From 1a7812992adfe49ba734dbccb00456d9eba2c009 Mon Sep 17 00:00:00 2001 From: appel_c Date: Wed, 30 Jul 2025 13:16:40 +0200 Subject: [PATCH] fix(mock-pv): add callbacks to mock_pv --- ophyd_devices/tests/utils.py | 31 ++++++++++++++++--- tests/test_utils.py | 58 +++++++++++++++++++++++++++++++++++- 2 files changed, 84 insertions(+), 5 deletions(-) diff --git a/ophyd_devices/tests/utils.py b/ophyd_devices/tests/utils.py index 8b1dcab..8ebe89d 100644 --- a/ophyd_devices/tests/utils.py +++ b/ophyd_devices/tests/utils.py @@ -2,7 +2,7 @@ import threading from time import sleep -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable from unittest import mock from bec_lib.devicemanager import ScanInfo @@ -174,7 +174,7 @@ class MockPV: self._args["access"] = "unknown" self._args["status"] = 0 self.connection_callbacks = [] - self.mock_data = 0 + self._mock_data = 0 if connection_callback is not None: self.connection_callbacks = [connection_callback] @@ -183,7 +183,7 @@ class MockPV: if access_callback is not None: self.access_callbacks = [access_callback] - self.callbacks = {} + self.callbacks: dict[int, tuple[Callable, dict]] = {} self._put_complete = None self._put_complete_event: threading.Event | None = None self._monref = None # holder of data returned from create_subscription @@ -205,6 +205,20 @@ class MockPV: for acc_cb in self.access_callbacks: acc_cb(True, True, pv=self) + @property + def mock_data(self): + """Get mock data""" + return self._mock_data + + @mock_data.setter + def mock_data(self, value): + """Set mock data""" + old_value = self._mock_data + + self._mock_data = value + for callback, kw in self.callbacks.values(): + callback(value=value, old_value=old_value, obj=self, **kw) + # pylint disable: unused-argument def wait_for_connection(self, timeout=None): """Wait for connection""" @@ -251,7 +265,15 @@ class MockPV: # pylint: disable=unused-argument def add_callback(self, callback=None, index=None, run_now=False, with_ctrlvars=True, **kw): """Add callback""" - return mock.MagicMock() + if callback is None: + logger.warning("Callback is None, cannot add callback") + return + if index is None: + index = len(self.callbacks) + self.callbacks[index] = (callback, kw) + if run_now: + callback(value=self.mock_data, old_value=self.mock_data, obj=self, **kw) + return index # pylint: disable=unused-argument def get_with_metadata( @@ -266,6 +288,7 @@ class MockPV: as_namespace=False, ): """Get MOCKPV data together with metadata""" + return {"value": self.mock_data} def get( diff --git a/tests/test_utils.py b/tests/test_utils.py index e776c47..8b13074 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,11 +1,14 @@ import threading import time +from unittest import mock import numpy as np +import ophyd import pytest from bec_lib import messages -from ophyd import Device, Signal +from ophyd import Device, EpicsSignalRO, Signal +from ophyd_devices.tests.utils import MockPV, patch_dual_pvs from ophyd_devices.utils.bec_signals import ( BECMessageSignal, DynamicSignal, @@ -31,6 +34,17 @@ from ophyd_devices.utils.psi_device_base_utils import ( ########################################## +@pytest.fixture(scope="function") +def mock_epics_signal_ro(): + name = "epics_signal_ro" + read_pv = "TEST:EPICS_SIGNAL_RO" + with mock.patch.object(ophyd, "cl") as mock_cl: + mock_cl.get_pv = MockPV + mock_cl.thread_class = threading.Thread + dev = EpicsSignalRO(name=name, read_pv=read_pv) + yield dev + + @pytest.fixture def file_handler(): """Fixture for FileHandler""" @@ -687,3 +701,45 @@ def test_transition_status_strings(): sig.put("d") # last transition assert status.done is True assert status.success is True + + +def test_compare_status_with_mock_pv(mock_epics_signal_ro): + """Test CompareStatus with EpicsSignalRO, this tests callbacks on EpicsSignals""" + + signal = mock_epics_signal_ro + status = CompareStatus(signal=signal, value=5, operation="==") + assert status.done is False + signal._read_pv.mock_data = 1 + assert status.done is False + signal._read_pv.mock_data = 5 + status.wait(timeout=1) + assert status.done is True + assert status.success is True + + +@pytest.mark.parametrize( + "transitions, expected_done, expected_success", + [ + ([1, 2, 3], True, True), # Transitions completed successfully + ([1, 3, 2], True, False), # Transitions completed with an error + ([5, 4, 2, 1, 2, 3], True, True), # Transitions completed successfully + ], +) +def test_transition_status_with_mock_pv( + mock_epics_signal_ro, transitions, expected_done, expected_success +): + """Test TransitionStatus with EpicsSignalRO, this tests callbacks on EpicsSignals""" + # Starts immediately with 1 + signal = mock_epics_signal_ro + signal._read_pv.mock_data = 1 + status = TransitionStatus(signal=signal, transitions=[1, 2, 3], strict=False) + assert status.done is False + # Does not have to wait + signal._read_pv.mock_data = 3 + signal._read_pv.mock_data = 2 + signal._read_pv.mock_data = 3 + status.wait(timeout=1) + assert status.done is True + assert status.success is True + # Test with various transitions + status = TransitionStatus(signal=signal, transitions=transitions, strict=True)