mirror of
https://github.com/bec-project/ophyd_devices.git
synced 2026-02-06 07:08:40 +01:00
fix(mock-pv): add callbacks to mock_pv
This commit is contained in:
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import threading
|
import threading
|
||||||
from time import sleep
|
from time import sleep
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Callable
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from bec_lib.devicemanager import ScanInfo
|
from bec_lib.devicemanager import ScanInfo
|
||||||
@@ -174,7 +174,7 @@ class MockPV:
|
|||||||
self._args["access"] = "unknown"
|
self._args["access"] = "unknown"
|
||||||
self._args["status"] = 0
|
self._args["status"] = 0
|
||||||
self.connection_callbacks = []
|
self.connection_callbacks = []
|
||||||
self.mock_data = 0
|
self._mock_data = 0
|
||||||
|
|
||||||
if connection_callback is not None:
|
if connection_callback is not None:
|
||||||
self.connection_callbacks = [connection_callback]
|
self.connection_callbacks = [connection_callback]
|
||||||
@@ -183,7 +183,7 @@ class MockPV:
|
|||||||
if access_callback is not None:
|
if access_callback is not None:
|
||||||
self.access_callbacks = [access_callback]
|
self.access_callbacks = [access_callback]
|
||||||
|
|
||||||
self.callbacks = {}
|
self.callbacks: dict[int, tuple[Callable, dict]] = {}
|
||||||
self._put_complete = None
|
self._put_complete = None
|
||||||
self._put_complete_event: threading.Event | None = None
|
self._put_complete_event: threading.Event | None = None
|
||||||
self._monref = None # holder of data returned from create_subscription
|
self._monref = None # holder of data returned from create_subscription
|
||||||
@@ -205,6 +205,20 @@ class MockPV:
|
|||||||
for acc_cb in self.access_callbacks:
|
for acc_cb in self.access_callbacks:
|
||||||
acc_cb(True, True, pv=self)
|
acc_cb(True, True, pv=self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mock_data(self):
|
||||||
|
"""Get mock data"""
|
||||||
|
return self._mock_data
|
||||||
|
|
||||||
|
@mock_data.setter
|
||||||
|
def mock_data(self, value):
|
||||||
|
"""Set mock data"""
|
||||||
|
old_value = self._mock_data
|
||||||
|
|
||||||
|
self._mock_data = value
|
||||||
|
for callback, kw in self.callbacks.values():
|
||||||
|
callback(value=value, old_value=old_value, obj=self, **kw)
|
||||||
|
|
||||||
# pylint disable: unused-argument
|
# pylint disable: unused-argument
|
||||||
def wait_for_connection(self, timeout=None):
|
def wait_for_connection(self, timeout=None):
|
||||||
"""Wait for connection"""
|
"""Wait for connection"""
|
||||||
@@ -251,7 +265,15 @@ class MockPV:
|
|||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def add_callback(self, callback=None, index=None, run_now=False, with_ctrlvars=True, **kw):
|
def add_callback(self, callback=None, index=None, run_now=False, with_ctrlvars=True, **kw):
|
||||||
"""Add callback"""
|
"""Add callback"""
|
||||||
return mock.MagicMock()
|
if callback is None:
|
||||||
|
logger.warning("Callback is None, cannot add callback")
|
||||||
|
return
|
||||||
|
if index is None:
|
||||||
|
index = len(self.callbacks)
|
||||||
|
self.callbacks[index] = (callback, kw)
|
||||||
|
if run_now:
|
||||||
|
callback(value=self.mock_data, old_value=self.mock_data, obj=self, **kw)
|
||||||
|
return index
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def get_with_metadata(
|
def get_with_metadata(
|
||||||
@@ -266,6 +288,7 @@ class MockPV:
|
|||||||
as_namespace=False,
|
as_namespace=False,
|
||||||
):
|
):
|
||||||
"""Get MOCKPV data together with metadata"""
|
"""Get MOCKPV data together with metadata"""
|
||||||
|
|
||||||
return {"value": self.mock_data}
|
return {"value": self.mock_data}
|
||||||
|
|
||||||
def get(
|
def get(
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import ophyd
|
||||||
import pytest
|
import pytest
|
||||||
from bec_lib import messages
|
from bec_lib import messages
|
||||||
from ophyd import Device, Signal
|
from ophyd import Device, EpicsSignalRO, Signal
|
||||||
|
|
||||||
|
from ophyd_devices.tests.utils import MockPV, patch_dual_pvs
|
||||||
from ophyd_devices.utils.bec_signals import (
|
from ophyd_devices.utils.bec_signals import (
|
||||||
BECMessageSignal,
|
BECMessageSignal,
|
||||||
DynamicSignal,
|
DynamicSignal,
|
||||||
@@ -31,6 +34,17 @@ from ophyd_devices.utils.psi_device_base_utils import (
|
|||||||
##########################################
|
##########################################
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def mock_epics_signal_ro():
|
||||||
|
name = "epics_signal_ro"
|
||||||
|
read_pv = "TEST:EPICS_SIGNAL_RO"
|
||||||
|
with mock.patch.object(ophyd, "cl") as mock_cl:
|
||||||
|
mock_cl.get_pv = MockPV
|
||||||
|
mock_cl.thread_class = threading.Thread
|
||||||
|
dev = EpicsSignalRO(name=name, read_pv=read_pv)
|
||||||
|
yield dev
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def file_handler():
|
def file_handler():
|
||||||
"""Fixture for FileHandler"""
|
"""Fixture for FileHandler"""
|
||||||
@@ -687,3 +701,45 @@ def test_transition_status_strings():
|
|||||||
sig.put("d") # last transition
|
sig.put("d") # last transition
|
||||||
assert status.done is True
|
assert status.done is True
|
||||||
assert status.success is True
|
assert status.success is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_compare_status_with_mock_pv(mock_epics_signal_ro):
|
||||||
|
"""Test CompareStatus with EpicsSignalRO, this tests callbacks on EpicsSignals"""
|
||||||
|
|
||||||
|
signal = mock_epics_signal_ro
|
||||||
|
status = CompareStatus(signal=signal, value=5, operation="==")
|
||||||
|
assert status.done is False
|
||||||
|
signal._read_pv.mock_data = 1
|
||||||
|
assert status.done is False
|
||||||
|
signal._read_pv.mock_data = 5
|
||||||
|
status.wait(timeout=1)
|
||||||
|
assert status.done is True
|
||||||
|
assert status.success is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"transitions, expected_done, expected_success",
|
||||||
|
[
|
||||||
|
([1, 2, 3], True, True), # Transitions completed successfully
|
||||||
|
([1, 3, 2], True, False), # Transitions completed with an error
|
||||||
|
([5, 4, 2, 1, 2, 3], True, True), # Transitions completed successfully
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_transition_status_with_mock_pv(
|
||||||
|
mock_epics_signal_ro, transitions, expected_done, expected_success
|
||||||
|
):
|
||||||
|
"""Test TransitionStatus with EpicsSignalRO, this tests callbacks on EpicsSignals"""
|
||||||
|
# Starts immediately with 1
|
||||||
|
signal = mock_epics_signal_ro
|
||||||
|
signal._read_pv.mock_data = 1
|
||||||
|
status = TransitionStatus(signal=signal, transitions=[1, 2, 3], strict=False)
|
||||||
|
assert status.done is False
|
||||||
|
# Does not have to wait
|
||||||
|
signal._read_pv.mock_data = 3
|
||||||
|
signal._read_pv.mock_data = 2
|
||||||
|
signal._read_pv.mock_data = 3
|
||||||
|
status.wait(timeout=1)
|
||||||
|
assert status.done is True
|
||||||
|
assert status.success is True
|
||||||
|
# Test with various transitions
|
||||||
|
status = TransitionStatus(signal=signal, transitions=transitions, strict=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user