diff --git a/ophyd_devices/utils/psi_device_base_utils.py b/ophyd_devices/utils/psi_device_base_utils.py index 0e7f1ea..644e9ae 100644 --- a/ophyd_devices/utils/psi_device_base_utils.py +++ b/ophyd_devices/utils/psi_device_base_utils.py @@ -357,6 +357,7 @@ class ExceptionStatus(CompareStatus): event_type=None, exception: Exception | None = None, ): + self._configured_exception = exception super().__init__( signal=signal, value=value, @@ -366,7 +367,6 @@ class ExceptionStatus(CompareStatus): run=run, event_type=event_type, ) - self._configured_exception = exception def _compare_callback(self, value: any, **kwargs) -> bool: try: diff --git a/tests/test_utils.py b/tests/test_utils.py index 8444d8d..f40536c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -881,6 +881,19 @@ def test_exception_status_andstatus_fails_early_with_custom_exception(): assert combined.success is False +def test_exception_status_with_exception(): + """Test that ExceptionStatus raises the specified exception when the condition is met.""" + sig = Signal(name="test_signal", value=0) + sig.put(1) + status = ExceptionStatus( + signal=sig, value=1, operation="==", exception=RuntimeError("Test signal reached 1") + ) + assert status.done is True + assert status.success is False + with pytest.raises(RuntimeError, match="Test signal reached 1"): + status.wait(timeout=1) + + def test_transition_status(): """Test TransitionStatus""" sig = Signal(name="test_signal", value=0)