From 39adeb72de3d4644b17958f0a269e994e21157c0 Mon Sep 17 00:00:00 2001 From: gac-x01da Date: Tue, 27 May 2025 11:40:49 +0200 Subject: [PATCH] test: add test for nidaq continous scan --- debye_bec/devices/mo1_bragg/mo1_bragg.py | 2 +- debye_bec/devices/nidaq/nidaq.py | 34 ++-- debye_bec/scans/__init__.py | 5 +- debye_bec/scans/nidaq_cont_scan.py | 17 +- tests/tests_devices/test_nidaq.py | 167 ++++++++++++++++++ .../tests_scans/test_nidaq_continous_scan.py | 127 +++++++++++++ 6 files changed, 320 insertions(+), 32 deletions(-) create mode 100644 tests/tests_devices/test_nidaq.py create mode 100644 tests/tests_scans/test_nidaq_continous_scan.py diff --git a/debye_bec/devices/mo1_bragg/mo1_bragg.py b/debye_bec/devices/mo1_bragg/mo1_bragg.py index 12a43a3..50a659e 100644 --- a/debye_bec/devices/mo1_bragg/mo1_bragg.py +++ b/debye_bec/devices/mo1_bragg/mo1_bragg.py @@ -204,7 +204,7 @@ class Mo1Bragg(PSIDeviceBase, Mo1BraggPositioner): self.wait_for_signal( self.scan_control.scan_msg, ScanControlLoadMessage.SUCCESS, - timeout=2*self.timeout_for_pvwait, + timeout=2 * self.timeout_for_pvwait, ) return None diff --git a/debye_bec/devices/nidaq/nidaq.py b/debye_bec/devices/nidaq/nidaq.py index a1eaf3e..f560949 100644 --- a/debye_bec/devices/nidaq/nidaq.py +++ b/debye_bec/devices/nidaq/nidaq.py @@ -346,7 +346,7 @@ class Nidaq(PSIDeviceBase, NidaqControl): def __init__(self, prefix: str = "", *, name: str, scan_info: ScanInfo = None, **kwargs): super().__init__(name=name, prefix=prefix, scan_info=scan_info, **kwargs) - self.scan_info : ScanInfo + self.scan_info: ScanInfo self.timeout_wait_for_signal = 5 # put 5s firsts self._timeout_wait_for_pv = 3 # 3s timeout for pv calls self.valid_scan_names = [ @@ -488,14 +488,16 @@ class Nidaq(PSIDeviceBase, NidaqControl): Called after the device is connected and its signals are connected. Default values for signals should be set here. """ + def heartbeat_callback(*, old_value, value, **kwargs): - return ((old_value) == 0 and (value == 1)) or ((old_value) == 1 and (value == 0)) + return ((old_value) == 0 and (value == 1)) or ((old_value) == 1 and (value == 0)) + status = SubscriptionStatus(self.heartbeat, callback=heartbeat_callback) try: - status.wait(timeout=self.timeout_wait_for_signal) # Raises if timeout is reached + status.wait(timeout=self.timeout_wait_for_signal) # Raises if timeout is reached except WaitTimeoutError: self.power.put(1) - + status.wait(timeout=self.timeout_wait_for_signal) if not self.wait_for_condition( @@ -527,15 +529,19 @@ class Nidaq(PSIDeviceBase, NidaqControl): raise NidaqError( f"Device {self.name} has not been reached in state STANDBY, current state {NidaqState(self.state.get())}" ) - # If scan is not part of the valid_scan_names, + # If scan is not part of the valid_scan_names, if self.scan_info.msg.scan_name != "nidaq_continuous_scan": self.scan_type.set(ScanType.TRIGGERED).wait(timeout=self._timeout_wait_for_pv) self.scan_duration.set(0).wait(timeout=self._timeout_wait_for_pv) self.enable_compression.set(1).wait(timeout=self._timeout_wait_for_pv) else: self.scan_type.set(ScanType.CONTINUOUS).wait(timeout=self._timeout_wait_for_pv) - self.scan_duration.set(self.scan_info.msg.scan_parameters["scan_duration"]).wait(timeout=self._timeout_wait_for_pv) - self.enable_compression.set(self.scan_info.msg.scan_parameters["compression"]).wait(timeout=self._timeout_wait_for_pv) + self.scan_duration.set(self.scan_info.msg.scan_parameters["scan_duration"]).wait( + timeout=self._timeout_wait_for_pv + ) + self.enable_compression.set(self.scan_info.msg.scan_parameters["compression"]).wait( + timeout=self._timeout_wait_for_pv + ) self.stage_call.set(1).wait(timeout=self._timeout_wait_for_pv) @@ -553,7 +559,7 @@ class Nidaq(PSIDeviceBase, NidaqControl): logger.info(f"Device {self.name} was staged: {NidaqState(self.state.get())}") def on_kickoff(self) -> DeviceStatus | StatusBase: - """ Kickoff the Nidaq""" + """Kickoff the Nidaq""" status = self.kickoff_call.set(1) return status @@ -583,11 +589,9 @@ class Nidaq(PSIDeviceBase, NidaqControl): """ if not self._check_if_scan_name_is_valid(): return None - + if self.scan_info.msg.scan_name == "nidaq_continuous_scan": - logger.info( - f"Device {self.name} ready to be kicked off for nidaq_continuous_scan" - ) + logger.info(f"Device {self.name} ready to be kicked off for nidaq_continuous_scan") return None def _wait_for_state(): @@ -625,7 +629,7 @@ class Nidaq(PSIDeviceBase, NidaqControl): # if time.time() > timeout_time: # raise TimeoutError(f"Device {self.name} ran into timeout") time.sleep(0.1) - + if self.scan_info.msg.scan_name != "nidaq_continuous_scan": self.on_stop() timeout = self.timeout_wait_for_signal @@ -637,7 +641,7 @@ class Nidaq(PSIDeviceBase, NidaqControl): else: status = self.task_handler.submit_task(task=_check_state, task_args=(self,)) return status - + def _progress_update(self, value, **kwargs) -> None: """Callback method to update the scan progress, runs a callback to SUB_PROGRESS subscribers, i.e. BEC. @@ -645,7 +649,7 @@ class Nidaq(PSIDeviceBase, NidaqControl): Args: value (int) : current progress value """ - scan_duration = self.scan_info.msg.scan_parameters.get("scan_duration", None) + scan_duration = self.scan_info.msg.scan_parameters.get("scan_duration", None) if not isinstance(scan_duration, (int, float)): return value = scan_duration - value diff --git a/debye_bec/scans/__init__.py b/debye_bec/scans/__init__.py index 206af37..9a3e710 100644 --- a/debye_bec/scans/__init__.py +++ b/debye_bec/scans/__init__.py @@ -4,7 +4,4 @@ from .mono_bragg_scans import ( XASSimpleScan, XASSimpleScanWithXRD, ) - -from .nidaq_cont_scan import ( - NIDAQContinuousScan, -) +from .nidaq_cont_scan import NIDAQContinuousScan diff --git a/debye_bec/scans/nidaq_cont_scan.py b/debye_bec/scans/nidaq_cont_scan.py index 7adf5dc..172da13 100644 --- a/debye_bec/scans/nidaq_cont_scan.py +++ b/debye_bec/scans/nidaq_cont_scan.py @@ -20,19 +20,12 @@ class NIDAQContinuousScan(AsyncFlyScanBase): required_kwargs = [] use_scan_progress_report = False pre_move = False - gui_config = { - "Scan Parameters": ["scan_duration"], - "Data Compression" : ["compression"], - } + gui_config = {"Scan Parameters": ["scan_duration"], "Data Compression": ["compression"]} def __init__( - self, - scan_duration: float, - daq: DeviceBase = "nidaq", - compression: bool = False, - **kwargs, + self, scan_duration: float, daq: DeviceBase = "nidaq", compression: bool = False, **kwargs ): - """ The NIDAQ continuous scan is used to measure with the NIDAQ without moving the + """The NIDAQ continuous scan is used to measure with the NIDAQ without moving the monochromator or any other motor. The NIDAQ thus runs in continuous mode, with a set scan_duration. @@ -78,7 +71,7 @@ class NIDAQContinuousScan(AsyncFlyScanBase): Kickoff the acquisition of the NIDAQ wait for the completion of the scan. """ kickoff_status = yield from self.stubs.kickoff(device=self.daq) - kickoff_status.wait(timeout=5) # wait for proper kickoff of device + kickoff_status.wait(timeout=5) # wait for proper kickoff of device complete_status = yield from self.stubs.complete(device=self.daq, wait=False) @@ -88,4 +81,4 @@ class NIDAQContinuousScan(AsyncFlyScanBase): time.sleep(self.primary_readout_cycle) self.point_id += 1 - self.num_pos = self.point_id \ No newline at end of file + self.num_pos = self.point_id diff --git a/tests/tests_devices/test_nidaq.py b/tests/tests_devices/test_nidaq.py new file mode 100644 index 0000000..0bd8c6f --- /dev/null +++ b/tests/tests_devices/test_nidaq.py @@ -0,0 +1,167 @@ +# pylint: skip-file +import threading +from typing import Generator +from unittest import mock + +import ophyd +import pytest +from bec_server.scan_server.scan_worker import ScanWorker +from ophyd.status import WaitTimeoutError +from ophyd_devices.interfaces.base_classes.psi_device_base import DeviceStoppedError +from ophyd_devices.tests.utils import MockPV + +# from bec_server.device_server.tests.utils import DMMock +from debye_bec.devices.nidaq.nidaq import Nidaq, NidaqError + +# TODO move this function to ophyd_devices, it is duplicated in csaxs_bec and needed for other pluging repositories +from debye_bec.devices.test_utils.utils import patch_dual_pvs + + +@pytest.fixture(scope="function") +def scan_worker_mock(scan_server_mock): + """Scan worker fixture, utility to generate scan_info for a given scan name.""" + scan_server_mock.device_manager.connector = mock.MagicMock() + scan_worker = ScanWorker(parent=scan_server_mock) + yield scan_worker + + +@pytest.fixture(scope="function") +def mock_nidaq() -> Generator[Nidaq, None, None]: + """Fixture for the Nidaq device.""" + name = "nidaq" + prefix = "nidaq:prefix_test:" + with mock.patch.object(ophyd, "cl") as mock_cl: + mock_cl.get_pv = MockPV + mock_cl.thread_class = threading.Thread + dev = Nidaq(name=name, prefix=prefix) + patch_dual_pvs(dev) + yield dev + + +def test_init(mock_nidaq): + """Test the initialization of the Nidaq device.""" + dev = mock_nidaq + assert dev.name == "nidaq" + assert dev.prefix == "nidaq:prefix_test:" + assert dev.valid_scan_names == [ + "xas_simple_scan", + "xas_simple_scan_with_xrd", + "xas_advanced_scan", + "xas_advanced_scan_with_xrd", + "nidaq_continuous_scan", + ] + + +def test_check_if_scan_name_is_valid(mock_nidaq): + """Test the check_if_scan_name_is_valid method.""" + dev = mock_nidaq + dev.scan_info.msg.scan_name = "xas_simple_scan" + assert dev._check_if_scan_name_is_valid() + dev.scan_info.msg.scan_name = "invalid_scan_name" + assert not dev._check_if_scan_name_is_valid() + + +def test_set_config(mock_nidaq): + dev = mock_nidaq + # TODO #21 Add test logic for set_config, issue created # + + +def test_on_connected(mock_nidaq): + """Test the on_connected method of the Nidaq device.""" + dev = mock_nidaq + dev.power.put(0) + dev.heartbeat._read_pv.mock_data = 0 + # First scenario, raise timeout error + + # This will raise a WaitTimeoutError error as we currently do not support callbacks in the MockPV + dev.timeout_wait_for_signal = 0.1 + # To check that it raised, we check that dev.power PV is set to 1 + # Set state PV to 0, 1 is expected value + dev.state._read_pv.mock_data = 0 + with pytest.raises(WaitTimeoutError): + dev.on_connected() + assert dev.power.get() == 1 + # TODO, once the MOCKPv supports callbacks, we can test the rest of the logic issue #22 + + +# def test_on_stage(mock_nidaq): +# dev = mock_nidaq +# #TODO Add once MockPV supports callbacks #22 + + +def test_on_kickoff(mock_nidaq): + """Test the on_kickoff method of the Nidaq device.""" + dev = mock_nidaq + dev.kickoff_call.put(0) + dev.kickoff() + assert dev.kickoff_call.get() == 1 + + +def test_on_unstage(mock_nidaq): + """Test the on_unstage method of the Nidaq device.""" + dev = mock_nidaq + dev.state._read_pv.mock_data = 0 # Set state to 0, 1 is Standby + dev._timeout_wait_for_pv = 0.1 # Set a short timeout for testing + dev.enable_compression._read_pv.mock_data = 0 # Compression enabled + with pytest.raises(NidaqError): + dev.on_unstage() + dev.state._read_pv.mock_data = 1 + dev.on_unstage() + assert dev.enable_compression.get() == 1 + + +@pytest.mark.parametrize( + ["scan_name", "raise_error", "nidaq_state"], + [ + ("line_scan", False, 0), + ("xas_simple_scan", False, 3), + ("xas_simple_scan", True, 0), + ("nidaq_continuous_scan", False, 0), + ], +) +def test_on_pre_scan(mock_nidaq, scan_name, raise_error, nidaq_state): + """Test the on_pre_scan method of the Nidaq device.""" + dev = mock_nidaq + dev.state.put(nidaq_state) + dev.scan_info.msg.scan_name = scan_name + dev._timeout_wait_for_pv = 0.1 # Set a short timeout for testing + if not raise_error: + dev.pre_scan() + else: + with pytest.raises(NidaqError): + dev.pre_scan() + + +def test_on_complete(mock_nidaq): + """Test the on_complete method of the Nidaq device.""" + dev = mock_nidaq + # Check for nidaq_continuous_scan + dev.scan_info.msg.scan_name = "nidaq_continuous_scan" + dev.state.put(0) # Set state to DISABLED + status = dev.complete() + assert status.done is False + dev.state.put(1) + # Should resolve now + status.wait(timeout=5) # Wait for the status to complete + assert status.done is True + + # Check for XAS simple scan + dev.scan_info.msg.scan_name = "xas_simple_scan" + dev.state.put(0) # Set state to ACQUIRE + dev.stop_call.put(0) + dev._timeout_wait_for_pv = 5 + status = dev.on_complete() + assert status.done is False + assert dev.stop_call.get() == 1 # Should have called stop + dev.state.put(1) # Set state to STANDBY + # Should resolve now + status.wait(timeout=5) # Wait for the status to complete + assert status.done is True + + # Test that it resolves if device is stopped + dev.state.put(0) # Set state to DISABLED + dev.stopped = True # Reset stopped state + status = dev.on_complete() + with pytest.raises(NidaqError): + status.wait(timeout=5) + assert status.done is True diff --git a/tests/tests_scans/test_nidaq_continous_scan.py b/tests/tests_scans/test_nidaq_continous_scan.py new file mode 100644 index 0000000..f71cbc2 --- /dev/null +++ b/tests/tests_scans/test_nidaq_continous_scan.py @@ -0,0 +1,127 @@ +# pylint: skip-file +from unittest import mock + +from bec_lib.messages import DeviceInstructionMessage +from bec_server.device_server.tests.utils import DMMock + +from debye_bec.scans import NIDAQContinuousScan + + +def get_instructions(request, ScanStubStatusMock): + request.metadata["RID"] = "my_test_request_id" + + def fake_done(): + """ + Fake done function for ScanStubStatusMock. Upon each call, it returns the next value from the generator. + This is used to simulate the completion of the scan. + """ + yield False + yield False + yield True + + def fake_complete(*args, **kwargs): + yield "fake_complete" + return ScanStubStatusMock(done_func=fake_done) + + with ( + mock.patch.object(request.stubs, "complete", side_effect=fake_complete), + mock.patch.object(request.stubs, "_get_result_from_status", return_value=None), + ): + reference_commands = list(request.run()) + + for cmd in reference_commands: + if not cmd or isinstance(cmd, str): + continue + if "RID" in cmd.metadata: + cmd.metadata["RID"] = "my_test_request_id" + if "rpc_id" in cmd.parameter: + cmd.parameter["rpc_id"] = "my_test_rpc_id" + cmd.metadata.pop("device_instr_id", None) + + return reference_commands + + +def test_xas_simple_scan(scan_assembler, ScanStubStatusMock): + + request = scan_assembler(NIDAQContinuousScan, scan_duration=10) + request.device_manager.add_device("nidaq") + reference_commands = get_instructions(request, ScanStubStatusMock) + assert reference_commands == [ + None, + None, + DeviceInstructionMessage( + metadata={"readout_priority": "monitored", "RID": "my_test_request_id"}, + device=None, + action="scan_report_instruction", + parameter={"device_progress": ["nidaq"]}, + ), + DeviceInstructionMessage( + metadata={"readout_priority": "monitored", "RID": "my_test_request_id"}, + device=None, + action="open_scan", + parameter={ + "scan_motors": [], + "readout_priority": { + "monitored": [], + "baseline": [], + "on_request": [], + "async": ["nidaq"], + }, + "num_points": 0, + "positions": [], + "scan_name": "nidaq_continuous_scan", + "scan_type": "fly", + }, + ), + DeviceInstructionMessage(metadata={}, device="nidaq", action="stage", parameter={}), + DeviceInstructionMessage( + metadata={}, + device=["bpm4i", "eiger", "mo1_bragg", "samx"], + action="stage", + parameter={}, + ), + DeviceInstructionMessage( + metadata={"readout_priority": "baseline", "RID": "my_test_request_id"}, + device=["samx"], + action="read", + parameter={}, + ), + DeviceInstructionMessage( + metadata={"readout_priority": "monitored", "RID": "my_test_request_id"}, + device=["bpm4i", "eiger", "mo1_bragg", "nidaq", "samx"], + action="pre_scan", + parameter={}, + ), + DeviceInstructionMessage( + metadata={"readout_priority": "monitored", "RID": "my_test_request_id"}, + device="nidaq", + action="kickoff", + parameter={"configure": {}}, + ), + "fake_complete", + DeviceInstructionMessage( + metadata={"readout_priority": "monitored", "RID": "my_test_request_id", "point_id": 0}, + device=["bpm4i", "eiger", "mo1_bragg"], + action="read", + parameter={"group": "monitored"}, + ), + DeviceInstructionMessage( + metadata={"readout_priority": "monitored", "RID": "my_test_request_id", "point_id": 1}, + device=["bpm4i", "eiger", "mo1_bragg"], + action="read", + parameter={"group": "monitored"}, + ), + "fake_complete", + DeviceInstructionMessage( + metadata={}, + device=["bpm4i", "eiger", "mo1_bragg", "nidaq", "samx"], + action="unstage", + parameter={}, + ), + DeviceInstructionMessage( + metadata={"readout_priority": "monitored", "RID": "my_test_request_id"}, + device=None, + action="close_scan", + parameter={}, + ), + ]