test(MockPv): improve MockPV, allow start value to be set

This commit is contained in:
2025-12-08 18:40:09 +01:00
committed by Christian Appel
parent c4296b0399
commit 1b2eeccbb8
3 changed files with 28 additions and 17 deletions

View File

@@ -2,6 +2,7 @@
import threading import threading
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial
from time import sleep from time import sleep
from typing import TYPE_CHECKING, Callable, Generator, TypeVar from typing import TYPE_CHECKING, Callable, Generator, TypeVar
from unittest import mock from unittest import mock
@@ -25,7 +26,9 @@ T = TypeVar("T", bound=Device)
@contextmanager @contextmanager
def patched_device(device_type: type[T], *args, **kwargs) -> Generator[T, None, None]: def patched_device(
device_type: type[T], *args, _mock_pv_initial_value=0, **kwargs
) -> Generator[T, None, None]:
"""Context manager to yield a patched ophyd device with certain initialisation args. """Context manager to yield a patched ophyd device with certain initialisation args.
*args and **kwargs are passed directly through to the device constructor. *args and **kwargs are passed directly through to the device constructor.
@@ -37,7 +40,7 @@ def patched_device(device_type: type[T], *args, **kwargs) -> Generator[T, None,
""" """
with mock.patch.object(ophyd, "cl") as mock_cl: with mock.patch.object(ophyd, "cl") as mock_cl:
mock_cl.get_pv = MockPV mock_cl.get_pv = partial(MockPV, _mock_pv_initial_value=_mock_pv_initial_value)
mock_cl.thread_class = threading.Thread mock_cl.thread_class = threading.Thread
device = device_type(*args, **kwargs) device = device_type(*args, **kwargs)
patch_dual_pvs(device) patch_dual_pvs(device)
@@ -137,8 +140,6 @@ class MockPV:
""" """
DEFAULT_VALUE = 0
_fmtsca = "<PV '%(pvname)s', count=%(count)i, type=%(typefull)s, access=%(access)s>" _fmtsca = "<PV '%(pvname)s', count=%(count)i, type=%(typefull)s, access=%(access)s>"
_fmtarr = "<PV '%(pvname)s', count=%(count)i/%(nelm)i, type=%(typefull)s, access=%(access)s>" _fmtarr = "<PV '%(pvname)s', count=%(count)i/%(nelm)i, type=%(typefull)s, access=%(access)s>"
_fields = ( _fields = (
@@ -173,6 +174,8 @@ class MockPV:
def __init__( def __init__(
self, self,
pvname, pvname,
*,
_mock_pv_initial_value=0,
callback=None, callback=None,
form="time", form="time",
verbose=False, verbose=False,
@@ -202,7 +205,7 @@ class MockPV:
self._args["access"] = "unknown" self._args["access"] = "unknown"
self._args["status"] = 0 self._args["status"] = 0
self.connection_callbacks = [] self.connection_callbacks = []
self._mock_data = self.DEFAULT_VALUE self._mock_data = _mock_pv_initial_value
if connection_callback is not None: if connection_callback is not None:
self.connection_callbacks = [connection_callback] self.connection_callbacks = [connection_callback]

View File

@@ -9,8 +9,7 @@ from unittest import mock
import ophyd import ophyd
import pytest import pytest
from ophyd_devices.devices.epics_motor_ex import EpicsMotorEx from ophyd_devices.devices.psi_motor import EpicsMotor, EpicsMotorEC, EpicsUserMotorVME, SpmgStates
from ophyd_devices.devices.psi_motor import EpicsMotor, EpicsMotorEC, EpicsUserMotors, SpmgStates
from ophyd_devices.tests.utils import MockPV, patched_device from ophyd_devices.tests.utils import MockPV, patched_device
@@ -169,7 +168,7 @@ def test_epics_vme_user_motor():
with mock.patch.object(ophyd, "cl") as mock_cl: with mock.patch.object(ophyd, "cl") as mock_cl:
mock_cl.get_pv = MockPV mock_cl.get_pv = MockPV
mock_cl.thread_class = threading.Thread mock_cl.thread_class = threading.Thread
device = EpicsUserMotors(name="test_motor_ex", prefix="SIM:MOTOR:EX") device = EpicsUserMotorVME(name="test_motor_ex", prefix="SIM:MOTOR:EX")
# Should raise # Should raise
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
device.wait_for_connection(all_signals=True) device.wait_for_connection(all_signals=True)

View File

@@ -12,15 +12,8 @@ from ophyd import Device, EpicsSignalRO, Signal
from ophyd.status import WaitTimeoutError from ophyd.status import WaitTimeoutError
from typeguard import TypeCheckError from typeguard import TypeCheckError
from ophyd_devices import ( from ophyd_devices.devices.psi_motor import EpicsMotor
AndStatus, from ophyd_devices.tests.utils import MockPV, patched_device
DeviceStatus,
MoveStatus,
Status,
StatusBase,
SubscriptionStatus,
)
from ophyd_devices.tests.utils import MockPV
from ophyd_devices.utils.bec_signals import ( from ophyd_devices.utils.bec_signals import (
AsyncMultiSignal, AsyncMultiSignal,
AsyncSignal, AsyncSignal,
@@ -31,8 +24,13 @@ from ophyd_devices.utils.bec_signals import (
ProgressSignal, ProgressSignal,
) )
from ophyd_devices.utils.psi_device_base_utils import ( from ophyd_devices.utils.psi_device_base_utils import (
AndStatus,
CompareStatus, CompareStatus,
DeviceStatus,
FileHandler, FileHandler,
MoveStatus,
StatusBase,
SubscriptionStatus,
TaskHandler, TaskHandler,
TaskKilledError, TaskKilledError,
TaskState, TaskState,
@@ -1021,3 +1019,14 @@ def test_patched_status_objects():
move_st.wait(timeout=10) move_st.wait(timeout=10)
assert move_st.done is True assert move_st.done is True
assert move_st.success is False assert move_st.success is False
@pytest.fixture(scope="function")
def mock_device_with_initial_value():
with patched_device(EpicsMotor, _mock_pv_initial_value=2, name="motor") as mtr:
yield mtr
def test_mock_device_initial_value(mock_device_with_initial_value: EpicsMotor):
mtr = mock_device_with_initial_value
assert mtr.velocity.get() == 2