fix(bec-status): Refactor CompareStatus and TransitionStatus

This commit is contained in:
2025-11-26 13:41:58 +01:00
committed by Christian Appel
parent e95d46a77d
commit 58d4a5141f
2 changed files with 178 additions and 79 deletions

View File

@@ -48,44 +48,62 @@ OP_MAP = {
class CompareStatus(SubscriptionStatus):
"""
Status class to compare a signal value against a given value.
Status to compare a signal value against a given value.
The comparison is done using the specified operation, which can be one of
'==', '!=', '<', '<=', '>', '>='. If the value is a string, only '==' and '!=' are allowed.
The status is finished when the comparison is true.
One may also define a value or list of values that will result in an exception if encountered.
The status is finished when the comparison is either true or an exception is raised.
Args:
signal: The device signal to compare.
value: The value to compare against.
operation: The operation to use for comparison. Defaults to '=='.
event_type: The type of event to trigger on comparison. Defaults to None (default sub).
timeout: The timeout for the status. Defaults to None (indefinite).
settle_time: The time to wait for the signal to settle before comparison. Defaults to 0.
run: Whether to run the status callback on creation or not. Defaults to True.
signal (Signal): The signal to monitor.
value (float | int | str): The target value to compare against.
operation_success (str, optional): The comparison operation for success. Defaults to '=='.
failure_value (float | int | str | list[float | int | str] | None, optional):
A value or list of values that will trigger an exception if encountered. Defaults to None.
operation_failure (str, optional): The comparison operation for failure values. Defaults to '=='.
event_type (int, optional): The event type to subscribe to. Defaults to None.
timeout (float, optional): Timeout for the status. Defaults to None.
settle_time (float, optional): Settle time before checking the status. Defaults to 0.
run (bool, optional): Whether to start the status immediately. Defaults to True
"""
def __init__(
self,
signal: Signal,
signal: "Signal",
value: float | int | str,
*,
operation: Literal["==", "!=", "<", "<=", ">", ">="] = "==",
event_type=None,
operation_success: Literal["==", "!=", "<", "<=", ">", ">="] = "==",
failure_value: float | int | str | list[float | int | str] | None = None,
operation_failure: Literal["==", "!=", "<", "<=", ">", ">="] = "==",
timeout: float = None,
settle_time: float = 0,
run: bool = True,
event_type=None,
):
if isinstance(value, str):
if operation not in ("==", "!="):
if operation_success not in ("==", "!=") and operation_failure not in ("==", "!="):
raise ValueError(
f"Invalid operation: {operation} for string comparison. Must be '==' or '!='."
f"Invalid operation_success: {operation_success} for string comparison. Must be '==' or '!='."
)
if operation not in ("==", "!=", "<", "<=", ">", ">="):
if operation_success not in ("==", "!=", "<", "<=", ">", ">="):
raise ValueError(
f"Invalid operation: {operation}. Must be one of '==', '!=', '<', '<=', '>', '>='."
f"Invalid operation_success: {operation_success}. Must be one of '==', '!=', '<', '<=', '>', '>='."
)
self._signal = signal
self._value = value
self._operation = operation
self._operation_success = operation_success
self._operation_failure = operation_failure
self.op_map = OP_MAP
if failure_value is None:
self._failure_values = []
elif isinstance(failure_value, (float, int, str)):
self._failure_values = [failure_value]
elif isinstance(failure_value, list):
self._failure_values = failure_value
else:
raise ValueError(
f"failure_value must be a float, int, str, list or None. Received: {failure_value}"
)
super().__init__(
device=signal,
callback=self._compare_callback,
@@ -95,49 +113,91 @@ class CompareStatus(SubscriptionStatus):
run=run,
)
def _compare_callback(self, value, **kwargs) -> bool:
"""Callback for subscription status"""
return OP_MAP[self._operation](value, self._value)
def _compare_callback(self, value: any, **kwargs) -> bool:
"""
Callback for subscription status
Args:
value (any): Current value of the signal
Returns:
bool: True if comparison is successful, False otherwise.
"""
try:
if isinstance(value, list):
self.set_exception(
ValueError(f"List values are not supported. Received value: {value}")
)
return False
if any(
self.op_map[self._operation_failure](value, failure_value)
for failure_value in self._failure_values
):
self.set_exception(
ValueError(
f"CompareStatus for signal {self._signal.name} "
f"did not reach the desired state {self._operation_success} {self._value}. "
f"But instead reached {value}, which is in list of failure values: {self._failure_values}"
)
)
return False
return self.op_map[self._operation_success](value, self._value)
except Exception as e:
# Catch any exception if the value comparison fails, e.g. value is numpy array
logger.error(f"Error in CompareStatus callback: {e}")
self.set_exception(e)
return False
class TransitionStatus(SubscriptionStatus):
"""
Status class to monitor transitions of a signal value through a list of specified transitions.
Status to monitor transitions of a signal value through a list of specified transitions.
The status is finished when all transitions have been observed in order. The keyword argument
`strict` determines whether the transitions must occur in strict order or not.
If `raise_states` is provided, the status will raise an exception if the signal value matches
any of the values in `raise_states`.
`strict` determines whether the transitions must occur in strict order or not. The strict option
only becomes relevant once the first transition has been observed.
If `failure_states` is provided, the status will raise an exception if the signal value matches
any of the values in `failure_states`.
Args:
signal: The device signal to monitor.
transitions: A list of values to transition through.
strict: Whether the transitions must occur in strict order. Defaults to True.
raise_states: A list of values that will raise an exception if encountered. Defaults to None.
run: Whether to run the status callback on creation or not. Defaults to True.
event_type: The type of event to trigger on transition. Defaults to None (default sub).
timeout: The timeout for the status. Defaults to None (indefinite).
settle_time: The time to wait for the signal to settle before comparison. Defaults to 0.
signal (Signal): The signal to monitor.
transitions (list[float | int | str]): List of values representing the transitions to observe.
strict (bool, optional): Whether to enforce strict order of transitions. Defaults to True.
failure_states (list[float | int | str] | None, optional):
A list of values that will trigger an exception if encountered. Defaults to None.
run (bool, optional): Whether to start the status immediately. Defaults to True.
event_type (int, optional): The event type to subscribe to. Defaults to None.
timeout (float, optional): Timeout for the status. Defaults to None.
settle_time (float, optional): Settle time before checking the status. Defaults to 0.
Notes:
The 'strict' option does not raise if transitions are observed which are out of order.
It only determines whether a transition is accepted if it is observed from the
previous value in the list of transitions to the next value.
For example, with strict=True and transitions=[1, 2, 3], the sequence
0 -> 1 -> 2 -> 3 is accepted, but 0 -> 1 -> 3 -> 2 -> 3 is not and the status
will not complete. With strict=False, both sequences are accepted.
However, with strict=True, the sequence 0 -> 1 -> 3 -> 1 -> 2 -> 3 is accepted.
To raise an exception if an out-of-order transition is observed, use the
`failure_states` keyword argument.
"""
def __init__(
self,
signal: Signal,
signal: "Signal",
transitions: list[float | int | str],
*,
strict: bool = True,
raise_states: list[float | int | str] | None = None,
failure_states: list[float | int | str] | None = None,
run: bool = True,
event_type=None,
timeout: float = None,
settle_time: float = 0,
event_type=None,
):
self._signal = signal
if not isinstance(transitions, list):
raise ValueError(f"Transitions must be a list of values. Received: {transitions}")
self._transitions = transitions
self._transitions = tuple(transitions)
self._index = 0
self._strict = strict
self._raise_states = raise_states if raise_states else []
self._failure_states = failure_states if failure_states else []
super().__init__(
device=signal,
callback=self._compare_callback,
@@ -147,34 +207,45 @@ class TransitionStatus(SubscriptionStatus):
run=run,
)
def _compare_callback(self, old_value, value, **kwargs) -> bool:
"""Callback for subscription Status"""
if value in self._raise_states:
self.set_exception(
ValueError(
f"Transition raised an exception: {value}. "
f"Expected transitions: {self._transitions}."
def _compare_callback(self, old_value: any, value: any, **kwargs) -> bool:
"""
Callback for subscription Status
Args:
old_value (any): Previous value of the signal
value (any): Current value of the signal
Returns:
bool: True if all transitions have been observed, False otherwise.
"""
try:
if value in self._failure_states:
self.set_exception(
ValueError(
f"Transition Status for {self._signal.name} resulted in a value: {value}. "
f"marked to raise {self._failure_states}. Expected transitions: {self._transitions}."
)
)
)
return False
if self._index == 0:
if value == self._transitions[0]:
self._index += 1
else:
if self._strict:
if (
old_value == self._transitions[self._index - 1]
and value == self._transitions[self._index]
):
return False
if self._index == 0:
if value == self._transitions[0]:
self._index += 1
else:
if value == self._transitions[self._index]:
self._index += 1
return self._is_finished()
def _is_finished(self) -> bool:
"""Check if the status is finished"""
return self._index >= len(self._transitions)
if self._strict:
if (
old_value == self._transitions[self._index - 1]
and value == self._transitions[self._index]
):
self._index += 1
else:
if value == self._transitions[self._index]:
self._index += 1
return self._index >= len(self._transitions)
except Exception as e:
# Catch any exception if the value comparison fails, e.g. value is numpy array
logger.error(f"Error in TransitionStatus callback: {e}")
self.set_exception(e)
return False
class TaskState(str, Enum):

View File

@@ -691,18 +691,19 @@ def test_utils_progress_signal():
def test_utils_compare_status_number():
"""Test CompareStatus"""
"""Test CompareStatus with different operations."""
sig = Signal(name="test_signal", value=0)
status = CompareStatus(signal=sig, value=5, operation="==")
status = CompareStatus(signal=sig, value=5, operation_success="==")
assert status.done is False
sig.put(1)
assert status.done is False
sig.put(5)
status.wait(timeout=5)
assert status.done is True
sig.put(5)
# Test with different operations
status = CompareStatus(signal=sig, value=5, operation="!=")
status = CompareStatus(signal=sig, value=5, operation_success="!=")
assert status.done is False
sig.put(5)
assert status.done is False
@@ -712,7 +713,7 @@ def test_utils_compare_status_number():
assert status.exception() is None
sig.put(0)
status = CompareStatus(signal=sig, value=5, operation=">")
status = CompareStatus(signal=sig, value=5, operation_success=">")
assert status.done is False
sig.put(5)
assert status.done is False
@@ -721,11 +722,44 @@ def test_utils_compare_status_number():
assert status.success is True
assert status.exception() is None
# Should raise
sig.put(0)
status = CompareStatus(signal=sig, value=5, operation_success="==", failure_value=[10])
with pytest.raises(ValueError):
sig.put(10)
status.wait()
assert status.done is True
assert status.success is False
assert isinstance(status.exception(), ValueError)
# failure_operation
sig.put(0)
status = CompareStatus(
signal=sig, value=5, operation_success="==", failure_value=10, operation_failure=">"
)
sig.put(10)
assert status.done is False
assert status.success is False
sig.put(11)
with pytest.raises(ValueError):
status.wait()
assert status.done is True
assert status.success is False
# raise if array is returned
sig.put(0)
status = CompareStatus(signal=sig, value=5, operation_success="==")
with pytest.raises(ValueError):
sig.put([1, 2, 3])
status.wait(timeout=2)
assert status.done is True
assert status.success is False
def test_compare_status_string():
"""Test CompareStatus with string values"""
sig = Signal(name="test_signal", value="test")
status = CompareStatus(signal=sig, value="test", operation="==")
status = CompareStatus(signal=sig, value="test", operation_success="==")
assert status.done is False
sig.put("test1")
assert status.done is False
@@ -734,7 +768,7 @@ def test_compare_status_string():
sig.put("test")
# Test with different operations
status = CompareStatus(signal=sig, value="test", operation="!=")
status = CompareStatus(signal=sig, value="test", operation_success="!=")
assert status.done is False
sig.put("test")
assert status.done is False
@@ -743,12 +777,6 @@ def test_compare_status_string():
assert status.success is True
assert status.exception() is None
# Test with greater than operation
# Raises ValueError for strings
sig.put("a")
with pytest.raises(ValueError):
status = CompareStatus(signal=sig, value="b", operation=">")
def test_transition_status():
"""Test TransitionStatus"""
@@ -768,9 +796,9 @@ def test_transition_status():
assert status.success is True
assert status.exception() is None
# Test strict=True, ra
# Test strict=True, failure_states
sig.put(1)
status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=True, raise_states=[4])
status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=True, failure_states=[4])
assert status.done is False
sig.put(4)
with pytest.raises(ValueError):
@@ -854,7 +882,7 @@ def test_compare_status_with_mock_pv(mock_epics_signal_ro):
"""Test CompareStatus with EpicsSignalRO, this tests callbacks on EpicsSignals"""
signal = mock_epics_signal_ro
status = CompareStatus(signal=signal, value=5, operation="==")
status = CompareStatus(signal=signal, value=5, operation_success="==")
assert status.done is False
signal._read_pv.mock_data = 1
assert status.done is False