feat(psi_device_base): add psi_device_base

This commit is contained in:
2025-02-22 12:57:38 +01:00
parent 5ce67e62cb
commit ac4f0c5af7
11 changed files with 981 additions and 510 deletions

View File

@ -6,200 +6,215 @@ import pytest
from ophyd import DeviceStatus, Staged
from ophyd.utils.errors import RedundantStaging
from ophyd_devices.interfaces.base_classes.bec_device_base import BECDeviceBase, CustomPrepare
from ophyd_devices.utils.bec_scaninfo_mixin import BecScaninfoMixin
from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase
from ophyd_devices.utils.errors import DeviceStopError, DeviceTimeoutError
@pytest.fixture
def detector_base():
yield BECDeviceBase(name="test_detector")
yield PSIDeviceBase(name="test_detector")
def test_detector_base_init(detector_base):
assert detector_base.stopped is False
assert detector_base.name == "test_detector"
assert "base_path" in detector_base.filewriter.service_config
assert isinstance(detector_base.scaninfo, BecScaninfoMixin)
assert issubclass(detector_base.custom_prepare_cls, CustomPrepare)
assert detector_base.staged == Staged.no
assert detector_base.destroyed == False
def test_stage(detector_base):
detector_base._staged = Staged.yes
with pytest.raises(RedundantStaging):
detector_base.stage()
assert detector_base._staged == Staged.no
assert detector_base.stopped is False
detector_base._staged = Staged.no
with (
mock.patch.object(detector_base.custom_prepare, "on_stage") as mock_on_stage,
mock.patch.object(detector_base.scaninfo, "load_scan_metadata") as mock_load_metadata,
):
with mock.patch.object(detector_base, "on_stage") as mock_on_stage:
rtr = detector_base.stage()
assert isinstance(rtr, list)
mock_on_stage.assert_called_once()
mock_load_metadata.assert_called_once()
assert mock_on_stage.called is True
with pytest.raises(RedundantStaging):
detector_base.stage()
detector_base._staged = Staged.no
detector_base.stopped = True
detector_base.stage()
assert detector_base.stopped is False
assert mock_on_stage.call_count == 2
def test_pre_scan(detector_base):
with mock.patch.object(detector_base.custom_prepare, "on_pre_scan") as mock_on_pre_scan:
detector_base.pre_scan()
mock_on_pre_scan.assert_called_once()
# def test_stage(detector_base):
# detector_base._staged = Staged.yes
# with pytest.raises(RedundantStaging):
# detector_base.stage()
# assert detector_base.stopped is False
# detector_base._staged = Staged.no
# with (
# mock.patch.object(detector_base.custom_prepare, "on_stage") as mock_on_stage,
# mock.patch.object(detector_base.scaninfo, "load_scan_metadata") as mock_load_metadata,
# ):
# rtr = detector_base.stage()
# assert isinstance(rtr, list)
# mock_on_stage.assert_called_once()
# mock_load_metadata.assert_called_once()
# assert detector_base.stopped is False
def test_trigger(detector_base):
status = DeviceStatus(detector_base)
with mock.patch.object(
detector_base.custom_prepare, "on_trigger", side_effect=[None, status]
) as mock_on_trigger:
st = detector_base.trigger()
assert isinstance(st, DeviceStatus)
time.sleep(0.1)
assert st.done is True
st = detector_base.trigger()
assert st.done is False
assert id(st) == id(status)
# def test_pre_scan(detector_base):
# with mock.patch.object(detector_base.custom_prepare, "on_pre_scan") as mock_on_pre_scan:
# detector_base.pre_scan()
# mock_on_pre_scan.assert_called_once()
def test_unstage(detector_base):
detector_base.stopped = True
with (
mock.patch.object(detector_base.custom_prepare, "on_unstage") as mock_on_unstage,
mock.patch.object(detector_base, "check_scan_id") as mock_check_scan_id,
):
rtr = detector_base.unstage()
assert isinstance(rtr, list)
assert mock_check_scan_id.call_count == 1
assert mock_on_unstage.call_count == 1
detector_base.stopped = False
rtr = detector_base.unstage()
assert isinstance(rtr, list)
assert mock_check_scan_id.call_count == 2
assert mock_on_unstage.call_count == 2
# def test_trigger(detector_base):
# status = DeviceStatus(detector_base)
# with mock.patch.object(
# detector_base.custom_prepare, "on_trigger", side_effect=[None, status]
# ) as mock_on_trigger:
# st = detector_base.trigger()
# assert isinstance(st, DeviceStatus)
# time.sleep(0.1)
# assert st.done is True
# st = detector_base.trigger()
# assert st.done is False
# assert id(st) == id(status)
def test_complete(detector_base):
status = DeviceStatus(detector_base)
with mock.patch.object(
detector_base.custom_prepare, "on_complete", side_effect=[None, status]
) as mock_on_complete:
st = detector_base.complete()
assert isinstance(st, DeviceStatus)
time.sleep(0.1)
assert st.done is True
st = detector_base.complete()
assert st.done is False
assert id(st) == id(status)
# def test_unstage(detector_base):
# detector_base.stopped = True
# with (
# mock.patch.object(detector_base.custom_prepare, "on_unstage") as mock_on_unstage,
# mock.patch.object(detector_base, "check_scan_id") as mock_check_scan_id,
# ):
# rtr = detector_base.unstage()
# assert isinstance(rtr, list)
# assert mock_check_scan_id.call_count == 1
# assert mock_on_unstage.call_count == 1
# detector_base.stopped = False
# rtr = detector_base.unstage()
# assert isinstance(rtr, list)
# assert mock_check_scan_id.call_count == 2
# assert mock_on_unstage.call_count == 2
def test_stop(detector_base):
with mock.patch.object(detector_base.custom_prepare, "on_stop") as mock_on_stop:
detector_base.stop()
mock_on_stop.assert_called_once()
assert detector_base.stopped is True
# def test_complete(detector_base):
# status = DeviceStatus(detector_base)
# with mock.patch.object(
# detector_base.custom_prepare, "on_complete", side_effect=[None, status]
# ) as mock_on_complete:
# st = detector_base.complete()
# assert isinstance(st, DeviceStatus)
# time.sleep(0.1)
# assert st.done is True
# st = detector_base.complete()
# assert st.done is False
# assert id(st) == id(status)
def test_check_scan_id(detector_base):
detector_base.scaninfo.scan_id = "abcde"
detector_base.stopped = False
detector_base.check_scan_id()
assert detector_base.stopped is True
detector_base.stopped = False
detector_base.check_scan_id()
assert detector_base.stopped is False
# def test_stop(detector_base):
# with mock.patch.object(detector_base.custom_prepare, "on_stop") as mock_on_stop:
# detector_base.stop()
# mock_on_stop.assert_called_once()
# assert detector_base.stopped is True
def test_wait_for_signal(detector_base):
my_value = False
def my_callback():
return my_value
detector_base
status = detector_base.custom_prepare.wait_with_status(
[(my_callback, True)],
check_stopped=True,
timeout=5,
interval=0.01,
exception_on_timeout=None,
)
time.sleep(0.1)
assert status.done is False
# Check first that it is stopped when detector_base.stop() is called
detector_base.stop()
# some delay to allow the stop to take effect
time.sleep(0.15)
assert status.done is True
assert status.exception().args == DeviceStopError(f"{detector_base.name} was stopped").args
detector_base.stopped = False
status = detector_base.custom_prepare.wait_with_status(
[(my_callback, True)],
check_stopped=True,
timeout=5,
interval=0.01,
exception_on_timeout=None,
)
# Check that thread resolves when expected value is set
my_value = True
# some delay to allow the stop to take effect
time.sleep(0.15)
assert status.done is True
assert status.success is True
assert status.exception() is None
detector_base.stopped = False
# Check that wait for status runs into timeout with expectd exception
my_value = "random_value"
exception = TimeoutError("Timeout")
status = detector_base.custom_prepare.wait_with_status(
[(my_callback, True)],
check_stopped=True,
timeout=0.01,
interval=0.01,
exception_on_timeout=exception,
)
time.sleep(0.2)
assert status.done is True
assert id(status.exception()) == id(exception)
assert status.success is False
# def test_check_scan_id(detector_base):
# detector_base.scaninfo.scan_id = "abcde"
# detector_base.stopped = False
# detector_base.check_scan_id()
# assert detector_base.stopped is True
# detector_base.stopped = False
# detector_base.check_scan_id()
# assert detector_base.stopped is False
def test_wait_for_signal_returns_exception(detector_base):
my_value = False
# def test_wait_for_signal(detector_base):
# my_value = False
def my_callback():
return my_value
# def my_callback():
# return my_value
# Check that wait for status runs into timeout with expectd exception
# detector_base
# status = detector_base.custom_prepare.wait_with_status(
# [(my_callback, True)],
# check_stopped=True,
# timeout=5,
# interval=0.01,
# exception_on_timeout=None,
# )
# time.sleep(0.1)
# assert status.done is False
# # Check first that it is stopped when detector_base.stop() is called
# detector_base.stop()
# # some delay to allow the stop to take effect
# time.sleep(0.15)
# assert status.done is True
# assert status.exception().args == DeviceStopError(f"{detector_base.name} was stopped").args
# detector_base.stopped = False
# status = detector_base.custom_prepare.wait_with_status(
# [(my_callback, True)],
# check_stopped=True,
# timeout=5,
# interval=0.01,
# exception_on_timeout=None,
# )
# # Check that thread resolves when expected value is set
# my_value = True
# # some delay to allow the stop to take effect
# time.sleep(0.15)
# assert status.done is True
# assert status.success is True
# assert status.exception() is None
exception = TimeoutError("Timeout")
status = detector_base.custom_prepare.wait_with_status(
[(my_callback, True)],
check_stopped=True,
timeout=0.01,
interval=0.01,
exception_on_timeout=exception,
)
time.sleep(0.2)
assert status.done is True
assert id(status.exception()) == id(exception)
assert status.success is False
# detector_base.stopped = False
# # Check that wait for status runs into timeout with expectd exception
# my_value = "random_value"
# exception = TimeoutError("Timeout")
# status = detector_base.custom_prepare.wait_with_status(
# [(my_callback, True)],
# check_stopped=True,
# timeout=0.01,
# interval=0.01,
# exception_on_timeout=exception,
# )
# time.sleep(0.2)
# assert status.done is True
# assert id(status.exception()) == id(exception)
# assert status.success is False
detector_base.stopped = False
# Check that standard exception is thrown
status = detector_base.custom_prepare.wait_with_status(
[(my_callback, True)],
check_stopped=True,
timeout=0.01,
interval=0.01,
exception_on_timeout=None,
)
time.sleep(0.2)
assert status.done is True
assert (
status.exception().args
== DeviceTimeoutError(
f"Timeout error for {detector_base.name} while waiting for signals {[(my_callback, True)]}"
).args
)
assert status.success is False
# def test_wait_for_signal_returns_exception(detector_base):
# my_value = False
# def my_callback():
# return my_value
# # Check that wait for status runs into timeout with expectd exception
# exception = TimeoutError("Timeout")
# status = detector_base.custom_prepare.wait_with_status(
# [(my_callback, True)],
# check_stopped=True,
# timeout=0.01,
# interval=0.01,
# exception_on_timeout=exception,
# )
# time.sleep(0.2)
# assert status.done is True
# assert id(status.exception()) == id(exception)
# assert status.success is False
# detector_base.stopped = False
# # Check that standard exception is thrown
# status = detector_base.custom_prepare.wait_with_status(
# [(my_callback, True)],
# check_stopped=True,
# timeout=0.01,
# interval=0.01,
# exception_on_timeout=None,
# )
# time.sleep(0.2)
# assert status.done is True
# assert (
# status.exception().args
# == DeviceTimeoutError(
# f"Timeout error for {detector_base.name} while waiting for signals {[(my_callback, True)]}"
# ).args
# )
# assert status.success is False

View File

@ -17,7 +17,6 @@ from ophyd import Device, Signal
from ophyd.status import wait as status_wait
from ophyd_devices.interfaces.protocols.bec_protocols import (
BECBaseProtocol,
BECDeviceProtocol,
BECFlyerProtocol,
BECPositionerProtocol,
@ -31,6 +30,7 @@ from ophyd_devices.sim.sim_positioner import SimLinearTrajectoryPositioner, SimP
from ophyd_devices.sim.sim_signals import ReadOnlySignal
from ophyd_devices.sim.sim_utils import H5Writer, LinearTrajectory
from ophyd_devices.sim.sim_waveform import SimWaveform
from ophyd_devices.tests.utils import get_mock_scan_info
from ophyd_devices.utils.bec_device_base import BECDevice, BECDeviceBase
@ -423,7 +423,6 @@ def test_h5proxy(h5proxy_fixture):
)
camera._registered_proxies.update({h5proxy.name: camera.image.name})
camera.sim.params = {"noise": "none", "noise_multiplier": 0}
camera.scaninfo.sim_mode = True
# pylint: disable=no-member
camera.image_shape.set(data.shape[1:])
camera.stage()
@ -544,15 +543,15 @@ def test_cam_stage_h5writer(camera):
mock.patch.object(camera, "h5_writer") as mock_h5_writer,
mock.patch.object(camera, "_run_subs") as mock_run_subs,
):
camera.scaninfo.num_points = 10
camera.scaninfo.frames_per_trigger = 1
camera.scaninfo.exp_time = 1
camera.scan_info.msg.num_points = 10
camera.scan_info.msg.scan_parameters["frames_per_trigger"] = 1
camera.scan_info.msg.scan_parameters["exp_time"] = 1
camera.stage()
assert mock_h5_writer.on_stage.call_count == 0
camera.unstage()
camera.write_to_disk.put(True)
camera.stage()
calls = [mock.call(file_path="", h5_entry="/entry/data/data")]
calls = [mock.call(file_path="./data/test_file_camera.h5", h5_entry="/entry/data/data")]
assert mock_h5_writer.on_stage.mock_calls == calls
# mock_h5_writer.prepare
@ -622,17 +621,17 @@ def test_async_monitor_stage(async_monitor):
def test_async_monitor_prep_random_interval(async_monitor):
"""Test the stage method of SimMonitorAsync."""
async_monitor.custom_prepare.prep_random_interval()
assert async_monitor.custom_prepare._counter == 0
async_monitor.prep_random_interval()
assert async_monitor._counter == 0
assert async_monitor.current_trigger.get() == 0
assert 0 < async_monitor.custom_prepare._random_send_interval < 10
assert 0 < async_monitor._random_send_interval < 10
def test_async_monitor_complete(async_monitor):
"""Test the on_complete method of SimMonitorAsync."""
with (
mock.patch.object(async_monitor.custom_prepare, "_send_data_to_bec") as mock_send,
mock.patch.object(async_monitor.custom_prepare, "prep_random_interval") as mock_prep,
mock.patch.object(async_monitor, "_send_data_to_bec") as mock_send,
mock.patch.object(async_monitor, "prep_random_interval") as mock_prep,
):
status = async_monitor.complete()
status_wait(status)
@ -649,11 +648,11 @@ def test_async_monitor_complete(async_monitor):
def test_async_mon_on_trigger(async_monitor):
"""Test the on_trigger method of SimMonitorAsync."""
with (mock.patch.object(async_monitor.custom_prepare, "_send_data_to_bec") as mock_send,):
async_monitor.custom_prepare.on_stage()
upper_limit = async_monitor.custom_prepare._random_send_interval
with (mock.patch.object(async_monitor, "_send_data_to_bec") as mock_send,):
async_monitor.on_stage()
upper_limit = async_monitor._random_send_interval
for ii in range(1, upper_limit + 1):
status = async_monitor.custom_prepare.on_trigger()
status = async_monitor.on_trigger()
status_wait(status)
assert async_monitor.current_trigger.get() == ii
assert mock_send.call_count == 1
@ -661,10 +660,10 @@ def test_async_mon_on_trigger(async_monitor):
def test_async_mon_send_data_to_bec(async_monitor):
"""Test the _send_data_to_bec method of SimMonitorAsync."""
async_monitor.scaninfo.scan_msg = SimpleNamespace(metadata={})
async_monitor.scan_info = get_mock_scan_info()
async_monitor.data_buffer.update({"value": [0, 5], "timestamp": [0, 0]})
with mock.patch.object(async_monitor.connector, "xadd") as mock_xadd:
async_monitor.custom_prepare._send_data_to_bec()
async_monitor._send_data_to_bec()
dev_msg = messages.DeviceMessage(
signals={async_monitor.readback.name: async_monitor.data_buffer},
metadata={"async_update": async_monitor.async_update.get()},
@ -673,10 +672,10 @@ def test_async_mon_send_data_to_bec(async_monitor):
call = [
mock.call(
MessageEndpoints.device_async_readback(
scan_id=async_monitor.scaninfo.scan_id, device=async_monitor.name
scan_id=async_monitor.scan_info.msg.scan_id, device=async_monitor.name
),
{"data": dev_msg},
expire=async_monitor.custom_prepare._stream_ttl,
expire=async_monitor._stream_ttl,
)
]
assert mock_xadd.mock_calls == call
@ -711,8 +710,8 @@ def test_waveform(waveform):
# Now also test the async readback
mock_connector = waveform.connector = mock.MagicMock()
mock_run_subs = waveform._run_subs = mock.MagicMock()
waveform.scaninfo.scan_msg = SimpleNamespace(metadata={})
waveform.scaninfo.scan_id = "test"
waveform.scan_info = get_mock_scan_info()
waveform.scan_info.msg.scan_id = "test"
status = waveform.trigger()
timer = 0
while not status.done: