From eceab997b82b248328bf5192010916f3e2905988 Mon Sep 17 00:00:00 2001 From: David Perl Date: Mon, 17 Nov 2025 13:02:59 +0100 Subject: [PATCH] fix: improve device mocking for tests --- ophyd_devices/tests/utils.py | 27 ++++++++++++++++++- tests/test_delay_generator.py | 19 ++----------- tests/test_dxp/test_dxp.py | 51 +++++++++-------------------------- tests/test_psi_motors.py | 23 +++------------- tests/test_undulator.py | 14 +++------- 5 files changed, 47 insertions(+), 87 deletions(-) diff --git a/ophyd_devices/tests/utils.py b/ophyd_devices/tests/utils.py index 23afc80..81707f6 100644 --- a/ophyd_devices/tests/utils.py +++ b/ophyd_devices/tests/utils.py @@ -1,10 +1,12 @@ """Utilities to mock and test devices.""" import threading +from contextlib import contextmanager from time import sleep -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Callable, Generator, TypeVar from unittest import mock +import ophyd from bec_lib.devicemanager import ScanInfo from bec_lib.logger import bec_logger from bec_lib.utils.import_utils import lazy_import_from @@ -19,6 +21,29 @@ else: logger = bec_logger.logger +T = TypeVar("T", bound=Device) + + +@contextmanager +def patched_device(device_type: type[T], *args, **kwargs) -> Generator[T, None, None]: + """Context manager to yield a patched ophyd device with certain initialisation args. + *args and **kwargs are passed directly through to the device constructor. + + Example: + @pytest.fixture(scope="function") + def mock_ddg(): + with patched_device(DelayGenerator, name="ddg", prefix="X12SA-CPCL-DDG3:") as ddg: + yield ddg + + """ + with mock.patch.object(ophyd, "cl") as mock_cl: + mock_cl.get_pv = MockPV + mock_cl.thread_class = threading.Thread + device = device_type(*args, **kwargs) + patch_dual_pvs(device) + patch_functions_required_for_connection(device) + yield device + def patch_dual_pvs(device): """Patch dual PVs""" diff --git a/tests/test_delay_generator.py b/tests/test_delay_generator.py index ca441b6..5ebdf05 100644 --- a/tests/test_delay_generator.py +++ b/tests/test_delay_generator.py @@ -1,7 +1,3 @@ -import threading -from unittest import mock - -import ophyd import pytest from ophyd_devices.devices.delay_generator_645 import ( @@ -9,23 +5,12 @@ from ophyd_devices.devices.delay_generator_645 import ( DelayGeneratorError, TriggerSource, ) -from ophyd_devices.tests.utils import ( - MockPV, - patch_dual_pvs, - patch_functions_required_for_connection, -) +from ophyd_devices.tests.utils import patched_device @pytest.fixture(scope="function") def mock_ddg(): - name = "ddg" - prefix = "X12SA-CPCL-DDG3:" - with mock.patch.object(ophyd, "cl") as mock_cl: - mock_cl.get_pv = MockPV - mock_cl.thread_class = threading.Thread - ddg = DelayGenerator(name=name, prefix=prefix) - patch_functions_required_for_connection(ddg) - patch_dual_pvs(ddg) + with patched_device(DelayGenerator, name="ddg", prefix="X12SA-CPCL-DDG3:") as ddg: yield ddg diff --git a/tests/test_dxp/test_dxp.py b/tests/test_dxp/test_dxp.py index 9ebb7f5..e2b95e8 100644 --- a/tests/test_dxp/test_dxp.py +++ b/tests/test_dxp/test_dxp.py @@ -3,10 +3,6 @@ Test module for DXP integration, i.e. Falcon, XMAP and Mercury detectors. This also includes EpicsMCARecord for data recording of multichannel analyzers. """ -import threading -from unittest import mock - -import ophyd import pytest from ophyd import Component as Cpt @@ -21,7 +17,7 @@ from ophyd_devices.devices.dxp import ( Mercury, xMAP, ) -from ophyd_devices.tests.utils import MockPV, patch_dual_pvs +from ophyd_devices.tests.utils import patched_device # from ophyd.mca import EpicsDXPMapping @@ -40,46 +36,27 @@ class TestFalcon(Falcon): @pytest.fixture(scope="function") def mock_falcon(): """Fixture to create a mock Falcon device for testing.""" - name = "mca" - prefix = "test_falcon" - with mock.patch.object(ophyd, "cl") as mock_cl: - mock_cl.get_pv = MockPV - mock_cl.thread_class = threading.Thread - ddg = TestFalcon(name=name, prefix=prefix) - patch_dual_pvs(ddg) - yield ddg + with patched_device(TestFalcon, name="mca", prefix="test_falcon") as falc: + yield falc @pytest.fixture(scope="function") def mock_xmap(): """Fixture to create a mock xMAP device for testing.""" - name = "mca" - prefix = "test_xmap" - with mock.patch.object(ophyd, "cl") as mock_cl: - mock_cl.get_pv = MockPV - mock_cl.thread_class = threading.Thread - ddg = xMAP(name=name, prefix=prefix) - patch_dual_pvs(ddg) - yield ddg + with patched_device(xMAP, name="mca", prefix="test_xmap") as xmap: + yield xmap @pytest.fixture(scope="function") def mock_mercury(): """Fixture to create a mock Mercury device for testing.""" - name = "mca" - prefix = "test_mercury" - with mock.patch.object(ophyd, "cl") as mock_cl: - mock_cl.get_pv = MockPV - mock_cl.thread_class = threading.Thread - ddg = Mercury(name=name, prefix=prefix) - patch_dual_pvs(ddg) - yield ddg + with patched_device(Mercury, name="mca", prefix="test_mercury") as merc: + yield merc -def test_falcon(mock_falcon): +def test_falcon(mock_falcon: TestFalcon): """Test the Falcon device.""" # Test the default values - mock_falcon: TestFalcon assert mock_falcon.name == "mca" assert mock_falcon.prefix == "test_falcon" assert isinstance(mock_falcon, EpicsDXPFalconMultiElementSystem) @@ -99,9 +76,8 @@ def test_falcon(mock_falcon): ] -def test_falcon_trigger(mock_falcon): +def test_falcon_trigger(mock_falcon: TestFalcon): """Test the Falcon device trigger method.""" - mock_falcon: TestFalcon mock_falcon.erase_start.put(0) assert mock_falcon.erase_start.get() == 0 status = mock_falcon.trigger() @@ -111,20 +87,18 @@ def test_falcon_trigger(mock_falcon): assert status.done is True -def test_xmap(mock_xmap): +def test_xmap(mock_xmap: xMAP): """Test the xMAP device.""" # Test the default values - mock_xmap: xMAP assert mock_xmap.name == "mca" assert mock_xmap.prefix == "test_xmap" assert isinstance(mock_xmap, EpicsDXPMultiElementSystem) assert isinstance(mock_xmap, ADBase) -def test_mercury(mock_mercury): +def test_mercury(mock_mercury: Mercury): """Test the Mercury device.""" # Test the default values - mock_mercury: Mercury assert mock_mercury.name == "mca" assert mock_mercury.prefix == "test_mercury" assert isinstance(mock_mercury, EpicsDXPMultiElementSystem) @@ -132,9 +106,8 @@ def test_mercury(mock_mercury): assert isinstance(mock_mercury, ADBase) -def test_xmap_trigger(mock_xmap): +def test_xmap_trigger(mock_xmap: xMAP): """Test the xMAP device trigger method.""" - mock_xmap: xMAP mock_xmap.erase_start.put(0) assert mock_xmap.erase_start.get() == 0 status = mock_xmap.trigger() diff --git a/tests/test_psi_motors.py b/tests/test_psi_motors.py index 521b28d..6dde638 100644 --- a/tests/test_psi_motors.py +++ b/tests/test_psi_motors.py @@ -3,42 +3,27 @@ PSI motor integration from the ophyd_devices.devices.psi_motor module.""" from __future__ import annotations -import threading from unittest import mock import ophyd import pytest from ophyd_devices.devices.psi_motor import EpicsMotor, EpicsMotorEC, SpmgStates -from ophyd_devices.tests.utils import MockPV, patch_dual_pvs +from ophyd_devices.tests.utils import patched_device @pytest.fixture(scope="function") def mock_epics_motor(): """Fixture to create a mock EpicsMotor instance.""" - name = "test_motor" - prefix = "SIM:MOTOR" - with mock.patch.object(ophyd, "cl") as mock_cl: - mock_cl.get_pv = MockPV - mock_cl.thread_class = threading.Thread - motor = EpicsMotor(name=name, prefix=prefix) - motor.wait_for_connection() - patch_dual_pvs(motor) + with patched_device(EpicsMotor, name="test_motor", prefix="SIM:MOTOR") as motor: yield motor @pytest.fixture(scope="function") def mock_epics_motor_ec(): """Fixture to create a mock EpicsMotorEC instance.""" - name = "test_motor_ec" - prefix = "SIM:MOTOR:EC" - with mock.patch.object(ophyd, "cl") as mock_cl: - mock_cl.get_pv = MockPV - mock_cl.thread_class = threading.Thread - motor_ec = EpicsMotorEC(name=name, prefix=prefix) - motor_ec.wait_for_connection() - patch_dual_pvs(motor_ec) - yield motor_ec + with patched_device(EpicsMotorEC, name="test_motor_ec", prefix="SIM:MOTOR:EC") as motor: + yield motor def test_epics_motor_limits_raise(mock_epics_motor): diff --git a/tests/test_undulator.py b/tests/test_undulator.py index 7b779c1..1ea48cb 100644 --- a/tests/test_undulator.py +++ b/tests/test_undulator.py @@ -1,23 +1,15 @@ -import threading from unittest import mock -import ophyd import pytest from ophyd_devices.devices.undulator import UndulatorGap -from ophyd_devices.tests.utils import MockPV, patch_dual_pvs +from ophyd_devices.tests.utils import patched_device @pytest.fixture(scope="function") def mock_undulator(): - name = "undulator" - prefix = "TEST:UNDULATOR" - with mock.patch.object(ophyd, "cl") as mock_cl: - mock_cl.get_pv = MockPV - mock_cl.thread_class = threading.Thread - dev = UndulatorGap(name=name, prefix=prefix) - patch_dual_pvs(dev) - yield dev + with patched_device(UndulatorGap, name="undulator", prefix="TEST:UNDULATOR") as und: + yield und @pytest.mark.parametrize(