refactor: cleanup, fix tests
This commit is contained in:
@@ -159,60 +159,6 @@ def test_set_control_settings(mock_bragg):
|
||||
assert dev.scan_control.scan_duration.get() == 5
|
||||
|
||||
|
||||
def test_update_scan_parameters(mock_bragg):
|
||||
dev = mock_bragg
|
||||
msg = ScanStatusMessage(
|
||||
scan_id="my_scan_id",
|
||||
status="closed",
|
||||
request_inputs={
|
||||
"inputs": {},
|
||||
"kwargs": {
|
||||
"start": 0,
|
||||
"stop": 5,
|
||||
"scan_time": 1,
|
||||
"scan_duration": 10,
|
||||
"xrd_enable_low": True,
|
||||
"xrd_enable_high": False,
|
||||
"num_trigger_low": 1,
|
||||
"num_trigger_high": 7,
|
||||
"exp_time_low": 1,
|
||||
"exp_time_high": 3,
|
||||
"cycle_low": 1,
|
||||
"cycle_high": 5,
|
||||
"p_kink": 50,
|
||||
"e_kink": 8000,
|
||||
},
|
||||
},
|
||||
info={
|
||||
"kwargs": {
|
||||
"start": 0,
|
||||
"stop": 5,
|
||||
"scan_time": 1,
|
||||
"scan_duration": 10,
|
||||
"xrd_enable_low": True,
|
||||
"xrd_enable_high": False,
|
||||
"num_trigger_low": 1,
|
||||
"num_trigger_high": 7,
|
||||
"exp_time_low": 1,
|
||||
"exp_time_high": 3,
|
||||
"cycle_low": 1,
|
||||
"cycle_high": 5,
|
||||
"p_kink": 50,
|
||||
"e_kink": 8000,
|
||||
}
|
||||
},
|
||||
metadata={},
|
||||
)
|
||||
mock_bragg.scan_info.msg = msg
|
||||
scan_param = dev.scan_parameter.model_dump()
|
||||
for _, v in scan_param.items():
|
||||
assert v == None
|
||||
dev._update_scan_parameter()
|
||||
scan_param = dev.scan_parameter.model_dump()
|
||||
for k, v in scan_param.items():
|
||||
assert v == msg.content["request_inputs"]["kwargs"].get(k, None)
|
||||
|
||||
|
||||
def test_kickoff_scan(mock_bragg):
|
||||
dev = mock_bragg
|
||||
dev.scan_control.scan_status._read_pv.mock_data = ScanControlScanStatus.READY
|
||||
|
||||
@@ -6,6 +6,7 @@ from unittest import mock
|
||||
import ophyd
|
||||
import pytest
|
||||
from bec_server.scan_server.scan_worker import ScanWorker
|
||||
from bec_server.scan_server.scans.scan_base import ScanInfo as ScanServerScanInfo
|
||||
from ophyd.status import WaitTimeoutError
|
||||
from ophyd_devices.interfaces.base_classes.psi_device_base import DeviceStoppedError
|
||||
from ophyd_devices.tests.utils import MockPV
|
||||
@@ -15,6 +16,13 @@ from debye_bec.devices.nidaq.nidaq import Nidaq, NidaqError
|
||||
|
||||
# TODO move this function to ophyd_devices, it is duplicated in csaxs_bec and needed for other pluging repositories
|
||||
from debye_bec.devices.test_utils.utils import patch_dual_pvs
|
||||
from debye_bec.devices.utils.utils import fetch_scan_info
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def scan_info_mock():
|
||||
"""Fixture for the ScanInfo object."""
|
||||
return ScanServerScanInfo(scan_name="xas_simple_scan", scan_id="test")
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@@ -52,13 +60,17 @@ def test_init(mock_nidaq):
|
||||
]
|
||||
|
||||
|
||||
def test_check_if_scan_name_is_valid(mock_nidaq):
|
||||
def test_check_if_scan_name_is_valid(mock_nidaq, scan_info_mock):
|
||||
"""Test the check_if_scan_name_is_valid method."""
|
||||
dev = mock_nidaq
|
||||
dev.scan_info.msg.scan_name = "xas_simple_scan"
|
||||
assert dev._check_if_scan_name_is_valid()
|
||||
dev.scan_info.msg.scan_name = "invalid_scan_name"
|
||||
assert not dev._check_if_scan_name_is_valid()
|
||||
scan_info_mock.scan_name = "xas_simple_scan"
|
||||
dev.scan_info.msg.info.update(scan_info_mock.model_dump())
|
||||
scan_parameters = fetch_scan_info(dev.scan_info)
|
||||
assert dev._check_if_scan_name_is_valid(scan_parameters)
|
||||
scan_info_mock.scan_name = "invalid_scan_name"
|
||||
dev.scan_info.msg.info.update(scan_info_mock.model_dump())
|
||||
scan_parameters = fetch_scan_info(dev.scan_info)
|
||||
assert not dev._check_if_scan_name_is_valid(scan_parameters)
|
||||
|
||||
|
||||
def test_set_config(mock_nidaq):
|
||||
@@ -120,11 +132,13 @@ def test_on_unstage(mock_nidaq):
|
||||
("nidaq_continuous_scan", False, 0),
|
||||
],
|
||||
)
|
||||
def test_on_pre_scan(mock_nidaq, scan_name, raise_error, nidaq_state):
|
||||
def test_on_pre_scan(mock_nidaq, scan_name, raise_error, nidaq_state, scan_info_mock):
|
||||
"""Test the on_pre_scan method of the Nidaq device."""
|
||||
dev = mock_nidaq
|
||||
dev.state.put(nidaq_state)
|
||||
dev.scan_info.msg.scan_name = scan_name
|
||||
scan_info_mock.scan_name = scan_name
|
||||
dev.scan_info.msg.info.update(scan_info_mock.model_dump())
|
||||
dev.scan_parameters = fetch_scan_info(dev.scan_info)
|
||||
dev._timeout_wait_for_pv = 0.1 # Set a short timeout for testing
|
||||
if not raise_error:
|
||||
dev.pre_scan()
|
||||
@@ -133,11 +147,13 @@ def test_on_pre_scan(mock_nidaq, scan_name, raise_error, nidaq_state):
|
||||
dev.pre_scan()
|
||||
|
||||
|
||||
def test_on_complete(mock_nidaq):
|
||||
def test_on_complete(mock_nidaq, scan_info_mock):
|
||||
"""Test the on_complete method of the Nidaq device."""
|
||||
dev = mock_nidaq
|
||||
scan_info_mock.scan_name = "nidaq_continuous_scan"
|
||||
dev.scan_info.msg.info.update(scan_info_mock.model_dump())
|
||||
dev.scan_parameters = fetch_scan_info(dev.scan_info)
|
||||
# Check for nidaq_continuous_scan
|
||||
dev.scan_info.msg.scan_name = "nidaq_continuous_scan"
|
||||
dev.state.put(0) # Set state to DISABLED
|
||||
status = dev.complete()
|
||||
assert status.done is False
|
||||
@@ -147,7 +163,9 @@ def test_on_complete(mock_nidaq):
|
||||
assert status.done is True
|
||||
|
||||
# Check for XAS simple scan
|
||||
dev.scan_info.msg.scan_name = "xas_simple_scan"
|
||||
scan_info_mock.scan_name = "xas_simple_scan"
|
||||
dev.scan_info.msg.info.update(scan_info_mock.model_dump())
|
||||
dev.scan_parameters = fetch_scan_info(dev.scan_info)
|
||||
dev.state.put(0) # Set state to ACQUIRE
|
||||
dev.stop_call.put(0)
|
||||
dev._timeout_wait_for_pv = 5
|
||||
|
||||
@@ -7,6 +7,9 @@ import ophyd
|
||||
import pytest
|
||||
from bec_lib.messages import ScanStatusMessage
|
||||
from bec_server.scan_server.scan_worker import ScanWorker
|
||||
from bec_server.scan_server.scans.scan_base import ScanInfo as ScanServerScanInfo
|
||||
from bec_server.scan_server.tests.scan_fixtures import *
|
||||
from bec_server.scan_server.tests.scan_fixtures import _MockDevice
|
||||
from ophyd_devices import CompareStatus, DeviceStatus
|
||||
from ophyd_devices.interfaces.base_classes.psi_device_base import DeviceStoppedError
|
||||
from ophyd_devices.tests.utils import MockPV, patch_dual_pvs
|
||||
@@ -20,6 +23,7 @@ from debye_bec.devices.pilatus.pilatus import (
|
||||
TRIGGERMODE,
|
||||
Pilatus,
|
||||
)
|
||||
from debye_bec.devices.utils.utils import fetch_scan_info
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
from bec_lib.messages import FileMessage
|
||||
@@ -34,32 +38,38 @@ if TYPE_CHECKING: # pragma no cover
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=[
|
||||
(0.1, 1, 1, "line_scan", "step"),
|
||||
(0.2, 2, 2, "time_scan", "step"),
|
||||
(0.5, 5, 5, "xas_advanced_scan", "fly"),
|
||||
(("samx", 0.1, 1, 5, "samy", 0, 1, 5), {"relative": True}, "_v4_hexagonal_scan"),
|
||||
((1, 0.2), {}, "_v4_time_scan"),
|
||||
((9000, 10000, 1, 20, 0.1, 9500), {}, "xas_advanced_scan"),
|
||||
],
|
||||
)
|
||||
def mock_scan_info(request, tmpdir):
|
||||
exp_time, frames_per_trigger, num_points, scan_name, scan_type = request.param
|
||||
scan_info = ScanStatusMessage(
|
||||
scan_id="test_id",
|
||||
status="open",
|
||||
scan_type=scan_type,
|
||||
scan_number=1,
|
||||
scan_parameters={
|
||||
"exp_time": exp_time,
|
||||
"frames_per_trigger": frames_per_trigger,
|
||||
"system_config": {},
|
||||
},
|
||||
info={"file_components": (f"{tmpdir}/data/S00000/S000001", "h5")},
|
||||
num_points=num_points,
|
||||
scan_name=scan_name,
|
||||
)
|
||||
yield scan_info
|
||||
def mock_scan_info(request, tmpdir, v4_scan_assembler, device_manager):
|
||||
args, kwargs, scan_name = request.param
|
||||
mo1_bragg = _MockDevice(name="mo1_bragg")
|
||||
nidaq = _MockDevice(name="nidaq")
|
||||
device_manager.add_device(mo1_bragg)
|
||||
device_manager.add_device(nidaq)
|
||||
scan = v4_scan_assembler(scan_name, *args, **kwargs)
|
||||
yield scan.scan_info
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def pilatus(mock_scan_info) -> Generator[Pilatus, None, None]:
|
||||
def mock_scan_status_message(mock_scan_info, tmpdir) -> ScanStatusMessage:
|
||||
info = mock_scan_info.model_dump()
|
||||
info.update({"file_components": (f"{tmpdir}/data/S00000/S000001", "h5")})
|
||||
return ScanStatusMessage(
|
||||
scan_id=mock_scan_info.scan_id,
|
||||
status="open",
|
||||
scan_number=1,
|
||||
scan_name=mock_scan_info.scan_name,
|
||||
scan_type="fly" if mock_scan_info.scan_type == "hardware_triggered" else "step",
|
||||
num_points=mock_scan_info.num_points,
|
||||
info=info,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def pilatus(mock_scan_status_message) -> Generator[Pilatus, None, None]:
|
||||
name = "pilatus"
|
||||
prefix = "X01DA-OP-MO1:PILATUS:"
|
||||
with mock.patch.object(ophyd, "cl") as mock_cl:
|
||||
@@ -70,8 +80,9 @@ def pilatus(mock_scan_info) -> Generator[Pilatus, None, None]:
|
||||
# dev.image1 = mock.MagicMock()
|
||||
# with mock.patch.object(dev, "image1"):
|
||||
with mock.patch.object(dev, "task_handler"):
|
||||
dev.scan_info.msg = mock_scan_info
|
||||
dev.scan_info.msg = mock_scan_status_message
|
||||
try:
|
||||
dev.scan_parameters = fetch_scan_info(dev.scan_info)
|
||||
yield dev
|
||||
finally:
|
||||
try:
|
||||
@@ -177,7 +188,6 @@ def test_pilatus_on_trigger_cancel_on_stop(pilatus):
|
||||
|
||||
def test_pilatus_on_complete(pilatus: Pilatus):
|
||||
"""Test the on_complete logic of the Pilatus detector."""
|
||||
|
||||
if pilatus.scan_info.msg.scan_name.startswith("xas"):
|
||||
# TODO add test cases for xas scans
|
||||
# status = pilatus.complete()
|
||||
@@ -196,8 +206,9 @@ def test_pilatus_on_complete(pilatus: Pilatus):
|
||||
pilatus.cam.acquire._read_pv.mock_data = ACQUIREMODE.ACQUIRING.value
|
||||
pilatus.hdf.capture._read_pv.mock_data = ACQUIREMODE.ACQUIRING.value
|
||||
pilatus.cam.armed._read_pv.mock_data = DETECTORSTATE.ARMED.value
|
||||
num_images = pilatus.scan_info.msg.num_points * pilatus.scan_info.msg.scan_parameters.get(
|
||||
"frames_per_trigger", 1
|
||||
num_images = (
|
||||
pilatus.scan_parameters.num_points
|
||||
* pilatus.scan_parameters.additional_scan_parameters.get("frames_per_trigger", 1)
|
||||
)
|
||||
pilatus.hdf.num_captured._read_pv.mock_data = num_images - 1
|
||||
# Call on complete
|
||||
@@ -275,9 +286,12 @@ def test_pilatus_on_complete(pilatus: Pilatus):
|
||||
|
||||
def test_pilatus_on_stage_raises_low_exp_time(pilatus):
|
||||
"""Test that on_stage raises a ValueError if the exposure time is too low."""
|
||||
pilatus.scan_info.msg.scan_parameters["exp_time"] = 0.09
|
||||
scan_msg = pilatus.scan_info.msg
|
||||
if scan_msg.scan_type != "step" and scan_msg.scan_name not in pilatus.xas_xrd_scan_names:
|
||||
pilatus.scan_info.msg.info["exp_time"] = 0.09
|
||||
pilatus.scan_parameters = fetch_scan_info(pilatus.scan_info)
|
||||
if (
|
||||
pilatus.scan_parameters.scan_type != "software_triggered"
|
||||
and pilatus.scan_parameters.scan_name not in pilatus.xas_xrd_scan_names
|
||||
):
|
||||
return
|
||||
with pytest.raises(ValueError):
|
||||
pilatus.on_stage()
|
||||
|
||||
Reference in New Issue
Block a user