From b918f1851ca8be9294b01f0191070f6bd86ba431 Mon Sep 17 00:00:00 2001 From: appel_c Date: Fri, 28 Nov 2025 11:04:47 +0100 Subject: [PATCH] fix(status): Add wrappers for ophyd status objects to improve error handling --- ophyd_devices/utils/psi_device_base_utils.py | 156 ++++++++++++++++++- tests/test_utils.py | 77 ++++++++- 2 files changed, 224 insertions(+), 9 deletions(-) diff --git a/ophyd_devices/utils/psi_device_base_utils.py b/ophyd_devices/utils/psi_device_base_utils.py index 6e0042d..3c5d06a 100644 --- a/ophyd_devices/utils/psi_device_base_utils.py +++ b/ophyd_devices/utils/psi_device_base_utils.py @@ -12,7 +12,11 @@ from bec_lib.file_utils import get_full_path from bec_lib.logger import bec_logger from bec_lib.utils.import_utils import lazy_import_from from ophyd import Device, Signal -from ophyd.status import AndStatus, DeviceStatus, MoveStatus, Status, StatusBase, SubscriptionStatus +from ophyd.status import DeviceStatus as _DeviceStatus +from ophyd.status import MoveStatus as _MoveStatus +from ophyd.status import Status as _Status +from ophyd.status import StatusBase as _StatusBase +from ophyd.status import SubscriptionStatus as _SubscriptionStatus if TYPE_CHECKING: # pragma: no cover from bec_lib.messages import ScanStatusMessage @@ -46,6 +50,142 @@ OP_MAP = { } +class StatusBase(_StatusBase): + """Base class for all status objects.""" + + def __init__( + self, obj: Device | None = None, *, timeout=None, settle_time=0, done=None, success=None + ): + self.obj = obj + super().__init__(timeout=timeout, settle_time=settle_time, done=done, success=success) + + def __and__(self, other): + """Returns a new 'composite' status object, AndStatus""" + return AndStatus(self, other) + + +class AndStatus(StatusBase): + """ + A Status that has composes two other Status objects using logical and. + If any of the two Status objects fails, the combined status will fail + with the exception of the first Status to fail. + + Parameters + ---------- + left: StatusBase + The left-hand Status object + right: StatusBase + The right-hand Status object + """ + + def __init__(self, left, right, **kwargs): + self.left = left + self.right = right + super().__init__(**kwargs) + self._trace_attributes["left"] = self.left._trace_attributes + self._trace_attributes["right"] = self.right._trace_attributes + + def inner(status): + with self._lock: + if self._externally_initiated_completion: + return + + # Return if status is already done.. + if self.done: + return + + with status._lock: + if status.done and not status.success: + self.set_exception(status.exception()) # st._exception + return + if self.left.done and self.right.done and self.left.success and self.right.success: + self.set_finished() + + self.left.add_callback(inner) + self.right.add_callback(inner) + + def __repr__(self): + return "({self.left!r} & {self.right!r})".format(self=self) + + def __str__(self): + return "{0}(done={1.done}, " "success={1.success})" "".format(self.__class__.__name__, self) + + def __contains__(self, status) -> bool: + for child in [self.left, self.right]: + if child == status: + return True + if isinstance(child, AndStatus): + if status in child: + return True + + return False + + +class Status(_Status): + """Thin wrapper around StatusBase to add __and__ operator.""" + + def __and__(self, other): + """Returns a new 'composite' status object, AndStatus""" + return AndStatus(self, other) + + +class DeviceStatus(_DeviceStatus): + """Thin wrapper around DeviceStatus to add __and__ operator and add stop on failure option, defaults to False""" + + def __and__(self, other): + """Returns a new 'composite' status object, AndStatus""" + return AndStatus(self, other) + + +class MoveStatus(_MoveStatus): + """Thin wrapper around MoveStatus to ensure __and__ operator and stop on failure.""" + + def __and__(self, other): + """Returns a new 'composite' status object, AndStatus""" + return AndStatus(self, other) + + +class SubscriptionStatus(StatusBase): + """Subscription status implementation based on wrapped StatusBase implementation.""" + + def __init__( + self, + obj: Device | Signal, + callback: Callable, + event_type=None, + timeout=None, + settle_time=None, + run=True, + ): + # Store device and attribute information + self.callback = callback + self.obj = obj + # Start timeout thread in the background + super().__init__(obj=obj, timeout=timeout, settle_time=settle_time) + + self.obj.subscribe(self.check_value, event_type=event_type, run=run) + + def check_value(self, *args, **kwargs): + """Update the status object""" + try: + success = self.callback(*args, **kwargs) + except Exception as e: + self.log.error(e) + raise + if success: + self.set_finished() + + def set_finished(self): + """Mark as finished successfully.""" + self.obj.clear_sub(self.check_value) + super().set_finished() + + def _handle_failure(self): + """Clear subscription on failure, run callbacks through super()""" + self.obj.clear_sub(self.check_value) + return super()._handle_failure() + + class CompareStatus(SubscriptionStatus): """ Status to compare a signal value against a given value. @@ -105,7 +245,7 @@ class CompareStatus(SubscriptionStatus): f"failure_value must be a float, int, str, list or None. Received: {failure_value}" ) super().__init__( - device=signal, + obj=signal, callback=self._compare_callback, timeout=timeout, settle_time=settle_time, @@ -199,7 +339,7 @@ class TransitionStatus(SubscriptionStatus): self._strict = strict self._failure_states = failure_states if failure_states else [] super().__init__( - device=signal, + obj=signal, callback=self._compare_callback, timeout=timeout, settle_time=settle_time, @@ -263,12 +403,14 @@ class TaskKilledError(Exception): """Exception raised when a task thread is killed""" -class TaskStatus(DeviceStatus): +class TaskStatus(StatusBase): """Thin wrapper around StatusBase to add information about tasks""" - def __init__(self, device: Device, *, timeout=None, settle_time=0, done=None, success=None): + def __init__( + self, obj: Device | Signal, *, timeout=None, settle_time=0, done=None, success=None + ): super().__init__( - device=device, timeout=timeout, settle_time=settle_time, done=done, success=success + obj=obj, timeout=timeout, settle_time=settle_time, done=done, success=success ) self._state = TaskState.NOT_STARTED self._task_id = str(uuid.uuid4()) @@ -312,7 +454,7 @@ class TaskHandler: """ task_args = task_args if task_args else () task_kwargs = task_kwargs if task_kwargs else {} - task_status = TaskStatus(device=self._parent) + task_status = TaskStatus(self._parent) thread = threading.Thread( target=self._wrap_task, args=(task, task_args, task_kwargs, task_status), diff --git a/tests/test_utils.py b/tests/test_utils.py index 9ff2d67..2f2d637 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -12,6 +12,14 @@ from ophyd import Device, EpicsSignalRO, Signal from ophyd.status import WaitTimeoutError from typeguard import TypeCheckError +from ophyd_devices import ( + AndStatus, + DeviceStatus, + MoveStatus, + Status, + StatusBase, + SubscriptionStatus, +) from ophyd_devices.tests.utils import MockPV from ophyd_devices.utils.bec_signals import ( AsyncMultiSignal, @@ -76,8 +84,8 @@ def test_utils_file_handler_has_full_path(file_handler): def test_utils_task_status(device): """Test TaskStatus creation""" - status = TaskStatus(device=device) - assert status.device.name == "device" + status = TaskStatus(device) + assert status.obj.name == "device" assert status.state == "not_started" assert status.task_id == status._task_id status.state = "running" @@ -929,3 +937,68 @@ def test_transition_status_with_mock_pv( status.wait(timeout=1) assert status.done is False assert status.success is False + + +def test_patched_status_objects(): + """Test the patched Status objects in ophyd_devices that improve error handling.""" + + # StatusBase & AndStatus + st = StatusBase() + st2 = StatusBase() + and_st = st & st2 + assert isinstance(and_st, AndStatus) + st.set_exception(ValueError("test error")) + with pytest.raises(ValueError): + and_st.wait(timeout=10) + + # DeviceStatus & StatusBase + dev = Device(name="device") + dev_status = DeviceStatus(device=dev) + assert dev_status.device == dev + dev_status.set_exception(RuntimeError("device error")) + + # Combine DeviceStatus with StatusBase and form AndStatus + st = StatusBase(obj=dev) + assert st.obj == dev + dev_st = DeviceStatus(device=dev) + combined_st = st & dev_st + st.set_finished() + dev_st.set_exception(RuntimeError("combined error")) + with pytest.raises(RuntimeError): + combined_st.wait(timeout=10) + + # SubscriptionStatus + sig = Signal(name="test_signal", value=0) + + def _cb(*args, **kwargs): + pass + + sub_st = SubscriptionStatus(sig, callback=_cb) + sub_st.set_exception(ValueError("subscription error")) + with pytest.raises(ValueError): + sub_st.wait(timeout=10) + assert sub_st.done is True + assert sub_st.success is False + + # MoveStatus, here the default for call_stop_on_failure is True + class Positioner(Device): + SUB_READBACK = "readback" + setpoint = Signal(name="setpoint", value=0) + readback = Signal(name="readback", value=0) + + @property + def position(self): + return self.readback.get() + + def stop(self): + pass + + pos = Positioner(name="positioner") + move_st = MoveStatus(pos, target=10) + with mock.patch.object(pos, "stop") as mock_stop: + move_st.set_exception(RuntimeError("move error")) + mock_stop.assert_called_once() + with pytest.raises(RuntimeError): + move_st.wait(timeout=10) + assert move_st.done is True + assert move_st.success is False