diff --git a/ophyd_devices/utils/psi_device_base_utils.py b/ophyd_devices/utils/psi_device_base_utils.py index 3c5d06a..4794e2b 100644 --- a/ophyd_devices/utils/psi_device_base_utils.py +++ b/ophyd_devices/utils/psi_device_base_utils.py @@ -6,7 +6,7 @@ import threading import traceback import uuid from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Literal +from typing import TYPE_CHECKING, Callable, Literal from bec_lib.file_utils import get_full_path from bec_lib.logger import bec_logger @@ -16,7 +16,6 @@ from ophyd.status import DeviceStatus as _DeviceStatus from ophyd.status import MoveStatus as _MoveStatus from ophyd.status import Status as _Status from ophyd.status import StatusBase as _StatusBase -from ophyd.status import SubscriptionStatus as _SubscriptionStatus if TYPE_CHECKING: # pragma: no cover from bec_lib.messages import ScanStatusMessage @@ -170,8 +169,9 @@ class SubscriptionStatus(StatusBase): try: success = self.callback(*args, **kwargs) except Exception as e: - self.log.error(e) - raise + logger.error(f"Error in SubscriptionStatus callback: {e}") + self.set_exception(e) + return if success: self.set_finished() @@ -221,7 +221,7 @@ class CompareStatus(SubscriptionStatus): event_type=None, ): if isinstance(value, str): - if operation_success not in ("==", "!=") and operation_failure not in ("==", "!="): + if operation_success not in ("==", "!=") or operation_failure not in ("==", "!="): raise ValueError( f"Invalid operation_success: {operation_success} for string comparison. Must be '==' or '!='." ) @@ -238,7 +238,7 @@ class CompareStatus(SubscriptionStatus): self._failure_values = [] elif isinstance(failure_value, (float, int, str)): self._failure_values = [failure_value] - elif isinstance(failure_value, list): + elif isinstance(failure_value, (list, tuple)): self._failure_values = failure_value else: raise ValueError( @@ -265,25 +265,18 @@ class CompareStatus(SubscriptionStatus): """ try: if isinstance(value, list): - self.set_exception( - ValueError(f"List values are not supported. Received value: {value}") - ) - return False + raise ValueError(f"List values are not supported. Received value: {value}") if any( self.op_map[self._operation_failure](value, failure_value) for failure_value in self._failure_values ): - self.set_exception( - ValueError( - f"CompareStatus for signal {self._signal.name} " - f"did not reach the desired state {self._operation_success} {self._value}. " - f"But instead reached {value}, which is in list of failure values: {self._failure_values}" - ) + raise ValueError( + f"CompareStatus for signal {self._signal.name} " + f"did not reach the desired state {self._operation_success} {self._value}. " + f"But instead reached {value}, which is in list of failure values: {self._failure_values}" ) - return False return self.op_map[self._operation_success](value, self._value) except Exception as e: - # Catch any exception if the value comparison fails, e.g. value is numpy array logger.error(f"Error in CompareStatus callback: {e}") self.set_exception(e) return False @@ -360,13 +353,10 @@ class TransitionStatus(SubscriptionStatus): """ try: if value in self._failure_states: - self.set_exception( - ValueError( - f"Transition Status for {self._signal.name} resulted in a value: {value}. " - f"marked to raise {self._failure_states}. Expected transitions: {self._transitions}." - ) + raise ValueError( + f"Transition Status for {self._signal.name} resulted in a value: {value}. " + f"marked to raise {self._failure_states}. Expected transitions: {self._transitions}." ) - return False if self._index == 0: if value == self._transitions[0]: self._index += 1 diff --git a/tests/test_utils.py b/tests/test_utils.py index 2f2d637..aec8141 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -900,6 +900,19 @@ def test_compare_status_with_mock_pv(mock_epics_signal_ro): assert status.success is True +def test_compare_status_raises_on_failed_comparison(mock_epics_signal_ro): + """Test CompareStatus raises on failed comparison with EpicsSignalRO""" + + signal = mock_epics_signal_ro + status = CompareStatus( + signal=signal, value=5, operation_success="==", failure_value=[np.array([10])] + ) + assert status.done is False + signal._read_pv.mock_data = 1 + with pytest.raises(Exception): + status.wait(timeout=5) + + @pytest.mark.parametrize( "transitions, expected_done, expected_success", [ @@ -946,6 +959,7 @@ def test_patched_status_objects(): st = StatusBase() st2 = StatusBase() and_st = st & st2 + assert st in and_st assert isinstance(and_st, AndStatus) st.set_exception(ValueError("test error")) with pytest.raises(ValueError):