fix: improve device mocking for tests

This commit is contained in:
2025-11-17 13:02:59 +01:00
committed by David Perl
parent 011b68f3dc
commit eceab997b8
5 changed files with 47 additions and 87 deletions

View File

@@ -1,10 +1,12 @@
"""Utilities to mock and test devices.""" """Utilities to mock and test devices."""
import threading import threading
from contextlib import contextmanager
from time import sleep from time import sleep
from typing import TYPE_CHECKING, Callable from typing import TYPE_CHECKING, Callable, Generator, TypeVar
from unittest import mock from unittest import mock
import ophyd
from bec_lib.devicemanager import ScanInfo from bec_lib.devicemanager import ScanInfo
from bec_lib.logger import bec_logger from bec_lib.logger import bec_logger
from bec_lib.utils.import_utils import lazy_import_from from bec_lib.utils.import_utils import lazy_import_from
@@ -19,6 +21,29 @@ else:
logger = bec_logger.logger logger = bec_logger.logger
T = TypeVar("T", bound=Device)
@contextmanager
def patched_device(device_type: type[T], *args, **kwargs) -> Generator[T, None, None]:
"""Context manager to yield a patched ophyd device with certain initialisation args.
*args and **kwargs are passed directly through to the device constructor.
Example:
@pytest.fixture(scope="function")
def mock_ddg():
with patched_device(DelayGenerator, name="ddg", prefix="X12SA-CPCL-DDG3:") as ddg:
yield ddg
"""
with mock.patch.object(ophyd, "cl") as mock_cl:
mock_cl.get_pv = MockPV
mock_cl.thread_class = threading.Thread
device = device_type(*args, **kwargs)
patch_dual_pvs(device)
patch_functions_required_for_connection(device)
yield device
def patch_dual_pvs(device): def patch_dual_pvs(device):
"""Patch dual PVs""" """Patch dual PVs"""

View File

@@ -1,7 +1,3 @@
import threading
from unittest import mock
import ophyd
import pytest import pytest
from ophyd_devices.devices.delay_generator_645 import ( from ophyd_devices.devices.delay_generator_645 import (
@@ -9,23 +5,12 @@ from ophyd_devices.devices.delay_generator_645 import (
DelayGeneratorError, DelayGeneratorError,
TriggerSource, TriggerSource,
) )
from ophyd_devices.tests.utils import ( from ophyd_devices.tests.utils import patched_device
MockPV,
patch_dual_pvs,
patch_functions_required_for_connection,
)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def mock_ddg(): def mock_ddg():
name = "ddg" with patched_device(DelayGenerator, name="ddg", prefix="X12SA-CPCL-DDG3:") as ddg:
prefix = "X12SA-CPCL-DDG3:"
with mock.patch.object(ophyd, "cl") as mock_cl:
mock_cl.get_pv = MockPV
mock_cl.thread_class = threading.Thread
ddg = DelayGenerator(name=name, prefix=prefix)
patch_functions_required_for_connection(ddg)
patch_dual_pvs(ddg)
yield ddg yield ddg

View File

@@ -3,10 +3,6 @@ Test module for DXP integration, i.e. Falcon, XMAP and Mercury detectors.
This also includes EpicsMCARecord for data recording of multichannel analyzers. This also includes EpicsMCARecord for data recording of multichannel analyzers.
""" """
import threading
from unittest import mock
import ophyd
import pytest import pytest
from ophyd import Component as Cpt from ophyd import Component as Cpt
@@ -21,7 +17,7 @@ from ophyd_devices.devices.dxp import (
Mercury, Mercury,
xMAP, xMAP,
) )
from ophyd_devices.tests.utils import MockPV, patch_dual_pvs from ophyd_devices.tests.utils import patched_device
# from ophyd.mca import EpicsDXPMapping # from ophyd.mca import EpicsDXPMapping
@@ -40,46 +36,27 @@ class TestFalcon(Falcon):
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def mock_falcon(): def mock_falcon():
"""Fixture to create a mock Falcon device for testing.""" """Fixture to create a mock Falcon device for testing."""
name = "mca" with patched_device(TestFalcon, name="mca", prefix="test_falcon") as falc:
prefix = "test_falcon" yield falc
with mock.patch.object(ophyd, "cl") as mock_cl:
mock_cl.get_pv = MockPV
mock_cl.thread_class = threading.Thread
ddg = TestFalcon(name=name, prefix=prefix)
patch_dual_pvs(ddg)
yield ddg
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def mock_xmap(): def mock_xmap():
"""Fixture to create a mock xMAP device for testing.""" """Fixture to create a mock xMAP device for testing."""
name = "mca" with patched_device(xMAP, name="mca", prefix="test_xmap") as xmap:
prefix = "test_xmap" yield xmap
with mock.patch.object(ophyd, "cl") as mock_cl:
mock_cl.get_pv = MockPV
mock_cl.thread_class = threading.Thread
ddg = xMAP(name=name, prefix=prefix)
patch_dual_pvs(ddg)
yield ddg
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def mock_mercury(): def mock_mercury():
"""Fixture to create a mock Mercury device for testing.""" """Fixture to create a mock Mercury device for testing."""
name = "mca" with patched_device(Mercury, name="mca", prefix="test_mercury") as merc:
prefix = "test_mercury" yield merc
with mock.patch.object(ophyd, "cl") as mock_cl:
mock_cl.get_pv = MockPV
mock_cl.thread_class = threading.Thread
ddg = Mercury(name=name, prefix=prefix)
patch_dual_pvs(ddg)
yield ddg
def test_falcon(mock_falcon): def test_falcon(mock_falcon: TestFalcon):
"""Test the Falcon device.""" """Test the Falcon device."""
# Test the default values # Test the default values
mock_falcon: TestFalcon
assert mock_falcon.name == "mca" assert mock_falcon.name == "mca"
assert mock_falcon.prefix == "test_falcon" assert mock_falcon.prefix == "test_falcon"
assert isinstance(mock_falcon, EpicsDXPFalconMultiElementSystem) assert isinstance(mock_falcon, EpicsDXPFalconMultiElementSystem)
@@ -99,9 +76,8 @@ def test_falcon(mock_falcon):
] ]
def test_falcon_trigger(mock_falcon): def test_falcon_trigger(mock_falcon: TestFalcon):
"""Test the Falcon device trigger method.""" """Test the Falcon device trigger method."""
mock_falcon: TestFalcon
mock_falcon.erase_start.put(0) mock_falcon.erase_start.put(0)
assert mock_falcon.erase_start.get() == 0 assert mock_falcon.erase_start.get() == 0
status = mock_falcon.trigger() status = mock_falcon.trigger()
@@ -111,20 +87,18 @@ def test_falcon_trigger(mock_falcon):
assert status.done is True assert status.done is True
def test_xmap(mock_xmap): def test_xmap(mock_xmap: xMAP):
"""Test the xMAP device.""" """Test the xMAP device."""
# Test the default values # Test the default values
mock_xmap: xMAP
assert mock_xmap.name == "mca" assert mock_xmap.name == "mca"
assert mock_xmap.prefix == "test_xmap" assert mock_xmap.prefix == "test_xmap"
assert isinstance(mock_xmap, EpicsDXPMultiElementSystem) assert isinstance(mock_xmap, EpicsDXPMultiElementSystem)
assert isinstance(mock_xmap, ADBase) assert isinstance(mock_xmap, ADBase)
def test_mercury(mock_mercury): def test_mercury(mock_mercury: Mercury):
"""Test the Mercury device.""" """Test the Mercury device."""
# Test the default values # Test the default values
mock_mercury: Mercury
assert mock_mercury.name == "mca" assert mock_mercury.name == "mca"
assert mock_mercury.prefix == "test_mercury" assert mock_mercury.prefix == "test_mercury"
assert isinstance(mock_mercury, EpicsDXPMultiElementSystem) assert isinstance(mock_mercury, EpicsDXPMultiElementSystem)
@@ -132,9 +106,8 @@ def test_mercury(mock_mercury):
assert isinstance(mock_mercury, ADBase) assert isinstance(mock_mercury, ADBase)
def test_xmap_trigger(mock_xmap): def test_xmap_trigger(mock_xmap: xMAP):
"""Test the xMAP device trigger method.""" """Test the xMAP device trigger method."""
mock_xmap: xMAP
mock_xmap.erase_start.put(0) mock_xmap.erase_start.put(0)
assert mock_xmap.erase_start.get() == 0 assert mock_xmap.erase_start.get() == 0
status = mock_xmap.trigger() status = mock_xmap.trigger()

View File

@@ -3,42 +3,27 @@ PSI motor integration from the ophyd_devices.devices.psi_motor module."""
from __future__ import annotations from __future__ import annotations
import threading
from unittest import mock from unittest import mock
import ophyd import ophyd
import pytest import pytest
from ophyd_devices.devices.psi_motor import EpicsMotor, EpicsMotorEC, SpmgStates from ophyd_devices.devices.psi_motor import EpicsMotor, EpicsMotorEC, SpmgStates
from ophyd_devices.tests.utils import MockPV, patch_dual_pvs from ophyd_devices.tests.utils import patched_device
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def mock_epics_motor(): def mock_epics_motor():
"""Fixture to create a mock EpicsMotor instance.""" """Fixture to create a mock EpicsMotor instance."""
name = "test_motor" with patched_device(EpicsMotor, name="test_motor", prefix="SIM:MOTOR") as motor:
prefix = "SIM:MOTOR"
with mock.patch.object(ophyd, "cl") as mock_cl:
mock_cl.get_pv = MockPV
mock_cl.thread_class = threading.Thread
motor = EpicsMotor(name=name, prefix=prefix)
motor.wait_for_connection()
patch_dual_pvs(motor)
yield motor yield motor
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def mock_epics_motor_ec(): def mock_epics_motor_ec():
"""Fixture to create a mock EpicsMotorEC instance.""" """Fixture to create a mock EpicsMotorEC instance."""
name = "test_motor_ec" with patched_device(EpicsMotorEC, name="test_motor_ec", prefix="SIM:MOTOR:EC") as motor:
prefix = "SIM:MOTOR:EC" yield motor
with mock.patch.object(ophyd, "cl") as mock_cl:
mock_cl.get_pv = MockPV
mock_cl.thread_class = threading.Thread
motor_ec = EpicsMotorEC(name=name, prefix=prefix)
motor_ec.wait_for_connection()
patch_dual_pvs(motor_ec)
yield motor_ec
def test_epics_motor_limits_raise(mock_epics_motor): def test_epics_motor_limits_raise(mock_epics_motor):

View File

@@ -1,23 +1,15 @@
import threading
from unittest import mock from unittest import mock
import ophyd
import pytest import pytest
from ophyd_devices.devices.undulator import UndulatorGap from ophyd_devices.devices.undulator import UndulatorGap
from ophyd_devices.tests.utils import MockPV, patch_dual_pvs from ophyd_devices.tests.utils import patched_device
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def mock_undulator(): def mock_undulator():
name = "undulator" with patched_device(UndulatorGap, name="undulator", prefix="TEST:UNDULATOR") as und:
prefix = "TEST:UNDULATOR" yield und
with mock.patch.object(ophyd, "cl") as mock_cl:
mock_cl.get_pv = MockPV
mock_cl.thread_class = threading.Thread
dev = UndulatorGap(name=name, prefix=prefix)
patch_dual_pvs(dev)
yield dev
@pytest.mark.parametrize( @pytest.mark.parametrize(