refactor(status): Improve logic to set exceptions to allow to catch the error traceback

This commit is contained in:
2025-11-28 17:08:17 +01:00
committed by Christian Appel
parent b918f1851c
commit 13d658241a
2 changed files with 28 additions and 24 deletions

View File

@@ -6,7 +6,7 @@ import threading
import traceback import traceback
import uuid import uuid
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Literal from typing import TYPE_CHECKING, Callable, Literal
from bec_lib.file_utils import get_full_path from bec_lib.file_utils import get_full_path
from bec_lib.logger import bec_logger from bec_lib.logger import bec_logger
@@ -16,7 +16,6 @@ from ophyd.status import DeviceStatus as _DeviceStatus
from ophyd.status import MoveStatus as _MoveStatus from ophyd.status import MoveStatus as _MoveStatus
from ophyd.status import Status as _Status from ophyd.status import Status as _Status
from ophyd.status import StatusBase as _StatusBase from ophyd.status import StatusBase as _StatusBase
from ophyd.status import SubscriptionStatus as _SubscriptionStatus
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
from bec_lib.messages import ScanStatusMessage from bec_lib.messages import ScanStatusMessage
@@ -170,8 +169,9 @@ class SubscriptionStatus(StatusBase):
try: try:
success = self.callback(*args, **kwargs) success = self.callback(*args, **kwargs)
except Exception as e: except Exception as e:
self.log.error(e) logger.error(f"Error in SubscriptionStatus callback: {e}")
raise self.set_exception(e)
return
if success: if success:
self.set_finished() self.set_finished()
@@ -221,7 +221,7 @@ class CompareStatus(SubscriptionStatus):
event_type=None, event_type=None,
): ):
if isinstance(value, str): if isinstance(value, str):
if operation_success not in ("==", "!=") and operation_failure not in ("==", "!="): if operation_success not in ("==", "!=") or operation_failure not in ("==", "!="):
raise ValueError( raise ValueError(
f"Invalid operation_success: {operation_success} for string comparison. Must be '==' or '!='." f"Invalid operation_success: {operation_success} for string comparison. Must be '==' or '!='."
) )
@@ -238,7 +238,7 @@ class CompareStatus(SubscriptionStatus):
self._failure_values = [] self._failure_values = []
elif isinstance(failure_value, (float, int, str)): elif isinstance(failure_value, (float, int, str)):
self._failure_values = [failure_value] self._failure_values = [failure_value]
elif isinstance(failure_value, list): elif isinstance(failure_value, (list, tuple)):
self._failure_values = failure_value self._failure_values = failure_value
else: else:
raise ValueError( raise ValueError(
@@ -265,25 +265,18 @@ class CompareStatus(SubscriptionStatus):
""" """
try: try:
if isinstance(value, list): if isinstance(value, list):
self.set_exception( raise ValueError(f"List values are not supported. Received value: {value}")
ValueError(f"List values are not supported. Received value: {value}")
)
return False
if any( if any(
self.op_map[self._operation_failure](value, failure_value) self.op_map[self._operation_failure](value, failure_value)
for failure_value in self._failure_values for failure_value in self._failure_values
): ):
self.set_exception( raise ValueError(
ValueError( f"CompareStatus for signal {self._signal.name} "
f"CompareStatus for signal {self._signal.name} " f"did not reach the desired state {self._operation_success} {self._value}. "
f"did not reach the desired state {self._operation_success} {self._value}. " f"But instead reached {value}, which is in list of failure values: {self._failure_values}"
f"But instead reached {value}, which is in list of failure values: {self._failure_values}"
)
) )
return False
return self.op_map[self._operation_success](value, self._value) return self.op_map[self._operation_success](value, self._value)
except Exception as e: except Exception as e:
# Catch any exception if the value comparison fails, e.g. value is numpy array
logger.error(f"Error in CompareStatus callback: {e}") logger.error(f"Error in CompareStatus callback: {e}")
self.set_exception(e) self.set_exception(e)
return False return False
@@ -360,13 +353,10 @@ class TransitionStatus(SubscriptionStatus):
""" """
try: try:
if value in self._failure_states: if value in self._failure_states:
self.set_exception( raise ValueError(
ValueError( f"Transition Status for {self._signal.name} resulted in a value: {value}. "
f"Transition Status for {self._signal.name} resulted in a value: {value}. " f"marked to raise {self._failure_states}. Expected transitions: {self._transitions}."
f"marked to raise {self._failure_states}. Expected transitions: {self._transitions}."
)
) )
return False
if self._index == 0: if self._index == 0:
if value == self._transitions[0]: if value == self._transitions[0]:
self._index += 1 self._index += 1

View File

@@ -900,6 +900,19 @@ def test_compare_status_with_mock_pv(mock_epics_signal_ro):
assert status.success is True assert status.success is True
def test_compare_status_raises_on_failed_comparison(mock_epics_signal_ro):
"""Test CompareStatus raises on failed comparison with EpicsSignalRO"""
signal = mock_epics_signal_ro
status = CompareStatus(
signal=signal, value=5, operation_success="==", failure_value=[np.array([10])]
)
assert status.done is False
signal._read_pv.mock_data = 1
with pytest.raises(Exception):
status.wait(timeout=5)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"transitions, expected_done, expected_success", "transitions, expected_done, expected_success",
[ [
@@ -946,6 +959,7 @@ def test_patched_status_objects():
st = StatusBase() st = StatusBase()
st2 = StatusBase() st2 = StatusBase()
and_st = st & st2 and_st = st & st2
assert st in and_st
assert isinstance(and_st, AndStatus) assert isinstance(and_st, AndStatus)
st.set_exception(ValueError("test error")) st.set_exception(ValueError("test error"))
with pytest.raises(ValueError): with pytest.raises(ValueError):