mirror of
https://github.com/bec-project/ophyd_devices.git
synced 2026-02-04 14:18:41 +01:00
fix(computed signal): fix various bugs in the computed signal
This commit is contained in:
@@ -2,11 +2,14 @@
|
|||||||
This module provides a class for creating a pseudo signal that is computed from other signals.
|
This module provides a class for creating a pseudo signal that is computed from other signals.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import ast
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import scipy as sp
|
||||||
from bec_lib import bec_logger
|
from bec_lib import bec_logger
|
||||||
from ophyd import SignalRO
|
from ophyd import Signal, SignalRO
|
||||||
from ophyd.ophydobj import Kind
|
from ophyd.ophydobj import Kind
|
||||||
|
|
||||||
logger = bec_logger.logger
|
logger = bec_logger.logger
|
||||||
@@ -59,7 +62,7 @@ class ComputedSignal(SignalRO):
|
|||||||
attr_name=attr_name,
|
attr_name=attr_name,
|
||||||
)
|
)
|
||||||
self._device_manager = device_manager
|
self._device_manager = device_manager
|
||||||
self._input_signals = []
|
self._input_signals: list[Signal] = []
|
||||||
self._signal_subs = []
|
self._signal_subs = []
|
||||||
self._compute_method = None
|
self._compute_method = None
|
||||||
self._compute_method_str = None
|
self._compute_method_str = None
|
||||||
@@ -94,17 +97,27 @@ class ComputedSignal(SignalRO):
|
|||||||
"""
|
"""
|
||||||
logger.info(f"Updating compute method for {self.name}.")
|
logger.info(f"Updating compute method for {self.name}.")
|
||||||
method = method.strip()
|
method = method.strip()
|
||||||
if not method.startswith("def"):
|
|
||||||
raise ValueError("The compute method should be a string representation of a function")
|
|
||||||
|
|
||||||
# get the function name
|
# Parse and validate the function using AST
|
||||||
function_name = method.split("(")[0].split(" ")[1]
|
try:
|
||||||
method = method.replace(function_name, "user_compute_method")
|
tree = ast.parse(method)
|
||||||
|
if not tree.body or not isinstance(tree.body[0], ast.FunctionDef):
|
||||||
|
raise ValueError("The compute method should be a valid function definition")
|
||||||
|
|
||||||
|
# Rename the function in the AST
|
||||||
|
func_def = tree.body[0]
|
||||||
|
func_def.name = "user_compute_method"
|
||||||
|
|
||||||
|
# Convert AST back to code and compile
|
||||||
|
code = compile(tree, "<string>", "exec")
|
||||||
|
except SyntaxError as exc:
|
||||||
|
raise ValueError(f"Invalid function syntax: {exc}") from exc
|
||||||
|
|
||||||
self._compute_method_str = method
|
self._compute_method_str = method
|
||||||
# pylint: disable=exec-used
|
# pylint: disable=exec-used
|
||||||
out = {}
|
namespace = {"np": np, "sp": sp}
|
||||||
exec(method, out)
|
exec(code, namespace)
|
||||||
self._compute_method = out["user_compute_method"]
|
self._compute_method = namespace["user_compute_method"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_signals(self):
|
def input_signals(self):
|
||||||
@@ -115,7 +128,7 @@ class ComputedSignal(SignalRO):
|
|||||||
*input_vars: The input signals to be used for the computation
|
*input_vars: The input signals to be used for the computation
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> signal.input_signals = ["samx_readback", "samx_readback"]
|
>>> signal.input_signals = ["samx.readback", "samx.readback"]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return self._input_signals
|
return self._input_signals
|
||||||
@@ -128,10 +141,7 @@ class ComputedSignal(SignalRO):
|
|||||||
signals = []
|
signals = []
|
||||||
for signal in input_vars:
|
for signal in input_vars:
|
||||||
if isinstance(signal, str):
|
if isinstance(signal, str):
|
||||||
target = signal.replace("_", ".")
|
obj = rgetattr(self._device_manager.devices, signal)
|
||||||
parts = target.split(".")
|
|
||||||
target = ".".join([parts[0], "obj"] + parts[1:])
|
|
||||||
obj = rgetattr(self._device_manager.devices, target)
|
|
||||||
sub_id = obj.subscribe(self._signal_callback)
|
sub_id = obj.subscribe(self._signal_callback)
|
||||||
self._signal_subs.append((obj, sub_id))
|
self._signal_subs.append((obj, sub_id))
|
||||||
signals.append(obj)
|
signals.append(obj)
|
||||||
|
|||||||
@@ -12,21 +12,33 @@ def device_manager_with_devices():
|
|||||||
dm.add_device("a")
|
dm.add_device("a")
|
||||||
dm.add_device("b")
|
dm.add_device("b")
|
||||||
device_mock = mock.MagicMock()
|
device_mock = mock.MagicMock()
|
||||||
device_mock.obj.readback.get.return_value = 20
|
|
||||||
dm.devices["a"] = device_mock
|
dm.devices["a"] = device_mock
|
||||||
dm.devices["b"] = device_mock
|
dm.devices["b"] = device_mock
|
||||||
|
|
||||||
return dm
|
return dm
|
||||||
|
|
||||||
|
|
||||||
def test_computed_signal(device_manager_with_devices):
|
@pytest.mark.parametrize(
|
||||||
|
"compute_method_str",
|
||||||
|
[
|
||||||
|
"def test(a, b): return a.get() + b.get()",
|
||||||
|
"def test(a, b): return a.get() + b.get()",
|
||||||
|
" def my_compute_method(a,b):\n return a.get() + b.get()\n",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_computed_signal(device_manager_with_devices, compute_method_str):
|
||||||
signal = ComputedSignal(name="test", device_manager=device_manager_with_devices)
|
signal = ComputedSignal(name="test", device_manager=device_manager_with_devices)
|
||||||
assert signal.get() is None
|
assert signal.get() is None
|
||||||
|
|
||||||
signal.compute_method = "def test(a, b): return a.get() + b.get()"
|
# Configure the mocks before setting input signals
|
||||||
signal.input_signals = ["a_readback", "b_readback"]
|
device_manager_with_devices.devices["a"].readback.get.return_value = 20
|
||||||
|
device_manager_with_devices.devices["b"].readback.get.return_value = 20
|
||||||
|
|
||||||
|
signal.compute_method = compute_method_str
|
||||||
|
signal.input_signals = ["a.readback", "b.readback"]
|
||||||
|
|
||||||
assert signal.get() == 40
|
assert signal.get() == 40
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
assert callable(signal._compute_method)
|
assert callable(signal._compute_method)
|
||||||
assert signal._compute_method_str == "def user_compute_method(a, b): return a.get() + b.get()"
|
assert signal._compute_method_str == compute_method_str.strip()
|
||||||
|
|||||||
Reference in New Issue
Block a user