From cb7f7ba932b372b60827b24f4e1e0234cd64026b Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Sun, 15 Jun 2025 13:17:02 +0200 Subject: [PATCH] feat(psi device base): stoppable status objects Add methods to PSIDeviceBase to register status object that should be cancelled when the device is stopped or destroyed. --- .../base_classes/psi_device_base.py | 41 +++++++++++++++++++ tests/test_psi_device_base.py | 30 ++++++++++++++ 2 files changed, 71 insertions(+) diff --git a/ophyd_devices/interfaces/base_classes/psi_device_base.py b/ophyd_devices/interfaces/base_classes/psi_device_base.py index 8fa6b46..216cb64 100644 --- a/ophyd_devices/interfaces/base_classes/psi_device_base.py +++ b/ophyd_devices/interfaces/base_classes/psi_device_base.py @@ -64,6 +64,7 @@ class PSIDeviceBase(Device): else: super().__init__(prefix=prefix, name=name, **kwargs) self._stopped = False + self._stoppable_status_objects: list[StatusBase] = [] self.task_handler = TaskHandler(parent=self) self.file_utils = FileHandler() if scan_info is None: @@ -113,6 +114,7 @@ class PSIDeviceBase(Device): """Unstage the device.""" super_unstage = super().unstage() status = self.on_unstage() # pylint: disable=assignment-from-no-return + self._stop_stoppable_status_objects() if isinstance(status, StatusBase): return status return super_unstage @@ -154,13 +156,52 @@ class PSIDeviceBase(Device): """ self.on_stop() self.stopped = True # Set stopped flag to True, in case a custom stop method listens to stopped property + # Stop all stoppable status objects + self._stop_stoppable_status_objects() super().stop(success=success) def destroy(self): """Destroy the device.""" self.on_destroy() # Call the on_destroy method + self._stop_stoppable_status_objects() return super().destroy() + ######################################## + # Stoppable Status Objects Management # + ######################################## + + def cancel_on_stop(self, status: StatusBase) -> None: + """ + Register a status object to be cancelled when the device is stopped. + + Args: + status (StatusBase): The status object to be cancelled. + """ + if not isinstance(status, StatusBase): + raise TypeError("status must be an instance of StatusBase") + self._stoppable_status_objects.append(status) + + def _clear_stoppable_status_objects(self) -> None: + """ + Clear all registered stoppable status objects. + + This is useful to reset the list of status objects that should be cancelled + when the device is stopped. + """ + self._stoppable_status_objects = [] + + def _stop_stoppable_status_objects(self) -> None: + """ + Stop all registered stoppable status objects. + + This method will cancel all status objects that have been registered + to be stopped when the device is stopped. + """ + for status in self._stoppable_status_objects: + if not status.done: + status.set_exception(DeviceStoppedError(f"Device {self.name} has been stopped")) + self._clear_stoppable_status_objects() + ######################################## # Utility Method to wait for signals # ######################################## diff --git a/tests/test_psi_device_base.py b/tests/test_psi_device_base.py index becfad9..ab1c388 100644 --- a/tests/test_psi_device_base.py +++ b/tests/test_psi_device_base.py @@ -1,5 +1,7 @@ """Module for testing the PSIDeviceBase class.""" +import threading +import time from unittest import mock import pytest @@ -146,3 +148,31 @@ def test_on_stop_hook(device): with mock.patch.object(device, "on_stop") as mock_on_stop: device.stop() mock_on_stop.assert_called_once() + + +def test_stoppable_status(device): + """Test stoppable status""" + status = StatusBase() + device.cancel_on_stop(status) + device.stop() + assert status.done is True + assert status.success is False + + +def test_stoppable_status_not_done(device): + """Test stoppable status not done""" + + def stop_after_delay(): + time.sleep(5) + device.stop() + + status = StatusBase() + device.cancel_on_stop(status) + thread = threading.Thread(target=stop_after_delay) + thread.start() + + with pytest.raises(DeviceStoppedError, match="Device device has been stopped"): + status.wait() + + assert status.done is True + assert status.success is False