fix(status): Add wrappers for ophyd status objects to improve error handling

This commit is contained in:
2025-11-28 11:04:47 +01:00
committed by Christian Appel
parent 58d4a5141f
commit b918f1851c
2 changed files with 224 additions and 9 deletions

View File

@@ -12,7 +12,11 @@ from bec_lib.file_utils import get_full_path
from bec_lib.logger import bec_logger from bec_lib.logger import bec_logger
from bec_lib.utils.import_utils import lazy_import_from from bec_lib.utils.import_utils import lazy_import_from
from ophyd import Device, Signal from ophyd import Device, Signal
from ophyd.status import AndStatus, DeviceStatus, MoveStatus, Status, StatusBase, SubscriptionStatus from ophyd.status import DeviceStatus as _DeviceStatus
from ophyd.status import MoveStatus as _MoveStatus
from ophyd.status import Status as _Status
from ophyd.status import StatusBase as _StatusBase
from ophyd.status import SubscriptionStatus as _SubscriptionStatus
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
from bec_lib.messages import ScanStatusMessage from bec_lib.messages import ScanStatusMessage
@@ -46,6 +50,142 @@ OP_MAP = {
} }
class StatusBase(_StatusBase):
"""Base class for all status objects."""
def __init__(
self, obj: Device | None = None, *, timeout=None, settle_time=0, done=None, success=None
):
self.obj = obj
super().__init__(timeout=timeout, settle_time=settle_time, done=done, success=success)
def __and__(self, other):
"""Returns a new 'composite' status object, AndStatus"""
return AndStatus(self, other)
class AndStatus(StatusBase):
"""
A Status that has composes two other Status objects using logical and.
If any of the two Status objects fails, the combined status will fail
with the exception of the first Status to fail.
Parameters
----------
left: StatusBase
The left-hand Status object
right: StatusBase
The right-hand Status object
"""
def __init__(self, left, right, **kwargs):
self.left = left
self.right = right
super().__init__(**kwargs)
self._trace_attributes["left"] = self.left._trace_attributes
self._trace_attributes["right"] = self.right._trace_attributes
def inner(status):
with self._lock:
if self._externally_initiated_completion:
return
# Return if status is already done..
if self.done:
return
with status._lock:
if status.done and not status.success:
self.set_exception(status.exception()) # st._exception
return
if self.left.done and self.right.done and self.left.success and self.right.success:
self.set_finished()
self.left.add_callback(inner)
self.right.add_callback(inner)
def __repr__(self):
return "({self.left!r} & {self.right!r})".format(self=self)
def __str__(self):
return "{0}(done={1.done}, " "success={1.success})" "".format(self.__class__.__name__, self)
def __contains__(self, status) -> bool:
for child in [self.left, self.right]:
if child == status:
return True
if isinstance(child, AndStatus):
if status in child:
return True
return False
class Status(_Status):
"""Thin wrapper around StatusBase to add __and__ operator."""
def __and__(self, other):
"""Returns a new 'composite' status object, AndStatus"""
return AndStatus(self, other)
class DeviceStatus(_DeviceStatus):
"""Thin wrapper around DeviceStatus to add __and__ operator and add stop on failure option, defaults to False"""
def __and__(self, other):
"""Returns a new 'composite' status object, AndStatus"""
return AndStatus(self, other)
class MoveStatus(_MoveStatus):
"""Thin wrapper around MoveStatus to ensure __and__ operator and stop on failure."""
def __and__(self, other):
"""Returns a new 'composite' status object, AndStatus"""
return AndStatus(self, other)
class SubscriptionStatus(StatusBase):
"""Subscription status implementation based on wrapped StatusBase implementation."""
def __init__(
self,
obj: Device | Signal,
callback: Callable,
event_type=None,
timeout=None,
settle_time=None,
run=True,
):
# Store device and attribute information
self.callback = callback
self.obj = obj
# Start timeout thread in the background
super().__init__(obj=obj, timeout=timeout, settle_time=settle_time)
self.obj.subscribe(self.check_value, event_type=event_type, run=run)
def check_value(self, *args, **kwargs):
"""Update the status object"""
try:
success = self.callback(*args, **kwargs)
except Exception as e:
self.log.error(e)
raise
if success:
self.set_finished()
def set_finished(self):
"""Mark as finished successfully."""
self.obj.clear_sub(self.check_value)
super().set_finished()
def _handle_failure(self):
"""Clear subscription on failure, run callbacks through super()"""
self.obj.clear_sub(self.check_value)
return super()._handle_failure()
class CompareStatus(SubscriptionStatus): class CompareStatus(SubscriptionStatus):
""" """
Status to compare a signal value against a given value. Status to compare a signal value against a given value.
@@ -105,7 +245,7 @@ class CompareStatus(SubscriptionStatus):
f"failure_value must be a float, int, str, list or None. Received: {failure_value}" f"failure_value must be a float, int, str, list or None. Received: {failure_value}"
) )
super().__init__( super().__init__(
device=signal, obj=signal,
callback=self._compare_callback, callback=self._compare_callback,
timeout=timeout, timeout=timeout,
settle_time=settle_time, settle_time=settle_time,
@@ -199,7 +339,7 @@ class TransitionStatus(SubscriptionStatus):
self._strict = strict self._strict = strict
self._failure_states = failure_states if failure_states else [] self._failure_states = failure_states if failure_states else []
super().__init__( super().__init__(
device=signal, obj=signal,
callback=self._compare_callback, callback=self._compare_callback,
timeout=timeout, timeout=timeout,
settle_time=settle_time, settle_time=settle_time,
@@ -263,12 +403,14 @@ class TaskKilledError(Exception):
"""Exception raised when a task thread is killed""" """Exception raised when a task thread is killed"""
class TaskStatus(DeviceStatus): class TaskStatus(StatusBase):
"""Thin wrapper around StatusBase to add information about tasks""" """Thin wrapper around StatusBase to add information about tasks"""
def __init__(self, device: Device, *, timeout=None, settle_time=0, done=None, success=None): def __init__(
self, obj: Device | Signal, *, timeout=None, settle_time=0, done=None, success=None
):
super().__init__( super().__init__(
device=device, timeout=timeout, settle_time=settle_time, done=done, success=success obj=obj, timeout=timeout, settle_time=settle_time, done=done, success=success
) )
self._state = TaskState.NOT_STARTED self._state = TaskState.NOT_STARTED
self._task_id = str(uuid.uuid4()) self._task_id = str(uuid.uuid4())
@@ -312,7 +454,7 @@ class TaskHandler:
""" """
task_args = task_args if task_args else () task_args = task_args if task_args else ()
task_kwargs = task_kwargs if task_kwargs else {} task_kwargs = task_kwargs if task_kwargs else {}
task_status = TaskStatus(device=self._parent) task_status = TaskStatus(self._parent)
thread = threading.Thread( thread = threading.Thread(
target=self._wrap_task, target=self._wrap_task,
args=(task, task_args, task_kwargs, task_status), args=(task, task_args, task_kwargs, task_status),

View File

@@ -12,6 +12,14 @@ from ophyd import Device, EpicsSignalRO, Signal
from ophyd.status import WaitTimeoutError from ophyd.status import WaitTimeoutError
from typeguard import TypeCheckError from typeguard import TypeCheckError
from ophyd_devices import (
AndStatus,
DeviceStatus,
MoveStatus,
Status,
StatusBase,
SubscriptionStatus,
)
from ophyd_devices.tests.utils import MockPV from ophyd_devices.tests.utils import MockPV
from ophyd_devices.utils.bec_signals import ( from ophyd_devices.utils.bec_signals import (
AsyncMultiSignal, AsyncMultiSignal,
@@ -76,8 +84,8 @@ def test_utils_file_handler_has_full_path(file_handler):
def test_utils_task_status(device): def test_utils_task_status(device):
"""Test TaskStatus creation""" """Test TaskStatus creation"""
status = TaskStatus(device=device) status = TaskStatus(device)
assert status.device.name == "device" assert status.obj.name == "device"
assert status.state == "not_started" assert status.state == "not_started"
assert status.task_id == status._task_id assert status.task_id == status._task_id
status.state = "running" status.state = "running"
@@ -929,3 +937,68 @@ def test_transition_status_with_mock_pv(
status.wait(timeout=1) status.wait(timeout=1)
assert status.done is False assert status.done is False
assert status.success is False assert status.success is False
def test_patched_status_objects():
"""Test the patched Status objects in ophyd_devices that improve error handling."""
# StatusBase & AndStatus
st = StatusBase()
st2 = StatusBase()
and_st = st & st2
assert isinstance(and_st, AndStatus)
st.set_exception(ValueError("test error"))
with pytest.raises(ValueError):
and_st.wait(timeout=10)
# DeviceStatus & StatusBase
dev = Device(name="device")
dev_status = DeviceStatus(device=dev)
assert dev_status.device == dev
dev_status.set_exception(RuntimeError("device error"))
# Combine DeviceStatus with StatusBase and form AndStatus
st = StatusBase(obj=dev)
assert st.obj == dev
dev_st = DeviceStatus(device=dev)
combined_st = st & dev_st
st.set_finished()
dev_st.set_exception(RuntimeError("combined error"))
with pytest.raises(RuntimeError):
combined_st.wait(timeout=10)
# SubscriptionStatus
sig = Signal(name="test_signal", value=0)
def _cb(*args, **kwargs):
pass
sub_st = SubscriptionStatus(sig, callback=_cb)
sub_st.set_exception(ValueError("subscription error"))
with pytest.raises(ValueError):
sub_st.wait(timeout=10)
assert sub_st.done is True
assert sub_st.success is False
# MoveStatus, here the default for call_stop_on_failure is True
class Positioner(Device):
SUB_READBACK = "readback"
setpoint = Signal(name="setpoint", value=0)
readback = Signal(name="readback", value=0)
@property
def position(self):
return self.readback.get()
def stop(self):
pass
pos = Positioner(name="positioner")
move_st = MoveStatus(pos, target=10)
with mock.patch.object(pos, "stop") as mock_stop:
move_st.set_exception(RuntimeError("move error"))
mock_stop.assert_called_once()
with pytest.raises(RuntimeError):
move_st.wait(timeout=10)
assert move_st.done is True
assert move_st.success is False