fix: improved patching of Ophyd 1.9
This commit is contained in:
parent
76d5f24e84
commit
8a9a6a9910
@ -1,18 +1,22 @@
|
|||||||
import inspect
|
import importlib
|
||||||
|
import importlib.util
|
||||||
|
import pathlib
|
||||||
|
import sys
|
||||||
|
from types import ModuleType
|
||||||
|
|
||||||
|
_patched_status_base = """
|
||||||
import threading
|
import threading
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
from ophyd import status as ophyd_status_module
|
_StatusBase = StatusBase
|
||||||
from ophyd.status import StatusBase
|
|
||||||
|
|
||||||
dummy_thread = Mock(spec=threading.Thread)
|
class StatusBase(_StatusBase):
|
||||||
|
_bec_patched = True
|
||||||
|
|
||||||
|
|
||||||
class PatchedStatusBase(StatusBase):
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
timeout = kwargs.get("timeout", None)
|
timeout = kwargs.get("timeout", None)
|
||||||
if not timeout:
|
if not timeout:
|
||||||
with patch("threading.Thread", dummy_thread):
|
with patch("threading.Thread", Mock(spec=threading.Thread)):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@ -35,15 +39,44 @@ class PatchedStatusBase(StatusBase):
|
|||||||
if isinstance(self._callback_thread, Mock):
|
if isinstance(self._callback_thread, Mock):
|
||||||
self._run_callbacks()
|
self._run_callbacks()
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class _CustomImporter:
|
||||||
|
def __init__(self):
|
||||||
|
origin = pathlib.Path(importlib.util.find_spec("ophyd").origin)
|
||||||
|
module_file = str(origin.parent / "status.py")
|
||||||
|
|
||||||
|
with open(module_file, "r") as source:
|
||||||
|
src = source.read()
|
||||||
|
before, _, after = src.partition("class StatusBase")
|
||||||
|
orig_status_base, _, final = after.partition("\nclass ")
|
||||||
|
|
||||||
|
self.patched_source = (
|
||||||
|
f"{before}class StatusBase{orig_status_base}{_patched_status_base}class {final}"
|
||||||
|
)
|
||||||
|
self.patched_code = compile(self.patched_source, module_file, "exec")
|
||||||
|
|
||||||
|
def find_module(self, fullname, path):
|
||||||
|
if fullname == "ophyd.status":
|
||||||
|
return self
|
||||||
|
return None
|
||||||
|
|
||||||
|
def load_module(self, fullname, module_dict=None):
|
||||||
|
"""Load and execute ophyd.status"""
|
||||||
|
status_module = ModuleType("ophyd.status")
|
||||||
|
status_module.__loader__ = self
|
||||||
|
status_module.__file__ = None
|
||||||
|
status_module.__name__ = fullname
|
||||||
|
|
||||||
|
exec(self.patched_code, status_module.__dict__)
|
||||||
|
sys.modules[fullname] = status_module
|
||||||
|
|
||||||
|
return status_module, True
|
||||||
|
|
||||||
|
def get_source(self, fullname):
|
||||||
|
return self.patched_source
|
||||||
|
|
||||||
|
|
||||||
def monkey_patch_ophyd():
|
def monkey_patch_ophyd():
|
||||||
if ophyd_status_module.StatusBase.__name__ == "PatchedStatusBase":
|
sys.meta_path.insert(0, _CustomImporter())
|
||||||
# prevent patching multiple times
|
|
||||||
return
|
|
||||||
for name, klass in inspect.getmembers(
|
|
||||||
ophyd_status_module, lambda x: inspect.isclass(x) and StatusBase in x.__mro__
|
|
||||||
):
|
|
||||||
mro = klass.mro()
|
|
||||||
bases = tuple(PatchedStatusBase if x is StatusBase else x for x in mro)
|
|
||||||
new_klass = type("Patched" + name, bases, {})
|
|
||||||
setattr(ophyd_status_module, name, new_klass)
|
|
||||||
|
@ -3,12 +3,19 @@ import time
|
|||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from ophyd.status import StatusBase, StatusTimeoutError
|
|
||||||
|
|
||||||
import ophyd_devices # ensure we are patched
|
import ophyd_devices # ensure we are patched
|
||||||
|
|
||||||
|
|
||||||
def test_ophyd_status_patch():
|
def test_ophyd_status_patch():
|
||||||
|
from ophyd.status import DeviceStatus, StatusBase, StatusTimeoutError
|
||||||
|
|
||||||
|
assert StatusBase._bec_patched
|
||||||
|
|
||||||
|
st = DeviceStatus(device="test")
|
||||||
|
assert st._bec_patched
|
||||||
|
assert isinstance(st, StatusBase)
|
||||||
|
|
||||||
cb = Mock()
|
cb = Mock()
|
||||||
|
|
||||||
st = StatusBase(timeout=1)
|
st = StatusBase(timeout=1)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user