add on_stage for xas_xrd scans

This commit is contained in:
gac-x01da
2025-09-10 16:55:52 +02:00
committed by appel_c
parent ee748d56c4
commit 1f7fdb89d7

View File

@@ -25,6 +25,7 @@ from ophyd_devices import (
PreviewSignal,
)
from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase
from pydantic import BaseModel, Field
if TYPE_CHECKING: # pragma: no cover
from bec_lib.devicemanager import ScanInfo
@@ -95,6 +96,32 @@ class TRIGGERMODE(int, enum.Enum):
def __str__(self):
return self.description()
class ScanParameter(BaseModel):
"""Dataclass to store the scan parameters for the Pilatus.
This needs to be in sync with the kwargs of the XRD related scans from Debye, to
ensure that the scan parameters are correctly set. Any changes in the scan kwargs,
i.e. renaming or adding new parameters, need to be represented here as well."""
scan_time: float | None = Field(None, description="Scan time for a half oscillation")
scan_duration: float | None = Field(None, description="Duration of the scan")
break_enable_low: bool | None = Field(
None, description="Break enabled for low, should be PV trig_ena_lo_enum"
) # trig_enable_low: bool = None
break_enable_high: bool | None = Field(
None, description="Break enabled for high, should be PV trig_ena_hi_enum"
) # trig_enable_high: bool = None
break_time_low: float | None = Field(None, description="Break time low energy/angle")
break_time_high: float | None = Field(None, description="Break time high energy/angle")
cycle_low: int | None = Field(None, description="Cycle for low energy/angle")
cycle_high: int | None = Field(None, description="Cycle for high energy/angle")
exp_time: float | None = Field(None, description="XRD trigger period")
n_of_trigger: int | None = Field(None, description="Amount of XRD triggers")
start: float | None = Field(None, description="Start value for energy/angle")
stop: float | None = Field(None, description="Stop value for energy/angle")
model_config: dict = {"validate_assignment": True}
class Pilatus(PSIDeviceBase, ADBase):
@@ -170,6 +197,7 @@ class Pilatus(PSIDeviceBase, ADBase):
super().__init__(
name=name, prefix=prefix, scan_info=scan_info, device_manager=device_manager, **kwargs
)
self.scan_parameter = ScanParameter()
self.device_manager = device_manager
self._readout_time = PILATUS_READOUT_TIME
self._full_path = ""
@@ -178,6 +206,10 @@ class Pilatus(PSIDeviceBase, ADBase):
)
self._poll_thread_kill_event = threading.Event()
self._poll_rate = 1 # Poll rate in Hz
self.xas_xrd_scan_names = [
"xas_simple_scan_with_xrd",
"xas_advanced_scan_with_xrd",
]
# self._live_mode_thread = threading.Thread(
# target=self._live_mode_loop, daemon=True, name=f"{self.name}_live_mode_thread"
# )
@@ -354,50 +386,108 @@ class Pilatus(PSIDeviceBase, ADBase):
(self.scan_info.msg) object.
"""
# self.stop_live_mode() # Make sure that live mode is stopped if scan runs
self._update_scan_parameter()
scan_msg: ScanStatusMessage = self.scan_info.msg
if scan_msg.scan_name.startswith("xas"):
return None
# TODO implement logic for 'xas' scans
else:
exp_time = scan_msg.scan_parameters.get("exp_time", 0.0)
if exp_time - self._readout_time <= 0:
raise ValueError(
f"Exposure time {exp_time} is too short for Pilatus with readout_time {self._readout_time}."
)
detector_exp_time = exp_time - self._readout_time
n_images = scan_msg.num_points * scan_msg.scan_parameters.get("frames_per_trigger", 1)
self._full_path = get_full_path(scan_msg, name="pilatus")
file_path = "/".join(self._full_path.split("/")[:-1])
file_name = self._full_path.split("/")[-1]
if scan_msg.scan_name in self.xas_xrd_scan_names:
total_osc = 0
total_trig_lo = 0
total_trig_hi = 0
calc_duration = 0
n_trig_lo = 1
n_trig_hi = 1
init_lo = 1
init_hi = 1
lo_done = 0
hi_done = 0
if not self.scan_parameter.break_enable_low:
lo_done = 1
if not self.scan_parameter.break_enable_high:
hi_done = 1
while True:
total_osc = total_osc + 2
calc_duration = calc_duration + 2 * self.scan_parameter.scan_time
if self.scan_parameter.break_enable_low and n_trig_lo >= self.scan_parameter.cycle_low:
n_trig_lo = 1
calc_duration = calc_duration + self.scan_parameter.break_time_low
if init_lo:
lo_done = 1
init_lo = 0
else:
n_trig_lo += 1
# Prepare detector and backend
self.cam.array_callbacks.set(1).wait(5) # Enable array callbacks
self.hdf.enable.set(1).wait(5) # Enable HDF5 plugin
# Camera settings
self.cam.num_exposures.set(1).wait(5)
self.cam.num_images.set(n_images).wait(5)
self.cam.acquire_time.set(detector_exp_time).wait(5) # let's try this
self.cam.acquire_period.set(exp_time).wait(5)
self.filter_number.set(0).wait(5)
# HDF5 settings
logger.debug(f"Setting HDF5 file path to {file_path} and file name to {file_name}")
self.hdf.file_path.set(file_path).wait(5)
self.hdf.file_name.set(file_name).wait(5)
self.hdf.num_capture.set(n_images).wait(5)
self.cam.array_counter.set(0).wait(5) # Reset array counter
self.file_event.put(
file_path=self._full_path,
done=False,
successful=False,
hinted_h5_entries={"data": "/entry/data/data"},
)
if self.scan_parameter.break_enable_high and n_trig_hi >= self.scan_parameter.cycle_high:
n_trig_hi = 1
calc_duration = calc_duration + self.scan_parameter.break_time_high
if init_hi:
hi_done = 1
init_hi = 0
else:
n_trig_hi += 1
if lo_done and hi_done:
n = np.floor(self.scan_parameter.scan_duration / calc_duration)
total_osc = total_osc * n
if self.scan_parameter.break_enable_low:
total_trig_lo = n + 1
if self.scan_parameter.break_enable_high:
total_trig_hi = n + 1
calc_duration = calc_duration * n
lo_done = 0
hi_done = 0
if calc_duration >= self.scan_parameter.scan_duration:
break
# logger.info(f'total_osc: {total_osc}')
# logger.info(f'total trig low: {total_trig_lo}')
# logger.info(f'total trig high: {total_trig_hi}')
n_images = total_trig_lo + total_trig_hi
exp_time = self.scan_parameter.exp_time
elif scan_msg.scan_type == 'step':
n_images = scan_msg.num_points * scan_msg.scan_parameters.get("frames_per_trigger", 1)
exp_time = scan_msg.scan_parameters.get("exp_time")
else:
return None
# Common settings
if exp_time - self._readout_time <= 0:
raise ValueError((f"Exposure time {exp_time} is too short ",
f"for Pilatus with readout_time {self._readout_time}."
))
detector_exp_time = exp_time - self._readout_time
self._full_path = get_full_path(scan_msg, name="pilatus")
file_path = "/".join(self._full_path.split("/")[:-1])
file_name = self._full_path.split("/")[-1]
# Prepare detector and backend
self.cam.array_callbacks.set(1).wait(5) # Enable array callbacks
self.hdf.enable.set(1).wait(5) # Enable HDF5 plugin
# Camera settings
self.cam.num_exposures.set(1).wait(5)
self.cam.num_images.set(n_images).wait(5)
self.cam.acquire_time.set(detector_exp_time).wait(5) # let's try this
self.cam.acquire_period.set(exp_time).wait(5)
self.filter_number.set(0).wait(5)
# HDF5 settings
logger.debug(f"Setting HDF5 file path to {file_path} and file name to {file_name}")
self.hdf.file_path.set(file_path).wait(5)
self.hdf.file_name.set(file_name).wait(5)
self.hdf.num_capture.set(n_images).wait(5)
self.cam.array_counter.set(0).wait(5) # Reset array counter
self.file_event.put(
file_path=self._full_path,
done=False,
successful=False,
hinted_h5_entries={"data": "/entry/data/data"},
)
def on_unstage(self) -> None:
"""Called while unstaging the device."""
def on_pre_scan(self) -> DeviceStatus | None:
"""Called right before the scan starts on all devices automatically."""
if self.scan_info.msg.scan_name.startswith("xas"):
if self.scan_info.msg.scan_name in self.xas_xrd_scan_names:
# TODO implement logic for 'xas' scans
return None
else:
@@ -414,7 +504,7 @@ class Pilatus(PSIDeviceBase, ADBase):
def on_trigger(self) -> DeviceStatus | None:
"""Called when the device is triggered."""
if self.scan_info.msg.scan_name.startswith("xas"):
if self.scan_info.msg.scan_name in self.xas_xrd_scan_names:
return None
# TODO implement logic for 'xas' scans
else:
@@ -446,7 +536,7 @@ class Pilatus(PSIDeviceBase, ADBase):
def on_complete(self) -> DeviceStatus | None:
"""Called to inquire if a device has completed a scans."""
if self.scan_info.msg.scan_name.startswith("xas"):
if self.scan_info.msg.scan_name in self.xas_xrd_scan_names:
# TODO implement logic for 'xas' scans
return None
status_hdf = CompareStatus(self.hdf.capture, ACQUIREMODE.DONE.value)
@@ -477,6 +567,15 @@ class Pilatus(PSIDeviceBase, ADBase):
# TODO do we need to clean the poll thread ourselves?
self.on_stop()
def _update_scan_parameter(self):
"""Get the scan_info parameters for the scan."""
for key, value in self.scan_info.msg.request_inputs["inputs"].items():
if hasattr(self.scan_parameter, key):
setattr(self.scan_parameter, key, value)
for key, value in self.scan_info.msg.request_inputs["kwargs"].items():
if hasattr(self.scan_parameter, key):
setattr(self.scan_parameter, key, value)
if __name__ == "__main__":
try: