fix: improved patching of Ophyd 1.9

This commit is contained in:
guijar_m 2024-10-18 16:09:47 +02:00
parent 76d5f24e84
commit 8a9a6a9910
2 changed files with 58 additions and 18 deletions

View File

@ -1,18 +1,22 @@
import inspect
import importlib
import importlib.util
import pathlib
import sys
from types import ModuleType
_patched_status_base = """
import threading
from unittest.mock import Mock, patch
from ophyd import status as ophyd_status_module
from ophyd.status import StatusBase
_StatusBase = StatusBase
dummy_thread = Mock(spec=threading.Thread)
class StatusBase(_StatusBase):
_bec_patched = True
class PatchedStatusBase(StatusBase):
def __init__(self, *args, **kwargs):
timeout = kwargs.get("timeout", None)
if not timeout:
with patch("threading.Thread", dummy_thread):
with patch("threading.Thread", Mock(spec=threading.Thread)):
super().__init__(*args, **kwargs)
else:
super().__init__(*args, **kwargs)
@ -35,15 +39,44 @@ class PatchedStatusBase(StatusBase):
if isinstance(self._callback_thread, Mock):
self._run_callbacks()
"""
class _CustomImporter:
def __init__(self):
origin = pathlib.Path(importlib.util.find_spec("ophyd").origin)
module_file = str(origin.parent / "status.py")
with open(module_file, "r") as source:
src = source.read()
before, _, after = src.partition("class StatusBase")
orig_status_base, _, final = after.partition("\nclass ")
self.patched_source = (
f"{before}class StatusBase{orig_status_base}{_patched_status_base}class {final}"
)
self.patched_code = compile(self.patched_source, module_file, "exec")
def find_module(self, fullname, path):
if fullname == "ophyd.status":
return self
return None
def load_module(self, fullname, module_dict=None):
"""Load and execute ophyd.status"""
status_module = ModuleType("ophyd.status")
status_module.__loader__ = self
status_module.__file__ = None
status_module.__name__ = fullname
exec(self.patched_code, status_module.__dict__)
sys.modules[fullname] = status_module
return status_module, True
def get_source(self, fullname):
return self.patched_source
def monkey_patch_ophyd():
if ophyd_status_module.StatusBase.__name__ == "PatchedStatusBase":
# prevent patching multiple times
return
for name, klass in inspect.getmembers(
ophyd_status_module, lambda x: inspect.isclass(x) and StatusBase in x.__mro__
):
mro = klass.mro()
bases = tuple(PatchedStatusBase if x is StatusBase else x for x in mro)
new_klass = type("Patched" + name, bases, {})
setattr(ophyd_status_module, name, new_klass)
sys.meta_path.insert(0, _CustomImporter())

View File

@ -3,12 +3,19 @@ import time
from unittest.mock import Mock
import pytest
from ophyd.status import StatusBase, StatusTimeoutError
import ophyd_devices # ensure we are patched
def test_ophyd_status_patch():
from ophyd.status import DeviceStatus, StatusBase, StatusTimeoutError
assert StatusBase._bec_patched
st = DeviceStatus(device="test")
assert st._bec_patched
assert isinstance(st, StatusBase)
cb = Mock()
st = StatusBase(timeout=1)