mirror of
https://github.com/bec-project/ophyd_devices.git
synced 2025-06-23 19:27:59 +02:00
refactor: refactored compare and transition state
This commit is contained in:
@ -20,5 +20,5 @@ from .devices.softpositioner import SoftPositioner
|
||||
from .utils.bec_device_base import BECDeviceBase
|
||||
from .utils.bec_signals import *
|
||||
from .utils.dynamic_pseudo import ComputedSignal
|
||||
from .utils.psi_device_base_utils import CompareStatus, TargetStatus
|
||||
from .utils.psi_device_base_utils import *
|
||||
from .utils.static_device_test import launch
|
||||
|
@ -11,8 +11,8 @@ from typing import TYPE_CHECKING, Any, Callable, Literal
|
||||
from bec_lib.file_utils import get_full_path
|
||||
from bec_lib.logger import bec_logger
|
||||
from bec_lib.utils.import_utils import lazy_import_from
|
||||
from ophyd import Device, DeviceStatus, Signal
|
||||
from ophyd.status import SubscriptionStatus
|
||||
from ophyd import Device, Signal
|
||||
from ophyd.status import AndStatus, DeviceStatus, MoveStatus, Status, StatusBase, SubscriptionStatus
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from bec_lib.messages import ScanStatusMessage
|
||||
@ -21,6 +21,17 @@ else:
|
||||
ScanStatusMessage = lazy_import_from("bec_lib.messages", ("ScanStatusMessage",))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CompareStatus",
|
||||
"TransitionStatus",
|
||||
"AndStatus",
|
||||
"DeviceStatus",
|
||||
"MoveStatus",
|
||||
"Status",
|
||||
"StatusBase",
|
||||
"SubscriptionStatus",
|
||||
]
|
||||
|
||||
logger = bec_logger.logger
|
||||
|
||||
set_async_exc = ctypes.pythonapi.PyThreadState_SetAsyncExc
|
||||
@ -36,12 +47,26 @@ OP_MAP = {
|
||||
|
||||
|
||||
class CompareStatus(SubscriptionStatus):
|
||||
"""Status class to compare a value from a device signal with a target value."""
|
||||
"""
|
||||
Status class to compare a value from a device signal with a target value.
|
||||
The value can be a float, int, or string. If the value is a string,
|
||||
the operation must be either '==' or '!='. For numeric (float or int) values,
|
||||
the operation can be any of the standard comparison operators.
|
||||
|
||||
Args:
|
||||
signal: The device signal to compare.
|
||||
value: The target value to compare against.
|
||||
operation: The comparison operation to use. Defaults to '=='.
|
||||
event_type: The type of event to trigger on comparison. Defaults to None (default sub).
|
||||
timeout: The timeout for the status. Defaults to None (indefinite).
|
||||
settle_time: The time to wait for the signal to settle before comparison. Defaults to 0.
|
||||
run: Whether to run the status immediately or not. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
signal: Signal,
|
||||
value: Any,
|
||||
value: float | int | str,
|
||||
*,
|
||||
operation: Literal["==", "!=", "<", "<=", ">", ">="] = "==",
|
||||
event_type=None,
|
||||
@ -49,6 +74,11 @@ class CompareStatus(SubscriptionStatus):
|
||||
settle_time: float = 0,
|
||||
run: bool = True,
|
||||
):
|
||||
if isinstance(value, str):
|
||||
if operation not in ("==", "!="):
|
||||
raise ValueError(
|
||||
f"Invalid operation: {operation} for string comparison. Must be '==' or '!='."
|
||||
)
|
||||
if operation not in ("==", "!=", "<", "<=", ">", ">="):
|
||||
raise ValueError(
|
||||
f"Invalid operation: {operation}. Must be one of '==', '!=', '<', '<=', '>', '>='."
|
||||
@ -66,25 +96,52 @@ class CompareStatus(SubscriptionStatus):
|
||||
)
|
||||
|
||||
def _compare_callback(self, value, **kwargs) -> bool:
|
||||
"""Callback for subscription Status"""
|
||||
"""Callback for subscription status"""
|
||||
return OP_MAP[self._operation](value, self._value)
|
||||
|
||||
|
||||
class TargetStatus(SubscriptionStatus):
|
||||
"""Status class to compare a list of values that are expected to be reached in sequence for a device signal."""
|
||||
class TransitionStatus(SubscriptionStatus):
|
||||
"""
|
||||
Status class to compare a list of transitions.
|
||||
The transitions can be a list of float, int, or string values.
|
||||
The transitions are checked in order, and the status is finished when all transitions
|
||||
have been matched in sequence. The keyword argument `strict` determines whether
|
||||
the transitions must match exactly in order, or if intermediate transitions are allowed.
|
||||
For the first value, the strict check is not applied, meaning that the sequence starts once
|
||||
the first transition is matched.
|
||||
|
||||
Args:
|
||||
signal: The device signal to compare.
|
||||
transitions: A list of transitions to compare against.
|
||||
strict: Whether to enforce strict matching of transitions. Defaults to True.
|
||||
run: Whether to run the status immediately or not. Defaults to True.
|
||||
event_type: The type of event to trigger on comparison. Defaults to None (default sub).
|
||||
timeout: The timeout for the status. Defaults to None (indefinite).
|
||||
settle_time: The time to wait for the signal to settle before comparison. Defaults to 0.
|
||||
|
||||
Raises:
|
||||
ValueError: If the transitions do not match the expected sequence. and strict is True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
signal: Signal,
|
||||
values: list[Any],
|
||||
transitions: list[float | int | str],
|
||||
*,
|
||||
strict: bool = True,
|
||||
raise_states: list[float | int | str] | None = None,
|
||||
run: bool = True,
|
||||
event_type=None,
|
||||
timeout: float = None,
|
||||
settle_time: float = 0,
|
||||
run: bool = True,
|
||||
):
|
||||
self._signal = signal
|
||||
self._values = values
|
||||
if not isinstance(transitions, list):
|
||||
raise ValueError(f"Transitions must be a list of values. Received: {transitions}")
|
||||
self._transitions = transitions
|
||||
self._index = 0
|
||||
self._strict = strict
|
||||
self._raise_states = raise_states if raise_states else []
|
||||
super().__init__(
|
||||
device=signal,
|
||||
callback=self._compare_callback,
|
||||
@ -94,13 +151,34 @@ class TargetStatus(SubscriptionStatus):
|
||||
run=run,
|
||||
)
|
||||
|
||||
def _compare_callback(self, value, **kwargs) -> bool:
|
||||
def _compare_callback(self, old_value, value, **kwargs) -> bool:
|
||||
"""Callback for subscription Status"""
|
||||
if value == self._values[0]:
|
||||
self._values.pop(0)
|
||||
if len(self._values) == 0:
|
||||
return True
|
||||
return False
|
||||
if value in self._raise_states:
|
||||
self.set_exception(
|
||||
ValueError(
|
||||
f"Transition raised an exception: {value}. "
|
||||
f"Expected transitions: {self._transitions}."
|
||||
)
|
||||
)
|
||||
return False
|
||||
if self._index == 0:
|
||||
if value == self._transitions[0]:
|
||||
self._index += 1
|
||||
else:
|
||||
if self._strict:
|
||||
if (
|
||||
old_value == self._transitions[self._index - 1]
|
||||
and value == self._transitions[self._index]
|
||||
):
|
||||
self._index += 1
|
||||
else:
|
||||
if value == self._transitions[self._index]:
|
||||
self._index += 1
|
||||
return self._is_finished()
|
||||
|
||||
def _is_finished(self) -> bool:
|
||||
"""Check if the status is finished"""
|
||||
return self._index >= len(self._transitions)
|
||||
|
||||
|
||||
class TaskState(str, Enum):
|
||||
|
@ -16,11 +16,11 @@ from ophyd_devices.utils.bec_signals import (
|
||||
from ophyd_devices.utils.psi_device_base_utils import (
|
||||
CompareStatus,
|
||||
FileHandler,
|
||||
TargetStatus,
|
||||
TaskHandler,
|
||||
TaskKilledError,
|
||||
TaskState,
|
||||
TaskStatus,
|
||||
TransitionStatus,
|
||||
)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
@ -529,20 +529,7 @@ def test_utils_progress_signal():
|
||||
signal.put({"wrong_key": "wrong_value"})
|
||||
|
||||
|
||||
def test_utils_target_status():
|
||||
"""Test TargetStatus"""
|
||||
sig = Signal(name="test_signal", value=0)
|
||||
status = TargetStatus(signal=sig, values=[1, 2, 3])
|
||||
assert status.done is False
|
||||
sig.put(1)
|
||||
assert status.done is False
|
||||
sig.put(2)
|
||||
assert status.done is False
|
||||
sig.put(3)
|
||||
assert status.done is True
|
||||
|
||||
|
||||
def test_utils_compare_status():
|
||||
def test_utils_compare_status_number():
|
||||
"""Test CompareStatus"""
|
||||
sig = Signal(name="test_signal", value=0)
|
||||
status = CompareStatus(signal=sig, value=5, operation="==")
|
||||
@ -560,6 +547,8 @@ def test_utils_compare_status():
|
||||
assert status.done is False
|
||||
sig.put(6)
|
||||
assert status.done is True
|
||||
assert status.success is True
|
||||
assert status.exception() is None
|
||||
|
||||
sig.put(0)
|
||||
status = CompareStatus(signal=sig, value=5, operation=">")
|
||||
@ -568,3 +557,80 @@ def test_utils_compare_status():
|
||||
assert status.done is False
|
||||
sig.put(10)
|
||||
assert status.done is True
|
||||
assert status.success is True
|
||||
assert status.exception() is None
|
||||
|
||||
|
||||
def test_compare_status_string():
|
||||
"""Test CompareStatus with string values"""
|
||||
sig = Signal(name="test_signal", value="test")
|
||||
status = CompareStatus(signal=sig, value="test", operation="==")
|
||||
assert status.done is False
|
||||
sig.put("test1")
|
||||
assert status.done is False
|
||||
sig.put("test")
|
||||
assert status.done is True
|
||||
|
||||
sig.put("test")
|
||||
# Test with different operations
|
||||
status = CompareStatus(signal=sig, value="test", operation="!=")
|
||||
assert status.done is False
|
||||
sig.put("test")
|
||||
assert status.done is False
|
||||
sig.put("test1")
|
||||
assert status.done is True
|
||||
assert status.success is True
|
||||
assert status.exception() is None
|
||||
|
||||
# Test with greater than operation
|
||||
# Raises ValueError for strings
|
||||
sig.put("a")
|
||||
with pytest.raises(ValueError):
|
||||
status = CompareStatus(signal=sig, value="b", operation=">")
|
||||
|
||||
|
||||
def test_transition_status():
|
||||
"""Test TransitionStatus"""
|
||||
sig = Signal(name="test_signal", value=0)
|
||||
|
||||
# Test strict=True, without intermediate transitions
|
||||
sig.put(0)
|
||||
status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=True)
|
||||
|
||||
assert status.done is False
|
||||
sig.put(1)
|
||||
assert status.done is False
|
||||
sig.put(2)
|
||||
assert status.done is False
|
||||
sig.put(3)
|
||||
assert status.done is True
|
||||
assert status.success is True
|
||||
assert status.exception() is None
|
||||
|
||||
# Test strict=True, ra
|
||||
sig.put(1)
|
||||
status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=True, raise_states=[4])
|
||||
assert status.done is False
|
||||
sig.put(4)
|
||||
with pytest.raises(ValueError):
|
||||
status.wait()
|
||||
|
||||
assert status.done is True
|
||||
assert status.success is False
|
||||
assert isinstance(status.exception(), ValueError)
|
||||
|
||||
# Test strict=False, with intermediate transitions
|
||||
sig.put(0)
|
||||
status = TransitionStatus(signal=sig, transitions=[1, 2, 3], strict=False)
|
||||
|
||||
assert status.done is False
|
||||
sig.put(1) # entering first transition
|
||||
sig.put(3)
|
||||
sig.put(2) # transision
|
||||
assert status.done is False
|
||||
sig.put(4)
|
||||
sig.put(2)
|
||||
sig.put(3) # last transition
|
||||
assert status.done is True
|
||||
assert status.success is True
|
||||
assert status.exception() is None
|
||||
|
Reference in New Issue
Block a user