From 74ff173f9824c29da33aefd3a80fd88d7c3e355c Mon Sep 17 00:00:00 2001 From: appel_c Date: Fri, 22 May 2026 14:45:11 +0200 Subject: [PATCH] refactor: cleanup, fix tests --- debye_bec/devices/mo1_bragg/mo1_bragg.py | 322 ++++++++++++----------- debye_bec/devices/pilatus/pilatus.py | 23 +- debye_bec/devices/utils/utils.py | 17 +- debye_bec/scans/xas_simple_scan.py | 16 +- tests/tests_devices/test_mo1_bragg.py | 54 ---- tests/tests_devices/test_nidaq.py | 38 ++- tests/tests_devices/test_pilatus.py | 70 +++-- 7 files changed, 280 insertions(+), 260 deletions(-) diff --git a/debye_bec/devices/mo1_bragg/mo1_bragg.py b/debye_bec/devices/mo1_bragg/mo1_bragg.py index 69df2d2..30419b9 100644 --- a/debye_bec/devices/mo1_bragg/mo1_bragg.py +++ b/debye_bec/devices/mo1_bragg/mo1_bragg.py @@ -70,6 +70,13 @@ class Mo1Bragg(PSIDeviceBase, Mo1BraggPositioner): super().__init__(name=name, scan_info=scan_info, prefix=prefix, **kwargs) self.scan_parameters: ScanServerScanInfo = None self.timeout_for_pvwait = 7.5 + self.valid_scan_names = [ + "xas_simple_scan", + "xas_simple_scan_with_xrd", + "xas_advanced_scan", + "xas_advanced_scan_with_xrd", + "nidaq_continuous_scan", + ] ######################################## # Beamline Specific Implementations # @@ -104,158 +111,172 @@ class Mo1Bragg(PSIDeviceBase, Mo1BraggPositioner): status.wait(timeout=self.timeout_for_pvwait) scan_name = self.scan_parameters.scan_name - start, stop = self.scan_parameters.positions or (None, None) - scan_time = self.scan_parameters.additional_scan_parameters.get("scan_time", None) - scan_duration = self.scan_parameters.additional_scan_parameters.get("scan_duration", None) - if scan_name == "xas_simple_scan": - if any(param is None for param in [start, stop, scan_time, scan_duration]): - raise Mo1BraggError( - f"Missing scan parameters for xas_simple_scan. Required parameters: start, stop, scan_time, scan_duration in additional_scan_parameters dict {self.scan_parameters.additional_scan_parameters}" + if self._check_if_scan_name_is_valid(self.scan_parameters): + start, stop = self.scan_parameters.positions or (None, None) + scan_time = self.scan_parameters.additional_scan_parameters.get("scan_time", None) + scan_duration = self.scan_parameters.additional_scan_parameters.get( + "scan_duration", None + ) + if scan_name == "xas_simple_scan": + if any(param is None for param in [start, stop, scan_time, scan_duration]): + raise Mo1BraggError( + f"Missing scan parameters for xas_simple_scan. Required parameters: start, stop, scan_time, scan_duration in additional_scan_parameters dict {self.scan_parameters.additional_scan_parameters}" + ) + self.set_xas_settings(low=start, high=stop, scan_time=scan_time) + self.set_trig_settings( + enable_low=False, + enable_high=False, + break_time_low=0, + break_time_high=0, + cycle_low=0, + cycle_high=0, + exp_time=0, + n_of_trigger=0, ) - self.set_xas_settings(low=start, high=stop, scan_time=scan_time) - self.set_trig_settings( - enable_low=False, - enable_high=False, - break_time_low=0, - break_time_high=0, - cycle_low=0, - cycle_high=0, - exp_time=0, - n_of_trigger=0, - ) - self.set_scan_control_settings(mode=ScanControlMode.SIMPLE, scan_duration=scan_duration) - elif scan_name == "xas_simple_scan_with_xrd": - break_enable_low = self.scan_parameters.additional_scan_parameters.get( - "break_enable_low", None - ) - break_enable_high = self.scan_parameters.additional_scan_parameters.get( - "break_enable_high", None - ) - break_time_low = self.scan_parameters.additional_scan_parameters.get( - "break_time_low", None - ) - break_time_high = self.scan_parameters.additional_scan_parameters.get( - "break_time_high", None - ) - cycle_low = self.scan_parameters.additional_scan_parameters.get("cycle_low", None) - cycle_high = self.scan_parameters.additional_scan_parameters.get("cycle_high", None) - exp_time = self.scan_parameters.exp_time - n_of_trigger = self.scan_parameters.additional_scan_parameters.get("n_of_trigger", None) - if any( - param is None - for param in [ - start, - stop, - scan_time, - scan_duration, - break_enable_low, - break_enable_high, - break_time_low, - break_time_high, - cycle_low, - cycle_high, - exp_time, - n_of_trigger, - ] - ): - raise Mo1BraggError( - f"Missing scan parameters for xas_simple_scan_with_xrd. Required parameters: start, stop, scan_time, scan_duration, break_enable_low, break_enable_high, break_time_low, break_time_high, cycle_low, cycle_high, exp_time, n_of_trigger in additional_scan_parameters dict {self.scan_parameters.additional_scan_parameters}" + self.set_scan_control_settings( + mode=ScanControlMode.SIMPLE, scan_duration=scan_duration ) - self.set_xas_settings(low=start, high=stop, scan_time=scan_time) - self.set_trig_settings( - enable_low=break_enable_low, - enable_high=break_enable_high, - break_time_low=break_time_low, - break_time_high=break_time_high, - cycle_low=cycle_low, - cycle_high=cycle_high, - exp_time=exp_time, - n_of_trigger=n_of_trigger, - ) - self.set_scan_control_settings(mode=ScanControlMode.SIMPLE, scan_duration=scan_duration) - elif scan_name == "xas_advanced_scan": - p_kink = self.scan_parameters.additional_scan_parameters.get("p_kink", None) - e_kink = self.scan_parameters.additional_scan_parameters.get("e_kink", None) - if any( - param is None for param in [start, stop, scan_time, scan_duration, p_kink, e_kink] - ): - raise Mo1BraggError( - f"Missing scan parameters for xas_advanced_scan. Required parameters: start, stop, scan_time, scan_duration, p_kink, e_kink in additional_scan_parameters dict {self.scan_parameters.additional_scan_parameters}" + elif scan_name == "xas_simple_scan_with_xrd": + break_enable_low = self.scan_parameters.additional_scan_parameters.get( + "break_enable_low", None ) - self.set_advanced_xas_settings( - low=start, high=stop, scan_time=scan_time, p_kink=p_kink, e_kink=e_kink - ) - self.set_trig_settings( - enable_low=False, - enable_high=False, - break_time_low=0, - break_time_high=0, - cycle_low=0, - cycle_high=0, - exp_time=0, - n_of_trigger=0, - ) - self.set_scan_control_settings( - mode=ScanControlMode.ADVANCED, scan_duration=scan_duration - ) - elif scan_name == "xas_advanced_scan_with_xrd": - p_kink = self.scan_parameters.additional_scan_parameters.get("p_kink", None) - e_kink = self.scan_parameters.additional_scan_parameters.get("e_kink", None) - break_enable_low = self.scan_parameters.additional_scan_parameters.get( - "break_enable_low", None - ) - break_enable_high = self.scan_parameters.additional_scan_parameters.get( - "break_enable_high", None - ) - break_time_low = self.scan_parameters.additional_scan_parameters.get( - "break_time_low", None - ) - break_time_high = self.scan_parameters.additional_scan_parameters.get( - "break_time_high", None - ) - cycle_low = self.scan_parameters.additional_scan_parameters.get("cycle_low", None) - cycle_high = self.scan_parameters.additional_scan_parameters.get("cycle_high", None) - exp_time = self.scan_parameters.exp_time - n_of_trigger = self.scan_parameters.additional_scan_parameters.get("n_of_trigger", None) - if any( - param is None - for param in [ - start, - stop, - scan_time, - scan_duration, - p_kink, - e_kink, - break_enable_low, - break_enable_high, - break_time_low, - break_time_high, - cycle_low, - cycle_high, - exp_time, - n_of_trigger, - ] - ): - raise Mo1BraggError( - f"Missing scan parameters for xas_advanced_scan_with_xrd. Required parameters: start, stop, scan_time, scan_duration, p_kink, e_kink, break_enable_low, break_enable_high, break_time_low, break_time_high, cycle_low, cycle_high, exp_time, n_of_trigger in additional_scan_parameters dict {self.scan_parameters.additional_scan_parameters}" + break_enable_high = self.scan_parameters.additional_scan_parameters.get( + "break_enable_high", None ) + break_time_low = self.scan_parameters.additional_scan_parameters.get( + "break_time_low", None + ) + break_time_high = self.scan_parameters.additional_scan_parameters.get( + "break_time_high", None + ) + cycle_low = self.scan_parameters.additional_scan_parameters.get("cycle_low", None) + cycle_high = self.scan_parameters.additional_scan_parameters.get("cycle_high", None) + exp_time = self.scan_parameters.exp_time + n_of_trigger = self.scan_parameters.additional_scan_parameters.get( + "n_of_trigger", None + ) + if any( + param is None + for param in [ + start, + stop, + scan_time, + scan_duration, + break_enable_low, + break_enable_high, + break_time_low, + break_time_high, + cycle_low, + cycle_high, + exp_time, + n_of_trigger, + ] + ): + raise Mo1BraggError( + f"Missing scan parameters for xas_simple_scan_with_xrd. Required parameters: start, stop, scan_time, scan_duration, break_enable_low, break_enable_high, break_time_low, break_time_high, cycle_low, cycle_high, exp_time, n_of_trigger in additional_scan_parameters dict {self.scan_parameters.additional_scan_parameters}" + ) + self.set_xas_settings(low=start, high=stop, scan_time=scan_time) + self.set_trig_settings( + enable_low=break_enable_low, + enable_high=break_enable_high, + break_time_low=break_time_low, + break_time_high=break_time_high, + cycle_low=cycle_low, + cycle_high=cycle_high, + exp_time=exp_time, + n_of_trigger=n_of_trigger, + ) + self.set_scan_control_settings( + mode=ScanControlMode.SIMPLE, scan_duration=scan_duration + ) + elif scan_name == "xas_advanced_scan": + p_kink = self.scan_parameters.additional_scan_parameters.get("p_kink", None) + e_kink = self.scan_parameters.additional_scan_parameters.get("e_kink", None) + if any( + param is None + for param in [start, stop, scan_time, scan_duration, p_kink, e_kink] + ): + raise Mo1BraggError( + f"Missing scan parameters for xas_advanced_scan. Required parameters: start, stop, scan_time, scan_duration, p_kink, e_kink in additional_scan_parameters dict {self.scan_parameters.additional_scan_parameters}" + ) + self.set_advanced_xas_settings( + low=start, high=stop, scan_time=scan_time, p_kink=p_kink, e_kink=e_kink + ) + self.set_trig_settings( + enable_low=False, + enable_high=False, + break_time_low=0, + break_time_high=0, + cycle_low=0, + cycle_high=0, + exp_time=0, + n_of_trigger=0, + ) + self.set_scan_control_settings( + mode=ScanControlMode.ADVANCED, scan_duration=scan_duration + ) + elif scan_name == "xas_advanced_scan_with_xrd": + p_kink = self.scan_parameters.additional_scan_parameters.get("p_kink", None) + e_kink = self.scan_parameters.additional_scan_parameters.get("e_kink", None) + break_enable_low = self.scan_parameters.additional_scan_parameters.get( + "break_enable_low", None + ) + break_enable_high = self.scan_parameters.additional_scan_parameters.get( + "break_enable_high", None + ) + break_time_low = self.scan_parameters.additional_scan_parameters.get( + "break_time_low", None + ) + break_time_high = self.scan_parameters.additional_scan_parameters.get( + "break_time_high", None + ) + cycle_low = self.scan_parameters.additional_scan_parameters.get("cycle_low", None) + cycle_high = self.scan_parameters.additional_scan_parameters.get("cycle_high", None) + exp_time = self.scan_parameters.exp_time + n_of_trigger = self.scan_parameters.additional_scan_parameters.get( + "n_of_trigger", None + ) + if any( + param is None + for param in [ + start, + stop, + scan_time, + scan_duration, + p_kink, + e_kink, + break_enable_low, + break_enable_high, + break_time_low, + break_time_high, + cycle_low, + cycle_high, + exp_time, + n_of_trigger, + ] + ): + raise Mo1BraggError( + f"Missing scan parameters for xas_advanced_scan_with_xrd. Required parameters: start, stop, scan_time, scan_duration, p_kink, e_kink, break_enable_low, break_enable_high, break_time_low, break_time_high, cycle_low, cycle_high, exp_time, n_of_trigger in additional_scan_parameters dict {self.scan_parameters.additional_scan_parameters}" + ) - self.set_advanced_xas_settings( - low=start, high=stop, scan_time=scan_time, p_kink=p_kink, e_kink=e_kink - ) - self.set_trig_settings( - enable_low=break_enable_low, - enable_high=break_enable_high, - break_time_low=break_time_low, - break_time_high=break_time_high, - cycle_low=cycle_low, - cycle_high=cycle_high, - exp_time=exp_time, - n_of_trigger=n_of_trigger, - ) - self.set_scan_control_settings( - mode=ScanControlMode.ADVANCED, scan_duration=scan_duration - ) + self.set_advanced_xas_settings( + low=start, high=stop, scan_time=scan_time, p_kink=p_kink, e_kink=e_kink + ) + self.set_trig_settings( + enable_low=break_enable_low, + enable_high=break_enable_high, + break_time_low=break_time_low, + break_time_high=break_time_high, + cycle_low=cycle_low, + cycle_high=cycle_high, + exp_time=exp_time, + n_of_trigger=n_of_trigger, + ) + self.set_scan_control_settings( + mode=ScanControlMode.ADVANCED, scan_duration=scan_duration + ) + else: + return # Should never happen. else: return # Setting scan duration seems to lag behind slightly in the backend, include small sleep @@ -335,6 +356,13 @@ class Mo1Bragg(PSIDeviceBase, Mo1BraggPositioner): self.stopped = True # Needs to be set to stop motion ######### Utility Methods ######### + + def _check_if_scan_name_is_valid(self, scan_parameters: ScanServerScanInfo) -> bool: + """Check if the scan is within the list of scans for which the backend is working""" + if scan_parameters.scan_name in self.valid_scan_names: + return True + return False + def _progress_update(self, value, **kwargs) -> None: """Callback method to update the scan progress, runs a callback to SUB_PROGRESS subscribers, i.e. BEC. diff --git a/debye_bec/devices/pilatus/pilatus.py b/debye_bec/devices/pilatus/pilatus.py index 6c721fe..91a422e 100644 --- a/debye_bec/devices/pilatus/pilatus.py +++ b/debye_bec/devices/pilatus/pilatus.py @@ -260,7 +260,7 @@ class Pilatus(PSIDeviceBase, ADBase): # self._live_mode_run_event = threading.Event() # self._live_mode_stopped_event = threading.Event() # self._live_mode_stopped_event.set() # Initial state is stopped - self.scan_parameters: ScanServerScanInfo | None = None + self.scan_parameters: ScanServerScanInfo = None ######################################## # Custom Beamline Methods # @@ -495,8 +495,7 @@ class Pilatus(PSIDeviceBase, ADBase): status_cam = CompareStatus(self.cam.acquire, ACQUIREMODE.DONE.value) status_cam.wait(timeout=5) - scan_msg: ScanStatusMessage = self.scan_info.msg - if scan_msg.scan_name in self.xas_xrd_scan_names: + if self.scan_parameters.scan_name in self.xas_xrd_scan_names: # Compute number of triggers total_trig_lo, total_trig_hi = self._calculate_trigger(self.scan_parameters) # Set the number of images, we may also set this to a higher values if preferred and stop the acquisition @@ -511,10 +510,12 @@ class Pilatus(PSIDeviceBase, ADBase): self.trigger_n_of.set( self.scan_parameters.additional_scan_parameters.get("n_of_trigger", 1) ).wait(5) - - elif scan_msg.scan_type == "step": + # TODO migrate logic to v4 once old scans are deprecated, + # TODO if num_points=None and no logic from scan_name applies, can't measure with this detector.. + elif self.scan_parameters.scan_type == "software_triggered": self.n_images = ( - self.scan_parameters.num_monitored_readouts * scan_msg.frames_per_trigger + self.scan_parameters.num_monitored_readouts + * self.scan_parameters.frames_per_trigger ) exp_time = self.scan_parameters.exp_time self.trigger_source.set(MONOTRIGGERSOURCE.EPICS).wait(5) @@ -537,7 +538,7 @@ class Pilatus(PSIDeviceBase, ADBase): ) ) detector_exp_time = exp_time - self._readout_time - self._full_path = get_full_path(scan_msg, name="pilatus") + self._full_path = get_full_path(self.scan_info.msg, name="pilatus") file_path = "/".join(self._full_path.split("/")[:-1]) file_name = self._full_path.split("/")[-1] # Prepare detector and backend @@ -571,7 +572,7 @@ class Pilatus(PSIDeviceBase, ADBase): """Called right before the scan starts on all devices automatically.""" if ( self.scan_parameters.scan_name in self.xas_xrd_scan_names - or self.scan_parameters.scan_type == "step" + or self.scan_parameters.scan_type == "software_triggered" ): # TODO how to deal with fly scans? status_hdf = CompareStatus(self.hdf.capture, ACQUIREMODE.ACQUIRING.value) status_cam = CompareStatus(self.cam.acquire, ACQUIREMODE.ACQUIRING.value) @@ -586,7 +587,7 @@ class Pilatus(PSIDeviceBase, ADBase): def on_trigger(self) -> DeviceStatus | None: """Called when the device is triggered.""" - if not self.scan_parameters.scan_type == "step": + if not self.scan_parameters.scan_type == "software_triggered": return None start_time = time.time() img_counter = self.hdf.num_captured.get() @@ -601,7 +602,7 @@ class Pilatus(PSIDeviceBase, ADBase): """Callback for when the device completes a scan.""" if ( self.scan_parameters.scan_name in self.xas_xrd_scan_names - or self.scan_parameters.scan_type == "step" + or self.scan_parameters.scan_type == "software_triggered" ): # TODO how to deal with fly scans? if status.success: self.file_event.put( @@ -624,7 +625,7 @@ class Pilatus(PSIDeviceBase, ADBase): """Called to inquire if a device has completed a scans.""" if ( self.scan_parameters.scan_name in self.xas_xrd_scan_names - or self.scan_parameters.scan_type == "step" + or self.scan_parameters.scan_type == "software_triggered" ): # TODO how to deal with fly scans? status_hdf = CompareStatus(self.hdf.capture, ACQUIREMODE.DONE.value) status_cam = CompareStatus(self.cam.acquire, ACQUIREMODE.DONE.value) diff --git a/debye_bec/devices/utils/utils.py b/debye_bec/devices/utils/utils.py index 8e92287..1475961 100644 --- a/debye_bec/devices/utils/utils.py +++ b/debye_bec/devices/utils/utils.py @@ -1,8 +1,11 @@ """Utility functions for the devices.""" +from copy import deepcopy + +import numpy as np from bec_lib.devicemanager import ScanInfo from bec_server.scan_server.scans.scan_base import ScanInfo as ScanServerScanInfo -import numpy as np +from pydantic import ValidationError def fetch_scan_info(scan_info: ScanInfo) -> ScanServerScanInfo: @@ -10,4 +13,14 @@ def fetch_scan_info(scan_info: ScanInfo) -> ScanServerScanInfo: info = scan_info.msg.info if isinstance(info["positions"], list): info["positions"] = np.array(info["positions"]) - return ScanServerScanInfo.model_validate(info) + try: + msg = ScanServerScanInfo.model_validate(info) + except ValidationError: # This means we have an old scan_info object. + info = deepcopy(info) + # We need to convert a few parameters manually. + info["scan_type"] = ( + "hardware_triggered" if info["scan_type"] == "fly" else "software_triggered" + ) + msg = ScanServerScanInfo.model_validate(info) + + return msg diff --git a/debye_bec/scans/xas_simple_scan.py b/debye_bec/scans/xas_simple_scan.py index 871093a..d6a313f 100644 --- a/debye_bec/scans/xas_simple_scan.py +++ b/debye_bec/scans/xas_simple_scan.py @@ -37,7 +37,7 @@ class XasSimpleScan(ScanBase): def __init__( self, - #fmt: off + # fmt: off start: Annotated[float, ScanArgument(display_name="Start Energy", description="Start energy.", units=Units.eV, ge=4500, le=64000)], stop: Annotated[float, ScanArgument(display_name="Stop Energy", description="Stop energy.", units=Units.eV, ge=4500, le=64000)], scan_time: Annotated[float, ScanArgument(display_name="Scan Time", description="Time for one scan cycle.", units=Units.s, ge=0.05)], @@ -45,7 +45,7 @@ class XasSimpleScan(ScanBase): motor: Annotated[DeviceBase | None, ScanArgument(display_name="Motor", description="Bragg motor device.")] = None, daq: Annotated[DeviceBase | None, ScanArgument(display_name="DAQ", description="NIDAQ device.")] = None, monitored_readout_cycle: Annotated[float, ScanArgument(display_name="Monitored Readout Cycle", description="Delay between monitored readouts.",units=Units.s, gt=0)] = 1, - #fmt: on + # fmt: on **kwargs, ): """ @@ -186,7 +186,7 @@ class XasSimpleScanWithXrd(XasSimpleScan): def __init__( self, - #fmt: off + # fmt: off start: Annotated[float, ScanArgument(display_name="Start Energy", description="Start energy.", units=Units.eV)], stop: Annotated[float, ScanArgument(display_name="Stop Energy", description="Stop energy.", units=Units.eV)], scan_time: Annotated[float, ScanArgument(display_name="Scan Time", description="Time for one scan cycle.", units=Units.s, ge=0)], @@ -203,7 +203,7 @@ class XasSimpleScanWithXrd(XasSimpleScan): daq: Annotated[DeviceBase | None, ScanArgument(display_name="DAQ", description="NIDAQ device.")] = None, monitored_readout_cycle: Annotated[float, ScanArgument(display_name="Monitored Readout Cycle", description="Delay between monitored readouts.", units=Units.s, gt=0)] = 1, **kwargs, - #fmt: on + # fmt: on ): super().__init__( start=start, @@ -239,7 +239,7 @@ class XasAdvancedScan(XasSimpleScan): def __init__( self, - #fmt: off + # fmt: off start: Annotated[float, ScanArgument(display_name="Start Energy", description="Start energy.", units=Units.eV)], stop: Annotated[float, ScanArgument(display_name="Stop Energy", description="Stop energy.", units=Units.eV)], scan_time: Annotated[float, ScanArgument(display_name="Scan Time", description="Time for one scan cycle.", units=Units.s, ge=0)], @@ -250,7 +250,7 @@ class XasAdvancedScan(XasSimpleScan): daq: Annotated[DeviceBase | None, ScanArgument(display_name="DAQ", description="NIDAQ device.")] = None, monitored_readout_cycle: Annotated[float, ScanArgument(display_name="Monitored Readout Cycle", description="Delay between monitored readouts.", units=Units.s, gt=0)] = 1, **kwargs, - #fmt: on + # fmt: on ): super().__init__( start=start, @@ -279,7 +279,7 @@ class XasAdvancedScanWithXrd(XasAdvancedScan): def __init__( self, - #fmt: off + # fmt: off start: Annotated[float, ScanArgument(display_name="Start Energy", description="Start energy.", units=Units.eV)], stop: Annotated[float, ScanArgument(display_name="Stop Energy", description="Stop energy.", units=Units.eV)], scan_time: Annotated[float, ScanArgument(display_name="Scan Time", description="Time for one scan cycle.", units=Units.s, ge=0)], @@ -298,7 +298,7 @@ class XasAdvancedScanWithXrd(XasAdvancedScan): daq: Annotated[DeviceBase | None, ScanArgument(display_name="DAQ", description="NIDAQ device.")] = None, monitored_readout_cycle: Annotated[float, ScanArgument(display_name="Monitored Readout Cycle", description="Delay between monitored readouts.", units=Units.s, gt=0)] = 1, **kwargs, - #fmt: on + # fmt: on ): super().__init__( start=start, diff --git a/tests/tests_devices/test_mo1_bragg.py b/tests/tests_devices/test_mo1_bragg.py index e1fd819..89fd6e2 100644 --- a/tests/tests_devices/test_mo1_bragg.py +++ b/tests/tests_devices/test_mo1_bragg.py @@ -159,60 +159,6 @@ def test_set_control_settings(mock_bragg): assert dev.scan_control.scan_duration.get() == 5 -def test_update_scan_parameters(mock_bragg): - dev = mock_bragg - msg = ScanStatusMessage( - scan_id="my_scan_id", - status="closed", - request_inputs={ - "inputs": {}, - "kwargs": { - "start": 0, - "stop": 5, - "scan_time": 1, - "scan_duration": 10, - "xrd_enable_low": True, - "xrd_enable_high": False, - "num_trigger_low": 1, - "num_trigger_high": 7, - "exp_time_low": 1, - "exp_time_high": 3, - "cycle_low": 1, - "cycle_high": 5, - "p_kink": 50, - "e_kink": 8000, - }, - }, - info={ - "kwargs": { - "start": 0, - "stop": 5, - "scan_time": 1, - "scan_duration": 10, - "xrd_enable_low": True, - "xrd_enable_high": False, - "num_trigger_low": 1, - "num_trigger_high": 7, - "exp_time_low": 1, - "exp_time_high": 3, - "cycle_low": 1, - "cycle_high": 5, - "p_kink": 50, - "e_kink": 8000, - } - }, - metadata={}, - ) - mock_bragg.scan_info.msg = msg - scan_param = dev.scan_parameter.model_dump() - for _, v in scan_param.items(): - assert v == None - dev._update_scan_parameter() - scan_param = dev.scan_parameter.model_dump() - for k, v in scan_param.items(): - assert v == msg.content["request_inputs"]["kwargs"].get(k, None) - - def test_kickoff_scan(mock_bragg): dev = mock_bragg dev.scan_control.scan_status._read_pv.mock_data = ScanControlScanStatus.READY diff --git a/tests/tests_devices/test_nidaq.py b/tests/tests_devices/test_nidaq.py index 972eb4a..85057a2 100644 --- a/tests/tests_devices/test_nidaq.py +++ b/tests/tests_devices/test_nidaq.py @@ -6,6 +6,7 @@ from unittest import mock import ophyd import pytest from bec_server.scan_server.scan_worker import ScanWorker +from bec_server.scan_server.scans.scan_base import ScanInfo as ScanServerScanInfo from ophyd.status import WaitTimeoutError from ophyd_devices.interfaces.base_classes.psi_device_base import DeviceStoppedError from ophyd_devices.tests.utils import MockPV @@ -15,6 +16,13 @@ from debye_bec.devices.nidaq.nidaq import Nidaq, NidaqError # TODO move this function to ophyd_devices, it is duplicated in csaxs_bec and needed for other pluging repositories from debye_bec.devices.test_utils.utils import patch_dual_pvs +from debye_bec.devices.utils.utils import fetch_scan_info + + +@pytest.fixture(scope="function") +def scan_info_mock(): + """Fixture for the ScanInfo object.""" + return ScanServerScanInfo(scan_name="xas_simple_scan", scan_id="test") @pytest.fixture(scope="function") @@ -52,13 +60,17 @@ def test_init(mock_nidaq): ] -def test_check_if_scan_name_is_valid(mock_nidaq): +def test_check_if_scan_name_is_valid(mock_nidaq, scan_info_mock): """Test the check_if_scan_name_is_valid method.""" dev = mock_nidaq - dev.scan_info.msg.scan_name = "xas_simple_scan" - assert dev._check_if_scan_name_is_valid() - dev.scan_info.msg.scan_name = "invalid_scan_name" - assert not dev._check_if_scan_name_is_valid() + scan_info_mock.scan_name = "xas_simple_scan" + dev.scan_info.msg.info.update(scan_info_mock.model_dump()) + scan_parameters = fetch_scan_info(dev.scan_info) + assert dev._check_if_scan_name_is_valid(scan_parameters) + scan_info_mock.scan_name = "invalid_scan_name" + dev.scan_info.msg.info.update(scan_info_mock.model_dump()) + scan_parameters = fetch_scan_info(dev.scan_info) + assert not dev._check_if_scan_name_is_valid(scan_parameters) def test_set_config(mock_nidaq): @@ -120,11 +132,13 @@ def test_on_unstage(mock_nidaq): ("nidaq_continuous_scan", False, 0), ], ) -def test_on_pre_scan(mock_nidaq, scan_name, raise_error, nidaq_state): +def test_on_pre_scan(mock_nidaq, scan_name, raise_error, nidaq_state, scan_info_mock): """Test the on_pre_scan method of the Nidaq device.""" dev = mock_nidaq dev.state.put(nidaq_state) - dev.scan_info.msg.scan_name = scan_name + scan_info_mock.scan_name = scan_name + dev.scan_info.msg.info.update(scan_info_mock.model_dump()) + dev.scan_parameters = fetch_scan_info(dev.scan_info) dev._timeout_wait_for_pv = 0.1 # Set a short timeout for testing if not raise_error: dev.pre_scan() @@ -133,11 +147,13 @@ def test_on_pre_scan(mock_nidaq, scan_name, raise_error, nidaq_state): dev.pre_scan() -def test_on_complete(mock_nidaq): +def test_on_complete(mock_nidaq, scan_info_mock): """Test the on_complete method of the Nidaq device.""" dev = mock_nidaq + scan_info_mock.scan_name = "nidaq_continuous_scan" + dev.scan_info.msg.info.update(scan_info_mock.model_dump()) + dev.scan_parameters = fetch_scan_info(dev.scan_info) # Check for nidaq_continuous_scan - dev.scan_info.msg.scan_name = "nidaq_continuous_scan" dev.state.put(0) # Set state to DISABLED status = dev.complete() assert status.done is False @@ -147,7 +163,9 @@ def test_on_complete(mock_nidaq): assert status.done is True # Check for XAS simple scan - dev.scan_info.msg.scan_name = "xas_simple_scan" + scan_info_mock.scan_name = "xas_simple_scan" + dev.scan_info.msg.info.update(scan_info_mock.model_dump()) + dev.scan_parameters = fetch_scan_info(dev.scan_info) dev.state.put(0) # Set state to ACQUIRE dev.stop_call.put(0) dev._timeout_wait_for_pv = 5 diff --git a/tests/tests_devices/test_pilatus.py b/tests/tests_devices/test_pilatus.py index 6403e94..3c0bac2 100644 --- a/tests/tests_devices/test_pilatus.py +++ b/tests/tests_devices/test_pilatus.py @@ -7,6 +7,9 @@ import ophyd import pytest from bec_lib.messages import ScanStatusMessage from bec_server.scan_server.scan_worker import ScanWorker +from bec_server.scan_server.scans.scan_base import ScanInfo as ScanServerScanInfo +from bec_server.scan_server.tests.scan_fixtures import * +from bec_server.scan_server.tests.scan_fixtures import _MockDevice from ophyd_devices import CompareStatus, DeviceStatus from ophyd_devices.interfaces.base_classes.psi_device_base import DeviceStoppedError from ophyd_devices.tests.utils import MockPV, patch_dual_pvs @@ -20,6 +23,7 @@ from debye_bec.devices.pilatus.pilatus import ( TRIGGERMODE, Pilatus, ) +from debye_bec.devices.utils.utils import fetch_scan_info if TYPE_CHECKING: # pragma no cover from bec_lib.messages import FileMessage @@ -34,32 +38,38 @@ if TYPE_CHECKING: # pragma no cover @pytest.fixture( scope="function", params=[ - (0.1, 1, 1, "line_scan", "step"), - (0.2, 2, 2, "time_scan", "step"), - (0.5, 5, 5, "xas_advanced_scan", "fly"), + (("samx", 0.1, 1, 5, "samy", 0, 1, 5), {"relative": True}, "_v4_hexagonal_scan"), + ((1, 0.2), {}, "_v4_time_scan"), + ((9000, 10000, 1, 20, 0.1, 9500), {}, "xas_advanced_scan"), ], ) -def mock_scan_info(request, tmpdir): - exp_time, frames_per_trigger, num_points, scan_name, scan_type = request.param - scan_info = ScanStatusMessage( - scan_id="test_id", - status="open", - scan_type=scan_type, - scan_number=1, - scan_parameters={ - "exp_time": exp_time, - "frames_per_trigger": frames_per_trigger, - "system_config": {}, - }, - info={"file_components": (f"{tmpdir}/data/S00000/S000001", "h5")}, - num_points=num_points, - scan_name=scan_name, - ) - yield scan_info +def mock_scan_info(request, tmpdir, v4_scan_assembler, device_manager): + args, kwargs, scan_name = request.param + mo1_bragg = _MockDevice(name="mo1_bragg") + nidaq = _MockDevice(name="nidaq") + device_manager.add_device(mo1_bragg) + device_manager.add_device(nidaq) + scan = v4_scan_assembler(scan_name, *args, **kwargs) + yield scan.scan_info @pytest.fixture(scope="function") -def pilatus(mock_scan_info) -> Generator[Pilatus, None, None]: +def mock_scan_status_message(mock_scan_info, tmpdir) -> ScanStatusMessage: + info = mock_scan_info.model_dump() + info.update({"file_components": (f"{tmpdir}/data/S00000/S000001", "h5")}) + return ScanStatusMessage( + scan_id=mock_scan_info.scan_id, + status="open", + scan_number=1, + scan_name=mock_scan_info.scan_name, + scan_type="fly" if mock_scan_info.scan_type == "hardware_triggered" else "step", + num_points=mock_scan_info.num_points, + info=info, + ) + + +@pytest.fixture(scope="function") +def pilatus(mock_scan_status_message) -> Generator[Pilatus, None, None]: name = "pilatus" prefix = "X01DA-OP-MO1:PILATUS:" with mock.patch.object(ophyd, "cl") as mock_cl: @@ -70,8 +80,9 @@ def pilatus(mock_scan_info) -> Generator[Pilatus, None, None]: # dev.image1 = mock.MagicMock() # with mock.patch.object(dev, "image1"): with mock.patch.object(dev, "task_handler"): - dev.scan_info.msg = mock_scan_info + dev.scan_info.msg = mock_scan_status_message try: + dev.scan_parameters = fetch_scan_info(dev.scan_info) yield dev finally: try: @@ -177,7 +188,6 @@ def test_pilatus_on_trigger_cancel_on_stop(pilatus): def test_pilatus_on_complete(pilatus: Pilatus): """Test the on_complete logic of the Pilatus detector.""" - if pilatus.scan_info.msg.scan_name.startswith("xas"): # TODO add test cases for xas scans # status = pilatus.complete() @@ -196,8 +206,9 @@ def test_pilatus_on_complete(pilatus: Pilatus): pilatus.cam.acquire._read_pv.mock_data = ACQUIREMODE.ACQUIRING.value pilatus.hdf.capture._read_pv.mock_data = ACQUIREMODE.ACQUIRING.value pilatus.cam.armed._read_pv.mock_data = DETECTORSTATE.ARMED.value - num_images = pilatus.scan_info.msg.num_points * pilatus.scan_info.msg.scan_parameters.get( - "frames_per_trigger", 1 + num_images = ( + pilatus.scan_parameters.num_points + * pilatus.scan_parameters.additional_scan_parameters.get("frames_per_trigger", 1) ) pilatus.hdf.num_captured._read_pv.mock_data = num_images - 1 # Call on complete @@ -275,9 +286,12 @@ def test_pilatus_on_complete(pilatus: Pilatus): def test_pilatus_on_stage_raises_low_exp_time(pilatus): """Test that on_stage raises a ValueError if the exposure time is too low.""" - pilatus.scan_info.msg.scan_parameters["exp_time"] = 0.09 - scan_msg = pilatus.scan_info.msg - if scan_msg.scan_type != "step" and scan_msg.scan_name not in pilatus.xas_xrd_scan_names: + pilatus.scan_info.msg.info["exp_time"] = 0.09 + pilatus.scan_parameters = fetch_scan_info(pilatus.scan_info) + if ( + pilatus.scan_parameters.scan_type != "software_triggered" + and pilatus.scan_parameters.scan_name not in pilatus.xas_xrd_scan_names + ): return with pytest.raises(ValueError): pilatus.on_stage()