fix: smarter strip for computed signal

This commit is contained in:
2025-11-13 13:58:19 +01:00
committed by Klaus Wakonig
parent 0306a0bd6f
commit 1241bcb014
2 changed files with 16 additions and 4 deletions

View File

@@ -3,6 +3,7 @@ This module provides a class for creating a pseudo signal that is computed from
""" """
import ast import ast
import re
from functools import reduce from functools import reduce
from typing import Callable from typing import Callable
@@ -14,6 +15,15 @@ from ophyd.ophydobj import Kind
logger = bec_logger.logger logger = bec_logger.logger
_FIRST_DEF = re.compile(r"(^|\\n)\s*def", re.MULTILINE)
def _smart_strip(method: str) -> str:
if (first_def := _FIRST_DEF.search(method)) is not None:
return method[first_def.span()[1] - 3 :]
else:
raise ValueError(f"No 'def' keyword found in function definition: {method}")
def rgetattr(obj, attr, *args): def rgetattr(obj, attr, *args):
"""See https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects""" """See https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects"""
@@ -96,8 +106,7 @@ class ComputedSignal(SignalRO):
""" """
logger.info(f"Updating compute method for {self.name}.") logger.info(f"Updating compute method for {self.name}.")
method = method.strip() method = _smart_strip(method)
# Parse and validate the function using AST # Parse and validate the function using AST
try: try:
tree = ast.parse(method) tree = ast.parse(method)

View File

@@ -3,7 +3,7 @@ from unittest import mock
import pytest import pytest
from bec_server.device_server.tests.utils import DMMock from bec_server.device_server.tests.utils import DMMock
from ophyd_devices.utils.dynamic_pseudo import ComputedSignal from ophyd_devices.utils.dynamic_pseudo import ComputedSignal, _smart_strip
@pytest.fixture @pytest.fixture
@@ -24,6 +24,9 @@ def device_manager_with_devices():
"def test(a, b): return a.get() + b.get()", "def test(a, b): return a.get() + b.get()",
"def test(a, b): return a.get() + b.get()", "def test(a, b): return a.get() + b.get()",
" def my_compute_method(a,b):\n return a.get() + b.get()\n", " def my_compute_method(a,b):\n return a.get() + b.get()\n",
"#comment goes here\n def my_compute_method(a,b):\n return a.get() + b.get()\n",
"#comment goes here\n\n\n def my_compute_method(a,b):\n return a.get() + b.get()\n",
"#comment goes here\n\n\n def my_compute_method(a,b):\n#comment inside\n return a.get() + b.get()\n",
], ],
) )
def test_computed_signal(device_manager_with_devices, compute_method_str): def test_computed_signal(device_manager_with_devices, compute_method_str):
@@ -41,4 +44,4 @@ def test_computed_signal(device_manager_with_devices, compute_method_str):
# pylint: disable=protected-access # pylint: disable=protected-access
assert callable(signal._compute_method) assert callable(signal._compute_method)
assert signal._compute_method_str == compute_method_str.strip() assert signal._compute_method_str == _smart_strip(compute_method_str)