fix: improved patching of Ophyd 1.9
This commit is contained in:
parent
76d5f24e84
commit
8a9a6a9910
@ -1,18 +1,22 @@
|
||||
import inspect
|
||||
import importlib
|
||||
import importlib.util
|
||||
import pathlib
|
||||
import sys
|
||||
from types import ModuleType
|
||||
|
||||
_patched_status_base = """
|
||||
import threading
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from ophyd import status as ophyd_status_module
|
||||
from ophyd.status import StatusBase
|
||||
_StatusBase = StatusBase
|
||||
|
||||
dummy_thread = Mock(spec=threading.Thread)
|
||||
class StatusBase(_StatusBase):
|
||||
_bec_patched = True
|
||||
|
||||
|
||||
class PatchedStatusBase(StatusBase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
timeout = kwargs.get("timeout", None)
|
||||
if not timeout:
|
||||
with patch("threading.Thread", dummy_thread):
|
||||
with patch("threading.Thread", Mock(spec=threading.Thread)):
|
||||
super().__init__(*args, **kwargs)
|
||||
else:
|
||||
super().__init__(*args, **kwargs)
|
||||
@ -35,15 +39,44 @@ class PatchedStatusBase(StatusBase):
|
||||
if isinstance(self._callback_thread, Mock):
|
||||
self._run_callbacks()
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class _CustomImporter:
|
||||
def __init__(self):
|
||||
origin = pathlib.Path(importlib.util.find_spec("ophyd").origin)
|
||||
module_file = str(origin.parent / "status.py")
|
||||
|
||||
with open(module_file, "r") as source:
|
||||
src = source.read()
|
||||
before, _, after = src.partition("class StatusBase")
|
||||
orig_status_base, _, final = after.partition("\nclass ")
|
||||
|
||||
self.patched_source = (
|
||||
f"{before}class StatusBase{orig_status_base}{_patched_status_base}class {final}"
|
||||
)
|
||||
self.patched_code = compile(self.patched_source, module_file, "exec")
|
||||
|
||||
def find_module(self, fullname, path):
|
||||
if fullname == "ophyd.status":
|
||||
return self
|
||||
return None
|
||||
|
||||
def load_module(self, fullname, module_dict=None):
|
||||
"""Load and execute ophyd.status"""
|
||||
status_module = ModuleType("ophyd.status")
|
||||
status_module.__loader__ = self
|
||||
status_module.__file__ = None
|
||||
status_module.__name__ = fullname
|
||||
|
||||
exec(self.patched_code, status_module.__dict__)
|
||||
sys.modules[fullname] = status_module
|
||||
|
||||
return status_module, True
|
||||
|
||||
def get_source(self, fullname):
|
||||
return self.patched_source
|
||||
|
||||
|
||||
def monkey_patch_ophyd():
|
||||
if ophyd_status_module.StatusBase.__name__ == "PatchedStatusBase":
|
||||
# prevent patching multiple times
|
||||
return
|
||||
for name, klass in inspect.getmembers(
|
||||
ophyd_status_module, lambda x: inspect.isclass(x) and StatusBase in x.__mro__
|
||||
):
|
||||
mro = klass.mro()
|
||||
bases = tuple(PatchedStatusBase if x is StatusBase else x for x in mro)
|
||||
new_klass = type("Patched" + name, bases, {})
|
||||
setattr(ophyd_status_module, name, new_klass)
|
||||
sys.meta_path.insert(0, _CustomImporter())
|
||||
|
@ -3,12 +3,19 @@ import time
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from ophyd.status import StatusBase, StatusTimeoutError
|
||||
|
||||
import ophyd_devices # ensure we are patched
|
||||
|
||||
|
||||
def test_ophyd_status_patch():
|
||||
from ophyd.status import DeviceStatus, StatusBase, StatusTimeoutError
|
||||
|
||||
assert StatusBase._bec_patched
|
||||
|
||||
st = DeviceStatus(device="test")
|
||||
assert st._bec_patched
|
||||
assert isinstance(st, StatusBase)
|
||||
|
||||
cb = Mock()
|
||||
|
||||
st = StatusBase(timeout=1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user