From d7fb4f55e8eaeabc9a40a30189db4d7920884eea Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Tue, 11 Nov 2025 18:28:28 +0100 Subject: [PATCH] fix(computed signal): fix various bugs in the computed signal --- ophyd_devices/utils/dynamic_pseudo.py | 40 +++++++++++++++++---------- tests/test_dynamic_pseudo.py | 22 +++++++++++---- 2 files changed, 42 insertions(+), 20 deletions(-) diff --git a/ophyd_devices/utils/dynamic_pseudo.py b/ophyd_devices/utils/dynamic_pseudo.py index 453035c..f8ce036 100644 --- a/ophyd_devices/utils/dynamic_pseudo.py +++ b/ophyd_devices/utils/dynamic_pseudo.py @@ -2,11 +2,14 @@ This module provides a class for creating a pseudo signal that is computed from other signals. """ +import ast from functools import reduce from typing import Callable +import numpy as np +import scipy as sp from bec_lib import bec_logger -from ophyd import SignalRO +from ophyd import Signal, SignalRO from ophyd.ophydobj import Kind logger = bec_logger.logger @@ -59,7 +62,7 @@ class ComputedSignal(SignalRO): attr_name=attr_name, ) self._device_manager = device_manager - self._input_signals = [] + self._input_signals: list[Signal] = [] self._signal_subs = [] self._compute_method = None self._compute_method_str = None @@ -94,17 +97,27 @@ class ComputedSignal(SignalRO): """ logger.info(f"Updating compute method for {self.name}.") method = method.strip() - if not method.startswith("def"): - raise ValueError("The compute method should be a string representation of a function") - # get the function name - function_name = method.split("(")[0].split(" ")[1] - method = method.replace(function_name, "user_compute_method") + # Parse and validate the function using AST + try: + tree = ast.parse(method) + if not tree.body or not isinstance(tree.body[0], ast.FunctionDef): + raise ValueError("The compute method should be a valid function definition") + + # Rename the function in the AST + func_def = tree.body[0] + func_def.name = "user_compute_method" + + # Convert AST back to code and compile + code = compile(tree, "", "exec") + except SyntaxError as exc: + raise ValueError(f"Invalid function syntax: {exc}") from exc + self._compute_method_str = method # pylint: disable=exec-used - out = {} - exec(method, out) - self._compute_method = out["user_compute_method"] + namespace = {"np": np, "sp": sp} + exec(code, namespace) + self._compute_method = namespace["user_compute_method"] @property def input_signals(self): @@ -115,7 +128,7 @@ class ComputedSignal(SignalRO): *input_vars: The input signals to be used for the computation Example: - >>> signal.input_signals = ["samx_readback", "samx_readback"] + >>> signal.input_signals = ["samx.readback", "samx.readback"] """ return self._input_signals @@ -128,10 +141,7 @@ class ComputedSignal(SignalRO): signals = [] for signal in input_vars: if isinstance(signal, str): - target = signal.replace("_", ".") - parts = target.split(".") - target = ".".join([parts[0], "obj"] + parts[1:]) - obj = rgetattr(self._device_manager.devices, target) + obj = rgetattr(self._device_manager.devices, signal) sub_id = obj.subscribe(self._signal_callback) self._signal_subs.append((obj, sub_id)) signals.append(obj) diff --git a/tests/test_dynamic_pseudo.py b/tests/test_dynamic_pseudo.py index 762630d..9b68109 100644 --- a/tests/test_dynamic_pseudo.py +++ b/tests/test_dynamic_pseudo.py @@ -12,21 +12,33 @@ def device_manager_with_devices(): dm.add_device("a") dm.add_device("b") device_mock = mock.MagicMock() - device_mock.obj.readback.get.return_value = 20 dm.devices["a"] = device_mock dm.devices["b"] = device_mock return dm -def test_computed_signal(device_manager_with_devices): +@pytest.mark.parametrize( + "compute_method_str", + [ + "def test(a, b): return a.get() + b.get()", + "def test(a, b): return a.get() + b.get()", + " def my_compute_method(a,b):\n return a.get() + b.get()\n", + ], +) +def test_computed_signal(device_manager_with_devices, compute_method_str): signal = ComputedSignal(name="test", device_manager=device_manager_with_devices) assert signal.get() is None - signal.compute_method = "def test(a, b): return a.get() + b.get()" - signal.input_signals = ["a_readback", "b_readback"] + # Configure the mocks before setting input signals + device_manager_with_devices.devices["a"].readback.get.return_value = 20 + device_manager_with_devices.devices["b"].readback.get.return_value = 20 + + signal.compute_method = compute_method_str + signal.input_signals = ["a.readback", "b.readback"] + assert signal.get() == 40 # pylint: disable=protected-access assert callable(signal._compute_method) - assert signal._compute_method_str == "def user_compute_method(a, b): return a.get() + b.get()" + assert signal._compute_method_str == compute_method_str.strip()