refactor: cleanup, fix tests
CI for debye_bec / test (pull_request) Successful in 53s
CI for debye_bec / test (push) Successful in 1m1s

This commit is contained in:
2026-05-22 14:45:11 +02:00
parent 87758710d9
commit 74ff173f98
7 changed files with 280 additions and 260 deletions
+175 -147
View File
@@ -70,6 +70,13 @@ class Mo1Bragg(PSIDeviceBase, Mo1BraggPositioner):
super().__init__(name=name, scan_info=scan_info, prefix=prefix, **kwargs)
self.scan_parameters: ScanServerScanInfo = None
self.timeout_for_pvwait = 7.5
self.valid_scan_names = [
"xas_simple_scan",
"xas_simple_scan_with_xrd",
"xas_advanced_scan",
"xas_advanced_scan_with_xrd",
"nidaq_continuous_scan",
]
########################################
# Beamline Specific Implementations #
@@ -104,158 +111,172 @@ class Mo1Bragg(PSIDeviceBase, Mo1BraggPositioner):
status.wait(timeout=self.timeout_for_pvwait)
scan_name = self.scan_parameters.scan_name
start, stop = self.scan_parameters.positions or (None, None)
scan_time = self.scan_parameters.additional_scan_parameters.get("scan_time", None)
scan_duration = self.scan_parameters.additional_scan_parameters.get("scan_duration", None)
if scan_name == "xas_simple_scan":
if any(param is None for param in [start, stop, scan_time, scan_duration]):
raise Mo1BraggError(
f"Missing scan parameters for xas_simple_scan. Required parameters: start, stop, scan_time, scan_duration in additional_scan_parameters dict {self.scan_parameters.additional_scan_parameters}"
if self._check_if_scan_name_is_valid(self.scan_parameters):
start, stop = self.scan_parameters.positions or (None, None)
scan_time = self.scan_parameters.additional_scan_parameters.get("scan_time", None)
scan_duration = self.scan_parameters.additional_scan_parameters.get(
"scan_duration", None
)
if scan_name == "xas_simple_scan":
if any(param is None for param in [start, stop, scan_time, scan_duration]):
raise Mo1BraggError(
f"Missing scan parameters for xas_simple_scan. Required parameters: start, stop, scan_time, scan_duration in additional_scan_parameters dict {self.scan_parameters.additional_scan_parameters}"
)
self.set_xas_settings(low=start, high=stop, scan_time=scan_time)
self.set_trig_settings(
enable_low=False,
enable_high=False,
break_time_low=0,
break_time_high=0,
cycle_low=0,
cycle_high=0,
exp_time=0,
n_of_trigger=0,
)
self.set_xas_settings(low=start, high=stop, scan_time=scan_time)
self.set_trig_settings(
enable_low=False,
enable_high=False,
break_time_low=0,
break_time_high=0,
cycle_low=0,
cycle_high=0,
exp_time=0,
n_of_trigger=0,
)
self.set_scan_control_settings(mode=ScanControlMode.SIMPLE, scan_duration=scan_duration)
elif scan_name == "xas_simple_scan_with_xrd":
break_enable_low = self.scan_parameters.additional_scan_parameters.get(
"break_enable_low", None
)
break_enable_high = self.scan_parameters.additional_scan_parameters.get(
"break_enable_high", None
)
break_time_low = self.scan_parameters.additional_scan_parameters.get(
"break_time_low", None
)
break_time_high = self.scan_parameters.additional_scan_parameters.get(
"break_time_high", None
)
cycle_low = self.scan_parameters.additional_scan_parameters.get("cycle_low", None)
cycle_high = self.scan_parameters.additional_scan_parameters.get("cycle_high", None)
exp_time = self.scan_parameters.exp_time
n_of_trigger = self.scan_parameters.additional_scan_parameters.get("n_of_trigger", None)
if any(
param is None
for param in [
start,
stop,
scan_time,
scan_duration,
break_enable_low,
break_enable_high,
break_time_low,
break_time_high,
cycle_low,
cycle_high,
exp_time,
n_of_trigger,
]
):
raise Mo1BraggError(
f"Missing scan parameters for xas_simple_scan_with_xrd. Required parameters: start, stop, scan_time, scan_duration, break_enable_low, break_enable_high, break_time_low, break_time_high, cycle_low, cycle_high, exp_time, n_of_trigger in additional_scan_parameters dict {self.scan_parameters.additional_scan_parameters}"
self.set_scan_control_settings(
mode=ScanControlMode.SIMPLE, scan_duration=scan_duration
)
self.set_xas_settings(low=start, high=stop, scan_time=scan_time)
self.set_trig_settings(
enable_low=break_enable_low,
enable_high=break_enable_high,
break_time_low=break_time_low,
break_time_high=break_time_high,
cycle_low=cycle_low,
cycle_high=cycle_high,
exp_time=exp_time,
n_of_trigger=n_of_trigger,
)
self.set_scan_control_settings(mode=ScanControlMode.SIMPLE, scan_duration=scan_duration)
elif scan_name == "xas_advanced_scan":
p_kink = self.scan_parameters.additional_scan_parameters.get("p_kink", None)
e_kink = self.scan_parameters.additional_scan_parameters.get("e_kink", None)
if any(
param is None for param in [start, stop, scan_time, scan_duration, p_kink, e_kink]
):
raise Mo1BraggError(
f"Missing scan parameters for xas_advanced_scan. Required parameters: start, stop, scan_time, scan_duration, p_kink, e_kink in additional_scan_parameters dict {self.scan_parameters.additional_scan_parameters}"
elif scan_name == "xas_simple_scan_with_xrd":
break_enable_low = self.scan_parameters.additional_scan_parameters.get(
"break_enable_low", None
)
self.set_advanced_xas_settings(
low=start, high=stop, scan_time=scan_time, p_kink=p_kink, e_kink=e_kink
)
self.set_trig_settings(
enable_low=False,
enable_high=False,
break_time_low=0,
break_time_high=0,
cycle_low=0,
cycle_high=0,
exp_time=0,
n_of_trigger=0,
)
self.set_scan_control_settings(
mode=ScanControlMode.ADVANCED, scan_duration=scan_duration
)
elif scan_name == "xas_advanced_scan_with_xrd":
p_kink = self.scan_parameters.additional_scan_parameters.get("p_kink", None)
e_kink = self.scan_parameters.additional_scan_parameters.get("e_kink", None)
break_enable_low = self.scan_parameters.additional_scan_parameters.get(
"break_enable_low", None
)
break_enable_high = self.scan_parameters.additional_scan_parameters.get(
"break_enable_high", None
)
break_time_low = self.scan_parameters.additional_scan_parameters.get(
"break_time_low", None
)
break_time_high = self.scan_parameters.additional_scan_parameters.get(
"break_time_high", None
)
cycle_low = self.scan_parameters.additional_scan_parameters.get("cycle_low", None)
cycle_high = self.scan_parameters.additional_scan_parameters.get("cycle_high", None)
exp_time = self.scan_parameters.exp_time
n_of_trigger = self.scan_parameters.additional_scan_parameters.get("n_of_trigger", None)
if any(
param is None
for param in [
start,
stop,
scan_time,
scan_duration,
p_kink,
e_kink,
break_enable_low,
break_enable_high,
break_time_low,
break_time_high,
cycle_low,
cycle_high,
exp_time,
n_of_trigger,
]
):
raise Mo1BraggError(
f"Missing scan parameters for xas_advanced_scan_with_xrd. Required parameters: start, stop, scan_time, scan_duration, p_kink, e_kink, break_enable_low, break_enable_high, break_time_low, break_time_high, cycle_low, cycle_high, exp_time, n_of_trigger in additional_scan_parameters dict {self.scan_parameters.additional_scan_parameters}"
break_enable_high = self.scan_parameters.additional_scan_parameters.get(
"break_enable_high", None
)
break_time_low = self.scan_parameters.additional_scan_parameters.get(
"break_time_low", None
)
break_time_high = self.scan_parameters.additional_scan_parameters.get(
"break_time_high", None
)
cycle_low = self.scan_parameters.additional_scan_parameters.get("cycle_low", None)
cycle_high = self.scan_parameters.additional_scan_parameters.get("cycle_high", None)
exp_time = self.scan_parameters.exp_time
n_of_trigger = self.scan_parameters.additional_scan_parameters.get(
"n_of_trigger", None
)
if any(
param is None
for param in [
start,
stop,
scan_time,
scan_duration,
break_enable_low,
break_enable_high,
break_time_low,
break_time_high,
cycle_low,
cycle_high,
exp_time,
n_of_trigger,
]
):
raise Mo1BraggError(
f"Missing scan parameters for xas_simple_scan_with_xrd. Required parameters: start, stop, scan_time, scan_duration, break_enable_low, break_enable_high, break_time_low, break_time_high, cycle_low, cycle_high, exp_time, n_of_trigger in additional_scan_parameters dict {self.scan_parameters.additional_scan_parameters}"
)
self.set_xas_settings(low=start, high=stop, scan_time=scan_time)
self.set_trig_settings(
enable_low=break_enable_low,
enable_high=break_enable_high,
break_time_low=break_time_low,
break_time_high=break_time_high,
cycle_low=cycle_low,
cycle_high=cycle_high,
exp_time=exp_time,
n_of_trigger=n_of_trigger,
)
self.set_scan_control_settings(
mode=ScanControlMode.SIMPLE, scan_duration=scan_duration
)
elif scan_name == "xas_advanced_scan":
p_kink = self.scan_parameters.additional_scan_parameters.get("p_kink", None)
e_kink = self.scan_parameters.additional_scan_parameters.get("e_kink", None)
if any(
param is None
for param in [start, stop, scan_time, scan_duration, p_kink, e_kink]
):
raise Mo1BraggError(
f"Missing scan parameters for xas_advanced_scan. Required parameters: start, stop, scan_time, scan_duration, p_kink, e_kink in additional_scan_parameters dict {self.scan_parameters.additional_scan_parameters}"
)
self.set_advanced_xas_settings(
low=start, high=stop, scan_time=scan_time, p_kink=p_kink, e_kink=e_kink
)
self.set_trig_settings(
enable_low=False,
enable_high=False,
break_time_low=0,
break_time_high=0,
cycle_low=0,
cycle_high=0,
exp_time=0,
n_of_trigger=0,
)
self.set_scan_control_settings(
mode=ScanControlMode.ADVANCED, scan_duration=scan_duration
)
elif scan_name == "xas_advanced_scan_with_xrd":
p_kink = self.scan_parameters.additional_scan_parameters.get("p_kink", None)
e_kink = self.scan_parameters.additional_scan_parameters.get("e_kink", None)
break_enable_low = self.scan_parameters.additional_scan_parameters.get(
"break_enable_low", None
)
break_enable_high = self.scan_parameters.additional_scan_parameters.get(
"break_enable_high", None
)
break_time_low = self.scan_parameters.additional_scan_parameters.get(
"break_time_low", None
)
break_time_high = self.scan_parameters.additional_scan_parameters.get(
"break_time_high", None
)
cycle_low = self.scan_parameters.additional_scan_parameters.get("cycle_low", None)
cycle_high = self.scan_parameters.additional_scan_parameters.get("cycle_high", None)
exp_time = self.scan_parameters.exp_time
n_of_trigger = self.scan_parameters.additional_scan_parameters.get(
"n_of_trigger", None
)
if any(
param is None
for param in [
start,
stop,
scan_time,
scan_duration,
p_kink,
e_kink,
break_enable_low,
break_enable_high,
break_time_low,
break_time_high,
cycle_low,
cycle_high,
exp_time,
n_of_trigger,
]
):
raise Mo1BraggError(
f"Missing scan parameters for xas_advanced_scan_with_xrd. Required parameters: start, stop, scan_time, scan_duration, p_kink, e_kink, break_enable_low, break_enable_high, break_time_low, break_time_high, cycle_low, cycle_high, exp_time, n_of_trigger in additional_scan_parameters dict {self.scan_parameters.additional_scan_parameters}"
)
self.set_advanced_xas_settings(
low=start, high=stop, scan_time=scan_time, p_kink=p_kink, e_kink=e_kink
)
self.set_trig_settings(
enable_low=break_enable_low,
enable_high=break_enable_high,
break_time_low=break_time_low,
break_time_high=break_time_high,
cycle_low=cycle_low,
cycle_high=cycle_high,
exp_time=exp_time,
n_of_trigger=n_of_trigger,
)
self.set_scan_control_settings(
mode=ScanControlMode.ADVANCED, scan_duration=scan_duration
)
self.set_advanced_xas_settings(
low=start, high=stop, scan_time=scan_time, p_kink=p_kink, e_kink=e_kink
)
self.set_trig_settings(
enable_low=break_enable_low,
enable_high=break_enable_high,
break_time_low=break_time_low,
break_time_high=break_time_high,
cycle_low=cycle_low,
cycle_high=cycle_high,
exp_time=exp_time,
n_of_trigger=n_of_trigger,
)
self.set_scan_control_settings(
mode=ScanControlMode.ADVANCED, scan_duration=scan_duration
)
else:
return # Should never happen.
else:
return
# Setting scan duration seems to lag behind slightly in the backend, include small sleep
@@ -335,6 +356,13 @@ class Mo1Bragg(PSIDeviceBase, Mo1BraggPositioner):
self.stopped = True # Needs to be set to stop motion
######### Utility Methods #########
def _check_if_scan_name_is_valid(self, scan_parameters: ScanServerScanInfo) -> bool:
"""Check if the scan is within the list of scans for which the backend is working"""
if scan_parameters.scan_name in self.valid_scan_names:
return True
return False
def _progress_update(self, value, **kwargs) -> None:
"""Callback method to update the scan progress, runs a callback
to SUB_PROGRESS subscribers, i.e. BEC.
+12 -11
View File
@@ -260,7 +260,7 @@ class Pilatus(PSIDeviceBase, ADBase):
# self._live_mode_run_event = threading.Event()
# self._live_mode_stopped_event = threading.Event()
# self._live_mode_stopped_event.set() # Initial state is stopped
self.scan_parameters: ScanServerScanInfo | None = None
self.scan_parameters: ScanServerScanInfo = None
########################################
# Custom Beamline Methods #
@@ -495,8 +495,7 @@ class Pilatus(PSIDeviceBase, ADBase):
status_cam = CompareStatus(self.cam.acquire, ACQUIREMODE.DONE.value)
status_cam.wait(timeout=5)
scan_msg: ScanStatusMessage = self.scan_info.msg
if scan_msg.scan_name in self.xas_xrd_scan_names:
if self.scan_parameters.scan_name in self.xas_xrd_scan_names:
# Compute number of triggers
total_trig_lo, total_trig_hi = self._calculate_trigger(self.scan_parameters)
# Set the number of images, we may also set this to a higher values if preferred and stop the acquisition
@@ -511,10 +510,12 @@ class Pilatus(PSIDeviceBase, ADBase):
self.trigger_n_of.set(
self.scan_parameters.additional_scan_parameters.get("n_of_trigger", 1)
).wait(5)
elif scan_msg.scan_type == "step":
# TODO migrate logic to v4 once old scans are deprecated,
# TODO if num_points=None and no logic from scan_name applies, can't measure with this detector..
elif self.scan_parameters.scan_type == "software_triggered":
self.n_images = (
self.scan_parameters.num_monitored_readouts * scan_msg.frames_per_trigger
self.scan_parameters.num_monitored_readouts
* self.scan_parameters.frames_per_trigger
)
exp_time = self.scan_parameters.exp_time
self.trigger_source.set(MONOTRIGGERSOURCE.EPICS).wait(5)
@@ -537,7 +538,7 @@ class Pilatus(PSIDeviceBase, ADBase):
)
)
detector_exp_time = exp_time - self._readout_time
self._full_path = get_full_path(scan_msg, name="pilatus")
self._full_path = get_full_path(self.scan_info.msg, name="pilatus")
file_path = "/".join(self._full_path.split("/")[:-1])
file_name = self._full_path.split("/")[-1]
# Prepare detector and backend
@@ -571,7 +572,7 @@ class Pilatus(PSIDeviceBase, ADBase):
"""Called right before the scan starts on all devices automatically."""
if (
self.scan_parameters.scan_name in self.xas_xrd_scan_names
or self.scan_parameters.scan_type == "step"
or self.scan_parameters.scan_type == "software_triggered"
): # TODO how to deal with fly scans?
status_hdf = CompareStatus(self.hdf.capture, ACQUIREMODE.ACQUIRING.value)
status_cam = CompareStatus(self.cam.acquire, ACQUIREMODE.ACQUIRING.value)
@@ -586,7 +587,7 @@ class Pilatus(PSIDeviceBase, ADBase):
def on_trigger(self) -> DeviceStatus | None:
"""Called when the device is triggered."""
if not self.scan_parameters.scan_type == "step":
if not self.scan_parameters.scan_type == "software_triggered":
return None
start_time = time.time()
img_counter = self.hdf.num_captured.get()
@@ -601,7 +602,7 @@ class Pilatus(PSIDeviceBase, ADBase):
"""Callback for when the device completes a scan."""
if (
self.scan_parameters.scan_name in self.xas_xrd_scan_names
or self.scan_parameters.scan_type == "step"
or self.scan_parameters.scan_type == "software_triggered"
): # TODO how to deal with fly scans?
if status.success:
self.file_event.put(
@@ -624,7 +625,7 @@ class Pilatus(PSIDeviceBase, ADBase):
"""Called to inquire if a device has completed a scans."""
if (
self.scan_parameters.scan_name in self.xas_xrd_scan_names
or self.scan_parameters.scan_type == "step"
or self.scan_parameters.scan_type == "software_triggered"
): # TODO how to deal with fly scans?
status_hdf = CompareStatus(self.hdf.capture, ACQUIREMODE.DONE.value)
status_cam = CompareStatus(self.cam.acquire, ACQUIREMODE.DONE.value)
+15 -2
View File
@@ -1,8 +1,11 @@
"""Utility functions for the devices."""
from copy import deepcopy
import numpy as np
from bec_lib.devicemanager import ScanInfo
from bec_server.scan_server.scans.scan_base import ScanInfo as ScanServerScanInfo
import numpy as np
from pydantic import ValidationError
def fetch_scan_info(scan_info: ScanInfo) -> ScanServerScanInfo:
@@ -10,4 +13,14 @@ def fetch_scan_info(scan_info: ScanInfo) -> ScanServerScanInfo:
info = scan_info.msg.info
if isinstance(info["positions"], list):
info["positions"] = np.array(info["positions"])
return ScanServerScanInfo.model_validate(info)
try:
msg = ScanServerScanInfo.model_validate(info)
except ValidationError: # This means we have an old scan_info object.
info = deepcopy(info)
# We need to convert a few parameters manually.
info["scan_type"] = (
"hardware_triggered" if info["scan_type"] == "fly" else "software_triggered"
)
msg = ScanServerScanInfo.model_validate(info)
return msg
+8 -8
View File
@@ -37,7 +37,7 @@ class XasSimpleScan(ScanBase):
def __init__(
self,
#fmt: off
# fmt: off
start: Annotated[float, ScanArgument(display_name="Start Energy", description="Start energy.", units=Units.eV, ge=4500, le=64000)],
stop: Annotated[float, ScanArgument(display_name="Stop Energy", description="Stop energy.", units=Units.eV, ge=4500, le=64000)],
scan_time: Annotated[float, ScanArgument(display_name="Scan Time", description="Time for one scan cycle.", units=Units.s, ge=0.05)],
@@ -45,7 +45,7 @@ class XasSimpleScan(ScanBase):
motor: Annotated[DeviceBase | None, ScanArgument(display_name="Motor", description="Bragg motor device.")] = None,
daq: Annotated[DeviceBase | None, ScanArgument(display_name="DAQ", description="NIDAQ device.")] = None,
monitored_readout_cycle: Annotated[float, ScanArgument(display_name="Monitored Readout Cycle", description="Delay between monitored readouts.",units=Units.s, gt=0)] = 1,
#fmt: on
# fmt: on
**kwargs,
):
"""
@@ -186,7 +186,7 @@ class XasSimpleScanWithXrd(XasSimpleScan):
def __init__(
self,
#fmt: off
# fmt: off
start: Annotated[float, ScanArgument(display_name="Start Energy", description="Start energy.", units=Units.eV)],
stop: Annotated[float, ScanArgument(display_name="Stop Energy", description="Stop energy.", units=Units.eV)],
scan_time: Annotated[float, ScanArgument(display_name="Scan Time", description="Time for one scan cycle.", units=Units.s, ge=0)],
@@ -203,7 +203,7 @@ class XasSimpleScanWithXrd(XasSimpleScan):
daq: Annotated[DeviceBase | None, ScanArgument(display_name="DAQ", description="NIDAQ device.")] = None,
monitored_readout_cycle: Annotated[float, ScanArgument(display_name="Monitored Readout Cycle", description="Delay between monitored readouts.", units=Units.s, gt=0)] = 1,
**kwargs,
#fmt: on
# fmt: on
):
super().__init__(
start=start,
@@ -239,7 +239,7 @@ class XasAdvancedScan(XasSimpleScan):
def __init__(
self,
#fmt: off
# fmt: off
start: Annotated[float, ScanArgument(display_name="Start Energy", description="Start energy.", units=Units.eV)],
stop: Annotated[float, ScanArgument(display_name="Stop Energy", description="Stop energy.", units=Units.eV)],
scan_time: Annotated[float, ScanArgument(display_name="Scan Time", description="Time for one scan cycle.", units=Units.s, ge=0)],
@@ -250,7 +250,7 @@ class XasAdvancedScan(XasSimpleScan):
daq: Annotated[DeviceBase | None, ScanArgument(display_name="DAQ", description="NIDAQ device.")] = None,
monitored_readout_cycle: Annotated[float, ScanArgument(display_name="Monitored Readout Cycle", description="Delay between monitored readouts.", units=Units.s, gt=0)] = 1,
**kwargs,
#fmt: on
# fmt: on
):
super().__init__(
start=start,
@@ -279,7 +279,7 @@ class XasAdvancedScanWithXrd(XasAdvancedScan):
def __init__(
self,
#fmt: off
# fmt: off
start: Annotated[float, ScanArgument(display_name="Start Energy", description="Start energy.", units=Units.eV)],
stop: Annotated[float, ScanArgument(display_name="Stop Energy", description="Stop energy.", units=Units.eV)],
scan_time: Annotated[float, ScanArgument(display_name="Scan Time", description="Time for one scan cycle.", units=Units.s, ge=0)],
@@ -298,7 +298,7 @@ class XasAdvancedScanWithXrd(XasAdvancedScan):
daq: Annotated[DeviceBase | None, ScanArgument(display_name="DAQ", description="NIDAQ device.")] = None,
monitored_readout_cycle: Annotated[float, ScanArgument(display_name="Monitored Readout Cycle", description="Delay between monitored readouts.", units=Units.s, gt=0)] = 1,
**kwargs,
#fmt: on
# fmt: on
):
super().__init__(
start=start,
-54
View File
@@ -159,60 +159,6 @@ def test_set_control_settings(mock_bragg):
assert dev.scan_control.scan_duration.get() == 5
def test_update_scan_parameters(mock_bragg):
dev = mock_bragg
msg = ScanStatusMessage(
scan_id="my_scan_id",
status="closed",
request_inputs={
"inputs": {},
"kwargs": {
"start": 0,
"stop": 5,
"scan_time": 1,
"scan_duration": 10,
"xrd_enable_low": True,
"xrd_enable_high": False,
"num_trigger_low": 1,
"num_trigger_high": 7,
"exp_time_low": 1,
"exp_time_high": 3,
"cycle_low": 1,
"cycle_high": 5,
"p_kink": 50,
"e_kink": 8000,
},
},
info={
"kwargs": {
"start": 0,
"stop": 5,
"scan_time": 1,
"scan_duration": 10,
"xrd_enable_low": True,
"xrd_enable_high": False,
"num_trigger_low": 1,
"num_trigger_high": 7,
"exp_time_low": 1,
"exp_time_high": 3,
"cycle_low": 1,
"cycle_high": 5,
"p_kink": 50,
"e_kink": 8000,
}
},
metadata={},
)
mock_bragg.scan_info.msg = msg
scan_param = dev.scan_parameter.model_dump()
for _, v in scan_param.items():
assert v == None
dev._update_scan_parameter()
scan_param = dev.scan_parameter.model_dump()
for k, v in scan_param.items():
assert v == msg.content["request_inputs"]["kwargs"].get(k, None)
def test_kickoff_scan(mock_bragg):
dev = mock_bragg
dev.scan_control.scan_status._read_pv.mock_data = ScanControlScanStatus.READY
+28 -10
View File
@@ -6,6 +6,7 @@ from unittest import mock
import ophyd
import pytest
from bec_server.scan_server.scan_worker import ScanWorker
from bec_server.scan_server.scans.scan_base import ScanInfo as ScanServerScanInfo
from ophyd.status import WaitTimeoutError
from ophyd_devices.interfaces.base_classes.psi_device_base import DeviceStoppedError
from ophyd_devices.tests.utils import MockPV
@@ -15,6 +16,13 @@ from debye_bec.devices.nidaq.nidaq import Nidaq, NidaqError
# TODO move this function to ophyd_devices, it is duplicated in csaxs_bec and needed for other pluging repositories
from debye_bec.devices.test_utils.utils import patch_dual_pvs
from debye_bec.devices.utils.utils import fetch_scan_info
@pytest.fixture(scope="function")
def scan_info_mock():
"""Fixture for the ScanInfo object."""
return ScanServerScanInfo(scan_name="xas_simple_scan", scan_id="test")
@pytest.fixture(scope="function")
@@ -52,13 +60,17 @@ def test_init(mock_nidaq):
]
def test_check_if_scan_name_is_valid(mock_nidaq):
def test_check_if_scan_name_is_valid(mock_nidaq, scan_info_mock):
"""Test the check_if_scan_name_is_valid method."""
dev = mock_nidaq
dev.scan_info.msg.scan_name = "xas_simple_scan"
assert dev._check_if_scan_name_is_valid()
dev.scan_info.msg.scan_name = "invalid_scan_name"
assert not dev._check_if_scan_name_is_valid()
scan_info_mock.scan_name = "xas_simple_scan"
dev.scan_info.msg.info.update(scan_info_mock.model_dump())
scan_parameters = fetch_scan_info(dev.scan_info)
assert dev._check_if_scan_name_is_valid(scan_parameters)
scan_info_mock.scan_name = "invalid_scan_name"
dev.scan_info.msg.info.update(scan_info_mock.model_dump())
scan_parameters = fetch_scan_info(dev.scan_info)
assert not dev._check_if_scan_name_is_valid(scan_parameters)
def test_set_config(mock_nidaq):
@@ -120,11 +132,13 @@ def test_on_unstage(mock_nidaq):
("nidaq_continuous_scan", False, 0),
],
)
def test_on_pre_scan(mock_nidaq, scan_name, raise_error, nidaq_state):
def test_on_pre_scan(mock_nidaq, scan_name, raise_error, nidaq_state, scan_info_mock):
"""Test the on_pre_scan method of the Nidaq device."""
dev = mock_nidaq
dev.state.put(nidaq_state)
dev.scan_info.msg.scan_name = scan_name
scan_info_mock.scan_name = scan_name
dev.scan_info.msg.info.update(scan_info_mock.model_dump())
dev.scan_parameters = fetch_scan_info(dev.scan_info)
dev._timeout_wait_for_pv = 0.1 # Set a short timeout for testing
if not raise_error:
dev.pre_scan()
@@ -133,11 +147,13 @@ def test_on_pre_scan(mock_nidaq, scan_name, raise_error, nidaq_state):
dev.pre_scan()
def test_on_complete(mock_nidaq):
def test_on_complete(mock_nidaq, scan_info_mock):
"""Test the on_complete method of the Nidaq device."""
dev = mock_nidaq
scan_info_mock.scan_name = "nidaq_continuous_scan"
dev.scan_info.msg.info.update(scan_info_mock.model_dump())
dev.scan_parameters = fetch_scan_info(dev.scan_info)
# Check for nidaq_continuous_scan
dev.scan_info.msg.scan_name = "nidaq_continuous_scan"
dev.state.put(0) # Set state to DISABLED
status = dev.complete()
assert status.done is False
@@ -147,7 +163,9 @@ def test_on_complete(mock_nidaq):
assert status.done is True
# Check for XAS simple scan
dev.scan_info.msg.scan_name = "xas_simple_scan"
scan_info_mock.scan_name = "xas_simple_scan"
dev.scan_info.msg.info.update(scan_info_mock.model_dump())
dev.scan_parameters = fetch_scan_info(dev.scan_info)
dev.state.put(0) # Set state to ACQUIRE
dev.stop_call.put(0)
dev._timeout_wait_for_pv = 5
+42 -28
View File
@@ -7,6 +7,9 @@ import ophyd
import pytest
from bec_lib.messages import ScanStatusMessage
from bec_server.scan_server.scan_worker import ScanWorker
from bec_server.scan_server.scans.scan_base import ScanInfo as ScanServerScanInfo
from bec_server.scan_server.tests.scan_fixtures import *
from bec_server.scan_server.tests.scan_fixtures import _MockDevice
from ophyd_devices import CompareStatus, DeviceStatus
from ophyd_devices.interfaces.base_classes.psi_device_base import DeviceStoppedError
from ophyd_devices.tests.utils import MockPV, patch_dual_pvs
@@ -20,6 +23,7 @@ from debye_bec.devices.pilatus.pilatus import (
TRIGGERMODE,
Pilatus,
)
from debye_bec.devices.utils.utils import fetch_scan_info
if TYPE_CHECKING: # pragma no cover
from bec_lib.messages import FileMessage
@@ -34,32 +38,38 @@ if TYPE_CHECKING: # pragma no cover
@pytest.fixture(
scope="function",
params=[
(0.1, 1, 1, "line_scan", "step"),
(0.2, 2, 2, "time_scan", "step"),
(0.5, 5, 5, "xas_advanced_scan", "fly"),
(("samx", 0.1, 1, 5, "samy", 0, 1, 5), {"relative": True}, "_v4_hexagonal_scan"),
((1, 0.2), {}, "_v4_time_scan"),
((9000, 10000, 1, 20, 0.1, 9500), {}, "xas_advanced_scan"),
],
)
def mock_scan_info(request, tmpdir):
exp_time, frames_per_trigger, num_points, scan_name, scan_type = request.param
scan_info = ScanStatusMessage(
scan_id="test_id",
status="open",
scan_type=scan_type,
scan_number=1,
scan_parameters={
"exp_time": exp_time,
"frames_per_trigger": frames_per_trigger,
"system_config": {},
},
info={"file_components": (f"{tmpdir}/data/S00000/S000001", "h5")},
num_points=num_points,
scan_name=scan_name,
)
yield scan_info
def mock_scan_info(request, tmpdir, v4_scan_assembler, device_manager):
args, kwargs, scan_name = request.param
mo1_bragg = _MockDevice(name="mo1_bragg")
nidaq = _MockDevice(name="nidaq")
device_manager.add_device(mo1_bragg)
device_manager.add_device(nidaq)
scan = v4_scan_assembler(scan_name, *args, **kwargs)
yield scan.scan_info
@pytest.fixture(scope="function")
def pilatus(mock_scan_info) -> Generator[Pilatus, None, None]:
def mock_scan_status_message(mock_scan_info, tmpdir) -> ScanStatusMessage:
info = mock_scan_info.model_dump()
info.update({"file_components": (f"{tmpdir}/data/S00000/S000001", "h5")})
return ScanStatusMessage(
scan_id=mock_scan_info.scan_id,
status="open",
scan_number=1,
scan_name=mock_scan_info.scan_name,
scan_type="fly" if mock_scan_info.scan_type == "hardware_triggered" else "step",
num_points=mock_scan_info.num_points,
info=info,
)
@pytest.fixture(scope="function")
def pilatus(mock_scan_status_message) -> Generator[Pilatus, None, None]:
name = "pilatus"
prefix = "X01DA-OP-MO1:PILATUS:"
with mock.patch.object(ophyd, "cl") as mock_cl:
@@ -70,8 +80,9 @@ def pilatus(mock_scan_info) -> Generator[Pilatus, None, None]:
# dev.image1 = mock.MagicMock()
# with mock.patch.object(dev, "image1"):
with mock.patch.object(dev, "task_handler"):
dev.scan_info.msg = mock_scan_info
dev.scan_info.msg = mock_scan_status_message
try:
dev.scan_parameters = fetch_scan_info(dev.scan_info)
yield dev
finally:
try:
@@ -177,7 +188,6 @@ def test_pilatus_on_trigger_cancel_on_stop(pilatus):
def test_pilatus_on_complete(pilatus: Pilatus):
"""Test the on_complete logic of the Pilatus detector."""
if pilatus.scan_info.msg.scan_name.startswith("xas"):
# TODO add test cases for xas scans
# status = pilatus.complete()
@@ -196,8 +206,9 @@ def test_pilatus_on_complete(pilatus: Pilatus):
pilatus.cam.acquire._read_pv.mock_data = ACQUIREMODE.ACQUIRING.value
pilatus.hdf.capture._read_pv.mock_data = ACQUIREMODE.ACQUIRING.value
pilatus.cam.armed._read_pv.mock_data = DETECTORSTATE.ARMED.value
num_images = pilatus.scan_info.msg.num_points * pilatus.scan_info.msg.scan_parameters.get(
"frames_per_trigger", 1
num_images = (
pilatus.scan_parameters.num_points
* pilatus.scan_parameters.additional_scan_parameters.get("frames_per_trigger", 1)
)
pilatus.hdf.num_captured._read_pv.mock_data = num_images - 1
# Call on complete
@@ -275,9 +286,12 @@ def test_pilatus_on_complete(pilatus: Pilatus):
def test_pilatus_on_stage_raises_low_exp_time(pilatus):
"""Test that on_stage raises a ValueError if the exposure time is too low."""
pilatus.scan_info.msg.scan_parameters["exp_time"] = 0.09
scan_msg = pilatus.scan_info.msg
if scan_msg.scan_type != "step" and scan_msg.scan_name not in pilatus.xas_xrd_scan_names:
pilatus.scan_info.msg.info["exp_time"] = 0.09
pilatus.scan_parameters = fetch_scan_info(pilatus.scan_info)
if (
pilatus.scan_parameters.scan_type != "software_triggered"
and pilatus.scan_parameters.scan_name not in pilatus.xas_xrd_scan_names
):
return
with pytest.raises(ValueError):
pilatus.on_stage()