mirror of
https://github.com/bec-project/ophyd_devices.git
synced 2026-02-09 00:28:40 +01:00
fix(bec-status): Refactor CompareStatus and TransitionStatus
This commit is contained in:
@@ -48,44 +48,62 @@ OP_MAP = {
|
||||
|
||||
class CompareStatus(SubscriptionStatus):
|
||||
"""
|
||||
Status class to compare a signal value against a given value.
|
||||
Status to compare a signal value against a given value.
|
||||
The comparison is done using the specified operation, which can be one of
|
||||
'==', '!=', '<', '<=', '>', '>='. If the value is a string, only '==' and '!=' are allowed.
|
||||
The status is finished when the comparison is true.
|
||||
One may also define a value or list of values that will result in an exception if encountered.
|
||||
The status is finished when the comparison is either true or an exception is raised.
|
||||
|
||||
Args:
|
||||
signal: The device signal to compare.
|
||||
value: The value to compare against.
|
||||
operation: The operation to use for comparison. Defaults to '=='.
|
||||
event_type: The type of event to trigger on comparison. Defaults to None (default sub).
|
||||
timeout: The timeout for the status. Defaults to None (indefinite).
|
||||
settle_time: The time to wait for the signal to settle before comparison. Defaults to 0.
|
||||
run: Whether to run the status callback on creation or not. Defaults to True.
|
||||
signal (Signal): The signal to monitor.
|
||||
value (float | int | str): The target value to compare against.
|
||||
operation_success (str, optional): The comparison operation for success. Defaults to '=='.
|
||||
failure_value (float | int | str | list[float | int | str] | None, optional):
|
||||
A value or list of values that will trigger an exception if encountered. Defaults to None.
|
||||
operation_failure (str, optional): The comparison operation for failure values. Defaults to '=='.
|
||||
event_type (int, optional): The event type to subscribe to. Defaults to None.
|
||||
timeout (float, optional): Timeout for the status. Defaults to None.
|
||||
settle_time (float, optional): Settle time before checking the status. Defaults to 0.
|
||||
run (bool, optional): Whether to start the status immediately. Defaults to True
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
signal: Signal,
|
||||
signal: "Signal",
|
||||
value: float | int | str,
|
||||
*,
|
||||
operation: Literal["==", "!=", "<", "<=", ">", ">="] = "==",
|
||||
event_type=None,
|
||||
operation_success: Literal["==", "!=", "<", "<=", ">", ">="] = "==",
|
||||
failure_value: float | int | str | list[float | int | str] | None = None,
|
||||
operation_failure: Literal["==", "!=", "<", "<=", ">", ">="] = "==",
|
||||
timeout: float = None,
|
||||
settle_time: float = 0,
|
||||
run: bool = True,
|
||||
event_type=None,
|
||||
):
|
||||
if isinstance(value, str):
|
||||
if operation not in ("==", "!="):
|
||||
if operation_success not in ("==", "!=") and operation_failure not in ("==", "!="):
|
||||
raise ValueError(
|
||||
f"Invalid operation: {operation} for string comparison. Must be '==' or '!='."
|
||||
f"Invalid operation_success: {operation_success} for string comparison. Must be '==' or '!='."
|
||||
)
|
||||
if operation not in ("==", "!=", "<", "<=", ">", ">="):
|
||||
if operation_success not in ("==", "!=", "<", "<=", ">", ">="):
|
||||
raise ValueError(
|
||||
f"Invalid operation: {operation}. Must be one of '==', '!=', '<', '<=', '>', '>='."
|
||||
f"Invalid operation_success: {operation_success}. Must be one of '==', '!=', '<', '<=', '>', '>='."
|
||||
)
|
||||
self._signal = signal
|
||||
self._value = value
|
||||
self._operation = operation
|
||||
self._operation_success = operation_success
|
||||
self._operation_failure = operation_failure
|
||||
self.op_map = OP_MAP
|
||||
if failure_value is None:
|
||||
self._failure_values = []
|
||||
elif isinstance(failure_value, (float, int, str)):
|
||||
self._failure_values = [failure_value]
|
||||
elif isinstance(failure_value, list):
|
||||
self._failure_values = failure_value
|
||||
else:
|
||||
raise ValueError(
|
||||
f"failure_value must be a float, int, str, list or None. Received: {failure_value}"
|
||||
)
|
||||
super().__init__(
|
||||
device=signal,
|
||||
callback=self._compare_callback,
|
||||
@@ -95,49 +113,91 @@ class CompareStatus(SubscriptionStatus):
|
||||
run=run,
|
||||
)
|
||||
|
||||
def _compare_callback(self, value, **kwargs) -> bool:
|
||||
"""Callback for subscription status"""
|
||||
return OP_MAP[self._operation](value, self._value)
|
||||
def _compare_callback(self, value: any, **kwargs) -> bool:
|
||||
"""
|
||||
Callback for subscription status
|
||||
|
||||
Args:
|
||||
value (any): Current value of the signal
|
||||
|
||||
Returns:
|
||||
bool: True if comparison is successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
if isinstance(value, list):
|
||||
self.set_exception(
|
||||
ValueError(f"List values are not supported. Received value: {value}")
|
||||
)
|
||||
return False
|
||||
if any(
|
||||
self.op_map[self._operation_failure](value, failure_value)
|
||||
for failure_value in self._failure_values
|
||||
):
|
||||
self.set_exception(
|
||||
ValueError(
|
||||
f"CompareStatus for signal {self._signal.name} "
|
||||
f"did not reach the desired state {self._operation_success} {self._value}. "
|
||||
f"But instead reached {value}, which is in list of failure values: {self._failure_values}"
|
||||
)
|
||||
)
|
||||
return False
|
||||
return self.op_map[self._operation_success](value, self._value)
|
||||
except Exception as e:
|
||||
# Catch any exception if the value comparison fails, e.g. value is numpy array
|
||||
logger.error(f"Error in CompareStatus callback: {e}")
|
||||
self.set_exception(e)
|
||||
return False
|
||||
|
||||
|
||||
class TransitionStatus(SubscriptionStatus):
|
||||
"""
|
||||
Status class to monitor transitions of a signal value through a list of specified transitions.
|
||||
Status to monitor transitions of a signal value through a list of specified transitions.
|
||||
The status is finished when all transitions have been observed in order. The keyword argument
|
||||
`strict` determines whether the transitions must occur in strict order or not.
|
||||
If `raise_states` is provided, the status will raise an exception if the signal value matches
|
||||
any of the values in `raise_states`.
|
||||
`strict` determines whether the transitions must occur in strict order or not. The strict option
|
||||
only becomes relevant once the first transition has been observed.
|
||||
If `failure_states` is provided, the status will raise an exception if the signal value matches
|
||||
any of the values in `failure_states`.
|
||||
|
||||
Args:
|
||||
signal: The device signal to monitor.
|
||||
transitions: A list of values to transition through.
|
||||
strict: Whether the transitions must occur in strict order. Defaults to True.
|
||||
raise_states: A list of values that will raise an exception if encountered. Defaults to None.
|
||||
run: Whether to run the status callback on creation or not. Defaults to True.
|
||||
event_type: The type of event to trigger on transition. Defaults to None (default sub).
|
||||
timeout: The timeout for the status. Defaults to None (indefinite).
|
||||
settle_time: The time to wait for the signal to settle before comparison. Defaults to 0.
|
||||
signal (Signal): The signal to monitor.
|
||||
transitions (list[float | int | str]): List of values representing the transitions to observe.
|
||||
strict (bool, optional): Whether to enforce strict order of transitions. Defaults to True.
|
||||
failure_states (list[float | int | str] | None, optional):
|
||||
A list of values that will trigger an exception if encountered. Defaults to None.
|
||||
run (bool, optional): Whether to start the status immediately. Defaults to True.
|
||||
event_type (int, optional): The event type to subscribe to. Defaults to None.
|
||||
timeout (float, optional): Timeout for the status. Defaults to None.
|
||||
settle_time (float, optional): Settle time before checking the status. Defaults to 0.
|
||||
|
||||
Notes:
|
||||
The 'strict' option does not raise if transitions are observed which are out of order.
|
||||
It only determines whether a transition is accepted if it is observed from the
|
||||
previous value in the list of transitions to the next value.
|
||||
For example, with strict=True and transitions=[1, 2, 3], the sequence
|
||||
0 -> 1 -> 2 -> 3 is accepted, but 0 -> 1 -> 3 -> 2 -> 3 is not and the status
|
||||
will not complete. With strict=False, both sequences are accepted.
|
||||
However, with strict=True, the sequence 0 -> 1 -> 3 -> 1 -> 2 -> 3 is accepted.
|
||||
To raise an exception if an out-of-order transition is observed, use the
|
||||
`failure_states` keyword argument.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
signal: Signal,
|
||||
signal: "Signal",
|
||||
transitions: list[float | int | str],
|
||||
*,
|
||||
strict: bool = True,
|
||||
raise_states: list[float | int | str] | None = None,
|
||||
failure_states: list[float | int | str] | None = None,
|
||||
run: bool = True,
|
||||
event_type=None,
|
||||
timeout: float = None,
|
||||
settle_time: float = 0,
|
||||
event_type=None,
|
||||
):
|
||||
self._signal = signal
|
||||
if not isinstance(transitions, list):
|
||||
raise ValueError(f"Transitions must be a list of values. Received: {transitions}")
|
||||
self._transitions = transitions
|
||||
self._transitions = tuple(transitions)
|
||||
self._index = 0
|
||||
self._strict = strict
|
||||
self._raise_states = raise_states if raise_states else []
|
||||
self._failure_states = failure_states if failure_states else []
|
||||
super().__init__(
|
||||
device=signal,
|
||||
callback=self._compare_callback,
|
||||
@@ -147,13 +207,23 @@ class TransitionStatus(SubscriptionStatus):
|
||||
run=run,
|
||||
)
|
||||
|
||||
def _compare_callback(self, old_value, value, **kwargs) -> bool:
|
||||
"""Callback for subscription Status"""
|
||||
if value in self._raise_states:
|
||||
def _compare_callback(self, old_value: any, value: any, **kwargs) -> bool:
|
||||
"""
|
||||
Callback for subscription Status
|
||||
|
||||
Args:
|
||||
old_value (any): Previous value of the signal
|
||||
value (any): Current value of the signal
|
||||
|
||||
Returns:
|
||||
bool: True if all transitions have been observed, False otherwise.
|
||||
"""
|
||||
try:
|
||||
if value in self._failure_states:
|
||||
self.set_exception(
|
||||
ValueError(
|
||||
f"Transition raised an exception: {value}. "
|
||||
f"Expected transitions: {self._transitions}."
|
||||
f"Transition Status for {self._signal.name} resulted in a value: {value}. "
|
||||
f"marked to raise {self._failure_states}. Expected transitions: {self._transitions}."
|
||||
)
|
||||
)
|
||||
return False
|
||||
@@ -170,11 +240,12 @@ class TransitionStatus(SubscriptionStatus):
|
||||
else:
|
||||
if value == self._transitions[self._index]:
|
||||
self._index += 1
|
||||
return self._is_finished()
|
||||
|
||||
def _is_finished(self) -> bool:
|
||||
"""Check if the status is finished"""
|
||||
return self._index >= len(self._transitions)
|
||||
except Exception as e:
|
||||
# Catch any exception if the value comparison fails, e.g. value is numpy array
|
||||
logger.error(f"Error in TransitionStatus callback: {e}")
|
||||
self.set_exception(e)
|
||||
return False
|
||||
|
||||
|
||||
class TaskState(str, Enum):
|
||||
|
||||
@@ -691,18 +691,19 @@ def test_utils_progress_signal():
|
||||
|
||||
|
||||
def test_utils_compare_status_number():
|
||||
"""Test CompareStatus"""
|
||||
"""Test CompareStatus with different operations."""
|
||||
sig = Signal(name="test_signal", value=0)
|
||||
status = CompareStatus(signal=sig, value=5, operation="==")
|
||||
status = CompareStatus(signal=sig, value=5, operation_success="==")
|
||||
assert status.done is False
|
||||
sig.put(1)
|
||||
assert status.done is False
|
||||
sig.put(5)
|
||||
status.wait(timeout=5)
|
||||
assert status.done is True
|
||||
|
||||
sig.put(5)
|
||||
# Test with different operations
|
||||
status = CompareStatus(signal=sig, value=5, operation="!=")
|
||||
status = CompareStatus(signal=sig, value=5, operation_success="!=")
|
||||
assert status.done is False
|
||||
sig.put(5)
|
||||
assert status.done is False
|
||||
@@ -712,7 +713,7 @@ def test_utils_compare_status_number():
|
||||
assert status.exception() is None
|
||||
|
||||
sig.put(0)
|
||||
status = CompareStatus(signal=sig, value=5, operation=">")
|
||||
status = CompareStatus(signal=sig, value=5, operation_success=">")
|
||||
assert status.done is False
|
||||
sig.put(5)
|
||||
assert status.done is False
|
||||
@@ -721,11 +722,44 @@ def test_utils_compare_status_number():
|
||||
assert status.success is True
|
||||
assert status.exception() is None
|
||||
|
||||
# Should raise
|
||||
sig.put(0)
|
||||
status = CompareStatus(signal=sig, value=5, operation_success="==", failure_value=[10])
|
||||
with pytest.raises(ValueError):
|
||||
sig.put(10)
|
||||
status.wait()
|
||||
assert status.done is True
|
||||
assert status.success is False
|
||||
assert isinstance(status.exception(), ValueError)
|
||||
|
||||
# failure_operation
|
||||
sig.put(0)
|
||||
status = CompareStatus(
|
||||
signal=sig, value=5, operation_success="==", failure_value=10, operation_failure=">"
|
||||
)
|
||||
sig.put(10)
|
||||
assert status.done is False
|
||||
assert status.success is False
|
||||
sig.put(11)
|
||||
with pytest.raises(ValueError):
|
||||
status.wait()
|
||||
assert status.done is True
|
||||
assert status.success is False
|
||||
|
||||
# raise if array is returned
|
||||
sig.put(0)
|
||||
status = CompareStatus(signal=sig, value=5, operation_success="==")
|
||||
with pytest.raises(ValueError):
|
||||
sig.put([1, 2, 3])
|
||||
status.wait(timeout=2)
|
||||
assert status.done is True
|
||||
assert status.success is False
|
||||
|
||||
|
||||
def test_compare_status_string():
|
||||
"""Test CompareStatus with string values"""
|
||||
sig = Signal(name="test_signal", value="test")
|
||||
status = CompareStatus(signal=sig, value="test", operation="==")
|
||||
status = CompareStatus(signal=sig, value="test", operation_success="==")
|
||||
assert status.done is False
|
||||
sig.put("test1")
|
||||
assert status.done is False
|
||||
@@ -734,7 +768,7 @@ def test_compare_status_string():
|
||||
|
||||
sig.put("test")
|
||||
# Test with different operations
|
||||
status = CompareStatus(signal=sig, value="test", operation="!=")
|
||||
status = CompareStatus(signal=sig, value="test", operation_success="!=")
|
||||
assert status.done is False
|
||||
sig.put("test")
|
||||
assert status.done is False
|
||||
@@ -743,12 +777,6 @@ def test_compare_status_string():
|
||||
assert status.success is True
|
||||
assert status.exception() is None
|
||||
|
||||
# Test with greater than operation
|
||||
# Raises ValueError for strings
|
||||
sig.put("a")
|
||||
with pytest.raises(ValueError):
|
||||
status = CompareStatus(signal=sig, value="b", operation=">")
|
||||
|
||||
|
||||
def test_transition_status():
|
||||
"""Test TransitionStatus"""
|
||||
@@ -768,9 +796,9 @@ def test_transition_status():
|
||||
assert status.success is True
|
||||
assert status.exception() is None
|
||||
|
||||
# Test strict=True, ra
|
||||
# Test strict=True, failure_states
|
||||
sig.put(1)
|
||||
status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=True, raise_states=[4])
|
||||
status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=True, failure_states=[4])
|
||||
assert status.done is False
|
||||
sig.put(4)
|
||||
with pytest.raises(ValueError):
|
||||
@@ -854,7 +882,7 @@ def test_compare_status_with_mock_pv(mock_epics_signal_ro):
|
||||
"""Test CompareStatus with EpicsSignalRO, this tests callbacks on EpicsSignals"""
|
||||
|
||||
signal = mock_epics_signal_ro
|
||||
status = CompareStatus(signal=signal, value=5, operation="==")
|
||||
status = CompareStatus(signal=signal, value=5, operation_success="==")
|
||||
assert status.done is False
|
||||
signal._read_pv.mock_data = 1
|
||||
assert status.done is False
|
||||
|
||||
Reference in New Issue
Block a user