diff --git a/ophyd_devices/utils/psi_device_base_utils.py b/ophyd_devices/utils/psi_device_base_utils.py index b1a2a1d..0e7f1ea 100644 --- a/ophyd_devices/utils/psi_device_base_utils.py +++ b/ophyd_devices/utils/psi_device_base_utils.py @@ -26,6 +26,7 @@ else: __all__ = [ "CompareStatus", + "ExceptionStatus", "TransitionStatus", "AndStatus", "DeviceStatus", @@ -52,6 +53,8 @@ OP_MAP = { class StatusBase(_StatusBase): """Base class for all status objects.""" + _blocks_success = True + def __init__( self, obj: Union["Device", None] = None, @@ -68,6 +71,15 @@ class StatusBase(_StatusBase): """Returns a new 'composite' status object, AndStatus""" return AndStatus(self, other) + @property + def blocks_success(self) -> bool: + """Whether this status must resolve successfully for a composite to succeed.""" + return self._blocks_success + + def _cleanup(self) -> None: + """Release resources held by the status once a composite no longer needs it.""" + return None + class AndStatus(StatusBase): """ @@ -106,9 +118,11 @@ class AndStatus(StatusBase): with status._lock: if status.done and not status.success: + self._cleanup() self.set_exception(status.exception()) # st._exception return - if self.left.done and self.right.done and self.left.success and self.right.success: + if self._required_statuses_succeeded(): + self._cleanup() self.set_finished() self.left.add_callback(inner) @@ -130,6 +144,30 @@ class AndStatus(StatusBase): return False + @property + def blocks_success(self) -> bool: + return self._child_blocks_success(self.left) or self._child_blocks_success(self.right) + + def _required_statuses_succeeded(self) -> bool: + return all( + not self._child_blocks_success(child) or (child.done and child.success) + for child in (self.left, self.right) + ) + + def _cleanup(self) -> None: + self._cleanup_child(self.left) + self._cleanup_child(self.right) + + @staticmethod + def _child_blocks_success(child) -> bool: + return getattr(child, "blocks_success", True) + + @staticmethod + def _cleanup_child(child) -> None: + cleanup = getattr(child, "_cleanup", None) + if cleanup is not None: + cleanup() + class Status(_Status): """Thin wrapper around StatusBase to add __and__ operator.""" @@ -187,14 +225,17 @@ class SubscriptionStatus(StatusBase): def set_finished(self): """Mark as finished successfully.""" - self.obj.clear_sub(self.check_value) + self._cleanup() super().set_finished() def _handle_failure(self): """Clear subscription on failure, run callbacks through super()""" - self.obj.clear_sub(self.check_value) + self._cleanup() return super()._handle_failure() + def _cleanup(self) -> None: + self.obj.clear_sub(self.check_value) + class CompareStatus(SubscriptionStatus): """ @@ -292,6 +333,59 @@ class CompareStatus(SubscriptionStatus): return False +class ExceptionStatus(CompareStatus): + """ + Status to watch for an error condition on a signal without blocking composite success. + + The status remains pending while the monitored value is in its expected state. If the + comparison matches, the status fails immediately and any composite AndStatus containing + it will fail as well. Unlike CompareStatus, this status never completes successfully on + its own and is intended to be combined with primary statuses using ``&``. + """ + + _blocks_success = False + + def __init__( + self, + signal: "Signal", + value: float | int | str, + *, + operation: Literal["==", "!=", "<", "<=", ">", ">="] = "==", + timeout: float = None, + settle_time: float = 0, + run: bool = True, + event_type=None, + exception: Exception | None = None, + ): + super().__init__( + signal=signal, + value=value, + operation_success=operation, + timeout=timeout, + settle_time=settle_time, + run=run, + event_type=event_type, + ) + self._configured_exception = exception + + def _compare_callback(self, value: any, **kwargs) -> bool: + try: + if isinstance(value, list): + raise ValueError(f"List values are not supported. Received value: {value}") + if self.op_map[self._operation_success](value, self._value): + if self._configured_exception is not None: + raise self._configured_exception + raise ValueError( + f"ExceptionStatus for signal {self._signal.name} reached monitored value " + f"{self._operation_success} {self._value}. Current value: {value}" + ) + return False + except Exception as e: + logger.error(f"Error in ExceptionStatus callback: {e}") + self.set_exception(e) + return False + + class TransitionStatus(SubscriptionStatus): """ Status to monitor transitions of a signal value through a list of specified transitions. diff --git a/tests/test_utils.py b/tests/test_utils.py index f0023ea..8444d8d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -27,6 +27,7 @@ from ophyd_devices.utils.psi_device_base_utils import ( AndStatus, CompareStatus, DeviceStatus, + ExceptionStatus, FileHandler, MoveStatus, Status, @@ -823,6 +824,63 @@ def test_compare_status_string(): assert status.exception() is None +def test_exception_status_andstatus_does_not_block_success(): + """ExceptionStatus should fail composites early but not block success while pending.""" + sig_primary_a = Signal(name="primary_a", value=0) + sig_primary_b = Signal(name="primary_b", value=0) + sig_watch = Signal(name="watch", value=0) + + primary_a = CompareStatus(signal=sig_primary_a, value=1, operation_success="==") + primary_b = CompareStatus(signal=sig_primary_b, value=2, operation_success="==") + watch = ExceptionStatus(signal=sig_watch, value=0, operation="!=") + + combined = primary_a & primary_b & watch + sig_primary_a.put(1) + assert not combined.done + sig_primary_b.put(2) + combined.wait(timeout=1) + assert combined.done is True + assert combined.success is True + assert watch.done is False + + +def test_exception_status_andstatus_fails_early(): + """ExceptionStatus should abort a composite status when the watched value is reached.""" + sig_primary = Signal(name="primary", value=0) + sig_watch = Signal(name="watch", value=0) + + primary = CompareStatus(signal=sig_primary, value=1, operation_success="==") + watch = ExceptionStatus(signal=sig_watch, value=0, operation="!=") + combined = primary & watch + + sig_watch.put(1) + with pytest.raises(ValueError): + combined.wait(timeout=1) + assert combined.done is True + assert combined.success is False + + +def test_exception_status_andstatus_fails_early_with_custom_exception(): + """ExceptionStatus should abort a composite status with the specified exception when the watched value is reached.""" + sig_primary = Signal(name="primary", value=0) + sig_watch = Signal(name="watch", value=0) + + primary = CompareStatus(signal=sig_primary, value=1, operation_success="==") + watch = ExceptionStatus( + signal=sig_watch, + value=0, + operation="!=", + exception=RuntimeError("Watch signal reached failure value"), + ) + combined = primary & watch + + sig_watch.put(1) + with pytest.raises(RuntimeError, match="Watch signal reached failure value"): + combined.wait(timeout=1) + assert combined.done is True + assert combined.success is False + + def test_transition_status(): """Test TransitionStatus""" sig = Signal(name="test_signal", value=0)