From df8ce79ca0606ad415f45cfd5d80b057aec107d9 Mon Sep 17 00:00:00 2001 From: Mathias Guijarro Date: Mon, 18 Mar 2024 15:32:12 +0100 Subject: [PATCH] feat(ophyd): temporary until new Ophyd release, prevent Status objects threads Monkey-patching of Ophyd library --- ophyd_devices/__init__.py | 4 +++ ophyd_devices/ophyd_patch.py | 50 +++++++++++++++++++++++++++++++++ tests/test_ophyd_status_obj.py | 51 ++++++++++++++++++++++++++++++++++ 3 files changed, 105 insertions(+) create mode 100644 ophyd_devices/ophyd_patch.py create mode 100644 tests/test_ophyd_status_obj.py diff --git a/ophyd_devices/__init__.py b/ophyd_devices/__init__.py index 4fc1976..e89b012 100644 --- a/ophyd_devices/__init__.py +++ b/ophyd_devices/__init__.py @@ -1,3 +1,7 @@ +from .ophyd_patch import monkey_patch_ophyd + +monkey_patch_ophyd() + from .eiger1p5m_csaxs.eiger1p5m import Eiger1p5MDetector from .epics import * from .galil.fgalil_ophyd import FlomniGalilMotor diff --git a/ophyd_devices/ophyd_patch.py b/ophyd_devices/ophyd_patch.py new file mode 100644 index 0000000..054289b --- /dev/null +++ b/ophyd_devices/ophyd_patch.py @@ -0,0 +1,50 @@ +import inspect +import threading +import types + +from ophyd import status as ophyd_status_module +from ophyd.status import StatusBase +from unittest.mock import patch, Mock + +dummy_thread = Mock(spec=threading.Thread) + + +class PatchedStatusBase(StatusBase): + def __init__(self, *args, **kwargs): + timeout = kwargs.get("timeout", None) + if not timeout: + with patch("threading.Thread", dummy_thread): + super().__init__(*args, **kwargs) + else: + super().__init__(*args, **kwargs) + + def set_finished(self, *args, **kwargs): + super().set_finished(*args, **kwargs) + if isinstance(self._callback_thread, Mock): + if self.settle_time > 0: + + def settle_done(): + self._settled_event.set() + self._run_callbacks() + + threading.Timer(self.settle_time, settle_done).start() + else: + self._run_callbacks() + + def set_exception(self, *args, **kwargs): + super().set_exception(*args, **kwargs) + if isinstance(self._callback_thread, Mock): + self._run_callbacks() + + +def monkey_patch_ophyd(): + if ophyd_status_module.StatusBase.__name__ == "PatchedStatusBase": + # prevent patching multiple times + return + for name, klass in inspect.getmembers( + ophyd_status_module, lambda x: inspect.isclass(x) and StatusBase in x.__mro__ + ): + mro = klass.mro() + bases = tuple(PatchedStatusBase if x is StatusBase else x for x in mro) + new_klass = type("Patched" + name, bases, {}) + setattr(ophyd_status_module, name, new_klass) diff --git a/tests/test_ophyd_status_obj.py b/tests/test_ophyd_status_obj.py new file mode 100644 index 0000000..7a86176 --- /dev/null +++ b/tests/test_ophyd_status_obj.py @@ -0,0 +1,51 @@ +import pytest +import threading +import time + +from unittest.mock import Mock + +import ophyd_devices # ensure we are patched +from ophyd.status import StatusBase, StatusTimeoutError + + +def test_ophyd_status_patch(): + cb = Mock() + + st = StatusBase(timeout=1) + assert isinstance(st._callback_thread, threading.Thread) + st.add_callback(cb) + with pytest.raises(StatusTimeoutError): + time.sleep(1.1) + st.wait() + cb.assert_called_once() + cb.reset_mock() + + st = StatusBase() + assert isinstance(st._callback_thread, Mock) + st.add_callback(cb) + st.set_finished() + cb.assert_called_once() + cb.reset_mock() + st.wait() + + st = StatusBase(settle_time=1) + st.add_callback(cb) + assert isinstance(st._callback_thread, Mock) + st.set_finished() + assert cb.call_count == 0 + time.sleep(0.5) + assert cb.call_count == 0 # not yet! + time.sleep(0.6) + cb.assert_called_once() + cb.reset_mock() + st.wait() + + class TestException(RuntimeError): + pass + + st = StatusBase() + st.add_callback(cb) + st.set_exception(TestException()) + cb.assert_called_once() + with pytest.raises(TestException): + st.wait()