mirror of
https://github.com/bec-project/ophyd_devices.git
synced 2026-02-06 23:28:41 +01:00
fix(status): Add wrappers for ophyd status objects to improve error handling
This commit is contained in:
@@ -12,7 +12,11 @@ from bec_lib.file_utils import get_full_path
|
||||
from bec_lib.logger import bec_logger
|
||||
from bec_lib.utils.import_utils import lazy_import_from
|
||||
from ophyd import Device, Signal
|
||||
from ophyd.status import AndStatus, DeviceStatus, MoveStatus, Status, StatusBase, SubscriptionStatus
|
||||
from ophyd.status import DeviceStatus as _DeviceStatus
|
||||
from ophyd.status import MoveStatus as _MoveStatus
|
||||
from ophyd.status import Status as _Status
|
||||
from ophyd.status import StatusBase as _StatusBase
|
||||
from ophyd.status import SubscriptionStatus as _SubscriptionStatus
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from bec_lib.messages import ScanStatusMessage
|
||||
@@ -46,6 +50,142 @@ OP_MAP = {
|
||||
}
|
||||
|
||||
|
||||
class StatusBase(_StatusBase):
|
||||
"""Base class for all status objects."""
|
||||
|
||||
def __init__(
|
||||
self, obj: Device | None = None, *, timeout=None, settle_time=0, done=None, success=None
|
||||
):
|
||||
self.obj = obj
|
||||
super().__init__(timeout=timeout, settle_time=settle_time, done=done, success=success)
|
||||
|
||||
def __and__(self, other):
|
||||
"""Returns a new 'composite' status object, AndStatus"""
|
||||
return AndStatus(self, other)
|
||||
|
||||
|
||||
class AndStatus(StatusBase):
|
||||
"""
|
||||
A Status that has composes two other Status objects using logical and.
|
||||
If any of the two Status objects fails, the combined status will fail
|
||||
with the exception of the first Status to fail.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
left: StatusBase
|
||||
The left-hand Status object
|
||||
right: StatusBase
|
||||
The right-hand Status object
|
||||
"""
|
||||
|
||||
def __init__(self, left, right, **kwargs):
|
||||
self.left = left
|
||||
self.right = right
|
||||
super().__init__(**kwargs)
|
||||
self._trace_attributes["left"] = self.left._trace_attributes
|
||||
self._trace_attributes["right"] = self.right._trace_attributes
|
||||
|
||||
def inner(status):
|
||||
with self._lock:
|
||||
if self._externally_initiated_completion:
|
||||
return
|
||||
|
||||
# Return if status is already done..
|
||||
if self.done:
|
||||
return
|
||||
|
||||
with status._lock:
|
||||
if status.done and not status.success:
|
||||
self.set_exception(status.exception()) # st._exception
|
||||
return
|
||||
if self.left.done and self.right.done and self.left.success and self.right.success:
|
||||
self.set_finished()
|
||||
|
||||
self.left.add_callback(inner)
|
||||
self.right.add_callback(inner)
|
||||
|
||||
def __repr__(self):
|
||||
return "({self.left!r} & {self.right!r})".format(self=self)
|
||||
|
||||
def __str__(self):
|
||||
return "{0}(done={1.done}, " "success={1.success})" "".format(self.__class__.__name__, self)
|
||||
|
||||
def __contains__(self, status) -> bool:
|
||||
for child in [self.left, self.right]:
|
||||
if child == status:
|
||||
return True
|
||||
if isinstance(child, AndStatus):
|
||||
if status in child:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class Status(_Status):
|
||||
"""Thin wrapper around StatusBase to add __and__ operator."""
|
||||
|
||||
def __and__(self, other):
|
||||
"""Returns a new 'composite' status object, AndStatus"""
|
||||
return AndStatus(self, other)
|
||||
|
||||
|
||||
class DeviceStatus(_DeviceStatus):
|
||||
"""Thin wrapper around DeviceStatus to add __and__ operator and add stop on failure option, defaults to False"""
|
||||
|
||||
def __and__(self, other):
|
||||
"""Returns a new 'composite' status object, AndStatus"""
|
||||
return AndStatus(self, other)
|
||||
|
||||
|
||||
class MoveStatus(_MoveStatus):
|
||||
"""Thin wrapper around MoveStatus to ensure __and__ operator and stop on failure."""
|
||||
|
||||
def __and__(self, other):
|
||||
"""Returns a new 'composite' status object, AndStatus"""
|
||||
return AndStatus(self, other)
|
||||
|
||||
|
||||
class SubscriptionStatus(StatusBase):
|
||||
"""Subscription status implementation based on wrapped StatusBase implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
obj: Device | Signal,
|
||||
callback: Callable,
|
||||
event_type=None,
|
||||
timeout=None,
|
||||
settle_time=None,
|
||||
run=True,
|
||||
):
|
||||
# Store device and attribute information
|
||||
self.callback = callback
|
||||
self.obj = obj
|
||||
# Start timeout thread in the background
|
||||
super().__init__(obj=obj, timeout=timeout, settle_time=settle_time)
|
||||
|
||||
self.obj.subscribe(self.check_value, event_type=event_type, run=run)
|
||||
|
||||
def check_value(self, *args, **kwargs):
|
||||
"""Update the status object"""
|
||||
try:
|
||||
success = self.callback(*args, **kwargs)
|
||||
except Exception as e:
|
||||
self.log.error(e)
|
||||
raise
|
||||
if success:
|
||||
self.set_finished()
|
||||
|
||||
def set_finished(self):
|
||||
"""Mark as finished successfully."""
|
||||
self.obj.clear_sub(self.check_value)
|
||||
super().set_finished()
|
||||
|
||||
def _handle_failure(self):
|
||||
"""Clear subscription on failure, run callbacks through super()"""
|
||||
self.obj.clear_sub(self.check_value)
|
||||
return super()._handle_failure()
|
||||
|
||||
|
||||
class CompareStatus(SubscriptionStatus):
|
||||
"""
|
||||
Status to compare a signal value against a given value.
|
||||
@@ -105,7 +245,7 @@ class CompareStatus(SubscriptionStatus):
|
||||
f"failure_value must be a float, int, str, list or None. Received: {failure_value}"
|
||||
)
|
||||
super().__init__(
|
||||
device=signal,
|
||||
obj=signal,
|
||||
callback=self._compare_callback,
|
||||
timeout=timeout,
|
||||
settle_time=settle_time,
|
||||
@@ -199,7 +339,7 @@ class TransitionStatus(SubscriptionStatus):
|
||||
self._strict = strict
|
||||
self._failure_states = failure_states if failure_states else []
|
||||
super().__init__(
|
||||
device=signal,
|
||||
obj=signal,
|
||||
callback=self._compare_callback,
|
||||
timeout=timeout,
|
||||
settle_time=settle_time,
|
||||
@@ -263,12 +403,14 @@ class TaskKilledError(Exception):
|
||||
"""Exception raised when a task thread is killed"""
|
||||
|
||||
|
||||
class TaskStatus(DeviceStatus):
|
||||
class TaskStatus(StatusBase):
|
||||
"""Thin wrapper around StatusBase to add information about tasks"""
|
||||
|
||||
def __init__(self, device: Device, *, timeout=None, settle_time=0, done=None, success=None):
|
||||
def __init__(
|
||||
self, obj: Device | Signal, *, timeout=None, settle_time=0, done=None, success=None
|
||||
):
|
||||
super().__init__(
|
||||
device=device, timeout=timeout, settle_time=settle_time, done=done, success=success
|
||||
obj=obj, timeout=timeout, settle_time=settle_time, done=done, success=success
|
||||
)
|
||||
self._state = TaskState.NOT_STARTED
|
||||
self._task_id = str(uuid.uuid4())
|
||||
@@ -312,7 +454,7 @@ class TaskHandler:
|
||||
"""
|
||||
task_args = task_args if task_args else ()
|
||||
task_kwargs = task_kwargs if task_kwargs else {}
|
||||
task_status = TaskStatus(device=self._parent)
|
||||
task_status = TaskStatus(self._parent)
|
||||
thread = threading.Thread(
|
||||
target=self._wrap_task,
|
||||
args=(task, task_args, task_kwargs, task_status),
|
||||
|
||||
Reference in New Issue
Block a user