mirror of
https://github.com/bec-project/ophyd_devices.git
synced 2026-02-06 23:28:41 +01:00
fix(status): Add wrappers for ophyd status objects to improve error handling
This commit is contained in:
@@ -12,7 +12,11 @@ from bec_lib.file_utils import get_full_path
|
|||||||
from bec_lib.logger import bec_logger
|
from bec_lib.logger import bec_logger
|
||||||
from bec_lib.utils.import_utils import lazy_import_from
|
from bec_lib.utils.import_utils import lazy_import_from
|
||||||
from ophyd import Device, Signal
|
from ophyd import Device, Signal
|
||||||
from ophyd.status import AndStatus, DeviceStatus, MoveStatus, Status, StatusBase, SubscriptionStatus
|
from ophyd.status import DeviceStatus as _DeviceStatus
|
||||||
|
from ophyd.status import MoveStatus as _MoveStatus
|
||||||
|
from ophyd.status import Status as _Status
|
||||||
|
from ophyd.status import StatusBase as _StatusBase
|
||||||
|
from ophyd.status import SubscriptionStatus as _SubscriptionStatus
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma: no cover
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
from bec_lib.messages import ScanStatusMessage
|
from bec_lib.messages import ScanStatusMessage
|
||||||
@@ -46,6 +50,142 @@ OP_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class StatusBase(_StatusBase):
|
||||||
|
"""Base class for all status objects."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, obj: Device | None = None, *, timeout=None, settle_time=0, done=None, success=None
|
||||||
|
):
|
||||||
|
self.obj = obj
|
||||||
|
super().__init__(timeout=timeout, settle_time=settle_time, done=done, success=success)
|
||||||
|
|
||||||
|
def __and__(self, other):
|
||||||
|
"""Returns a new 'composite' status object, AndStatus"""
|
||||||
|
return AndStatus(self, other)
|
||||||
|
|
||||||
|
|
||||||
|
class AndStatus(StatusBase):
|
||||||
|
"""
|
||||||
|
A Status that has composes two other Status objects using logical and.
|
||||||
|
If any of the two Status objects fails, the combined status will fail
|
||||||
|
with the exception of the first Status to fail.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
left: StatusBase
|
||||||
|
The left-hand Status object
|
||||||
|
right: StatusBase
|
||||||
|
The right-hand Status object
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, left, right, **kwargs):
|
||||||
|
self.left = left
|
||||||
|
self.right = right
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._trace_attributes["left"] = self.left._trace_attributes
|
||||||
|
self._trace_attributes["right"] = self.right._trace_attributes
|
||||||
|
|
||||||
|
def inner(status):
|
||||||
|
with self._lock:
|
||||||
|
if self._externally_initiated_completion:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Return if status is already done..
|
||||||
|
if self.done:
|
||||||
|
return
|
||||||
|
|
||||||
|
with status._lock:
|
||||||
|
if status.done and not status.success:
|
||||||
|
self.set_exception(status.exception()) # st._exception
|
||||||
|
return
|
||||||
|
if self.left.done and self.right.done and self.left.success and self.right.success:
|
||||||
|
self.set_finished()
|
||||||
|
|
||||||
|
self.left.add_callback(inner)
|
||||||
|
self.right.add_callback(inner)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "({self.left!r} & {self.right!r})".format(self=self)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "{0}(done={1.done}, " "success={1.success})" "".format(self.__class__.__name__, self)
|
||||||
|
|
||||||
|
def __contains__(self, status) -> bool:
|
||||||
|
for child in [self.left, self.right]:
|
||||||
|
if child == status:
|
||||||
|
return True
|
||||||
|
if isinstance(child, AndStatus):
|
||||||
|
if status in child:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class Status(_Status):
|
||||||
|
"""Thin wrapper around StatusBase to add __and__ operator."""
|
||||||
|
|
||||||
|
def __and__(self, other):
|
||||||
|
"""Returns a new 'composite' status object, AndStatus"""
|
||||||
|
return AndStatus(self, other)
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceStatus(_DeviceStatus):
|
||||||
|
"""Thin wrapper around DeviceStatus to add __and__ operator and add stop on failure option, defaults to False"""
|
||||||
|
|
||||||
|
def __and__(self, other):
|
||||||
|
"""Returns a new 'composite' status object, AndStatus"""
|
||||||
|
return AndStatus(self, other)
|
||||||
|
|
||||||
|
|
||||||
|
class MoveStatus(_MoveStatus):
|
||||||
|
"""Thin wrapper around MoveStatus to ensure __and__ operator and stop on failure."""
|
||||||
|
|
||||||
|
def __and__(self, other):
|
||||||
|
"""Returns a new 'composite' status object, AndStatus"""
|
||||||
|
return AndStatus(self, other)
|
||||||
|
|
||||||
|
|
||||||
|
class SubscriptionStatus(StatusBase):
|
||||||
|
"""Subscription status implementation based on wrapped StatusBase implementation."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
obj: Device | Signal,
|
||||||
|
callback: Callable,
|
||||||
|
event_type=None,
|
||||||
|
timeout=None,
|
||||||
|
settle_time=None,
|
||||||
|
run=True,
|
||||||
|
):
|
||||||
|
# Store device and attribute information
|
||||||
|
self.callback = callback
|
||||||
|
self.obj = obj
|
||||||
|
# Start timeout thread in the background
|
||||||
|
super().__init__(obj=obj, timeout=timeout, settle_time=settle_time)
|
||||||
|
|
||||||
|
self.obj.subscribe(self.check_value, event_type=event_type, run=run)
|
||||||
|
|
||||||
|
def check_value(self, *args, **kwargs):
|
||||||
|
"""Update the status object"""
|
||||||
|
try:
|
||||||
|
success = self.callback(*args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
self.log.error(e)
|
||||||
|
raise
|
||||||
|
if success:
|
||||||
|
self.set_finished()
|
||||||
|
|
||||||
|
def set_finished(self):
|
||||||
|
"""Mark as finished successfully."""
|
||||||
|
self.obj.clear_sub(self.check_value)
|
||||||
|
super().set_finished()
|
||||||
|
|
||||||
|
def _handle_failure(self):
|
||||||
|
"""Clear subscription on failure, run callbacks through super()"""
|
||||||
|
self.obj.clear_sub(self.check_value)
|
||||||
|
return super()._handle_failure()
|
||||||
|
|
||||||
|
|
||||||
class CompareStatus(SubscriptionStatus):
|
class CompareStatus(SubscriptionStatus):
|
||||||
"""
|
"""
|
||||||
Status to compare a signal value against a given value.
|
Status to compare a signal value against a given value.
|
||||||
@@ -105,7 +245,7 @@ class CompareStatus(SubscriptionStatus):
|
|||||||
f"failure_value must be a float, int, str, list or None. Received: {failure_value}"
|
f"failure_value must be a float, int, str, list or None. Received: {failure_value}"
|
||||||
)
|
)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
device=signal,
|
obj=signal,
|
||||||
callback=self._compare_callback,
|
callback=self._compare_callback,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
settle_time=settle_time,
|
settle_time=settle_time,
|
||||||
@@ -199,7 +339,7 @@ class TransitionStatus(SubscriptionStatus):
|
|||||||
self._strict = strict
|
self._strict = strict
|
||||||
self._failure_states = failure_states if failure_states else []
|
self._failure_states = failure_states if failure_states else []
|
||||||
super().__init__(
|
super().__init__(
|
||||||
device=signal,
|
obj=signal,
|
||||||
callback=self._compare_callback,
|
callback=self._compare_callback,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
settle_time=settle_time,
|
settle_time=settle_time,
|
||||||
@@ -263,12 +403,14 @@ class TaskKilledError(Exception):
|
|||||||
"""Exception raised when a task thread is killed"""
|
"""Exception raised when a task thread is killed"""
|
||||||
|
|
||||||
|
|
||||||
class TaskStatus(DeviceStatus):
|
class TaskStatus(StatusBase):
|
||||||
"""Thin wrapper around StatusBase to add information about tasks"""
|
"""Thin wrapper around StatusBase to add information about tasks"""
|
||||||
|
|
||||||
def __init__(self, device: Device, *, timeout=None, settle_time=0, done=None, success=None):
|
def __init__(
|
||||||
|
self, obj: Device | Signal, *, timeout=None, settle_time=0, done=None, success=None
|
||||||
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
device=device, timeout=timeout, settle_time=settle_time, done=done, success=success
|
obj=obj, timeout=timeout, settle_time=settle_time, done=done, success=success
|
||||||
)
|
)
|
||||||
self._state = TaskState.NOT_STARTED
|
self._state = TaskState.NOT_STARTED
|
||||||
self._task_id = str(uuid.uuid4())
|
self._task_id = str(uuid.uuid4())
|
||||||
@@ -312,7 +454,7 @@ class TaskHandler:
|
|||||||
"""
|
"""
|
||||||
task_args = task_args if task_args else ()
|
task_args = task_args if task_args else ()
|
||||||
task_kwargs = task_kwargs if task_kwargs else {}
|
task_kwargs = task_kwargs if task_kwargs else {}
|
||||||
task_status = TaskStatus(device=self._parent)
|
task_status = TaskStatus(self._parent)
|
||||||
thread = threading.Thread(
|
thread = threading.Thread(
|
||||||
target=self._wrap_task,
|
target=self._wrap_task,
|
||||||
args=(task, task_args, task_kwargs, task_status),
|
args=(task, task_args, task_kwargs, task_status),
|
||||||
|
|||||||
@@ -12,6 +12,14 @@ from ophyd import Device, EpicsSignalRO, Signal
|
|||||||
from ophyd.status import WaitTimeoutError
|
from ophyd.status import WaitTimeoutError
|
||||||
from typeguard import TypeCheckError
|
from typeguard import TypeCheckError
|
||||||
|
|
||||||
|
from ophyd_devices import (
|
||||||
|
AndStatus,
|
||||||
|
DeviceStatus,
|
||||||
|
MoveStatus,
|
||||||
|
Status,
|
||||||
|
StatusBase,
|
||||||
|
SubscriptionStatus,
|
||||||
|
)
|
||||||
from ophyd_devices.tests.utils import MockPV
|
from ophyd_devices.tests.utils import MockPV
|
||||||
from ophyd_devices.utils.bec_signals import (
|
from ophyd_devices.utils.bec_signals import (
|
||||||
AsyncMultiSignal,
|
AsyncMultiSignal,
|
||||||
@@ -76,8 +84,8 @@ def test_utils_file_handler_has_full_path(file_handler):
|
|||||||
|
|
||||||
def test_utils_task_status(device):
|
def test_utils_task_status(device):
|
||||||
"""Test TaskStatus creation"""
|
"""Test TaskStatus creation"""
|
||||||
status = TaskStatus(device=device)
|
status = TaskStatus(device)
|
||||||
assert status.device.name == "device"
|
assert status.obj.name == "device"
|
||||||
assert status.state == "not_started"
|
assert status.state == "not_started"
|
||||||
assert status.task_id == status._task_id
|
assert status.task_id == status._task_id
|
||||||
status.state = "running"
|
status.state = "running"
|
||||||
@@ -929,3 +937,68 @@ def test_transition_status_with_mock_pv(
|
|||||||
status.wait(timeout=1)
|
status.wait(timeout=1)
|
||||||
assert status.done is False
|
assert status.done is False
|
||||||
assert status.success is False
|
assert status.success is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_patched_status_objects():
|
||||||
|
"""Test the patched Status objects in ophyd_devices that improve error handling."""
|
||||||
|
|
||||||
|
# StatusBase & AndStatus
|
||||||
|
st = StatusBase()
|
||||||
|
st2 = StatusBase()
|
||||||
|
and_st = st & st2
|
||||||
|
assert isinstance(and_st, AndStatus)
|
||||||
|
st.set_exception(ValueError("test error"))
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
and_st.wait(timeout=10)
|
||||||
|
|
||||||
|
# DeviceStatus & StatusBase
|
||||||
|
dev = Device(name="device")
|
||||||
|
dev_status = DeviceStatus(device=dev)
|
||||||
|
assert dev_status.device == dev
|
||||||
|
dev_status.set_exception(RuntimeError("device error"))
|
||||||
|
|
||||||
|
# Combine DeviceStatus with StatusBase and form AndStatus
|
||||||
|
st = StatusBase(obj=dev)
|
||||||
|
assert st.obj == dev
|
||||||
|
dev_st = DeviceStatus(device=dev)
|
||||||
|
combined_st = st & dev_st
|
||||||
|
st.set_finished()
|
||||||
|
dev_st.set_exception(RuntimeError("combined error"))
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
combined_st.wait(timeout=10)
|
||||||
|
|
||||||
|
# SubscriptionStatus
|
||||||
|
sig = Signal(name="test_signal", value=0)
|
||||||
|
|
||||||
|
def _cb(*args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
sub_st = SubscriptionStatus(sig, callback=_cb)
|
||||||
|
sub_st.set_exception(ValueError("subscription error"))
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
sub_st.wait(timeout=10)
|
||||||
|
assert sub_st.done is True
|
||||||
|
assert sub_st.success is False
|
||||||
|
|
||||||
|
# MoveStatus, here the default for call_stop_on_failure is True
|
||||||
|
class Positioner(Device):
|
||||||
|
SUB_READBACK = "readback"
|
||||||
|
setpoint = Signal(name="setpoint", value=0)
|
||||||
|
readback = Signal(name="readback", value=0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def position(self):
|
||||||
|
return self.readback.get()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
pos = Positioner(name="positioner")
|
||||||
|
move_st = MoveStatus(pos, target=10)
|
||||||
|
with mock.patch.object(pos, "stop") as mock_stop:
|
||||||
|
move_st.set_exception(RuntimeError("move error"))
|
||||||
|
mock_stop.assert_called_once()
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
move_st.wait(timeout=10)
|
||||||
|
assert move_st.done is True
|
||||||
|
assert move_st.success is False
|
||||||
|
|||||||
Reference in New Issue
Block a user