fix(bec-status): Refactor CompareStatus and TransitionStatus

This commit is contained in:
2025-11-26 13:41:58 +01:00
committed by Christian Appel
parent e95d46a77d
commit 58d4a5141f
2 changed files with 178 additions and 79 deletions

View File

@@ -48,44 +48,62 @@ OP_MAP = {
class CompareStatus(SubscriptionStatus): class CompareStatus(SubscriptionStatus):
""" """
Status class to compare a signal value against a given value. Status to compare a signal value against a given value.
The comparison is done using the specified operation, which can be one of The comparison is done using the specified operation, which can be one of
'==', '!=', '<', '<=', '>', '>='. If the value is a string, only '==' and '!=' are allowed. '==', '!=', '<', '<=', '>', '>='. If the value is a string, only '==' and '!=' are allowed.
The status is finished when the comparison is true. One may also define a value or list of values that will result in an exception if encountered.
The status is finished when the comparison is either true or an exception is raised.
Args: Args:
signal: The device signal to compare. signal (Signal): The signal to monitor.
value: The value to compare against. value (float | int | str): The target value to compare against.
operation: The operation to use for comparison. Defaults to '=='. operation_success (str, optional): The comparison operation for success. Defaults to '=='.
event_type: The type of event to trigger on comparison. Defaults to None (default sub). failure_value (float | int | str | list[float | int | str] | None, optional):
timeout: The timeout for the status. Defaults to None (indefinite). A value or list of values that will trigger an exception if encountered. Defaults to None.
settle_time: The time to wait for the signal to settle before comparison. Defaults to 0. operation_failure (str, optional): The comparison operation for failure values. Defaults to '=='.
run: Whether to run the status callback on creation or not. Defaults to True. event_type (int, optional): The event type to subscribe to. Defaults to None.
timeout (float, optional): Timeout for the status. Defaults to None.
settle_time (float, optional): Settle time before checking the status. Defaults to 0.
run (bool, optional): Whether to start the status immediately. Defaults to True
""" """
def __init__( def __init__(
self, self,
signal: Signal, signal: "Signal",
value: float | int | str, value: float | int | str,
*, *,
operation: Literal["==", "!=", "<", "<=", ">", ">="] = "==", operation_success: Literal["==", "!=", "<", "<=", ">", ">="] = "==",
event_type=None, failure_value: float | int | str | list[float | int | str] | None = None,
operation_failure: Literal["==", "!=", "<", "<=", ">", ">="] = "==",
timeout: float = None, timeout: float = None,
settle_time: float = 0, settle_time: float = 0,
run: bool = True, run: bool = True,
event_type=None,
): ):
if isinstance(value, str): if isinstance(value, str):
if operation not in ("==", "!="): if operation_success not in ("==", "!=") and operation_failure not in ("==", "!="):
raise ValueError( raise ValueError(
f"Invalid operation: {operation} for string comparison. Must be '==' or '!='." f"Invalid operation_success: {operation_success} for string comparison. Must be '==' or '!='."
) )
if operation not in ("==", "!=", "<", "<=", ">", ">="): if operation_success not in ("==", "!=", "<", "<=", ">", ">="):
raise ValueError( raise ValueError(
f"Invalid operation: {operation}. Must be one of '==', '!=', '<', '<=', '>', '>='." f"Invalid operation_success: {operation_success}. Must be one of '==', '!=', '<', '<=', '>', '>='."
) )
self._signal = signal self._signal = signal
self._value = value self._value = value
self._operation = operation self._operation_success = operation_success
self._operation_failure = operation_failure
self.op_map = OP_MAP
if failure_value is None:
self._failure_values = []
elif isinstance(failure_value, (float, int, str)):
self._failure_values = [failure_value]
elif isinstance(failure_value, list):
self._failure_values = failure_value
else:
raise ValueError(
f"failure_value must be a float, int, str, list or None. Received: {failure_value}"
)
super().__init__( super().__init__(
device=signal, device=signal,
callback=self._compare_callback, callback=self._compare_callback,
@@ -95,49 +113,91 @@ class CompareStatus(SubscriptionStatus):
run=run, run=run,
) )
def _compare_callback(self, value, **kwargs) -> bool: def _compare_callback(self, value: any, **kwargs) -> bool:
"""Callback for subscription status""" """
return OP_MAP[self._operation](value, self._value) Callback for subscription status
Args:
value (any): Current value of the signal
Returns:
bool: True if comparison is successful, False otherwise.
"""
try:
if isinstance(value, list):
self.set_exception(
ValueError(f"List values are not supported. Received value: {value}")
)
return False
if any(
self.op_map[self._operation_failure](value, failure_value)
for failure_value in self._failure_values
):
self.set_exception(
ValueError(
f"CompareStatus for signal {self._signal.name} "
f"did not reach the desired state {self._operation_success} {self._value}. "
f"But instead reached {value}, which is in list of failure values: {self._failure_values}"
)
)
return False
return self.op_map[self._operation_success](value, self._value)
except Exception as e:
# Catch any exception if the value comparison fails, e.g. value is numpy array
logger.error(f"Error in CompareStatus callback: {e}")
self.set_exception(e)
return False
class TransitionStatus(SubscriptionStatus): class TransitionStatus(SubscriptionStatus):
""" """
Status class to monitor transitions of a signal value through a list of specified transitions. Status to monitor transitions of a signal value through a list of specified transitions.
The status is finished when all transitions have been observed in order. The keyword argument The status is finished when all transitions have been observed in order. The keyword argument
`strict` determines whether the transitions must occur in strict order or not. `strict` determines whether the transitions must occur in strict order or not. The strict option
If `raise_states` is provided, the status will raise an exception if the signal value matches only becomes relevant once the first transition has been observed.
any of the values in `raise_states`. If `failure_states` is provided, the status will raise an exception if the signal value matches
any of the values in `failure_states`.
Args: Args:
signal: The device signal to monitor. signal (Signal): The signal to monitor.
transitions: A list of values to transition through. transitions (list[float | int | str]): List of values representing the transitions to observe.
strict: Whether the transitions must occur in strict order. Defaults to True. strict (bool, optional): Whether to enforce strict order of transitions. Defaults to True.
raise_states: A list of values that will raise an exception if encountered. Defaults to None. failure_states (list[float | int | str] | None, optional):
run: Whether to run the status callback on creation or not. Defaults to True. A list of values that will trigger an exception if encountered. Defaults to None.
event_type: The type of event to trigger on transition. Defaults to None (default sub). run (bool, optional): Whether to start the status immediately. Defaults to True.
timeout: The timeout for the status. Defaults to None (indefinite). event_type (int, optional): The event type to subscribe to. Defaults to None.
settle_time: The time to wait for the signal to settle before comparison. Defaults to 0. timeout (float, optional): Timeout for the status. Defaults to None.
settle_time (float, optional): Settle time before checking the status. Defaults to 0.
Notes:
The 'strict' option does not raise if transitions are observed which are out of order.
It only determines whether a transition is accepted if it is observed from the
previous value in the list of transitions to the next value.
For example, with strict=True and transitions=[1, 2, 3], the sequence
0 -> 1 -> 2 -> 3 is accepted, but 0 -> 1 -> 3 -> 2 -> 3 is not and the status
will not complete. With strict=False, both sequences are accepted.
However, with strict=True, the sequence 0 -> 1 -> 3 -> 1 -> 2 -> 3 is accepted.
To raise an exception if an out-of-order transition is observed, use the
`failure_states` keyword argument.
""" """
def __init__( def __init__(
self, self,
signal: Signal, signal: "Signal",
transitions: list[float | int | str], transitions: list[float | int | str],
*, *,
strict: bool = True, strict: bool = True,
raise_states: list[float | int | str] | None = None, failure_states: list[float | int | str] | None = None,
run: bool = True, run: bool = True,
event_type=None,
timeout: float = None, timeout: float = None,
settle_time: float = 0, settle_time: float = 0,
event_type=None,
): ):
self._signal = signal self._signal = signal
if not isinstance(transitions, list): self._transitions = tuple(transitions)
raise ValueError(f"Transitions must be a list of values. Received: {transitions}")
self._transitions = transitions
self._index = 0 self._index = 0
self._strict = strict self._strict = strict
self._raise_states = raise_states if raise_states else [] self._failure_states = failure_states if failure_states else []
super().__init__( super().__init__(
device=signal, device=signal,
callback=self._compare_callback, callback=self._compare_callback,
@@ -147,34 +207,45 @@ class TransitionStatus(SubscriptionStatus):
run=run, run=run,
) )
def _compare_callback(self, old_value, value, **kwargs) -> bool: def _compare_callback(self, old_value: any, value: any, **kwargs) -> bool:
"""Callback for subscription Status""" """
if value in self._raise_states: Callback for subscription Status
self.set_exception(
ValueError( Args:
f"Transition raised an exception: {value}. " old_value (any): Previous value of the signal
f"Expected transitions: {self._transitions}." value (any): Current value of the signal
Returns:
bool: True if all transitions have been observed, False otherwise.
"""
try:
if value in self._failure_states:
self.set_exception(
ValueError(
f"Transition Status for {self._signal.name} resulted in a value: {value}. "
f"marked to raise {self._failure_states}. Expected transitions: {self._transitions}."
)
) )
) return False
return False if self._index == 0:
if self._index == 0: if value == self._transitions[0]:
if value == self._transitions[0]:
self._index += 1
else:
if self._strict:
if (
old_value == self._transitions[self._index - 1]
and value == self._transitions[self._index]
):
self._index += 1 self._index += 1
else: else:
if value == self._transitions[self._index]: if self._strict:
self._index += 1 if (
return self._is_finished() old_value == self._transitions[self._index - 1]
and value == self._transitions[self._index]
def _is_finished(self) -> bool: ):
"""Check if the status is finished""" self._index += 1
return self._index >= len(self._transitions) else:
if value == self._transitions[self._index]:
self._index += 1
return self._index >= len(self._transitions)
except Exception as e:
# Catch any exception if the value comparison fails, e.g. value is numpy array
logger.error(f"Error in TransitionStatus callback: {e}")
self.set_exception(e)
return False
class TaskState(str, Enum): class TaskState(str, Enum):

View File

@@ -691,18 +691,19 @@ def test_utils_progress_signal():
def test_utils_compare_status_number(): def test_utils_compare_status_number():
"""Test CompareStatus""" """Test CompareStatus with different operations."""
sig = Signal(name="test_signal", value=0) sig = Signal(name="test_signal", value=0)
status = CompareStatus(signal=sig, value=5, operation="==") status = CompareStatus(signal=sig, value=5, operation_success="==")
assert status.done is False assert status.done is False
sig.put(1) sig.put(1)
assert status.done is False assert status.done is False
sig.put(5) sig.put(5)
status.wait(timeout=5)
assert status.done is True assert status.done is True
sig.put(5) sig.put(5)
# Test with different operations # Test with different operations
status = CompareStatus(signal=sig, value=5, operation="!=") status = CompareStatus(signal=sig, value=5, operation_success="!=")
assert status.done is False assert status.done is False
sig.put(5) sig.put(5)
assert status.done is False assert status.done is False
@@ -712,7 +713,7 @@ def test_utils_compare_status_number():
assert status.exception() is None assert status.exception() is None
sig.put(0) sig.put(0)
status = CompareStatus(signal=sig, value=5, operation=">") status = CompareStatus(signal=sig, value=5, operation_success=">")
assert status.done is False assert status.done is False
sig.put(5) sig.put(5)
assert status.done is False assert status.done is False
@@ -721,11 +722,44 @@ def test_utils_compare_status_number():
assert status.success is True assert status.success is True
assert status.exception() is None assert status.exception() is None
# Should raise
sig.put(0)
status = CompareStatus(signal=sig, value=5, operation_success="==", failure_value=[10])
with pytest.raises(ValueError):
sig.put(10)
status.wait()
assert status.done is True
assert status.success is False
assert isinstance(status.exception(), ValueError)
# failure_operation
sig.put(0)
status = CompareStatus(
signal=sig, value=5, operation_success="==", failure_value=10, operation_failure=">"
)
sig.put(10)
assert status.done is False
assert status.success is False
sig.put(11)
with pytest.raises(ValueError):
status.wait()
assert status.done is True
assert status.success is False
# raise if array is returned
sig.put(0)
status = CompareStatus(signal=sig, value=5, operation_success="==")
with pytest.raises(ValueError):
sig.put([1, 2, 3])
status.wait(timeout=2)
assert status.done is True
assert status.success is False
def test_compare_status_string(): def test_compare_status_string():
"""Test CompareStatus with string values""" """Test CompareStatus with string values"""
sig = Signal(name="test_signal", value="test") sig = Signal(name="test_signal", value="test")
status = CompareStatus(signal=sig, value="test", operation="==") status = CompareStatus(signal=sig, value="test", operation_success="==")
assert status.done is False assert status.done is False
sig.put("test1") sig.put("test1")
assert status.done is False assert status.done is False
@@ -734,7 +768,7 @@ def test_compare_status_string():
sig.put("test") sig.put("test")
# Test with different operations # Test with different operations
status = CompareStatus(signal=sig, value="test", operation="!=") status = CompareStatus(signal=sig, value="test", operation_success="!=")
assert status.done is False assert status.done is False
sig.put("test") sig.put("test")
assert status.done is False assert status.done is False
@@ -743,12 +777,6 @@ def test_compare_status_string():
assert status.success is True assert status.success is True
assert status.exception() is None assert status.exception() is None
# Test with greater than operation
# Raises ValueError for strings
sig.put("a")
with pytest.raises(ValueError):
status = CompareStatus(signal=sig, value="b", operation=">")
def test_transition_status(): def test_transition_status():
"""Test TransitionStatus""" """Test TransitionStatus"""
@@ -768,9 +796,9 @@ def test_transition_status():
assert status.success is True assert status.success is True
assert status.exception() is None assert status.exception() is None
# Test strict=True, ra # Test strict=True, failure_states
sig.put(1) sig.put(1)
status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=True, raise_states=[4]) status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=True, failure_states=[4])
assert status.done is False assert status.done is False
sig.put(4) sig.put(4)
with pytest.raises(ValueError): with pytest.raises(ValueError):
@@ -854,7 +882,7 @@ def test_compare_status_with_mock_pv(mock_epics_signal_ro):
"""Test CompareStatus with EpicsSignalRO, this tests callbacks on EpicsSignals""" """Test CompareStatus with EpicsSignalRO, this tests callbacks on EpicsSignals"""
signal = mock_epics_signal_ro signal = mock_epics_signal_ro
status = CompareStatus(signal=signal, value=5, operation="==") status = CompareStatus(signal=signal, value=5, operation_success="==")
assert status.done is False assert status.done is False
signal._read_pv.mock_data = 1 signal._read_pv.mock_data = 1
assert status.done is False assert status.done is False