fix(computed signal): fix various bugs in the computed signal

This commit is contained in:
2025-11-11 18:28:28 +01:00
committed by David Perl
parent 176c95d0f1
commit d7fb4f55e8
2 changed files with 42 additions and 20 deletions

View File

@@ -2,11 +2,14 @@
This module provides a class for creating a pseudo signal that is computed from other signals. This module provides a class for creating a pseudo signal that is computed from other signals.
""" """
import ast
from functools import reduce from functools import reduce
from typing import Callable from typing import Callable
import numpy as np
import scipy as sp
from bec_lib import bec_logger from bec_lib import bec_logger
from ophyd import SignalRO from ophyd import Signal, SignalRO
from ophyd.ophydobj import Kind from ophyd.ophydobj import Kind
logger = bec_logger.logger logger = bec_logger.logger
@@ -59,7 +62,7 @@ class ComputedSignal(SignalRO):
attr_name=attr_name, attr_name=attr_name,
) )
self._device_manager = device_manager self._device_manager = device_manager
self._input_signals = [] self._input_signals: list[Signal] = []
self._signal_subs = [] self._signal_subs = []
self._compute_method = None self._compute_method = None
self._compute_method_str = None self._compute_method_str = None
@@ -94,17 +97,27 @@ class ComputedSignal(SignalRO):
""" """
logger.info(f"Updating compute method for {self.name}.") logger.info(f"Updating compute method for {self.name}.")
method = method.strip() method = method.strip()
if not method.startswith("def"):
raise ValueError("The compute method should be a string representation of a function")
# get the function name # Parse and validate the function using AST
function_name = method.split("(")[0].split(" ")[1] try:
method = method.replace(function_name, "user_compute_method") tree = ast.parse(method)
if not tree.body or not isinstance(tree.body[0], ast.FunctionDef):
raise ValueError("The compute method should be a valid function definition")
# Rename the function in the AST
func_def = tree.body[0]
func_def.name = "user_compute_method"
# Convert AST back to code and compile
code = compile(tree, "<string>", "exec")
except SyntaxError as exc:
raise ValueError(f"Invalid function syntax: {exc}") from exc
self._compute_method_str = method self._compute_method_str = method
# pylint: disable=exec-used # pylint: disable=exec-used
out = {} namespace = {"np": np, "sp": sp}
exec(method, out) exec(code, namespace)
self._compute_method = out["user_compute_method"] self._compute_method = namespace["user_compute_method"]
@property @property
def input_signals(self): def input_signals(self):
@@ -115,7 +128,7 @@ class ComputedSignal(SignalRO):
*input_vars: The input signals to be used for the computation *input_vars: The input signals to be used for the computation
Example: Example:
>>> signal.input_signals = ["samx_readback", "samx_readback"] >>> signal.input_signals = ["samx.readback", "samx.readback"]
""" """
return self._input_signals return self._input_signals
@@ -128,10 +141,7 @@ class ComputedSignal(SignalRO):
signals = [] signals = []
for signal in input_vars: for signal in input_vars:
if isinstance(signal, str): if isinstance(signal, str):
target = signal.replace("_", ".") obj = rgetattr(self._device_manager.devices, signal)
parts = target.split(".")
target = ".".join([parts[0], "obj"] + parts[1:])
obj = rgetattr(self._device_manager.devices, target)
sub_id = obj.subscribe(self._signal_callback) sub_id = obj.subscribe(self._signal_callback)
self._signal_subs.append((obj, sub_id)) self._signal_subs.append((obj, sub_id))
signals.append(obj) signals.append(obj)

View File

@@ -12,21 +12,33 @@ def device_manager_with_devices():
dm.add_device("a") dm.add_device("a")
dm.add_device("b") dm.add_device("b")
device_mock = mock.MagicMock() device_mock = mock.MagicMock()
device_mock.obj.readback.get.return_value = 20
dm.devices["a"] = device_mock dm.devices["a"] = device_mock
dm.devices["b"] = device_mock dm.devices["b"] = device_mock
return dm return dm
def test_computed_signal(device_manager_with_devices): @pytest.mark.parametrize(
"compute_method_str",
[
"def test(a, b): return a.get() + b.get()",
"def test(a, b): return a.get() + b.get()",
" def my_compute_method(a,b):\n return a.get() + b.get()\n",
],
)
def test_computed_signal(device_manager_with_devices, compute_method_str):
signal = ComputedSignal(name="test", device_manager=device_manager_with_devices) signal = ComputedSignal(name="test", device_manager=device_manager_with_devices)
assert signal.get() is None assert signal.get() is None
signal.compute_method = "def test(a, b): return a.get() + b.get()" # Configure the mocks before setting input signals
signal.input_signals = ["a_readback", "b_readback"] device_manager_with_devices.devices["a"].readback.get.return_value = 20
device_manager_with_devices.devices["b"].readback.get.return_value = 20
signal.compute_method = compute_method_str
signal.input_signals = ["a.readback", "b.readback"]
assert signal.get() == 40 assert signal.get() == 40
# pylint: disable=protected-access # pylint: disable=protected-access
assert callable(signal._compute_method) assert callable(signal._compute_method)
assert signal._compute_method_str == "def user_compute_method(a, b): return a.get() + b.get()" assert signal._compute_method_str == compute_method_str.strip()