From 262a0b6318bd72f8b0f04176d18dcc87913eab30 Mon Sep 17 00:00:00 2001 From: appel_c Date: Fri, 22 May 2026 10:04:49 +0200 Subject: [PATCH] refactor(pilatus): migrate to scans v4 interface --- debye_bec/devices/pilatus/pilatus.py | 128 +++++++++++++++------------ debye_bec/devices/utils/__init__.py | 0 debye_bec/devices/utils/utils.py | 9 ++ 3 files changed, 82 insertions(+), 55 deletions(-) create mode 100644 debye_bec/devices/utils/__init__.py create mode 100644 debye_bec/devices/utils/utils.py diff --git a/debye_bec/devices/pilatus/pilatus.py b/debye_bec/devices/pilatus/pilatus.py index 2a7668f..6c721fe 100644 --- a/debye_bec/devices/pilatus/pilatus.py +++ b/debye_bec/devices/pilatus/pilatus.py @@ -11,16 +11,26 @@ from typing import TYPE_CHECKING, Tuple import numpy as np from bec_lib.file_utils import get_full_path from bec_lib.logger import bec_logger +from bec_server.scan_server.scans.scan_base import ScanInfo as ScanServerScanInfo from ophyd import Component as Cpt from ophyd import EpicsSignal, EpicsSignalRO, Kind from ophyd.areadetector.cam import ADBase, PilatusDetectorCam from ophyd.areadetector.plugins import HDF5Plugin_V22 as HDF5Plugin from ophyd.areadetector.plugins import ImagePlugin_V22 as ImagePlugin from ophyd.status import WaitTimeoutError -from ophyd_devices import AndStatus, CompareStatus, DeviceStatus, FileEventSignal, PreviewSignal +from ophyd_devices import ( + AndStatus, + CompareStatus, + DeviceStatus, + ExceptionStatus, + FileEventSignal, + PreviewSignal, +) from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase from pydantic import BaseModel, Field +from debye_bec.devices.utils.utils import fetch_scan_info + if TYPE_CHECKING: # pragma: no cover from bec_lib.devicemanager import ScanInfo from bec_lib.messages import DevicePreviewMessage, ScanStatusMessage @@ -145,17 +155,17 @@ class Pilatus(PSIDeviceBase, ADBase): # USER_ACCESS = ["start_live_mode", "stop_live_mode"] - cam_gain_menu_string = Cpt(EpicsSignalRO, suffix='cam1:GainMenu', string=True) + cam_gain_menu_string = Cpt(EpicsSignalRO, suffix="cam1:GainMenu", string=True) _default_configuration_attrs = [ - 'cam.threshold_energy', - 'cam.threshold_auto_apply', - 'cam.gain_menu', - 'cam_gain_menu_string', - 'cam.pixel_cut_off', - 'cam.acquire_time', - 'cam.num_exposures', - 'cam.model', + "cam.threshold_energy", + "cam.threshold_auto_apply", + "cam.gain_menu", + "cam_gain_menu_string", + "cam.pixel_cut_off", + "cam.acquire_time", + "cam.num_exposures", + "cam.model", ] cam = Cpt(PilatusDetectorCam, "cam1:") @@ -233,7 +243,6 @@ class Pilatus(PSIDeviceBase, ADBase): super().__init__( name=name, prefix=prefix, scan_info=scan_info, device_manager=device_manager, **kwargs ) - self.scan_parameter = ScanParameter() self.device_manager = device_manager self._readout_time = PILATUS_READOUT_TIME self._full_path = "" @@ -251,6 +260,7 @@ class Pilatus(PSIDeviceBase, ADBase): # self._live_mode_run_event = threading.Event() # self._live_mode_stopped_event = threading.Event() # self._live_mode_stopped_event.set() # Initial state is stopped + self.scan_parameters: ScanServerScanInfo | None = None ######################################## # Custom Beamline Methods # @@ -368,19 +378,22 @@ class Pilatus(PSIDeviceBase, ADBase): status = status_acquire & status_writing & status_cam_server return status - def _calculate_trigger(self, scan_msg: ScanStatusMessage) -> Tuple[float, float]: - self._update_scan_parameter() + def _calculate_trigger(self, scan_parameters: ScanServerScanInfo) -> Tuple[float, float]: total_osc = 0 calc_duration = 0 total_trig_lo = 0 total_trig_hi = 0 # Switching high/low is intended as angle is inverse to energy and settings in BEC are always in energy - loc_break_enable_low = self.scan_parameter.break_enable_high - loc_break_time_low = self.scan_parameter.break_time_high - loc_cycle_low = self.scan_parameter.cycle_high - loc_break_enable_high = self.scan_parameter.break_enable_low - loc_break_time_high = self.scan_parameter.break_time_low - loc_cycle_high = self.scan_parameter.cycle_low + loc_break_enable_low = scan_parameters.additional_scan_parameters.get( + "break_enable_high", False + ) + loc_break_time_low = scan_parameters.additional_scan_parameters.get("break_time_high", 0) + loc_cycle_low = scan_parameters.additional_scan_parameters.get("cycle_high", 1) + loc_break_enable_high = scan_parameters.additional_scan_parameters.get( + "break_enable_low", False + ) + loc_break_time_high = scan_parameters.additional_scan_parameters.get("break_time_low", 0) + loc_cycle_high = scan_parameters.additional_scan_parameters.get("cycle_low", 1) if not loc_break_enable_low: loc_break_time_low = 0 @@ -389,28 +402,36 @@ class Pilatus(PSIDeviceBase, ADBase): loc_break_time_high = 0 loc_cycle_high = 1 - total_osc = self.scan_parameter.scan_duration / ( - self.scan_parameter.scan_time + - loc_break_time_low / (2 * loc_cycle_low) + - loc_break_time_high / (2 * loc_cycle_high) + total_osc = scan_parameters.additional_scan_parameters.get("scan_duration", 0) / ( + scan_parameters.additional_scan_parameters.get("scan_time", 0) + + loc_break_time_low / (2 * loc_cycle_low) + + loc_break_time_high / (2 * loc_cycle_high) ) total_osc = np.ceil(total_osc) - total_osc = total_osc + total_osc % 2 # round up to the next even number + total_osc = total_osc + total_osc % 2 # round up to the next even number if loc_break_enable_low: total_trig_lo = np.floor(total_osc / (2 * loc_cycle_low)) if loc_break_enable_high: total_trig_hi = np.floor(total_osc / (2 * loc_cycle_high)) - calc_duration = total_osc * self.scan_parameter.scan_time + total_trig_lo * loc_break_time_low + total_trig_hi * loc_break_time_high - - if calc_duration < self.scan_parameter.scan_duration: + calc_duration = ( + total_osc * scan_parameters.additional_scan_parameters.get("scan_time", 0) + + total_trig_lo * loc_break_time_low + + total_trig_hi * loc_break_time_high + ) + + if calc_duration < scan_parameters.additional_scan_parameters.get("scan_duration", 0): # Due to inaccuracy in formula, this can happen, we then need to manually add two oscillations and recalculate the triggers total_osc = total_osc + 2 if loc_break_enable_low: total_trig_lo = np.floor(total_osc / (2 * loc_cycle_low)) if loc_break_enable_high: total_trig_hi = np.floor(total_osc / (2 * loc_cycle_high)) - calc_duration = total_osc * self.scan_parameter.scan_time + total_trig_lo * loc_break_time_low + total_trig_hi * loc_break_time_high + calc_duration = ( + total_osc * scan_parameters.additional_scan_parameters.get("scan_time", 0) + + total_trig_lo * loc_break_time_low + + total_trig_hi * loc_break_time_high + ) return total_trig_lo, total_trig_hi @@ -464,6 +485,7 @@ class Pilatus(PSIDeviceBase, ADBase): (self.scan_info.msg) object. """ # self.stop_live_mode() # Make sure that live mode is stopped if scan runs + self.scan_parameters = fetch_scan_info(self.scan_info) # If user has activated alignment mode on qt panel, switch back to multitrigger and stop acquisition if self.cam.trigger_mode.get() != TRIGGERMODE.MULT_TRIGGER.value: @@ -475,23 +497,26 @@ class Pilatus(PSIDeviceBase, ADBase): scan_msg: ScanStatusMessage = self.scan_info.msg if scan_msg.scan_name in self.xas_xrd_scan_names: - self._update_scan_parameter() # Compute number of triggers - total_trig_lo, total_trig_hi = self._calculate_trigger(scan_msg) + total_trig_lo, total_trig_hi = self._calculate_trigger(self.scan_parameters) # Set the number of images, we may also set this to a higher values if preferred and stop the acquisition # TODO This logic is prone to errors, as we rely on the scans to nicely resolve to n_images. We should # use here instead a way of settings the n_images independently of the scan parameters to avoid running out of sync # with the complete method. Ideally we comput them in the scan itself.. This is much safer IMO! - self.n_images = (total_trig_lo + total_trig_hi) * self.scan_parameter.n_of_trigger - exp_time = self.scan_parameter.exp_time + self.n_images = ( + total_trig_lo + total_trig_hi + ) * self.scan_parameters.additional_scan_parameters.get("n_of_trigger", 1) + exp_time = self.scan_parameters.exp_time self.trigger_source.set(MONOTRIGGERSOURCE.INPOS).wait(5) - self.trigger_n_of.set(self.scan_parameter.n_of_trigger).wait(5) + self.trigger_n_of.set( + self.scan_parameters.additional_scan_parameters.get("n_of_trigger", 1) + ).wait(5) elif scan_msg.scan_type == "step": - self.n_images = scan_msg.num_points * scan_msg.scan_parameters.get( - "frames_per_trigger", 1 + self.n_images = ( + self.scan_parameters.num_monitored_readouts * scan_msg.frames_per_trigger ) - exp_time = scan_msg.scan_parameters.get("exp_time") + exp_time = self.scan_parameters.exp_time self.trigger_source.set(MONOTRIGGERSOURCE.EPICS).wait(5) self.trigger_n_of.set(1).wait(5) # BEC will trigger each acquisition else: @@ -544,9 +569,9 @@ class Pilatus(PSIDeviceBase, ADBase): def on_pre_scan(self) -> DeviceStatus | None: """Called right before the scan starts on all devices automatically.""" - scan_msg: ScanStatusMessage = self.scan_info.msg if ( - scan_msg.scan_name in self.xas_xrd_scan_names or scan_msg.scan_type == "step" + self.scan_parameters.scan_name in self.xas_xrd_scan_names + or self.scan_parameters.scan_type == "step" ): # TODO how to deal with fly scans? status_hdf = CompareStatus(self.hdf.capture, ACQUIREMODE.ACQUIRING.value) status_cam = CompareStatus(self.cam.acquire, ACQUIREMODE.ACQUIRING.value) @@ -561,8 +586,7 @@ class Pilatus(PSIDeviceBase, ADBase): def on_trigger(self) -> DeviceStatus | None: """Called when the device is triggered.""" - scan_msg: ScanStatusMessage = self.scan_info.msg - if not scan_msg.scan_type == "step": + if not self.scan_parameters.scan_type == "step": return None start_time = time.time() img_counter = self.hdf.num_captured.get() @@ -575,9 +599,9 @@ class Pilatus(PSIDeviceBase, ADBase): def _complete_callback(self, status: DeviceStatus): """Callback for when the device completes a scan.""" - scan_msg: ScanStatusMessage = self.scan_info.msg if ( - scan_msg.scan_name in self.xas_xrd_scan_names or scan_msg.scan_type == "step" + self.scan_parameters.scan_name in self.xas_xrd_scan_names + or self.scan_parameters.scan_type == "step" ): # TODO how to deal with fly scans? if status.success: self.file_event.put( @@ -598,14 +622,15 @@ class Pilatus(PSIDeviceBase, ADBase): def on_complete(self) -> DeviceStatus | None: """Called to inquire if a device has completed a scans.""" - scan_msg: ScanStatusMessage = self.scan_info.msg if ( - scan_msg.scan_name in self.xas_xrd_scan_names or scan_msg.scan_type == "step" + self.scan_parameters.scan_name in self.xas_xrd_scan_names + or self.scan_parameters.scan_type == "step" ): # TODO how to deal with fly scans? status_hdf = CompareStatus(self.hdf.capture, ACQUIREMODE.DONE.value) status_cam = CompareStatus(self.cam.acquire, ACQUIREMODE.DONE.value) status_cam_server = CompareStatus(self.cam.armed, DETECTORSTATE.UNARMED.value) - if self.scan_info.msg.scan_name in self.xas_xrd_scan_names: + # status_write_error = ExceptionStatus(self.hdf.write_status, 0, operation="!=") + if self.scan_parameters.scan_name in self.xas_xrd_scan_names: # For long scans, it can be that the mono will execute one cycle more, # meaning a few more XRD triggers will be sent status_img_written = CompareStatus( @@ -614,7 +639,9 @@ class Pilatus(PSIDeviceBase, ADBase): else: status_img_written = CompareStatus(self.hdf.num_captured, self.n_images) status_img_written = CompareStatus(self.hdf.num_captured, self.n_images) - status = status_hdf & status_cam & status_img_written & status_cam_server + status = ( + status_hdf & status_cam & status_img_written & status_cam_server + ) # & status_write_error status.add_callback(self._complete_callback) # Callback that writing was successful self.cancel_on_stop(status) return status @@ -635,15 +662,6 @@ class Pilatus(PSIDeviceBase, ADBase): # TODO do we need to clean the poll thread ourselves? self.on_stop() - def _update_scan_parameter(self): - """Get the scan_info parameters for the scan.""" - for key, value in self.scan_info.msg.request_inputs["inputs"].items(): - if hasattr(self.scan_parameter, key): - setattr(self.scan_parameter, key, value) - for key, value in self.scan_info.msg.request_inputs["kwargs"].items(): - if hasattr(self.scan_parameter, key): - setattr(self.scan_parameter, key, value) - if __name__ == "__main__": try: diff --git a/debye_bec/devices/utils/__init__.py b/debye_bec/devices/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/debye_bec/devices/utils/utils.py b/debye_bec/devices/utils/utils.py new file mode 100644 index 0000000..f747655 --- /dev/null +++ b/debye_bec/devices/utils/utils.py @@ -0,0 +1,9 @@ +"""Utility functions for the devices.""" + +from bec_lib.devicemanager import ScanInfo +from bec_server.scan_server.scans.scan_base import ScanInfo as ScanServerScanInfo + + +def fetch_scan_info(scan_info: ScanInfo) -> ScanServerScanInfo: + """Fetch the scan parameters from the scan_info object and return them as a ScanServerScanInfo object.""" + return ScanServerScanInfo.model_validate(scan_info.msg.info)