From 8a9a6a9910b44d55412e80443f145d629b1cfc2f Mon Sep 17 00:00:00 2001 From: Mathias Guijarro Date: Fri, 18 Oct 2024 16:09:47 +0200 Subject: [PATCH] fix: improved patching of Ophyd 1.9 --- ophyd_devices/ophyd_patch.py | 67 +++++++++++++++++++++++++--------- tests/test_ophyd_status_obj.py | 9 ++++- 2 files changed, 58 insertions(+), 18 deletions(-) diff --git a/ophyd_devices/ophyd_patch.py b/ophyd_devices/ophyd_patch.py index b30df25..31526d3 100644 --- a/ophyd_devices/ophyd_patch.py +++ b/ophyd_devices/ophyd_patch.py @@ -1,18 +1,22 @@ -import inspect +import importlib +import importlib.util +import pathlib +import sys +from types import ModuleType + +_patched_status_base = """ import threading from unittest.mock import Mock, patch -from ophyd import status as ophyd_status_module -from ophyd.status import StatusBase +_StatusBase = StatusBase -dummy_thread = Mock(spec=threading.Thread) +class StatusBase(_StatusBase): + _bec_patched = True - -class PatchedStatusBase(StatusBase): def __init__(self, *args, **kwargs): timeout = kwargs.get("timeout", None) if not timeout: - with patch("threading.Thread", dummy_thread): + with patch("threading.Thread", Mock(spec=threading.Thread)): super().__init__(*args, **kwargs) else: super().__init__(*args, **kwargs) @@ -35,15 +39,44 @@ class PatchedStatusBase(StatusBase): if isinstance(self._callback_thread, Mock): self._run_callbacks() +""" + + +class _CustomImporter: + def __init__(self): + origin = pathlib.Path(importlib.util.find_spec("ophyd").origin) + module_file = str(origin.parent / "status.py") + + with open(module_file, "r") as source: + src = source.read() + before, _, after = src.partition("class StatusBase") + orig_status_base, _, final = after.partition("\nclass ") + + self.patched_source = ( + f"{before}class StatusBase{orig_status_base}{_patched_status_base}class {final}" + ) + self.patched_code = compile(self.patched_source, module_file, "exec") + + def find_module(self, fullname, path): + if fullname == "ophyd.status": + return self + return None + + def load_module(self, fullname, module_dict=None): + """Load and execute ophyd.status""" + status_module = ModuleType("ophyd.status") + status_module.__loader__ = self + status_module.__file__ = None + status_module.__name__ = fullname + + exec(self.patched_code, status_module.__dict__) + sys.modules[fullname] = status_module + + return status_module, True + + def get_source(self, fullname): + return self.patched_source + def monkey_patch_ophyd(): - if ophyd_status_module.StatusBase.__name__ == "PatchedStatusBase": - # prevent patching multiple times - return - for name, klass in inspect.getmembers( - ophyd_status_module, lambda x: inspect.isclass(x) and StatusBase in x.__mro__ - ): - mro = klass.mro() - bases = tuple(PatchedStatusBase if x is StatusBase else x for x in mro) - new_klass = type("Patched" + name, bases, {}) - setattr(ophyd_status_module, name, new_klass) + sys.meta_path.insert(0, _CustomImporter()) diff --git a/tests/test_ophyd_status_obj.py b/tests/test_ophyd_status_obj.py index 6acb6c6..218ebf1 100644 --- a/tests/test_ophyd_status_obj.py +++ b/tests/test_ophyd_status_obj.py @@ -3,12 +3,19 @@ import time from unittest.mock import Mock import pytest -from ophyd.status import StatusBase, StatusTimeoutError import ophyd_devices # ensure we are patched def test_ophyd_status_patch(): + from ophyd.status import DeviceStatus, StatusBase, StatusTimeoutError + + assert StatusBase._bec_patched + + st = DeviceStatus(device="test") + assert st._bec_patched + assert isinstance(st, StatusBase) + cb = Mock() st = StatusBase(timeout=1)