From d092b8b51afb8b8f4aad5f98e8b6372bed546cce Mon Sep 17 00:00:00 2001 From: appel_c Date: Mon, 16 Jun 2025 15:17:43 +0200 Subject: [PATCH] feat: add custom status, CompareStatus and TargetStatus for easier signal value comparison --- ophyd_devices/__init__.py | 1 + ophyd_devices/utils/psi_device_base_utils.py | 83 +++++++++++++++++++- tests/test_utils.py | 45 ++++++++++- 3 files changed, 126 insertions(+), 3 deletions(-) diff --git a/ophyd_devices/__init__.py b/ophyd_devices/__init__.py index 8c863e4..2aa07f2 100644 --- a/ophyd_devices/__init__.py +++ b/ophyd_devices/__init__.py @@ -20,4 +20,5 @@ from .devices.softpositioner import SoftPositioner from .utils.bec_device_base import BECDeviceBase from .utils.bec_signals import * from .utils.dynamic_pseudo import ComputedSignal +from .utils.psi_device_base_utils import CompareStatus, TargetStatus from .utils.static_device_test import launch diff --git a/ophyd_devices/utils/psi_device_base_utils.py b/ophyd_devices/utils/psi_device_base_utils.py index 97dde80..cdbed6f 100644 --- a/ophyd_devices/utils/psi_device_base_utils.py +++ b/ophyd_devices/utils/psi_device_base_utils.py @@ -1,16 +1,18 @@ """Utility handler to run tasks (function, conditions) in an asynchronous fashion.""" import ctypes +import operator import threading import traceback import uuid from enum import Enum -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Any, Callable, Literal from bec_lib.file_utils import get_full_path from bec_lib.logger import bec_logger from bec_lib.utils.import_utils import lazy_import_from -from ophyd import Device, DeviceStatus +from ophyd import Device, DeviceStatus, Signal +from ophyd.status import SubscriptionStatus if TYPE_CHECKING: # pragma: no cover from bec_lib.messages import ScanStatusMessage @@ -23,6 +25,83 @@ logger = bec_logger.logger set_async_exc = ctypes.pythonapi.PyThreadState_SetAsyncExc +OP_MAP = { + "==": operator.eq, + "!=": operator.ne, + "<": operator.lt, + "<=": operator.le, + ">": operator.gt, + ">=": operator.ge, +} + + +class CompareStatus(SubscriptionStatus): + """Status class to compare a value from a device signal with a target value.""" + + def __init__( + self, + signal: Signal, + value: Any, + *, + operation: Literal["==", "!=", "<", "<=", ">", ">="] = "==", + event_type=None, + timeout: float = None, + settle_time: float = 0, + run: bool = True, + ): + if operation not in ("==", "!=", "<", "<=", ">", ">="): + raise ValueError( + f"Invalid operation: {operation}. Must be one of '==', '!=', '<', '<=', '>', '>='." + ) + self._signal = signal + self._value = value + self._operation = operation + super().__init__( + device=signal, + callback=self._compare_callback, + timeout=timeout, + settle_time=settle_time, + event_type=event_type, + run=run, + ) + + def _compare_callback(self, value, **kwargs) -> bool: + """Callback for subscription Status""" + return OP_MAP[self._operation](value, self._value) + + +class TargetStatus(SubscriptionStatus): + """Status class to compare a list of values that are expected to be reached in sequence for a device signal.""" + + def __init__( + self, + signal: Signal, + values: list[Any], + *, + event_type=None, + timeout: float = None, + settle_time: float = 0, + run: bool = True, + ): + self._signal = signal + self._values = values + super().__init__( + device=signal, + callback=self._compare_callback, + timeout=timeout, + settle_time=settle_time, + event_type=event_type, + run=run, + ) + + def _compare_callback(self, value, **kwargs) -> bool: + """Callback for subscription Status""" + if value == self._values[0]: + self._values.pop(0) + if len(self._values) == 0: + return True + return False + class TaskState(str, Enum): """Possible task states""" diff --git a/tests/test_utils.py b/tests/test_utils.py index b22695c..1c2c883 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,7 +4,7 @@ import time import numpy as np import pytest from bec_lib import messages -from ophyd import Device +from ophyd import Device, Signal from ophyd_devices.utils.bec_signals import ( BECMessageSignal, @@ -14,7 +14,9 @@ from ophyd_devices.utils.bec_signals import ( ProgressSignal, ) from ophyd_devices.utils.psi_device_base_utils import ( + CompareStatus, FileHandler, + TargetStatus, TaskHandler, TaskKilledError, TaskState, @@ -525,3 +527,44 @@ def test_utils_progress_signal(): # Put fails with wrong dict with pytest.raises(ValueError): signal.put({"wrong_key": "wrong_value"}) + + +def test_utils_target_status(): + """Test TargetStatus""" + sig = Signal(name="test_signal", value=0) + status = TargetStatus(signal=sig, values=[1, 2, 3]) + assert status.done is False + sig.put(1) + assert status.done is False + sig.put(2) + assert status.done is False + sig.put(3) + assert status.done is True + + +def test_utils_compare_status(): + """Test CompareStatus""" + sig = Signal(name="test_signal", value=0) + status = CompareStatus(signal=sig, value=5, operation="==") + assert status.done is False + sig.put(1) + assert status.done is False + sig.put(5) + assert status.done is True + + sig.put(5) + # Test with different operations + status = CompareStatus(signal=sig, value=5, operation="!=") + assert status.done is False + sig.put(5) + assert status.done is False + sig.put(6) + assert status.done is True + + sig.put(0) + status = CompareStatus(signal=sig, value=5, operation=">") + assert status.done is False + sig.put(5) + assert status.done is False + sig.put(10) + assert status.done is True