refactor(pilatus): migrate to scans v4 interface

This commit is contained in:
2026-05-22 10:04:49 +02:00
parent 98d5c22667
commit 262a0b6318
3 changed files with 82 additions and 55 deletions
+73 -55
View File
@@ -11,16 +11,26 @@ from typing import TYPE_CHECKING, Tuple
import numpy as np
from bec_lib.file_utils import get_full_path
from bec_lib.logger import bec_logger
from bec_server.scan_server.scans.scan_base import ScanInfo as ScanServerScanInfo
from ophyd import Component as Cpt
from ophyd import EpicsSignal, EpicsSignalRO, Kind
from ophyd.areadetector.cam import ADBase, PilatusDetectorCam
from ophyd.areadetector.plugins import HDF5Plugin_V22 as HDF5Plugin
from ophyd.areadetector.plugins import ImagePlugin_V22 as ImagePlugin
from ophyd.status import WaitTimeoutError
from ophyd_devices import AndStatus, CompareStatus, DeviceStatus, FileEventSignal, PreviewSignal
from ophyd_devices import (
AndStatus,
CompareStatus,
DeviceStatus,
ExceptionStatus,
FileEventSignal,
PreviewSignal,
)
from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase
from pydantic import BaseModel, Field
from debye_bec.devices.utils.utils import fetch_scan_info
if TYPE_CHECKING: # pragma: no cover
from bec_lib.devicemanager import ScanInfo
from bec_lib.messages import DevicePreviewMessage, ScanStatusMessage
@@ -145,17 +155,17 @@ class Pilatus(PSIDeviceBase, ADBase):
# USER_ACCESS = ["start_live_mode", "stop_live_mode"]
cam_gain_menu_string = Cpt(EpicsSignalRO, suffix='cam1:GainMenu', string=True)
cam_gain_menu_string = Cpt(EpicsSignalRO, suffix="cam1:GainMenu", string=True)
_default_configuration_attrs = [
'cam.threshold_energy',
'cam.threshold_auto_apply',
'cam.gain_menu',
'cam_gain_menu_string',
'cam.pixel_cut_off',
'cam.acquire_time',
'cam.num_exposures',
'cam.model',
"cam.threshold_energy",
"cam.threshold_auto_apply",
"cam.gain_menu",
"cam_gain_menu_string",
"cam.pixel_cut_off",
"cam.acquire_time",
"cam.num_exposures",
"cam.model",
]
cam = Cpt(PilatusDetectorCam, "cam1:")
@@ -233,7 +243,6 @@ class Pilatus(PSIDeviceBase, ADBase):
super().__init__(
name=name, prefix=prefix, scan_info=scan_info, device_manager=device_manager, **kwargs
)
self.scan_parameter = ScanParameter()
self.device_manager = device_manager
self._readout_time = PILATUS_READOUT_TIME
self._full_path = ""
@@ -251,6 +260,7 @@ class Pilatus(PSIDeviceBase, ADBase):
# self._live_mode_run_event = threading.Event()
# self._live_mode_stopped_event = threading.Event()
# self._live_mode_stopped_event.set() # Initial state is stopped
self.scan_parameters: ScanServerScanInfo | None = None
########################################
# Custom Beamline Methods #
@@ -368,19 +378,22 @@ class Pilatus(PSIDeviceBase, ADBase):
status = status_acquire & status_writing & status_cam_server
return status
def _calculate_trigger(self, scan_msg: ScanStatusMessage) -> Tuple[float, float]:
self._update_scan_parameter()
def _calculate_trigger(self, scan_parameters: ScanServerScanInfo) -> Tuple[float, float]:
total_osc = 0
calc_duration = 0
total_trig_lo = 0
total_trig_hi = 0
# Switching high/low is intended as angle is inverse to energy and settings in BEC are always in energy
loc_break_enable_low = self.scan_parameter.break_enable_high
loc_break_time_low = self.scan_parameter.break_time_high
loc_cycle_low = self.scan_parameter.cycle_high
loc_break_enable_high = self.scan_parameter.break_enable_low
loc_break_time_high = self.scan_parameter.break_time_low
loc_cycle_high = self.scan_parameter.cycle_low
loc_break_enable_low = scan_parameters.additional_scan_parameters.get(
"break_enable_high", False
)
loc_break_time_low = scan_parameters.additional_scan_parameters.get("break_time_high", 0)
loc_cycle_low = scan_parameters.additional_scan_parameters.get("cycle_high", 1)
loc_break_enable_high = scan_parameters.additional_scan_parameters.get(
"break_enable_low", False
)
loc_break_time_high = scan_parameters.additional_scan_parameters.get("break_time_low", 0)
loc_cycle_high = scan_parameters.additional_scan_parameters.get("cycle_low", 1)
if not loc_break_enable_low:
loc_break_time_low = 0
@@ -389,28 +402,36 @@ class Pilatus(PSIDeviceBase, ADBase):
loc_break_time_high = 0
loc_cycle_high = 1
total_osc = self.scan_parameter.scan_duration / (
self.scan_parameter.scan_time +
loc_break_time_low / (2 * loc_cycle_low) +
loc_break_time_high / (2 * loc_cycle_high)
total_osc = scan_parameters.additional_scan_parameters.get("scan_duration", 0) / (
scan_parameters.additional_scan_parameters.get("scan_time", 0)
+ loc_break_time_low / (2 * loc_cycle_low)
+ loc_break_time_high / (2 * loc_cycle_high)
)
total_osc = np.ceil(total_osc)
total_osc = total_osc + total_osc % 2 # round up to the next even number
total_osc = total_osc + total_osc % 2 # round up to the next even number
if loc_break_enable_low:
total_trig_lo = np.floor(total_osc / (2 * loc_cycle_low))
if loc_break_enable_high:
total_trig_hi = np.floor(total_osc / (2 * loc_cycle_high))
calc_duration = total_osc * self.scan_parameter.scan_time + total_trig_lo * loc_break_time_low + total_trig_hi * loc_break_time_high
if calc_duration < self.scan_parameter.scan_duration:
calc_duration = (
total_osc * scan_parameters.additional_scan_parameters.get("scan_time", 0)
+ total_trig_lo * loc_break_time_low
+ total_trig_hi * loc_break_time_high
)
if calc_duration < scan_parameters.additional_scan_parameters.get("scan_duration", 0):
# Due to inaccuracy in formula, this can happen, we then need to manually add two oscillations and recalculate the triggers
total_osc = total_osc + 2
if loc_break_enable_low:
total_trig_lo = np.floor(total_osc / (2 * loc_cycle_low))
if loc_break_enable_high:
total_trig_hi = np.floor(total_osc / (2 * loc_cycle_high))
calc_duration = total_osc * self.scan_parameter.scan_time + total_trig_lo * loc_break_time_low + total_trig_hi * loc_break_time_high
calc_duration = (
total_osc * scan_parameters.additional_scan_parameters.get("scan_time", 0)
+ total_trig_lo * loc_break_time_low
+ total_trig_hi * loc_break_time_high
)
return total_trig_lo, total_trig_hi
@@ -464,6 +485,7 @@ class Pilatus(PSIDeviceBase, ADBase):
(self.scan_info.msg) object.
"""
# self.stop_live_mode() # Make sure that live mode is stopped if scan runs
self.scan_parameters = fetch_scan_info(self.scan_info)
# If user has activated alignment mode on qt panel, switch back to multitrigger and stop acquisition
if self.cam.trigger_mode.get() != TRIGGERMODE.MULT_TRIGGER.value:
@@ -475,23 +497,26 @@ class Pilatus(PSIDeviceBase, ADBase):
scan_msg: ScanStatusMessage = self.scan_info.msg
if scan_msg.scan_name in self.xas_xrd_scan_names:
self._update_scan_parameter()
# Compute number of triggers
total_trig_lo, total_trig_hi = self._calculate_trigger(scan_msg)
total_trig_lo, total_trig_hi = self._calculate_trigger(self.scan_parameters)
# Set the number of images, we may also set this to a higher values if preferred and stop the acquisition
# TODO This logic is prone to errors, as we rely on the scans to nicely resolve to n_images. We should
# use here instead a way of settings the n_images independently of the scan parameters to avoid running out of sync
# with the complete method. Ideally we comput them in the scan itself.. This is much safer IMO!
self.n_images = (total_trig_lo + total_trig_hi) * self.scan_parameter.n_of_trigger
exp_time = self.scan_parameter.exp_time
self.n_images = (
total_trig_lo + total_trig_hi
) * self.scan_parameters.additional_scan_parameters.get("n_of_trigger", 1)
exp_time = self.scan_parameters.exp_time
self.trigger_source.set(MONOTRIGGERSOURCE.INPOS).wait(5)
self.trigger_n_of.set(self.scan_parameter.n_of_trigger).wait(5)
self.trigger_n_of.set(
self.scan_parameters.additional_scan_parameters.get("n_of_trigger", 1)
).wait(5)
elif scan_msg.scan_type == "step":
self.n_images = scan_msg.num_points * scan_msg.scan_parameters.get(
"frames_per_trigger", 1
self.n_images = (
self.scan_parameters.num_monitored_readouts * scan_msg.frames_per_trigger
)
exp_time = scan_msg.scan_parameters.get("exp_time")
exp_time = self.scan_parameters.exp_time
self.trigger_source.set(MONOTRIGGERSOURCE.EPICS).wait(5)
self.trigger_n_of.set(1).wait(5) # BEC will trigger each acquisition
else:
@@ -544,9 +569,9 @@ class Pilatus(PSIDeviceBase, ADBase):
def on_pre_scan(self) -> DeviceStatus | None:
"""Called right before the scan starts on all devices automatically."""
scan_msg: ScanStatusMessage = self.scan_info.msg
if (
scan_msg.scan_name in self.xas_xrd_scan_names or scan_msg.scan_type == "step"
self.scan_parameters.scan_name in self.xas_xrd_scan_names
or self.scan_parameters.scan_type == "step"
): # TODO how to deal with fly scans?
status_hdf = CompareStatus(self.hdf.capture, ACQUIREMODE.ACQUIRING.value)
status_cam = CompareStatus(self.cam.acquire, ACQUIREMODE.ACQUIRING.value)
@@ -561,8 +586,7 @@ class Pilatus(PSIDeviceBase, ADBase):
def on_trigger(self) -> DeviceStatus | None:
"""Called when the device is triggered."""
scan_msg: ScanStatusMessage = self.scan_info.msg
if not scan_msg.scan_type == "step":
if not self.scan_parameters.scan_type == "step":
return None
start_time = time.time()
img_counter = self.hdf.num_captured.get()
@@ -575,9 +599,9 @@ class Pilatus(PSIDeviceBase, ADBase):
def _complete_callback(self, status: DeviceStatus):
"""Callback for when the device completes a scan."""
scan_msg: ScanStatusMessage = self.scan_info.msg
if (
scan_msg.scan_name in self.xas_xrd_scan_names or scan_msg.scan_type == "step"
self.scan_parameters.scan_name in self.xas_xrd_scan_names
or self.scan_parameters.scan_type == "step"
): # TODO how to deal with fly scans?
if status.success:
self.file_event.put(
@@ -598,14 +622,15 @@ class Pilatus(PSIDeviceBase, ADBase):
def on_complete(self) -> DeviceStatus | None:
"""Called to inquire if a device has completed a scans."""
scan_msg: ScanStatusMessage = self.scan_info.msg
if (
scan_msg.scan_name in self.xas_xrd_scan_names or scan_msg.scan_type == "step"
self.scan_parameters.scan_name in self.xas_xrd_scan_names
or self.scan_parameters.scan_type == "step"
): # TODO how to deal with fly scans?
status_hdf = CompareStatus(self.hdf.capture, ACQUIREMODE.DONE.value)
status_cam = CompareStatus(self.cam.acquire, ACQUIREMODE.DONE.value)
status_cam_server = CompareStatus(self.cam.armed, DETECTORSTATE.UNARMED.value)
if self.scan_info.msg.scan_name in self.xas_xrd_scan_names:
# status_write_error = ExceptionStatus(self.hdf.write_status, 0, operation="!=")
if self.scan_parameters.scan_name in self.xas_xrd_scan_names:
# For long scans, it can be that the mono will execute one cycle more,
# meaning a few more XRD triggers will be sent
status_img_written = CompareStatus(
@@ -614,7 +639,9 @@ class Pilatus(PSIDeviceBase, ADBase):
else:
status_img_written = CompareStatus(self.hdf.num_captured, self.n_images)
status_img_written = CompareStatus(self.hdf.num_captured, self.n_images)
status = status_hdf & status_cam & status_img_written & status_cam_server
status = (
status_hdf & status_cam & status_img_written & status_cam_server
) # & status_write_error
status.add_callback(self._complete_callback) # Callback that writing was successful
self.cancel_on_stop(status)
return status
@@ -635,15 +662,6 @@ class Pilatus(PSIDeviceBase, ADBase):
# TODO do we need to clean the poll thread ourselves?
self.on_stop()
def _update_scan_parameter(self):
"""Get the scan_info parameters for the scan."""
for key, value in self.scan_info.msg.request_inputs["inputs"].items():
if hasattr(self.scan_parameter, key):
setattr(self.scan_parameter, key, value)
for key, value in self.scan_info.msg.request_inputs["kwargs"].items():
if hasattr(self.scan_parameter, key):
setattr(self.scan_parameter, key, value)
if __name__ == "__main__":
try:
View File
+9
View File
@@ -0,0 +1,9 @@
"""Utility functions for the devices."""
from bec_lib.devicemanager import ScanInfo
from bec_server.scan_server.scans.scan_base import ScanInfo as ScanServerScanInfo
def fetch_scan_info(scan_info: ScanInfo) -> ScanServerScanInfo:
"""Fetch the scan parameters from the scan_info object and return them as a ScanServerScanInfo object."""
return ScanServerScanInfo.model_validate(scan_info.msg.info)