mirror of
https://github.com/bec-project/ophyd_devices.git
synced 2026-02-06 15:18:40 +01:00
refactor(status): Improve logic to set exceptions to allow to catch the error traceback
This commit is contained in:
@@ -6,7 +6,7 @@ import threading
|
|||||||
import traceback
|
import traceback
|
||||||
import uuid
|
import uuid
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Literal
|
from typing import TYPE_CHECKING, Callable, Literal
|
||||||
|
|
||||||
from bec_lib.file_utils import get_full_path
|
from bec_lib.file_utils import get_full_path
|
||||||
from bec_lib.logger import bec_logger
|
from bec_lib.logger import bec_logger
|
||||||
@@ -16,7 +16,6 @@ from ophyd.status import DeviceStatus as _DeviceStatus
|
|||||||
from ophyd.status import MoveStatus as _MoveStatus
|
from ophyd.status import MoveStatus as _MoveStatus
|
||||||
from ophyd.status import Status as _Status
|
from ophyd.status import Status as _Status
|
||||||
from ophyd.status import StatusBase as _StatusBase
|
from ophyd.status import StatusBase as _StatusBase
|
||||||
from ophyd.status import SubscriptionStatus as _SubscriptionStatus
|
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma: no cover
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
from bec_lib.messages import ScanStatusMessage
|
from bec_lib.messages import ScanStatusMessage
|
||||||
@@ -170,8 +169,9 @@ class SubscriptionStatus(StatusBase):
|
|||||||
try:
|
try:
|
||||||
success = self.callback(*args, **kwargs)
|
success = self.callback(*args, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log.error(e)
|
logger.error(f"Error in SubscriptionStatus callback: {e}")
|
||||||
raise
|
self.set_exception(e)
|
||||||
|
return
|
||||||
if success:
|
if success:
|
||||||
self.set_finished()
|
self.set_finished()
|
||||||
|
|
||||||
@@ -221,7 +221,7 @@ class CompareStatus(SubscriptionStatus):
|
|||||||
event_type=None,
|
event_type=None,
|
||||||
):
|
):
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
if operation_success not in ("==", "!=") and operation_failure not in ("==", "!="):
|
if operation_success not in ("==", "!=") or operation_failure not in ("==", "!="):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid operation_success: {operation_success} for string comparison. Must be '==' or '!='."
|
f"Invalid operation_success: {operation_success} for string comparison. Must be '==' or '!='."
|
||||||
)
|
)
|
||||||
@@ -238,7 +238,7 @@ class CompareStatus(SubscriptionStatus):
|
|||||||
self._failure_values = []
|
self._failure_values = []
|
||||||
elif isinstance(failure_value, (float, int, str)):
|
elif isinstance(failure_value, (float, int, str)):
|
||||||
self._failure_values = [failure_value]
|
self._failure_values = [failure_value]
|
||||||
elif isinstance(failure_value, list):
|
elif isinstance(failure_value, (list, tuple)):
|
||||||
self._failure_values = failure_value
|
self._failure_values = failure_value
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -265,25 +265,18 @@ class CompareStatus(SubscriptionStatus):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
self.set_exception(
|
raise ValueError(f"List values are not supported. Received value: {value}")
|
||||||
ValueError(f"List values are not supported. Received value: {value}")
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
if any(
|
if any(
|
||||||
self.op_map[self._operation_failure](value, failure_value)
|
self.op_map[self._operation_failure](value, failure_value)
|
||||||
for failure_value in self._failure_values
|
for failure_value in self._failure_values
|
||||||
):
|
):
|
||||||
self.set_exception(
|
raise ValueError(
|
||||||
ValueError(
|
f"CompareStatus for signal {self._signal.name} "
|
||||||
f"CompareStatus for signal {self._signal.name} "
|
f"did not reach the desired state {self._operation_success} {self._value}. "
|
||||||
f"did not reach the desired state {self._operation_success} {self._value}. "
|
f"But instead reached {value}, which is in list of failure values: {self._failure_values}"
|
||||||
f"But instead reached {value}, which is in list of failure values: {self._failure_values}"
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
return False
|
|
||||||
return self.op_map[self._operation_success](value, self._value)
|
return self.op_map[self._operation_success](value, self._value)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Catch any exception if the value comparison fails, e.g. value is numpy array
|
|
||||||
logger.error(f"Error in CompareStatus callback: {e}")
|
logger.error(f"Error in CompareStatus callback: {e}")
|
||||||
self.set_exception(e)
|
self.set_exception(e)
|
||||||
return False
|
return False
|
||||||
@@ -360,13 +353,10 @@ class TransitionStatus(SubscriptionStatus):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if value in self._failure_states:
|
if value in self._failure_states:
|
||||||
self.set_exception(
|
raise ValueError(
|
||||||
ValueError(
|
f"Transition Status for {self._signal.name} resulted in a value: {value}. "
|
||||||
f"Transition Status for {self._signal.name} resulted in a value: {value}. "
|
f"marked to raise {self._failure_states}. Expected transitions: {self._transitions}."
|
||||||
f"marked to raise {self._failure_states}. Expected transitions: {self._transitions}."
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
return False
|
|
||||||
if self._index == 0:
|
if self._index == 0:
|
||||||
if value == self._transitions[0]:
|
if value == self._transitions[0]:
|
||||||
self._index += 1
|
self._index += 1
|
||||||
|
|||||||
@@ -900,6 +900,19 @@ def test_compare_status_with_mock_pv(mock_epics_signal_ro):
|
|||||||
assert status.success is True
|
assert status.success is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_compare_status_raises_on_failed_comparison(mock_epics_signal_ro):
|
||||||
|
"""Test CompareStatus raises on failed comparison with EpicsSignalRO"""
|
||||||
|
|
||||||
|
signal = mock_epics_signal_ro
|
||||||
|
status = CompareStatus(
|
||||||
|
signal=signal, value=5, operation_success="==", failure_value=[np.array([10])]
|
||||||
|
)
|
||||||
|
assert status.done is False
|
||||||
|
signal._read_pv.mock_data = 1
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
status.wait(timeout=5)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"transitions, expected_done, expected_success",
|
"transitions, expected_done, expected_success",
|
||||||
[
|
[
|
||||||
@@ -946,6 +959,7 @@ def test_patched_status_objects():
|
|||||||
st = StatusBase()
|
st = StatusBase()
|
||||||
st2 = StatusBase()
|
st2 = StatusBase()
|
||||||
and_st = st & st2
|
and_st = st & st2
|
||||||
|
assert st in and_st
|
||||||
assert isinstance(and_st, AndStatus)
|
assert isinstance(and_st, AndStatus)
|
||||||
st.set_exception(ValueError("test error"))
|
st.set_exception(ValueError("test error"))
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
|
|||||||
Reference in New Issue
Block a user