refactor: refactored compare and transition state

This commit is contained in:
2025-06-17 07:10:03 +02:00
committed by Christian Appel
parent d092b8b51a
commit 20eb5dd83f
3 changed files with 176 additions and 32 deletions

View File

@ -20,5 +20,5 @@ from .devices.softpositioner import SoftPositioner
from .utils.bec_device_base import BECDeviceBase
from .utils.bec_signals import *
from .utils.dynamic_pseudo import ComputedSignal
from .utils.psi_device_base_utils import CompareStatus, TargetStatus
from .utils.psi_device_base_utils import *
from .utils.static_device_test import launch

View File

@ -11,8 +11,8 @@ from typing import TYPE_CHECKING, Any, Callable, Literal
from bec_lib.file_utils import get_full_path
from bec_lib.logger import bec_logger
from bec_lib.utils.import_utils import lazy_import_from
from ophyd import Device, DeviceStatus, Signal
from ophyd.status import SubscriptionStatus
from ophyd import Device, Signal
from ophyd.status import AndStatus, DeviceStatus, MoveStatus, Status, StatusBase, SubscriptionStatus
if TYPE_CHECKING: # pragma: no cover
from bec_lib.messages import ScanStatusMessage
@ -21,6 +21,17 @@ else:
ScanStatusMessage = lazy_import_from("bec_lib.messages", ("ScanStatusMessage",))
__all__ = [
"CompareStatus",
"TransitionStatus",
"AndStatus",
"DeviceStatus",
"MoveStatus",
"Status",
"StatusBase",
"SubscriptionStatus",
]
logger = bec_logger.logger
set_async_exc = ctypes.pythonapi.PyThreadState_SetAsyncExc
@ -36,12 +47,26 @@ OP_MAP = {
class CompareStatus(SubscriptionStatus):
"""Status class to compare a value from a device signal with a target value."""
"""
Status class to compare a value from a device signal with a target value.
The value can be a float, int, or string. If the value is a string,
the operation must be either '==' or '!='. For numeric (float or int) values,
the operation can be any of the standard comparison operators.
Args:
signal: The device signal to compare.
value: The target value to compare against.
operation: The comparison operation to use. Defaults to '=='.
event_type: The type of event to trigger on comparison. Defaults to None (default sub).
timeout: The timeout for the status. Defaults to None (indefinite).
settle_time: The time to wait for the signal to settle before comparison. Defaults to 0.
run: Whether to run the status immediately or not. Defaults to True.
"""
def __init__(
self,
signal: Signal,
value: Any,
value: float | int | str,
*,
operation: Literal["==", "!=", "<", "<=", ">", ">="] = "==",
event_type=None,
@ -49,6 +74,11 @@ class CompareStatus(SubscriptionStatus):
settle_time: float = 0,
run: bool = True,
):
if isinstance(value, str):
if operation not in ("==", "!="):
raise ValueError(
f"Invalid operation: {operation} for string comparison. Must be '==' or '!='."
)
if operation not in ("==", "!=", "<", "<=", ">", ">="):
raise ValueError(
f"Invalid operation: {operation}. Must be one of '==', '!=', '<', '<=', '>', '>='."
@ -66,25 +96,52 @@ class CompareStatus(SubscriptionStatus):
)
def _compare_callback(self, value, **kwargs) -> bool:
"""Callback for subscription Status"""
"""Callback for subscription status"""
return OP_MAP[self._operation](value, self._value)
class TargetStatus(SubscriptionStatus):
"""Status class to compare a list of values that are expected to be reached in sequence for a device signal."""
class TransitionStatus(SubscriptionStatus):
"""
Status class to compare a list of transitions.
The transitions can be a list of float, int, or string values.
The transitions are checked in order, and the status is finished when all transitions
have been matched in sequence. The keyword argument `strict` determines whether
the transitions must match exactly in order, or if intermediate transitions are allowed.
For the first value, the strict check is not applied, meaning that the sequence starts once
the first transition is matched.
Args:
signal: The device signal to compare.
transitions: A list of transitions to compare against.
strict: Whether to enforce strict matching of transitions. Defaults to True.
run: Whether to run the status immediately or not. Defaults to True.
event_type: The type of event to trigger on comparison. Defaults to None (default sub).
timeout: The timeout for the status. Defaults to None (indefinite).
settle_time: The time to wait for the signal to settle before comparison. Defaults to 0.
Raises:
ValueError: If the transitions do not match the expected sequence. and strict is True.
"""
def __init__(
self,
signal: Signal,
values: list[Any],
transitions: list[float | int | str],
*,
strict: bool = True,
raise_states: list[float | int | str] | None = None,
run: bool = True,
event_type=None,
timeout: float = None,
settle_time: float = 0,
run: bool = True,
):
self._signal = signal
self._values = values
if not isinstance(transitions, list):
raise ValueError(f"Transitions must be a list of values. Received: {transitions}")
self._transitions = transitions
self._index = 0
self._strict = strict
self._raise_states = raise_states if raise_states else []
super().__init__(
device=signal,
callback=self._compare_callback,
@ -94,13 +151,34 @@ class TargetStatus(SubscriptionStatus):
run=run,
)
def _compare_callback(self, value, **kwargs) -> bool:
def _compare_callback(self, old_value, value, **kwargs) -> bool:
"""Callback for subscription Status"""
if value == self._values[0]:
self._values.pop(0)
if len(self._values) == 0:
return True
return False
if value in self._raise_states:
self.set_exception(
ValueError(
f"Transition raised an exception: {value}. "
f"Expected transitions: {self._transitions}."
)
)
return False
if self._index == 0:
if value == self._transitions[0]:
self._index += 1
else:
if self._strict:
if (
old_value == self._transitions[self._index - 1]
and value == self._transitions[self._index]
):
self._index += 1
else:
if value == self._transitions[self._index]:
self._index += 1
return self._is_finished()
def _is_finished(self) -> bool:
"""Check if the status is finished"""
return self._index >= len(self._transitions)
class TaskState(str, Enum):

View File

@ -16,11 +16,11 @@ from ophyd_devices.utils.bec_signals import (
from ophyd_devices.utils.psi_device_base_utils import (
CompareStatus,
FileHandler,
TargetStatus,
TaskHandler,
TaskKilledError,
TaskState,
TaskStatus,
TransitionStatus,
)
# pylint: disable=protected-access
@ -529,20 +529,7 @@ def test_utils_progress_signal():
signal.put({"wrong_key": "wrong_value"})
def test_utils_target_status():
"""Test TargetStatus"""
sig = Signal(name="test_signal", value=0)
status = TargetStatus(signal=sig, values=[1, 2, 3])
assert status.done is False
sig.put(1)
assert status.done is False
sig.put(2)
assert status.done is False
sig.put(3)
assert status.done is True
def test_utils_compare_status():
def test_utils_compare_status_number():
"""Test CompareStatus"""
sig = Signal(name="test_signal", value=0)
status = CompareStatus(signal=sig, value=5, operation="==")
@ -560,6 +547,8 @@ def test_utils_compare_status():
assert status.done is False
sig.put(6)
assert status.done is True
assert status.success is True
assert status.exception() is None
sig.put(0)
status = CompareStatus(signal=sig, value=5, operation=">")
@ -568,3 +557,80 @@ def test_utils_compare_status():
assert status.done is False
sig.put(10)
assert status.done is True
assert status.success is True
assert status.exception() is None
def test_compare_status_string():
"""Test CompareStatus with string values"""
sig = Signal(name="test_signal", value="test")
status = CompareStatus(signal=sig, value="test", operation="==")
assert status.done is False
sig.put("test1")
assert status.done is False
sig.put("test")
assert status.done is True
sig.put("test")
# Test with different operations
status = CompareStatus(signal=sig, value="test", operation="!=")
assert status.done is False
sig.put("test")
assert status.done is False
sig.put("test1")
assert status.done is True
assert status.success is True
assert status.exception() is None
# Test with greater than operation
# Raises ValueError for strings
sig.put("a")
with pytest.raises(ValueError):
status = CompareStatus(signal=sig, value="b", operation=">")
def test_transition_status():
"""Test TransitionStatus"""
sig = Signal(name="test_signal", value=0)
# Test strict=True, without intermediate transitions
sig.put(0)
status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=True)
assert status.done is False
sig.put(1)
assert status.done is False
sig.put(2)
assert status.done is False
sig.put(3)
assert status.done is True
assert status.success is True
assert status.exception() is None
# Test strict=True, ra
sig.put(1)
status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=True, raise_states=[4])
assert status.done is False
sig.put(4)
with pytest.raises(ValueError):
status.wait()
assert status.done is True
assert status.success is False
assert isinstance(status.exception(), ValueError)
# Test strict=False, with intermediate transitions
sig.put(0)
status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=False)
assert status.done is False
sig.put(1) # entering first transition
sig.put(3)
sig.put(2) # transision
assert status.done is False
sig.put(4)
sig.put(2)
sig.put(3) # last transition
assert status.done is True
assert status.success is True
assert status.exception() is None