diff --git a/ophyd_devices/utils/psi_device_base_utils.py b/ophyd_devices/utils/psi_device_base_utils.py index 38f88a1..6e0042d 100644 --- a/ophyd_devices/utils/psi_device_base_utils.py +++ b/ophyd_devices/utils/psi_device_base_utils.py @@ -48,44 +48,62 @@ OP_MAP = { class CompareStatus(SubscriptionStatus): """ - Status class to compare a signal value against a given value. + Status to compare a signal value against a given value. The comparison is done using the specified operation, which can be one of '==', '!=', '<', '<=', '>', '>='. If the value is a string, only '==' and '!=' are allowed. - The status is finished when the comparison is true. + One may also define a value or list of values that will result in an exception if encountered. + The status is finished when the comparison is either true or an exception is raised. Args: - signal: The device signal to compare. - value: The value to compare against. - operation: The operation to use for comparison. Defaults to '=='. - event_type: The type of event to trigger on comparison. Defaults to None (default sub). - timeout: The timeout for the status. Defaults to None (indefinite). - settle_time: The time to wait for the signal to settle before comparison. Defaults to 0. - run: Whether to run the status callback on creation or not. Defaults to True. + signal (Signal): The signal to monitor. + value (float | int | str): The target value to compare against. + operation_success (str, optional): The comparison operation for success. Defaults to '=='. + failure_value (float | int | str | list[float | int | str] | None, optional): + A value or list of values that will trigger an exception if encountered. Defaults to None. + operation_failure (str, optional): The comparison operation for failure values. Defaults to '=='. + event_type (int, optional): The event type to subscribe to. Defaults to None. + timeout (float, optional): Timeout for the status. Defaults to None. + settle_time (float, optional): Settle time before checking the status. Defaults to 0. + run (bool, optional): Whether to start the status immediately. Defaults to True """ def __init__( self, - signal: Signal, + signal: "Signal", value: float | int | str, *, - operation: Literal["==", "!=", "<", "<=", ">", ">="] = "==", - event_type=None, + operation_success: Literal["==", "!=", "<", "<=", ">", ">="] = "==", + failure_value: float | int | str | list[float | int | str] | None = None, + operation_failure: Literal["==", "!=", "<", "<=", ">", ">="] = "==", timeout: float = None, settle_time: float = 0, run: bool = True, + event_type=None, ): if isinstance(value, str): - if operation not in ("==", "!="): + if operation_success not in ("==", "!=") and operation_failure not in ("==", "!="): raise ValueError( - f"Invalid operation: {operation} for string comparison. Must be '==' or '!='." + f"Invalid operation_success: {operation_success} for string comparison. Must be '==' or '!='." ) - if operation not in ("==", "!=", "<", "<=", ">", ">="): + if operation_success not in ("==", "!=", "<", "<=", ">", ">="): raise ValueError( - f"Invalid operation: {operation}. Must be one of '==', '!=', '<', '<=', '>', '>='." + f"Invalid operation_success: {operation_success}. Must be one of '==', '!=', '<', '<=', '>', '>='." ) self._signal = signal self._value = value - self._operation = operation + self._operation_success = operation_success + self._operation_failure = operation_failure + self.op_map = OP_MAP + if failure_value is None: + self._failure_values = [] + elif isinstance(failure_value, (float, int, str)): + self._failure_values = [failure_value] + elif isinstance(failure_value, list): + self._failure_values = failure_value + else: + raise ValueError( + f"failure_value must be a float, int, str, list or None. Received: {failure_value}" + ) super().__init__( device=signal, callback=self._compare_callback, @@ -95,49 +113,91 @@ class CompareStatus(SubscriptionStatus): run=run, ) - def _compare_callback(self, value, **kwargs) -> bool: - """Callback for subscription status""" - return OP_MAP[self._operation](value, self._value) + def _compare_callback(self, value: any, **kwargs) -> bool: + """ + Callback for subscription status + + Args: + value (any): Current value of the signal + + Returns: + bool: True if comparison is successful, False otherwise. + """ + try: + if isinstance(value, list): + self.set_exception( + ValueError(f"List values are not supported. Received value: {value}") + ) + return False + if any( + self.op_map[self._operation_failure](value, failure_value) + for failure_value in self._failure_values + ): + self.set_exception( + ValueError( + f"CompareStatus for signal {self._signal.name} " + f"did not reach the desired state {self._operation_success} {self._value}. " + f"But instead reached {value}, which is in list of failure values: {self._failure_values}" + ) + ) + return False + return self.op_map[self._operation_success](value, self._value) + except Exception as e: + # Catch any exception if the value comparison fails, e.g. value is numpy array + logger.error(f"Error in CompareStatus callback: {e}") + self.set_exception(e) + return False class TransitionStatus(SubscriptionStatus): """ - Status class to monitor transitions of a signal value through a list of specified transitions. + Status to monitor transitions of a signal value through a list of specified transitions. The status is finished when all transitions have been observed in order. The keyword argument - `strict` determines whether the transitions must occur in strict order or not. - If `raise_states` is provided, the status will raise an exception if the signal value matches - any of the values in `raise_states`. + `strict` determines whether the transitions must occur in strict order or not. The strict option + only becomes relevant once the first transition has been observed. + If `failure_states` is provided, the status will raise an exception if the signal value matches + any of the values in `failure_states`. Args: - signal: The device signal to monitor. - transitions: A list of values to transition through. - strict: Whether the transitions must occur in strict order. Defaults to True. - raise_states: A list of values that will raise an exception if encountered. Defaults to None. - run: Whether to run the status callback on creation or not. Defaults to True. - event_type: The type of event to trigger on transition. Defaults to None (default sub). - timeout: The timeout for the status. Defaults to None (indefinite). - settle_time: The time to wait for the signal to settle before comparison. Defaults to 0. + signal (Signal): The signal to monitor. + transitions (list[float | int | str]): List of values representing the transitions to observe. + strict (bool, optional): Whether to enforce strict order of transitions. Defaults to True. + failure_states (list[float | int | str] | None, optional): + A list of values that will trigger an exception if encountered. Defaults to None. + run (bool, optional): Whether to start the status immediately. Defaults to True. + event_type (int, optional): The event type to subscribe to. Defaults to None. + timeout (float, optional): Timeout for the status. Defaults to None. + settle_time (float, optional): Settle time before checking the status. Defaults to 0. + + Notes: + The 'strict' option does not raise if transitions are observed which are out of order. + It only determines whether a transition is accepted if it is observed from the + previous value in the list of transitions to the next value. + For example, with strict=True and transitions=[1, 2, 3], the sequence + 0 -> 1 -> 2 -> 3 is accepted, but 0 -> 1 -> 3 -> 2 -> 3 is not and the status + will not complete. With strict=False, both sequences are accepted. + However, with strict=True, the sequence 0 -> 1 -> 3 -> 1 -> 2 -> 3 is accepted. + To raise an exception if an out-of-order transition is observed, use the + `failure_states` keyword argument. """ def __init__( self, - signal: Signal, + signal: "Signal", transitions: list[float | int | str], *, strict: bool = True, - raise_states: list[float | int | str] | None = None, + failure_states: list[float | int | str] | None = None, run: bool = True, - event_type=None, timeout: float = None, settle_time: float = 0, + event_type=None, ): self._signal = signal - if not isinstance(transitions, list): - raise ValueError(f"Transitions must be a list of values. Received: {transitions}") - self._transitions = transitions + self._transitions = tuple(transitions) self._index = 0 self._strict = strict - self._raise_states = raise_states if raise_states else [] + self._failure_states = failure_states if failure_states else [] super().__init__( device=signal, callback=self._compare_callback, @@ -147,34 +207,45 @@ class TransitionStatus(SubscriptionStatus): run=run, ) - def _compare_callback(self, old_value, value, **kwargs) -> bool: - """Callback for subscription Status""" - if value in self._raise_states: - self.set_exception( - ValueError( - f"Transition raised an exception: {value}. " - f"Expected transitions: {self._transitions}." + def _compare_callback(self, old_value: any, value: any, **kwargs) -> bool: + """ + Callback for subscription Status + + Args: + old_value (any): Previous value of the signal + value (any): Current value of the signal + + Returns: + bool: True if all transitions have been observed, False otherwise. + """ + try: + if value in self._failure_states: + self.set_exception( + ValueError( + f"Transition Status for {self._signal.name} resulted in a value: {value}. " + f"marked to raise {self._failure_states}. Expected transitions: {self._transitions}." + ) ) - ) - return False - if self._index == 0: - if value == self._transitions[0]: - self._index += 1 - else: - if self._strict: - if ( - old_value == self._transitions[self._index - 1] - and value == self._transitions[self._index] - ): + return False + if self._index == 0: + if value == self._transitions[0]: self._index += 1 else: - if value == self._transitions[self._index]: - self._index += 1 - return self._is_finished() - - def _is_finished(self) -> bool: - """Check if the status is finished""" - return self._index >= len(self._transitions) + if self._strict: + if ( + old_value == self._transitions[self._index - 1] + and value == self._transitions[self._index] + ): + self._index += 1 + else: + if value == self._transitions[self._index]: + self._index += 1 + return self._index >= len(self._transitions) + except Exception as e: + # Catch any exception if the value comparison fails, e.g. value is numpy array + logger.error(f"Error in TransitionStatus callback: {e}") + self.set_exception(e) + return False class TaskState(str, Enum): diff --git a/tests/test_utils.py b/tests/test_utils.py index 9bf0be0..9ff2d67 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -691,18 +691,19 @@ def test_utils_progress_signal(): def test_utils_compare_status_number(): - """Test CompareStatus""" + """Test CompareStatus with different operations.""" sig = Signal(name="test_signal", value=0) - status = CompareStatus(signal=sig, value=5, operation="==") + status = CompareStatus(signal=sig, value=5, operation_success="==") assert status.done is False sig.put(1) assert status.done is False sig.put(5) + status.wait(timeout=5) assert status.done is True sig.put(5) # Test with different operations - status = CompareStatus(signal=sig, value=5, operation="!=") + status = CompareStatus(signal=sig, value=5, operation_success="!=") assert status.done is False sig.put(5) assert status.done is False @@ -712,7 +713,7 @@ def test_utils_compare_status_number(): assert status.exception() is None sig.put(0) - status = CompareStatus(signal=sig, value=5, operation=">") + status = CompareStatus(signal=sig, value=5, operation_success=">") assert status.done is False sig.put(5) assert status.done is False @@ -721,11 +722,44 @@ def test_utils_compare_status_number(): assert status.success is True assert status.exception() is None + # Should raise + sig.put(0) + status = CompareStatus(signal=sig, value=5, operation_success="==", failure_value=[10]) + with pytest.raises(ValueError): + sig.put(10) + status.wait() + assert status.done is True + assert status.success is False + assert isinstance(status.exception(), ValueError) + + # failure_operation + sig.put(0) + status = CompareStatus( + signal=sig, value=5, operation_success="==", failure_value=10, operation_failure=">" + ) + sig.put(10) + assert status.done is False + assert status.success is False + sig.put(11) + with pytest.raises(ValueError): + status.wait() + assert status.done is True + assert status.success is False + + # raise if array is returned + sig.put(0) + status = CompareStatus(signal=sig, value=5, operation_success="==") + with pytest.raises(ValueError): + sig.put([1, 2, 3]) + status.wait(timeout=2) + assert status.done is True + assert status.success is False + def test_compare_status_string(): """Test CompareStatus with string values""" sig = Signal(name="test_signal", value="test") - status = CompareStatus(signal=sig, value="test", operation="==") + status = CompareStatus(signal=sig, value="test", operation_success="==") assert status.done is False sig.put("test1") assert status.done is False @@ -734,7 +768,7 @@ def test_compare_status_string(): sig.put("test") # Test with different operations - status = CompareStatus(signal=sig, value="test", operation="!=") + status = CompareStatus(signal=sig, value="test", operation_success="!=") assert status.done is False sig.put("test") assert status.done is False @@ -743,12 +777,6 @@ def test_compare_status_string(): assert status.success is True assert status.exception() is None - # Test with greater than operation - # Raises ValueError for strings - sig.put("a") - with pytest.raises(ValueError): - status = CompareStatus(signal=sig, value="b", operation=">") - def test_transition_status(): """Test TransitionStatus""" @@ -768,9 +796,9 @@ def test_transition_status(): assert status.success is True assert status.exception() is None - # Test strict=True, ra + # Test strict=True, failure_states sig.put(1) - status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=True, raise_states=[4]) + status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=True, failure_states=[4]) assert status.done is False sig.put(4) with pytest.raises(ValueError): @@ -854,7 +882,7 @@ def test_compare_status_with_mock_pv(mock_epics_signal_ro): """Test CompareStatus with EpicsSignalRO, this tests callbacks on EpicsSignals""" signal = mock_epics_signal_ro - status = CompareStatus(signal=signal, value=5, operation="==") + status = CompareStatus(signal=signal, value=5, operation_success="==") assert status.done is False signal._read_pv.mock_data = 1 assert status.done is False