From 1f7fdb89d72ca8b8530c73ed3715744a80e456a1 Mon Sep 17 00:00:00 2001 From: gac-x01da Date: Wed, 10 Sep 2025 16:55:52 +0200 Subject: [PATCH] add on_stage for xas_xrd scans --- debye_bec/devices/pilatus/pilatus.py | 175 +++++++++++++++++++++------ 1 file changed, 137 insertions(+), 38 deletions(-) diff --git a/debye_bec/devices/pilatus/pilatus.py b/debye_bec/devices/pilatus/pilatus.py index ba307cc..be471c4 100644 --- a/debye_bec/devices/pilatus/pilatus.py +++ b/debye_bec/devices/pilatus/pilatus.py @@ -25,6 +25,7 @@ from ophyd_devices import ( PreviewSignal, ) from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase +from pydantic import BaseModel, Field if TYPE_CHECKING: # pragma: no cover from bec_lib.devicemanager import ScanInfo @@ -95,6 +96,32 @@ class TRIGGERMODE(int, enum.Enum): def __str__(self): return self.description() + + + +class ScanParameter(BaseModel): + """Dataclass to store the scan parameters for the Pilatus. + This needs to be in sync with the kwargs of the XRD related scans from Debye, to + ensure that the scan parameters are correctly set. Any changes in the scan kwargs, + i.e. renaming or adding new parameters, need to be represented here as well.""" + + scan_time: float | None = Field(None, description="Scan time for a half oscillation") + scan_duration: float | None = Field(None, description="Duration of the scan") + break_enable_low: bool | None = Field( + None, description="Break enabled for low, should be PV trig_ena_lo_enum" + ) # trig_enable_low: bool = None + break_enable_high: bool | None = Field( + None, description="Break enabled for high, should be PV trig_ena_hi_enum" + ) # trig_enable_high: bool = None + break_time_low: float | None = Field(None, description="Break time low energy/angle") + break_time_high: float | None = Field(None, description="Break time high energy/angle") + cycle_low: int | None = Field(None, description="Cycle for low energy/angle") + cycle_high: int | None = Field(None, description="Cycle for high energy/angle") + exp_time: float | None = Field(None, description="XRD trigger period") + n_of_trigger: int | None = Field(None, description="Amount of XRD triggers") + start: float | None = Field(None, description="Start value for energy/angle") + stop: float | None = Field(None, description="Stop value for energy/angle") + model_config: dict = {"validate_assignment": True} class Pilatus(PSIDeviceBase, ADBase): @@ -170,6 +197,7 @@ class Pilatus(PSIDeviceBase, ADBase): super().__init__( name=name, prefix=prefix, scan_info=scan_info, device_manager=device_manager, **kwargs ) + self.scan_parameter = ScanParameter() self.device_manager = device_manager self._readout_time = PILATUS_READOUT_TIME self._full_path = "" @@ -178,6 +206,10 @@ class Pilatus(PSIDeviceBase, ADBase): ) self._poll_thread_kill_event = threading.Event() self._poll_rate = 1 # Poll rate in Hz + self.xas_xrd_scan_names = [ + "xas_simple_scan_with_xrd", + "xas_advanced_scan_with_xrd", + ] # self._live_mode_thread = threading.Thread( # target=self._live_mode_loop, daemon=True, name=f"{self.name}_live_mode_thread" # ) @@ -354,50 +386,108 @@ class Pilatus(PSIDeviceBase, ADBase): (self.scan_info.msg) object. """ # self.stop_live_mode() # Make sure that live mode is stopped if scan runs + self._update_scan_parameter() scan_msg: ScanStatusMessage = self.scan_info.msg - if scan_msg.scan_name.startswith("xas"): - return None - # TODO implement logic for 'xas' scans - else: - exp_time = scan_msg.scan_parameters.get("exp_time", 0.0) - if exp_time - self._readout_time <= 0: - raise ValueError( - f"Exposure time {exp_time} is too short for Pilatus with readout_time {self._readout_time}." - ) - detector_exp_time = exp_time - self._readout_time - n_images = scan_msg.num_points * scan_msg.scan_parameters.get("frames_per_trigger", 1) - self._full_path = get_full_path(scan_msg, name="pilatus") - file_path = "/".join(self._full_path.split("/")[:-1]) - file_name = self._full_path.split("/")[-1] + if scan_msg.scan_name in self.xas_xrd_scan_names: + total_osc = 0 + total_trig_lo = 0 + total_trig_hi = 0 + calc_duration = 0 + n_trig_lo = 1 + n_trig_hi = 1 + init_lo = 1 + init_hi = 1 + lo_done = 0 + hi_done = 0 + if not self.scan_parameter.break_enable_low: + lo_done = 1 + if not self.scan_parameter.break_enable_high: + hi_done = 1 + while True: + total_osc = total_osc + 2 + calc_duration = calc_duration + 2 * self.scan_parameter.scan_time + + if self.scan_parameter.break_enable_low and n_trig_lo >= self.scan_parameter.cycle_low: + n_trig_lo = 1 + calc_duration = calc_duration + self.scan_parameter.break_time_low + if init_lo: + lo_done = 1 + init_lo = 0 + else: + n_trig_lo += 1 - # Prepare detector and backend - self.cam.array_callbacks.set(1).wait(5) # Enable array callbacks - self.hdf.enable.set(1).wait(5) # Enable HDF5 plugin - # Camera settings - self.cam.num_exposures.set(1).wait(5) - self.cam.num_images.set(n_images).wait(5) - self.cam.acquire_time.set(detector_exp_time).wait(5) # let's try this - self.cam.acquire_period.set(exp_time).wait(5) - self.filter_number.set(0).wait(5) - # HDF5 settings - logger.debug(f"Setting HDF5 file path to {file_path} and file name to {file_name}") - self.hdf.file_path.set(file_path).wait(5) - self.hdf.file_name.set(file_name).wait(5) - self.hdf.num_capture.set(n_images).wait(5) - self.cam.array_counter.set(0).wait(5) # Reset array counter - self.file_event.put( - file_path=self._full_path, - done=False, - successful=False, - hinted_h5_entries={"data": "/entry/data/data"}, - ) + if self.scan_parameter.break_enable_high and n_trig_hi >= self.scan_parameter.cycle_high: + n_trig_hi = 1 + calc_duration = calc_duration + self.scan_parameter.break_time_high + if init_hi: + hi_done = 1 + init_hi = 0 + else: + n_trig_hi += 1 + + if lo_done and hi_done: + n = np.floor(self.scan_parameter.scan_duration / calc_duration) + total_osc = total_osc * n + if self.scan_parameter.break_enable_low: + total_trig_lo = n + 1 + if self.scan_parameter.break_enable_high: + total_trig_hi = n + 1 + calc_duration = calc_duration * n + lo_done = 0 + hi_done = 0 + + if calc_duration >= self.scan_parameter.scan_duration: + break + + # logger.info(f'total_osc: {total_osc}') + # logger.info(f'total trig low: {total_trig_lo}') + # logger.info(f'total trig high: {total_trig_hi}') + + n_images = total_trig_lo + total_trig_hi + exp_time = self.scan_parameter.exp_time + + elif scan_msg.scan_type == 'step': + n_images = scan_msg.num_points * scan_msg.scan_parameters.get("frames_per_trigger", 1) + exp_time = scan_msg.scan_parameters.get("exp_time") + else: + return None + # Common settings + if exp_time - self._readout_time <= 0: + raise ValueError((f"Exposure time {exp_time} is too short ", + f"for Pilatus with readout_time {self._readout_time}." + )) + detector_exp_time = exp_time - self._readout_time + self._full_path = get_full_path(scan_msg, name="pilatus") + file_path = "/".join(self._full_path.split("/")[:-1]) + file_name = self._full_path.split("/")[-1] + # Prepare detector and backend + self.cam.array_callbacks.set(1).wait(5) # Enable array callbacks + self.hdf.enable.set(1).wait(5) # Enable HDF5 plugin + # Camera settings + self.cam.num_exposures.set(1).wait(5) + self.cam.num_images.set(n_images).wait(5) + self.cam.acquire_time.set(detector_exp_time).wait(5) # let's try this + self.cam.acquire_period.set(exp_time).wait(5) + self.filter_number.set(0).wait(5) + # HDF5 settings + logger.debug(f"Setting HDF5 file path to {file_path} and file name to {file_name}") + self.hdf.file_path.set(file_path).wait(5) + self.hdf.file_name.set(file_name).wait(5) + self.hdf.num_capture.set(n_images).wait(5) + self.cam.array_counter.set(0).wait(5) # Reset array counter + self.file_event.put( + file_path=self._full_path, + done=False, + successful=False, + hinted_h5_entries={"data": "/entry/data/data"}, + ) def on_unstage(self) -> None: """Called while unstaging the device.""" def on_pre_scan(self) -> DeviceStatus | None: """Called right before the scan starts on all devices automatically.""" - if self.scan_info.msg.scan_name.startswith("xas"): + if self.scan_info.msg.scan_name in self.xas_xrd_scan_names: # TODO implement logic for 'xas' scans return None else: @@ -414,7 +504,7 @@ class Pilatus(PSIDeviceBase, ADBase): def on_trigger(self) -> DeviceStatus | None: """Called when the device is triggered.""" - if self.scan_info.msg.scan_name.startswith("xas"): + if self.scan_info.msg.scan_name in self.xas_xrd_scan_names: return None # TODO implement logic for 'xas' scans else: @@ -446,7 +536,7 @@ class Pilatus(PSIDeviceBase, ADBase): def on_complete(self) -> DeviceStatus | None: """Called to inquire if a device has completed a scans.""" - if self.scan_info.msg.scan_name.startswith("xas"): + if self.scan_info.msg.scan_name in self.xas_xrd_scan_names: # TODO implement logic for 'xas' scans return None status_hdf = CompareStatus(self.hdf.capture, ACQUIREMODE.DONE.value) @@ -477,6 +567,15 @@ class Pilatus(PSIDeviceBase, ADBase): # TODO do we need to clean the poll thread ourselves? self.on_stop() + def _update_scan_parameter(self): + """Get the scan_info parameters for the scan.""" + for key, value in self.scan_info.msg.request_inputs["inputs"].items(): + if hasattr(self.scan_parameter, key): + setattr(self.scan_parameter, key, value) + for key, value in self.scan_info.msg.request_inputs["kwargs"].items(): + if hasattr(self.scan_parameter, key): + setattr(self.scan_parameter, key, value) + if __name__ == "__main__": try: