diff --git a/ophyd_devices/utils/psi_device_base_utils.py b/ophyd_devices/utils/psi_device_base_utils.py index 4794e2b..fdf4eee 100644 --- a/ophyd_devices/utils/psi_device_base_utils.py +++ b/ophyd_devices/utils/psi_device_base_utils.py @@ -6,12 +6,11 @@ import threading import traceback import uuid from enum import Enum -from typing import TYPE_CHECKING, Callable, Literal +from typing import TYPE_CHECKING, Callable, Literal, Union from bec_lib.file_utils import get_full_path from bec_lib.logger import bec_logger from bec_lib.utils.import_utils import lazy_import_from -from ophyd import Device, Signal from ophyd.status import DeviceStatus as _DeviceStatus from ophyd.status import MoveStatus as _MoveStatus from ophyd.status import Status as _Status @@ -19,6 +18,7 @@ from ophyd.status import StatusBase as _StatusBase if TYPE_CHECKING: # pragma: no cover from bec_lib.messages import ScanStatusMessage + from ophyd import Device, Signal else: # TODO: put back normal import when Pydantic gets faster ScanStatusMessage = lazy_import_from("bec_lib.messages", ("ScanStatusMessage",)) @@ -53,7 +53,13 @@ class StatusBase(_StatusBase): """Base class for all status objects.""" def __init__( - self, obj: Device | None = None, *, timeout=None, settle_time=0, done=None, success=None + self, + obj: Union["Device", None] = None, + *, + timeout=None, + settle_time=0, + done=None, + success=None, ): self.obj = obj super().__init__(timeout=timeout, settle_time=settle_time, done=done, success=success) @@ -129,7 +135,7 @@ class Status(_Status): class DeviceStatus(_DeviceStatus): - """Thin wrapper around DeviceStatus to add __and__ operator and add stop on failure option, defaults to False""" + """Thin wrapper around DeviceStatus to add __and__ operator.""" def __and__(self, other): """Returns a new 'composite' status object, AndStatus""" @@ -149,7 +155,7 @@ class SubscriptionStatus(StatusBase): def __init__( self, - obj: Device | Signal, + obj: Union["Device", "Signal"], callback: Callable, event_type=None, timeout=None, @@ -328,6 +334,8 @@ class TransitionStatus(SubscriptionStatus): ): self._signal = signal self._transitions = tuple(transitions) + if not transitions: + raise ValueError("Transitions {transitions}must contain at least one value") self._index = 0 self._strict = strict self._failure_states = failure_states if failure_states else [] @@ -397,7 +405,13 @@ class TaskStatus(StatusBase): """Thin wrapper around StatusBase to add information about tasks""" def __init__( - self, obj: Device | Signal, *, timeout=None, settle_time=0, done=None, success=None + self, + obj: Union["Device", "Signal"], + *, + timeout=None, + settle_time=0, + done=None, + success=None, ): super().__init__( obj=obj, timeout=timeout, settle_time=settle_time, done=done, success=success @@ -423,7 +437,7 @@ class TaskStatus(StatusBase): class TaskHandler: """Handler to manage asynchronous tasks""" - def __init__(self, parent: Device): + def __init__(self, parent: "Device"): """Initialize the handler""" self._tasks = {} self._parent = parent diff --git a/tests/test_utils.py b/tests/test_utils.py index aec8141..7aa83e6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -965,11 +965,16 @@ def test_patched_status_objects(): with pytest.raises(ValueError): and_st.wait(timeout=10) - # DeviceStatus & StatusBase + # DeviceStatus & Status dev = Device(name="device") dev_status = DeviceStatus(device=dev) + + st = Status() + and_st = st and dev_status assert dev_status.device == dev dev_status.set_exception(RuntimeError("device error")) + with pytest.raises(RuntimeError): + and_st.wait(timeout=10) # Combine DeviceStatus with StatusBase and form AndStatus st = StatusBase(obj=dev)