diff --git a/ophyd_devices/interfaces/base_classes/psi_detector_base.py b/ophyd_devices/interfaces/base_classes/psi_detector_base.py index 79bef22..9cc8f20 100644 --- a/ophyd_devices/interfaces/base_classes/psi_detector_base.py +++ b/ophyd_devices/interfaces/base_classes/psi_detector_base.py @@ -5,6 +5,7 @@ We use composition with a custom prepare class to implement BL specific logic fo The beamlines need to inherit from the CustomDetectorMixing for their mixin classes.""" import os +import threading import time from bec_lib import messages @@ -75,7 +76,7 @@ class CustomDetectorMixin: This step should include stopping the detector and backend service. """ - def on_trigger(self) -> None: + def on_trigger(self) -> None | DeviceStatus: """ Specify actions to be executed upon receiving trigger signal. Return a DeviceStatus object or None @@ -88,7 +89,7 @@ class CustomDetectorMixin: Only use if needed, and it is recommended to keep this function as short/fast as possible. """ - def on_complete(self) -> None: + def on_complete(self) -> None | DeviceStatus: """ Specify actions to be executed when the scan is complete. @@ -152,6 +153,7 @@ class CustomDetectorMixin: >>> Example usage for EPICS PVs: >>> self.wait_for_signals(signal_conditions=[(self.acquiring.get, False)], timeout=5, interval=0.05, check_stopped=True, all_signals=True) """ + timer = 0 while True: checks = [ @@ -167,6 +169,88 @@ class CustomDetectorMixin: time.sleep(interval) timer += interval + def wait_with_status( + self, + signal_conditions: list[tuple], + timeout: float, + check_stopped: bool = False, + interval: float = 0.05, + all_signals: bool = False, + exception_on_timeout: Exception = TimeoutError("Timeout while waiting for signals"), + ) -> DeviceStatus: + """Utility function to wait for signals in a thread. + Returns a DevicesStatus object that resolves either to set_finished or set_exception. + The DeviceStatus is attached to the parent device, i.e. the detector object inheriting from PSIDetectorBase. + + Usage: + This function should be used to wait for signals to reach a certain condition, especially in the context of + on_trigger and on_complete. If it is not used, functions may block and slow down the performance of BEC. + It will return a DeviceStatus object that is to be returned from the function. Once the conditions are met, + the DeviceStatus will be set to set_finished in case of success or set_exception in case of a timeout or exception. + The exception can be specified with the exception_on_timeout argument. The default exception is a TimeoutError. + + Args: + signal_conditions (list[tuple]): tuple of executable calls for conditions (get_current_state, condition) to check + timeout (float): timeout in seconds + check_stopped (bool): True if stopped flag should be checked + interval (float): interval in seconds + all_signals (bool): True if all signals should be True, False if any signal should be True + exception_on_timeout (Exception): Exception to raise on timeout + + Returns: + DeviceStatus: DeviceStatus object that resolves either to set_finished or set_exception + """ + + status = DeviceStatus(self.parent) + + # utility function to wrap the wait_for_signals function + def wait_for_signals_wrapper( + status: DeviceStatus, + signal_conditions: list[tuple], + timeout: float, + check_stopped: bool, + interval: float, + all_signals: bool, + exception_on_timeout: Exception = TimeoutError("Timeout while waiting for signals"), + ): + """Convenient wrapper around wait_for_signals to set status based on the result. + + Args: + status (DeviceStatus): DeviceStatus object to be set + signal_conditions (list[tuple]): tuple of executable calls for conditions (get_current_state, condition) to check + timeout (float): timeout in seconds + check_stopped (bool): True if stopped flag should be checked + interval (float): interval in seconds + all_signals (bool): True if all signals should be True, False if any signal should be True + exception_on_timeout (Exception): Exception to raise on timeout + """ + try: + result = self.wait_for_signals( + signal_conditions, timeout, check_stopped, interval, all_signals + ) + if result: + status.set_finished() + else: + status.set_exception(exception_on_timeout) + except Exception as exc: + status.set_exception(exc=exc) + + thread = threading.Thread( + target=wait_for_signals_wrapper, + args=( + status, + signal_conditions, + timeout, + check_stopped, + interval, + all_signals, + exception_on_timeout, + ), + daemon=True, + ) + thread.start() + return status + class PSIDetectorBase(Device): """ @@ -281,7 +365,10 @@ class PSIDetectorBase(Device): def trigger(self) -> DeviceStatus: """Trigger the detector, called from BEC.""" - self.custom_prepare.on_trigger() + # pylint: disable=assignment-from-no-return + status = self.custom_prepare.on_trigger() + if isinstance(status, DeviceStatus): + return status return super().trigger() def complete(self) -> None: @@ -292,8 +379,11 @@ class PSIDetectorBase(Device): Actions are implemented in custom_prepare.on_complete since they are beamline specific. """ + # pylint: disable=assignment-from-no-return + status = self.custom_prepare.on_complete() + if isinstance(status, DeviceStatus): + return status status = DeviceStatus(self) - self.custom_prepare.on_complete() status.set_finished() return status diff --git a/tests/test_base_classes.py b/tests/test_base_classes.py index 80740a8..69686e8 100644 --- a/tests/test_base_classes.py +++ b/tests/test_base_classes.py @@ -1,4 +1,5 @@ # pylint: skip-file +import time from unittest import mock import pytest @@ -49,10 +50,17 @@ def test_pre_scan(detector_base): def test_trigger(detector_base): - with mock.patch.object(detector_base.custom_prepare, "on_trigger") as mock_on_trigger: - rtr = detector_base.trigger() - assert isinstance(rtr, DeviceStatus) - mock_on_trigger.assert_called_once() + status = DeviceStatus(detector_base) + with mock.patch.object( + detector_base.custom_prepare, "on_trigger", side_effect=[None, status] + ) as mock_on_trigger: + st = detector_base.trigger() + assert isinstance(st, DeviceStatus) + time.sleep(0.1) + assert st.done is True + st = detector_base.trigger() + assert st.done is False + assert id(st) == id(status) def test_unstage(detector_base): @@ -74,9 +82,17 @@ def test_unstage(detector_base): def test_complete(detector_base): - with mock.patch.object(detector_base.custom_prepare, "on_complete") as mock_on_complete: - detector_base.complete() - mock_on_complete.assert_called_once() + status = DeviceStatus(detector_base) + with mock.patch.object( + detector_base.custom_prepare, "on_complete", side_effect=[None, status] + ) as mock_on_complete: + st = detector_base.complete() + assert isinstance(st, DeviceStatus) + time.sleep(0.1) + assert st.done is True + st = detector_base.complete() + assert st.done is False + assert id(st) == id(status) def test_stop(detector_base): @@ -94,3 +110,38 @@ def test_check_scan_id(detector_base): detector_base.stopped = False detector_base.check_scan_id() assert detector_base.stopped is False + + +def test_wait_for_signal(detector_base): + expected_value = "test" + exception = TimeoutError("Timeout") + status = detector_base.custom_prepare.wait_with_status( + [(detector_base.filepath.get, expected_value)], + check_stopped=True, + timeout=5, + interval=0.01, + exception_on_timeout=exception, + ) + time.sleep(0.1) + assert status.done is False + # Check first that it is stopped when detector_base.stop() is called + detector_base.stop() + # some delay to allow the stop to take effect + time.sleep(0.15) + assert status.done is True + assert id(status.exception()) == id(exception) + detector_base.stopped = False + status = detector_base.custom_prepare.wait_with_status( + [(detector_base.filepath.get, expected_value)], + check_stopped=True, + timeout=5, + interval=0.01, + exception_on_timeout=exception, + ) + # Check that thread resolves when expected value is set + detector_base.filepath.set(expected_value) + # some delay to allow the stop to take effect + time.sleep(0.15) + assert status.done is True + assert status.success is True + assert status.exception() is None