mirror of
https://github.com/bec-project/ophyd_devices.git
synced 2026-02-04 06:08:42 +01:00
refactor: refactored compare and transition state
This commit is contained in:
@@ -20,5 +20,5 @@ from .devices.softpositioner import SoftPositioner
|
||||
from .utils.bec_device_base import BECDeviceBase
|
||||
from .utils.bec_signals import *
|
||||
from .utils.dynamic_pseudo import ComputedSignal
|
||||
from .utils.psi_device_base_utils import CompareStatus, TargetStatus
|
||||
from .utils.psi_device_base_utils import *
|
||||
from .utils.static_device_test import launch
|
||||
|
||||
@@ -11,8 +11,8 @@ from typing import TYPE_CHECKING, Any, Callable, Literal
|
||||
from bec_lib.file_utils import get_full_path
|
||||
from bec_lib.logger import bec_logger
|
||||
from bec_lib.utils.import_utils import lazy_import_from
|
||||
from ophyd import Device, DeviceStatus, Signal
|
||||
from ophyd.status import SubscriptionStatus
|
||||
from ophyd import Device, Signal
|
||||
from ophyd.status import AndStatus, DeviceStatus, MoveStatus, Status, StatusBase, SubscriptionStatus
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from bec_lib.messages import ScanStatusMessage
|
||||
@@ -21,6 +21,17 @@ else:
|
||||
ScanStatusMessage = lazy_import_from("bec_lib.messages", ("ScanStatusMessage",))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CompareStatus",
|
||||
"TransitionStatus",
|
||||
"AndStatus",
|
||||
"DeviceStatus",
|
||||
"MoveStatus",
|
||||
"Status",
|
||||
"StatusBase",
|
||||
"SubscriptionStatus",
|
||||
]
|
||||
|
||||
logger = bec_logger.logger
|
||||
|
||||
set_async_exc = ctypes.pythonapi.PyThreadState_SetAsyncExc
|
||||
@@ -36,12 +47,26 @@ OP_MAP = {
|
||||
|
||||
|
||||
class CompareStatus(SubscriptionStatus):
|
||||
"""Status class to compare a value from a device signal with a target value."""
|
||||
"""
|
||||
Status class to compare a value from a device signal with a target value.
|
||||
The value can be a float, int, or string. If the value is a string,
|
||||
the operation must be either '==' or '!='. For numeric (float or int) values,
|
||||
the operation can be any of the standard comparison operators.
|
||||
|
||||
Args:
|
||||
signal: The device signal to compare.
|
||||
value: The target value to compare against.
|
||||
operation: The comparison operation to use. Defaults to '=='.
|
||||
event_type: The type of event to trigger on comparison. Defaults to None (default sub).
|
||||
timeout: The timeout for the status. Defaults to None (indefinite).
|
||||
settle_time: The time to wait for the signal to settle before comparison. Defaults to 0.
|
||||
run: Whether to run the status immediately or not. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
signal: Signal,
|
||||
value: Any,
|
||||
value: float | int | str,
|
||||
*,
|
||||
operation: Literal["==", "!=", "<", "<=", ">", ">="] = "==",
|
||||
event_type=None,
|
||||
@@ -49,6 +74,11 @@ class CompareStatus(SubscriptionStatus):
|
||||
settle_time: float = 0,
|
||||
run: bool = True,
|
||||
):
|
||||
if isinstance(value, str):
|
||||
if operation not in ("==", "!="):
|
||||
raise ValueError(
|
||||
f"Invalid operation: {operation} for string comparison. Must be '==' or '!='."
|
||||
)
|
||||
if operation not in ("==", "!=", "<", "<=", ">", ">="):
|
||||
raise ValueError(
|
||||
f"Invalid operation: {operation}. Must be one of '==', '!=', '<', '<=', '>', '>='."
|
||||
@@ -66,25 +96,52 @@ class CompareStatus(SubscriptionStatus):
|
||||
)
|
||||
|
||||
def _compare_callback(self, value, **kwargs) -> bool:
|
||||
"""Callback for subscription Status"""
|
||||
"""Callback for subscription status"""
|
||||
return OP_MAP[self._operation](value, self._value)
|
||||
|
||||
|
||||
class TargetStatus(SubscriptionStatus):
|
||||
"""Status class to compare a list of values that are expected to be reached in sequence for a device signal."""
|
||||
class TransitionStatus(SubscriptionStatus):
|
||||
"""
|
||||
Status class to compare a list of transitions.
|
||||
The transitions can be a list of float, int, or string values.
|
||||
The transitions are checked in order, and the status is finished when all transitions
|
||||
have been matched in sequence. The keyword argument `strict` determines whether
|
||||
the transitions must match exactly in order, or if intermediate transitions are allowed.
|
||||
For the first value, the strict check is not applied, meaning that the sequence starts once
|
||||
the first transition is matched.
|
||||
|
||||
Args:
|
||||
signal: The device signal to compare.
|
||||
transitions: A list of transitions to compare against.
|
||||
strict: Whether to enforce strict matching of transitions. Defaults to True.
|
||||
run: Whether to run the status immediately or not. Defaults to True.
|
||||
event_type: The type of event to trigger on comparison. Defaults to None (default sub).
|
||||
timeout: The timeout for the status. Defaults to None (indefinite).
|
||||
settle_time: The time to wait for the signal to settle before comparison. Defaults to 0.
|
||||
|
||||
Raises:
|
||||
ValueError: If the transitions do not match the expected sequence. and strict is True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
signal: Signal,
|
||||
values: list[Any],
|
||||
transitions: list[float | int | str],
|
||||
*,
|
||||
strict: bool = True,
|
||||
raise_states: list[float | int | str] | None = None,
|
||||
run: bool = True,
|
||||
event_type=None,
|
||||
timeout: float = None,
|
||||
settle_time: float = 0,
|
||||
run: bool = True,
|
||||
):
|
||||
self._signal = signal
|
||||
self._values = values
|
||||
if not isinstance(transitions, list):
|
||||
raise ValueError(f"Transitions must be a list of values. Received: {transitions}")
|
||||
self._transitions = transitions
|
||||
self._index = 0
|
||||
self._strict = strict
|
||||
self._raise_states = raise_states if raise_states else []
|
||||
super().__init__(
|
||||
device=signal,
|
||||
callback=self._compare_callback,
|
||||
@@ -94,13 +151,34 @@ class TargetStatus(SubscriptionStatus):
|
||||
run=run,
|
||||
)
|
||||
|
||||
def _compare_callback(self, value, **kwargs) -> bool:
|
||||
def _compare_callback(self, old_value, value, **kwargs) -> bool:
|
||||
"""Callback for subscription Status"""
|
||||
if value == self._values[0]:
|
||||
self._values.pop(0)
|
||||
if len(self._values) == 0:
|
||||
return True
|
||||
return False
|
||||
if value in self._raise_states:
|
||||
self.set_exception(
|
||||
ValueError(
|
||||
f"Transition raised an exception: {value}. "
|
||||
f"Expected transitions: {self._transitions}."
|
||||
)
|
||||
)
|
||||
return False
|
||||
if self._index == 0:
|
||||
if value == self._transitions[0]:
|
||||
self._index += 1
|
||||
else:
|
||||
if self._strict:
|
||||
if (
|
||||
old_value == self._transitions[self._index - 1]
|
||||
and value == self._transitions[self._index]
|
||||
):
|
||||
self._index += 1
|
||||
else:
|
||||
if value == self._transitions[self._index]:
|
||||
self._index += 1
|
||||
return self._is_finished()
|
||||
|
||||
def _is_finished(self) -> bool:
|
||||
"""Check if the status is finished"""
|
||||
return self._index >= len(self._transitions)
|
||||
|
||||
|
||||
class TaskState(str, Enum):
|
||||
|
||||
Reference in New Issue
Block a user