diff --git a/ophyd_devices/__init__.py b/ophyd_devices/__init__.py index 2aa07f2..0818817 100644 --- a/ophyd_devices/__init__.py +++ b/ophyd_devices/__init__.py @@ -20,5 +20,5 @@ from .devices.softpositioner import SoftPositioner from .utils.bec_device_base import BECDeviceBase from .utils.bec_signals import * from .utils.dynamic_pseudo import ComputedSignal -from .utils.psi_device_base_utils import CompareStatus, TargetStatus +from .utils.psi_device_base_utils import * from .utils.static_device_test import launch diff --git a/ophyd_devices/utils/psi_device_base_utils.py b/ophyd_devices/utils/psi_device_base_utils.py index cdbed6f..aace679 100644 --- a/ophyd_devices/utils/psi_device_base_utils.py +++ b/ophyd_devices/utils/psi_device_base_utils.py @@ -11,8 +11,8 @@ from typing import TYPE_CHECKING, Any, Callable, Literal from bec_lib.file_utils import get_full_path from bec_lib.logger import bec_logger from bec_lib.utils.import_utils import lazy_import_from -from ophyd import Device, DeviceStatus, Signal -from ophyd.status import SubscriptionStatus +from ophyd import Device, Signal +from ophyd.status import AndStatus, DeviceStatus, MoveStatus, Status, StatusBase, SubscriptionStatus if TYPE_CHECKING: # pragma: no cover from bec_lib.messages import ScanStatusMessage @@ -21,6 +21,17 @@ else: ScanStatusMessage = lazy_import_from("bec_lib.messages", ("ScanStatusMessage",)) +__all__ = [ + "CompareStatus", + "TransitionStatus", + "AndStatus", + "DeviceStatus", + "MoveStatus", + "Status", + "StatusBase", + "SubscriptionStatus", +] + logger = bec_logger.logger set_async_exc = ctypes.pythonapi.PyThreadState_SetAsyncExc @@ -36,12 +47,26 @@ OP_MAP = { class CompareStatus(SubscriptionStatus): - """Status class to compare a value from a device signal with a target value.""" + """ + Status class to compare a value from a device signal with a target value. + The value can be a float, int, or string. If the value is a string, + the operation must be either '==' or '!='. For numeric (float or int) values, + the operation can be any of the standard comparison operators. + + Args: + signal: The device signal to compare. + value: The target value to compare against. + operation: The comparison operation to use. Defaults to '=='. + event_type: The type of event to trigger on comparison. Defaults to None (default sub). + timeout: The timeout for the status. Defaults to None (indefinite). + settle_time: The time to wait for the signal to settle before comparison. Defaults to 0. + run: Whether to run the status immediately or not. Defaults to True. + """ def __init__( self, signal: Signal, - value: Any, + value: float | int | str, *, operation: Literal["==", "!=", "<", "<=", ">", ">="] = "==", event_type=None, @@ -49,6 +74,11 @@ class CompareStatus(SubscriptionStatus): settle_time: float = 0, run: bool = True, ): + if isinstance(value, str): + if operation not in ("==", "!="): + raise ValueError( + f"Invalid operation: {operation} for string comparison. Must be '==' or '!='." + ) if operation not in ("==", "!=", "<", "<=", ">", ">="): raise ValueError( f"Invalid operation: {operation}. Must be one of '==', '!=', '<', '<=', '>', '>='." @@ -66,25 +96,52 @@ class CompareStatus(SubscriptionStatus): ) def _compare_callback(self, value, **kwargs) -> bool: - """Callback for subscription Status""" + """Callback for subscription status""" return OP_MAP[self._operation](value, self._value) -class TargetStatus(SubscriptionStatus): - """Status class to compare a list of values that are expected to be reached in sequence for a device signal.""" +class TransitionStatus(SubscriptionStatus): + """ + Status class to compare a list of transitions. + The transitions can be a list of float, int, or string values. + The transitions are checked in order, and the status is finished when all transitions + have been matched in sequence. The keyword argument `strict` determines whether + the transitions must match exactly in order, or if intermediate transitions are allowed. + For the first value, the strict check is not applied, meaning that the sequence starts once + the first transition is matched. + + Args: + signal: The device signal to compare. + transitions: A list of transitions to compare against. + strict: Whether to enforce strict matching of transitions. Defaults to True. + run: Whether to run the status immediately or not. Defaults to True. + event_type: The type of event to trigger on comparison. Defaults to None (default sub). + timeout: The timeout for the status. Defaults to None (indefinite). + settle_time: The time to wait for the signal to settle before comparison. Defaults to 0. + + Raises: + ValueError: If the transitions do not match the expected sequence. and strict is True. + """ def __init__( self, signal: Signal, - values: list[Any], + transitions: list[float | int | str], *, + strict: bool = True, + raise_states: list[float | int | str] | None = None, + run: bool = True, event_type=None, timeout: float = None, settle_time: float = 0, - run: bool = True, ): self._signal = signal - self._values = values + if not isinstance(transitions, list): + raise ValueError(f"Transitions must be a list of values. Received: {transitions}") + self._transitions = transitions + self._index = 0 + self._strict = strict + self._raise_states = raise_states if raise_states else [] super().__init__( device=signal, callback=self._compare_callback, @@ -94,13 +151,34 @@ class TargetStatus(SubscriptionStatus): run=run, ) - def _compare_callback(self, value, **kwargs) -> bool: + def _compare_callback(self, old_value, value, **kwargs) -> bool: """Callback for subscription Status""" - if value == self._values[0]: - self._values.pop(0) - if len(self._values) == 0: - return True - return False + if value in self._raise_states: + self.set_exception( + ValueError( + f"Transition raised an exception: {value}. " + f"Expected transitions: {self._transitions}." + ) + ) + return False + if self._index == 0: + if value == self._transitions[0]: + self._index += 1 + else: + if self._strict: + if ( + old_value == self._transitions[self._index - 1] + and value == self._transitions[self._index] + ): + self._index += 1 + else: + if value == self._transitions[self._index]: + self._index += 1 + return self._is_finished() + + def _is_finished(self) -> bool: + """Check if the status is finished""" + return self._index >= len(self._transitions) class TaskState(str, Enum): diff --git a/tests/test_utils.py b/tests/test_utils.py index 1c2c883..f780cb6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -16,11 +16,11 @@ from ophyd_devices.utils.bec_signals import ( from ophyd_devices.utils.psi_device_base_utils import ( CompareStatus, FileHandler, - TargetStatus, TaskHandler, TaskKilledError, TaskState, TaskStatus, + TransitionStatus, ) # pylint: disable=protected-access @@ -529,20 +529,7 @@ def test_utils_progress_signal(): signal.put({"wrong_key": "wrong_value"}) -def test_utils_target_status(): - """Test TargetStatus""" - sig = Signal(name="test_signal", value=0) - status = TargetStatus(signal=sig, values=[1, 2, 3]) - assert status.done is False - sig.put(1) - assert status.done is False - sig.put(2) - assert status.done is False - sig.put(3) - assert status.done is True - - -def test_utils_compare_status(): +def test_utils_compare_status_number(): """Test CompareStatus""" sig = Signal(name="test_signal", value=0) status = CompareStatus(signal=sig, value=5, operation="==") @@ -560,6 +547,8 @@ def test_utils_compare_status(): assert status.done is False sig.put(6) assert status.done is True + assert status.success is True + assert status.exception() is None sig.put(0) status = CompareStatus(signal=sig, value=5, operation=">") @@ -568,3 +557,80 @@ def test_utils_compare_status(): assert status.done is False sig.put(10) assert status.done is True + assert status.success is True + assert status.exception() is None + + +def test_compare_status_string(): + """Test CompareStatus with string values""" + sig = Signal(name="test_signal", value="test") + status = CompareStatus(signal=sig, value="test", operation="==") + assert status.done is False + sig.put("test1") + assert status.done is False + sig.put("test") + assert status.done is True + + sig.put("test") + # Test with different operations + status = CompareStatus(signal=sig, value="test", operation="!=") + assert status.done is False + sig.put("test") + assert status.done is False + sig.put("test1") + assert status.done is True + assert status.success is True + assert status.exception() is None + + # Test with greater than operation + # Raises ValueError for strings + sig.put("a") + with pytest.raises(ValueError): + status = CompareStatus(signal=sig, value="b", operation=">") + + +def test_transition_status(): + """Test TransitionStatus""" + sig = Signal(name="test_signal", value=0) + + # Test strict=True, without intermediate transitions + sig.put(0) + status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=True) + + assert status.done is False + sig.put(1) + assert status.done is False + sig.put(2) + assert status.done is False + sig.put(3) + assert status.done is True + assert status.success is True + assert status.exception() is None + + # Test strict=True, ra + sig.put(1) + status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=True, raise_states=[4]) + assert status.done is False + sig.put(4) + with pytest.raises(ValueError): + status.wait() + + assert status.done is True + assert status.success is False + assert isinstance(status.exception(), ValueError) + + # Test strict=False, with intermediate transitions + sig.put(0) + status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=False) + + assert status.done is False + sig.put(1) # entering first transition + sig.put(3) + sig.put(2) # transision + assert status.done is False + sig.put(4) + sig.put(2) + sig.put(3) # last transition + assert status.done is True + assert status.success is True + assert status.exception() is None