mirror of
https://github.com/bec-project/ophyd_devices.git
synced 2026-02-09 00:28:40 +01:00
fix(bec-status): Refactor CompareStatus and TransitionStatus
This commit is contained in:
@@ -48,44 +48,62 @@ OP_MAP = {
|
|||||||
|
|
||||||
class CompareStatus(SubscriptionStatus):
|
class CompareStatus(SubscriptionStatus):
|
||||||
"""
|
"""
|
||||||
Status class to compare a signal value against a given value.
|
Status to compare a signal value against a given value.
|
||||||
The comparison is done using the specified operation, which can be one of
|
The comparison is done using the specified operation, which can be one of
|
||||||
'==', '!=', '<', '<=', '>', '>='. If the value is a string, only '==' and '!=' are allowed.
|
'==', '!=', '<', '<=', '>', '>='. If the value is a string, only '==' and '!=' are allowed.
|
||||||
The status is finished when the comparison is true.
|
One may also define a value or list of values that will result in an exception if encountered.
|
||||||
|
The status is finished when the comparison is either true or an exception is raised.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
signal: The device signal to compare.
|
signal (Signal): The signal to monitor.
|
||||||
value: The value to compare against.
|
value (float | int | str): The target value to compare against.
|
||||||
operation: The operation to use for comparison. Defaults to '=='.
|
operation_success (str, optional): The comparison operation for success. Defaults to '=='.
|
||||||
event_type: The type of event to trigger on comparison. Defaults to None (default sub).
|
failure_value (float | int | str | list[float | int | str] | None, optional):
|
||||||
timeout: The timeout for the status. Defaults to None (indefinite).
|
A value or list of values that will trigger an exception if encountered. Defaults to None.
|
||||||
settle_time: The time to wait for the signal to settle before comparison. Defaults to 0.
|
operation_failure (str, optional): The comparison operation for failure values. Defaults to '=='.
|
||||||
run: Whether to run the status callback on creation or not. Defaults to True.
|
event_type (int, optional): The event type to subscribe to. Defaults to None.
|
||||||
|
timeout (float, optional): Timeout for the status. Defaults to None.
|
||||||
|
settle_time (float, optional): Settle time before checking the status. Defaults to 0.
|
||||||
|
run (bool, optional): Whether to start the status immediately. Defaults to True
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
signal: Signal,
|
signal: "Signal",
|
||||||
value: float | int | str,
|
value: float | int | str,
|
||||||
*,
|
*,
|
||||||
operation: Literal["==", "!=", "<", "<=", ">", ">="] = "==",
|
operation_success: Literal["==", "!=", "<", "<=", ">", ">="] = "==",
|
||||||
event_type=None,
|
failure_value: float | int | str | list[float | int | str] | None = None,
|
||||||
|
operation_failure: Literal["==", "!=", "<", "<=", ">", ">="] = "==",
|
||||||
timeout: float = None,
|
timeout: float = None,
|
||||||
settle_time: float = 0,
|
settle_time: float = 0,
|
||||||
run: bool = True,
|
run: bool = True,
|
||||||
|
event_type=None,
|
||||||
):
|
):
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
if operation not in ("==", "!="):
|
if operation_success not in ("==", "!=") and operation_failure not in ("==", "!="):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid operation: {operation} for string comparison. Must be '==' or '!='."
|
f"Invalid operation_success: {operation_success} for string comparison. Must be '==' or '!='."
|
||||||
)
|
)
|
||||||
if operation not in ("==", "!=", "<", "<=", ">", ">="):
|
if operation_success not in ("==", "!=", "<", "<=", ">", ">="):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid operation: {operation}. Must be one of '==', '!=', '<', '<=', '>', '>='."
|
f"Invalid operation_success: {operation_success}. Must be one of '==', '!=', '<', '<=', '>', '>='."
|
||||||
)
|
)
|
||||||
self._signal = signal
|
self._signal = signal
|
||||||
self._value = value
|
self._value = value
|
||||||
self._operation = operation
|
self._operation_success = operation_success
|
||||||
|
self._operation_failure = operation_failure
|
||||||
|
self.op_map = OP_MAP
|
||||||
|
if failure_value is None:
|
||||||
|
self._failure_values = []
|
||||||
|
elif isinstance(failure_value, (float, int, str)):
|
||||||
|
self._failure_values = [failure_value]
|
||||||
|
elif isinstance(failure_value, list):
|
||||||
|
self._failure_values = failure_value
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"failure_value must be a float, int, str, list or None. Received: {failure_value}"
|
||||||
|
)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
device=signal,
|
device=signal,
|
||||||
callback=self._compare_callback,
|
callback=self._compare_callback,
|
||||||
@@ -95,49 +113,91 @@ class CompareStatus(SubscriptionStatus):
|
|||||||
run=run,
|
run=run,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _compare_callback(self, value, **kwargs) -> bool:
|
def _compare_callback(self, value: any, **kwargs) -> bool:
|
||||||
"""Callback for subscription status"""
|
"""
|
||||||
return OP_MAP[self._operation](value, self._value)
|
Callback for subscription status
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value (any): Current value of the signal
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if comparison is successful, False otherwise.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if isinstance(value, list):
|
||||||
|
self.set_exception(
|
||||||
|
ValueError(f"List values are not supported. Received value: {value}")
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
if any(
|
||||||
|
self.op_map[self._operation_failure](value, failure_value)
|
||||||
|
for failure_value in self._failure_values
|
||||||
|
):
|
||||||
|
self.set_exception(
|
||||||
|
ValueError(
|
||||||
|
f"CompareStatus for signal {self._signal.name} "
|
||||||
|
f"did not reach the desired state {self._operation_success} {self._value}. "
|
||||||
|
f"But instead reached {value}, which is in list of failure values: {self._failure_values}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
return self.op_map[self._operation_success](value, self._value)
|
||||||
|
except Exception as e:
|
||||||
|
# Catch any exception if the value comparison fails, e.g. value is numpy array
|
||||||
|
logger.error(f"Error in CompareStatus callback: {e}")
|
||||||
|
self.set_exception(e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class TransitionStatus(SubscriptionStatus):
|
class TransitionStatus(SubscriptionStatus):
|
||||||
"""
|
"""
|
||||||
Status class to monitor transitions of a signal value through a list of specified transitions.
|
Status to monitor transitions of a signal value through a list of specified transitions.
|
||||||
The status is finished when all transitions have been observed in order. The keyword argument
|
The status is finished when all transitions have been observed in order. The keyword argument
|
||||||
`strict` determines whether the transitions must occur in strict order or not.
|
`strict` determines whether the transitions must occur in strict order or not. The strict option
|
||||||
If `raise_states` is provided, the status will raise an exception if the signal value matches
|
only becomes relevant once the first transition has been observed.
|
||||||
any of the values in `raise_states`.
|
If `failure_states` is provided, the status will raise an exception if the signal value matches
|
||||||
|
any of the values in `failure_states`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
signal: The device signal to monitor.
|
signal (Signal): The signal to monitor.
|
||||||
transitions: A list of values to transition through.
|
transitions (list[float | int | str]): List of values representing the transitions to observe.
|
||||||
strict: Whether the transitions must occur in strict order. Defaults to True.
|
strict (bool, optional): Whether to enforce strict order of transitions. Defaults to True.
|
||||||
raise_states: A list of values that will raise an exception if encountered. Defaults to None.
|
failure_states (list[float | int | str] | None, optional):
|
||||||
run: Whether to run the status callback on creation or not. Defaults to True.
|
A list of values that will trigger an exception if encountered. Defaults to None.
|
||||||
event_type: The type of event to trigger on transition. Defaults to None (default sub).
|
run (bool, optional): Whether to start the status immediately. Defaults to True.
|
||||||
timeout: The timeout for the status. Defaults to None (indefinite).
|
event_type (int, optional): The event type to subscribe to. Defaults to None.
|
||||||
settle_time: The time to wait for the signal to settle before comparison. Defaults to 0.
|
timeout (float, optional): Timeout for the status. Defaults to None.
|
||||||
|
settle_time (float, optional): Settle time before checking the status. Defaults to 0.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
The 'strict' option does not raise if transitions are observed which are out of order.
|
||||||
|
It only determines whether a transition is accepted if it is observed from the
|
||||||
|
previous value in the list of transitions to the next value.
|
||||||
|
For example, with strict=True and transitions=[1, 2, 3], the sequence
|
||||||
|
0 -> 1 -> 2 -> 3 is accepted, but 0 -> 1 -> 3 -> 2 -> 3 is not and the status
|
||||||
|
will not complete. With strict=False, both sequences are accepted.
|
||||||
|
However, with strict=True, the sequence 0 -> 1 -> 3 -> 1 -> 2 -> 3 is accepted.
|
||||||
|
To raise an exception if an out-of-order transition is observed, use the
|
||||||
|
`failure_states` keyword argument.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
signal: Signal,
|
signal: "Signal",
|
||||||
transitions: list[float | int | str],
|
transitions: list[float | int | str],
|
||||||
*,
|
*,
|
||||||
strict: bool = True,
|
strict: bool = True,
|
||||||
raise_states: list[float | int | str] | None = None,
|
failure_states: list[float | int | str] | None = None,
|
||||||
run: bool = True,
|
run: bool = True,
|
||||||
event_type=None,
|
|
||||||
timeout: float = None,
|
timeout: float = None,
|
||||||
settle_time: float = 0,
|
settle_time: float = 0,
|
||||||
|
event_type=None,
|
||||||
):
|
):
|
||||||
self._signal = signal
|
self._signal = signal
|
||||||
if not isinstance(transitions, list):
|
self._transitions = tuple(transitions)
|
||||||
raise ValueError(f"Transitions must be a list of values. Received: {transitions}")
|
|
||||||
self._transitions = transitions
|
|
||||||
self._index = 0
|
self._index = 0
|
||||||
self._strict = strict
|
self._strict = strict
|
||||||
self._raise_states = raise_states if raise_states else []
|
self._failure_states = failure_states if failure_states else []
|
||||||
super().__init__(
|
super().__init__(
|
||||||
device=signal,
|
device=signal,
|
||||||
callback=self._compare_callback,
|
callback=self._compare_callback,
|
||||||
@@ -147,13 +207,23 @@ class TransitionStatus(SubscriptionStatus):
|
|||||||
run=run,
|
run=run,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _compare_callback(self, old_value, value, **kwargs) -> bool:
|
def _compare_callback(self, old_value: any, value: any, **kwargs) -> bool:
|
||||||
"""Callback for subscription Status"""
|
"""
|
||||||
if value in self._raise_states:
|
Callback for subscription Status
|
||||||
|
|
||||||
|
Args:
|
||||||
|
old_value (any): Previous value of the signal
|
||||||
|
value (any): Current value of the signal
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if all transitions have been observed, False otherwise.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if value in self._failure_states:
|
||||||
self.set_exception(
|
self.set_exception(
|
||||||
ValueError(
|
ValueError(
|
||||||
f"Transition raised an exception: {value}. "
|
f"Transition Status for {self._signal.name} resulted in a value: {value}. "
|
||||||
f"Expected transitions: {self._transitions}."
|
f"marked to raise {self._failure_states}. Expected transitions: {self._transitions}."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
@@ -170,11 +240,12 @@ class TransitionStatus(SubscriptionStatus):
|
|||||||
else:
|
else:
|
||||||
if value == self._transitions[self._index]:
|
if value == self._transitions[self._index]:
|
||||||
self._index += 1
|
self._index += 1
|
||||||
return self._is_finished()
|
|
||||||
|
|
||||||
def _is_finished(self) -> bool:
|
|
||||||
"""Check if the status is finished"""
|
|
||||||
return self._index >= len(self._transitions)
|
return self._index >= len(self._transitions)
|
||||||
|
except Exception as e:
|
||||||
|
# Catch any exception if the value comparison fails, e.g. value is numpy array
|
||||||
|
logger.error(f"Error in TransitionStatus callback: {e}")
|
||||||
|
self.set_exception(e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class TaskState(str, Enum):
|
class TaskState(str, Enum):
|
||||||
|
|||||||
@@ -691,18 +691,19 @@ def test_utils_progress_signal():
|
|||||||
|
|
||||||
|
|
||||||
def test_utils_compare_status_number():
|
def test_utils_compare_status_number():
|
||||||
"""Test CompareStatus"""
|
"""Test CompareStatus with different operations."""
|
||||||
sig = Signal(name="test_signal", value=0)
|
sig = Signal(name="test_signal", value=0)
|
||||||
status = CompareStatus(signal=sig, value=5, operation="==")
|
status = CompareStatus(signal=sig, value=5, operation_success="==")
|
||||||
assert status.done is False
|
assert status.done is False
|
||||||
sig.put(1)
|
sig.put(1)
|
||||||
assert status.done is False
|
assert status.done is False
|
||||||
sig.put(5)
|
sig.put(5)
|
||||||
|
status.wait(timeout=5)
|
||||||
assert status.done is True
|
assert status.done is True
|
||||||
|
|
||||||
sig.put(5)
|
sig.put(5)
|
||||||
# Test with different operations
|
# Test with different operations
|
||||||
status = CompareStatus(signal=sig, value=5, operation="!=")
|
status = CompareStatus(signal=sig, value=5, operation_success="!=")
|
||||||
assert status.done is False
|
assert status.done is False
|
||||||
sig.put(5)
|
sig.put(5)
|
||||||
assert status.done is False
|
assert status.done is False
|
||||||
@@ -712,7 +713,7 @@ def test_utils_compare_status_number():
|
|||||||
assert status.exception() is None
|
assert status.exception() is None
|
||||||
|
|
||||||
sig.put(0)
|
sig.put(0)
|
||||||
status = CompareStatus(signal=sig, value=5, operation=">")
|
status = CompareStatus(signal=sig, value=5, operation_success=">")
|
||||||
assert status.done is False
|
assert status.done is False
|
||||||
sig.put(5)
|
sig.put(5)
|
||||||
assert status.done is False
|
assert status.done is False
|
||||||
@@ -721,11 +722,44 @@ def test_utils_compare_status_number():
|
|||||||
assert status.success is True
|
assert status.success is True
|
||||||
assert status.exception() is None
|
assert status.exception() is None
|
||||||
|
|
||||||
|
# Should raise
|
||||||
|
sig.put(0)
|
||||||
|
status = CompareStatus(signal=sig, value=5, operation_success="==", failure_value=[10])
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
sig.put(10)
|
||||||
|
status.wait()
|
||||||
|
assert status.done is True
|
||||||
|
assert status.success is False
|
||||||
|
assert isinstance(status.exception(), ValueError)
|
||||||
|
|
||||||
|
# failure_operation
|
||||||
|
sig.put(0)
|
||||||
|
status = CompareStatus(
|
||||||
|
signal=sig, value=5, operation_success="==", failure_value=10, operation_failure=">"
|
||||||
|
)
|
||||||
|
sig.put(10)
|
||||||
|
assert status.done is False
|
||||||
|
assert status.success is False
|
||||||
|
sig.put(11)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
status.wait()
|
||||||
|
assert status.done is True
|
||||||
|
assert status.success is False
|
||||||
|
|
||||||
|
# raise if array is returned
|
||||||
|
sig.put(0)
|
||||||
|
status = CompareStatus(signal=sig, value=5, operation_success="==")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
sig.put([1, 2, 3])
|
||||||
|
status.wait(timeout=2)
|
||||||
|
assert status.done is True
|
||||||
|
assert status.success is False
|
||||||
|
|
||||||
|
|
||||||
def test_compare_status_string():
|
def test_compare_status_string():
|
||||||
"""Test CompareStatus with string values"""
|
"""Test CompareStatus with string values"""
|
||||||
sig = Signal(name="test_signal", value="test")
|
sig = Signal(name="test_signal", value="test")
|
||||||
status = CompareStatus(signal=sig, value="test", operation="==")
|
status = CompareStatus(signal=sig, value="test", operation_success="==")
|
||||||
assert status.done is False
|
assert status.done is False
|
||||||
sig.put("test1")
|
sig.put("test1")
|
||||||
assert status.done is False
|
assert status.done is False
|
||||||
@@ -734,7 +768,7 @@ def test_compare_status_string():
|
|||||||
|
|
||||||
sig.put("test")
|
sig.put("test")
|
||||||
# Test with different operations
|
# Test with different operations
|
||||||
status = CompareStatus(signal=sig, value="test", operation="!=")
|
status = CompareStatus(signal=sig, value="test", operation_success="!=")
|
||||||
assert status.done is False
|
assert status.done is False
|
||||||
sig.put("test")
|
sig.put("test")
|
||||||
assert status.done is False
|
assert status.done is False
|
||||||
@@ -743,12 +777,6 @@ def test_compare_status_string():
|
|||||||
assert status.success is True
|
assert status.success is True
|
||||||
assert status.exception() is None
|
assert status.exception() is None
|
||||||
|
|
||||||
# Test with greater than operation
|
|
||||||
# Raises ValueError for strings
|
|
||||||
sig.put("a")
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
status = CompareStatus(signal=sig, value="b", operation=">")
|
|
||||||
|
|
||||||
|
|
||||||
def test_transition_status():
|
def test_transition_status():
|
||||||
"""Test TransitionStatus"""
|
"""Test TransitionStatus"""
|
||||||
@@ -768,9 +796,9 @@ def test_transition_status():
|
|||||||
assert status.success is True
|
assert status.success is True
|
||||||
assert status.exception() is None
|
assert status.exception() is None
|
||||||
|
|
||||||
# Test strict=True, ra
|
# Test strict=True, failure_states
|
||||||
sig.put(1)
|
sig.put(1)
|
||||||
status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=True, raise_states=[4])
|
status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=True, failure_states=[4])
|
||||||
assert status.done is False
|
assert status.done is False
|
||||||
sig.put(4)
|
sig.put(4)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
@@ -854,7 +882,7 @@ def test_compare_status_with_mock_pv(mock_epics_signal_ro):
|
|||||||
"""Test CompareStatus with EpicsSignalRO, this tests callbacks on EpicsSignals"""
|
"""Test CompareStatus with EpicsSignalRO, this tests callbacks on EpicsSignals"""
|
||||||
|
|
||||||
signal = mock_epics_signal_ro
|
signal = mock_epics_signal_ro
|
||||||
status = CompareStatus(signal=signal, value=5, operation="==")
|
status = CompareStatus(signal=signal, value=5, operation_success="==")
|
||||||
assert status.done is False
|
assert status.done is False
|
||||||
signal._read_pv.mock_data = 1
|
signal._read_pv.mock_data = 1
|
||||||
assert status.done is False
|
assert status.done is False
|
||||||
|
|||||||
Reference in New Issue
Block a user