diff --git a/tests/tests_devices/test_std_daq_live_processing.py b/tests/tests_devices/test_std_daq_live_processing.py index 6bf4745..991066e 100644 --- a/tests/tests_devices/test_std_daq_live_processing.py +++ b/tests/tests_devices/test_std_daq_live_processing.py @@ -148,3 +148,25 @@ def test_std_daq_live_processing_apply_flat_dark_correction_with_dark(std_daq_li assert isinstance(corrected_image, np.ndarray) assert corrected_image.shape == (100, 100) assert np.all(corrected_image >= 0), "Corrected image should not have negative values" + + +def test_std_daq_live_processing_apply_flat_correction_zero_division(std_daq_live_processing): + + # Create a mock image + image = np.random.rand(100, 100) * 1000 + 10 # Scale to simulate a realistic image + + # Set flat reference with epsilon values + flat = np.ones((100, 100)) * 2 + std_daq_live_processing.references["flat_(100, 100)"] = flat + + # Set dark reference to ones + dark = np.ones((100, 100)) * 2 + + std_daq_live_processing.references["dark_(100, 100)"] = dark + + # Apply flat correction + corrected_image = std_daq_live_processing.apply_flat_dark_correction(image) + assert isinstance(corrected_image, np.ndarray) + assert corrected_image.shape == (100, 100) + assert np.all(corrected_image >= 0), "Corrected image should not have negative values" + assert np.any(corrected_image < np.inf), "Corrected image should not have infinite values" diff --git a/tomcat_bec/devices/pco_edge/pcoedge_base.py b/tomcat_bec/devices/pco_edge/pcoedge_base.py index 15905ae..e78bb81 100644 --- a/tomcat_bec/devices/pco_edge/pcoedge_base.py +++ b/tomcat_bec/devices/pco_edge/pcoedge_base.py @@ -154,7 +154,7 @@ class PcoEdgeBase(Device): kind=Kind.config, ) - statuscode = Cpt(EpicsSignalRO, "STATUSCODE", auto_monitor=True, kind=Kind.config, string=True) + statuscode = Cpt(EpicsSignalRO, "STATUSCODE", auto_monitor=True, kind=Kind.config) init = Cpt( EpicsSignalRO, "INIT", diff --git a/tomcat_bec/devices/pco_edge/pcoedgecamera.py b/tomcat_bec/devices/pco_edge/pcoedgecamera.py index 108fe52..aad85f7 100644 --- a/tomcat_bec/devices/pco_edge/pcoedgecamera.py +++ b/tomcat_bec/devices/pco_edge/pcoedgecamera.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +from collections import deque from typing import TYPE_CHECKING, Literal, cast import numpy as np @@ -69,7 +70,14 @@ class PcoEdge5M(PSIDeviceBase, PcoEdgeBase): analysis_signal = Cpt(Signal, name="analysis_signal", kind=Kind.hinted, doc="Analysis Signal") analysis_signal2 = Cpt(Signal, name="analysis_signal2", kind=Kind.hinted, doc="Analysis Signal") - preview = Cpt(PreviewSignal, ndim=2, name="preview", doc="Camera raw data preview signal") + preview = Cpt(PreviewSignal, ndim=2, name="preview", doc="Camera raw data preview signal", num_rotation_90=1, transpose=False) + preview_corrected = Cpt( + PreviewSignal, + ndim=2, + name="preview_corrected", + doc="Camera preview signal with flat and dark correction", + num_rotation_90=1, transpose=False + ) progress = Cpt( ProgressSignal, name="progress", @@ -117,6 +125,8 @@ class PcoEdge5M(PSIDeviceBase, PcoEdgeBase): self.backend = StdDaqClient(parent=self, ws_url=std_daq_ws, rest_url=std_daq_rest) self.backend.add_count_callback(self._on_count_update) self.live_preview = None + self.converted_files = deque(maxlen=100) # Store the last 10 converted files + self.target_files = deque(maxlen=100) # Store the last 10 target files self.acq_configs = {} if std_daq_live is not None: self.live_preview = StdDaqPreview(url=std_daq_live, cb=self._on_preview_update) @@ -207,8 +217,8 @@ class PcoEdge5M(PSIDeviceBase, PcoEdgeBase): def _on_preview_update(self, img: np.ndarray): corrected_img = self.live_processing.apply_flat_dark_correction(img) self.live_processing.on_new_data(corrected_img) - self.preview.put(corrected_img) - self._run_subs(sub_type=self.SUB_DEVICE_MONITOR_2D, obj=self, value=corrected_img) + self.preview.put(img) + self.preview_corrected.put(corrected_img) def _on_count_update(self, count: int): """ @@ -245,7 +255,8 @@ class PcoEdge5M(PSIDeviceBase, PcoEdgeBase): self, name: str, file_path: str = "", - file_prefix: str = "", + file_name: str | None = None, + file_suffix: str = "", num_images: int | None = None, frames_per_trigger: int | None = None, ) -> StatusBase: @@ -263,14 +274,20 @@ class PcoEdge5M(PSIDeviceBase, PcoEdgeBase): Returns: DeviceStatus: The status of the restart operation. It resolves when the camera is ready to receive the first image. """ + if file_name is not None and file_suffix: + raise ValueError("Both file_name and file_suffix are specified. Please choose one.") + self.acq_configs[name] = {} conf = {} if file_path: self.acq_configs[name]["file_path"] = self.file_path.get() conf["file_path"] = file_path - if file_prefix: + if file_suffix: self.acq_configs[name]["file_prefix"] = self.file_prefix.get() - conf["file_prefix"] = file_prefix + conf["file_prefix"] = "_".join([self.file_prefix.get(), file_suffix]) + if file_name: + self.acq_configs[name]["file_prefix"] = self.file_prefix.get() + conf["file_prefix"] = file_name if num_images is not None: self.acq_configs[name]["num_images"] = self.num_images.get() conf["num_images"] = num_images @@ -322,7 +339,7 @@ class PcoEdge5M(PSIDeviceBase, PcoEdgeBase): status = self.live_processing.update_reference_with_file( reference_type=reference_type, file_path=self.target_file, - entry="tomcat-gigafrost/data", # type: ignore + entry="tomcat-pco/data", # type: ignore wait=False, # Do not wait for the update to finish ) return status @@ -346,6 +363,7 @@ class PcoEdge5M(PSIDeviceBase, PcoEdgeBase): num_images=self.num_images.get(), # type: ignore ) self.camera_status.set(CameraStatus.RUNNING).wait() + self.target_files.append(self.target_file) def is_running(*, value, timestamp, **_): return bool(value == CameraStatusCode.RUNNING) @@ -367,10 +385,15 @@ class PcoEdge5M(PSIDeviceBase, PcoEdgeBase): """Stop the camera acquisition and set it to idle state""" self.set_idle().wait() status = DeviceStatus(self) - self.backend.add_status_callback( - status, success=[StdDaqStatus.IDLE], error=[StdDaqStatus.REJECTED, StdDaqStatus.ERROR] - ) - self.backend.stop() + if self.backend.status != StdDaqStatus.IDLE: + self.backend.add_status_callback( + status, + success=[StdDaqStatus.IDLE], + error=[StdDaqStatus.REJECTED, StdDaqStatus.ERROR], + ) + self.backend.stop() + else: + status.set_finished() return status @property @@ -431,11 +454,16 @@ class PcoEdge5M(PSIDeviceBase, PcoEdgeBase): "/gpfs/test/test-beamline" # FIXME: This should be from the scan message ) if "file_prefix" not in scan_args: - scan_args["file_prefix"] = scan_msg.info["file_components"][0].split("/")[-1] + "_" + file_base = scan_msg.info["file_components"][0].split("/")[-1] + file_suffix = scan_msg.info.get("file_suffix") or "" + comps = [file_base, self.name] + if file_suffix: + comps.append(file_suffix) + scan_args["file_prefix"] = "_".join(comps) self.configure(scan_args) if scan_msg.scan_type == "step": - num_points = self.frames_per_trigger.get() * scan_msg.num_points # type: ignore + num_points = self.frames_per_trigger.get() * max(scan_msg.num_points, 1) # type: ignore else: num_points = self.frames_per_trigger.get() @@ -516,6 +544,12 @@ class PcoEdge5M(PSIDeviceBase, PcoEdgeBase): """Called to inquire if a device has completed a scans.""" def _create_dataset(_status: DeviceStatus): + if ( + self.target_file in self.converted_files + or self.target_file not in self.target_files + ): + logger.info(f"File {self.target_file} already processed or not in target files.") + return self.backend.create_virtual_datasets( self.file_path.get(), file_prefix=self.file_prefix.get() # type: ignore ) @@ -524,8 +558,10 @@ class PcoEdge5M(PSIDeviceBase, PcoEdgeBase): file_path=self.target_file, done=True, successful=True, - hinted_location={"data": "data"}, + hinted_location={"data": "tomcat-pco/data"}, ) + self.converted_files.append(self.target_file) + logger.info(f"Finished writing to {self.target_file}") status = self.acq_done() status.add_callback(_create_dataset) diff --git a/tomcat_bec/devices/std_daq/std_daq_live_processing.py b/tomcat_bec/devices/std_daq/std_daq_live_processing.py index e427786..5368d24 100644 --- a/tomcat_bec/devices/std_daq/std_daq_live_processing.py +++ b/tomcat_bec/devices/std_daq/std_daq_live_processing.py @@ -121,8 +121,12 @@ class StdDaqLiveProcessing: corrected_data = data - dark return corrected_data - corrected_data = (data - dark) / (flat - dark) - return corrected_data + # Ensure that the division does not lead to division by zero + flat_corr = np.abs(flat-dark) + corrected_data = np.divide( + data - dark, flat_corr, out=np.zeros_like(data, dtype=np.float32), where=flat_corr != 0 + ) + return np.clip(corrected_data, a_min=0, a_max=None) @typechecked def _load_and_update_reference( @@ -145,6 +149,15 @@ class StdDaqLiveProcessing: """ try: + + ######################################################## + # Remove these lines once the mount is fixed + if not isinstance(file_path, str): + file_path = str(file_path) + if file_path.startswith("/gpfs/test"): + file_path = file_path.replace("/gpfs/test", "/data/test") + ######################################################## + with h5py.File(file_path, "r") as file: if entry not in file: raise ValueError(f"Entry '{entry}' not found in the file.") diff --git a/tomcat_bec/scans/__init__.py b/tomcat_bec/scans/__init__.py index ad3c815..46e6b20 100644 --- a/tomcat_bec/scans/__init__.py +++ b/tomcat_bec/scans/__init__.py @@ -1,9 +1,10 @@ -from .simple_scans import TomoFlyScan, TomoScan +from .simple_scans import AcquireDark, AcquireFlat, AcquireReferences, TomoFlyScan, TomoScan from .tomcat_scans import TomcatSimpleSequence, TomcatSnapNStep -from .tutorial_fly_scan import ( - AcquireDark, - AcquireProjections, - AcquireRefs, - AcquireWhite, - TutorialFlyScanContLine, -) + +# from .tutorial_fly_scan import ( +# # AcquireDark, +# AcquireProjections, +# AcquireRefs, +# AcquireWhite, +# TutorialFlyScanContLine, +# ) diff --git a/tomcat_bec/scans/simple_scans.py b/tomcat_bec/scans/simple_scans.py index 695dff8..8419ee6 100644 --- a/tomcat_bec/scans/simple_scans.py +++ b/tomcat_bec/scans/simple_scans.py @@ -16,6 +16,7 @@ class TomoComponents: self.scan = scan self.stubs = scan.stubs self.device_manager = scan.device_manager + self.connector = scan.device_manager.connector # Update the available cameras for the current scan self.cameras = self._get_cameras() @@ -44,26 +45,60 @@ class TomoComponents: self, name: str, num_images: int, - prefix: str = "", + file_suffix: str = "", file_path: str = "", frames_per_trigger: int = 1, ): - if not prefix: - return + """ + Restart the cameras with a new configuration. + This is typically used to reset the cameras during another scan, e.g. before acquiring dark or flat images. + Args: + name (str): Name of the configuration to restart with. + num_images (int): Number of images to acquire. + file_suffix (str): Suffix for the file names. + file_path (str): Path where the files will be saved. + frames_per_trigger (int): Number of frames to acquire per trigger. + """ for cam in self.cameras: yield from self.stubs.send_rpc_and_wait( device=cam, func_name="restart_with_new_config", name=name, - file_prefix=prefix, + file_suffix=file_suffix, file_path=file_path, num_images=num_images, frames_per_trigger=frames_per_trigger, ) - def complete_and_restore_configs(self, name: str): + def scan_report_instructions(self): + """ + Generate scan report instructions for the acquisition. + This method provides the necessary instructions to listen to the camera progress during the scan. + """ + if not self.cameras: + return + + # Use the first camera or "gfcam" if available for reporting + report_camera = "gfcam" if "gfcam" in self.cameras else self.cameras[0] + yield from self.stubs.scan_report_instruction({"device_progress": [report_camera]}) + + def complete(self): + """ + Complete the acquisition by sending an RPC to each camera. + This method is typically called after the acquisition is done to finalize the process and start + writing the virtual dataset. + """ for cam in self.cameras: yield from self.stubs.send_rpc_and_wait(device=cam, func_name="on_complete") + + def restore_configs(self, name: str): + """ + Restore the camera configurations after an acquisition. + + Args: + name (str): Name of the configuration to restore. + """ + for cam in self.cameras: yield from self.stubs.send_rpc_and_wait( device=cam, func_name="restore_config", name=name ) @@ -84,7 +119,7 @@ class TomoComponents: device=cam, func_name="update_live_processing_reference", reference_type=ref_type ) - def acquire_dark(self, num_images: int, exposure_time: float, name="dark"): + def acquire_dark(self, num_images: int, exposure_time: float, name="dark", restart=True, restore=True): """ Acquire dark images. @@ -95,23 +130,24 @@ class TomoComponents: if not num_images: return logger.info(f"Acquiring {num_images} dark images with exposure time {exposure_time}s.") + self.connector.send_client_info(f"Acquiring {num_images} dark images.") - yield from self.restart_cameras( - name=name, prefix=name, num_images=num_images, frames_per_trigger=1 - ) + if restart: + yield from self.restart_cameras( + name=name, file_suffix=name, num_images=num_images, frames_per_trigger=num_images + ) # yield from self.close_shutter() - for i in range(num_images): - logger.debug(f"Acquiring dark image {i+1}/{num_images}.") - yield from self.stubs.trigger(min_wait=exposure_time) - yield from self.stubs.read(group="monitored", point_id=self.scan.point_id) - self.scan.point_id += 1 - yield from self.complete_and_restore_configs(name=name) + yield from self.stubs.trigger(min_wait=exposure_time * num_images) + yield from self.complete() yield from self.update_live_processing_references(ref_type="dark") + if restore: + yield from self.restore_configs(name=name) # yield from self.open_shutter() + self.connector.send_client_info("") logger.info("Dark image acquisition complete.") - def acquire_flat(self, num_images: int, exposure_time: float, name="flat"): + def acquire_flat(self, num_images: int, exposure_time: float, name="flat", restart=True, restore=True): """ Acquire flat images. @@ -122,25 +158,116 @@ class TomoComponents: if not num_images: return logger.info(f"Acquiring {num_images} flat images with exposure time {exposure_time}s.") + self.connector.send_client_info(f"Acquiring {num_images} flat images.") - yield from self.restart_cameras( - name=name, prefix=name, num_images=num_images, frames_per_trigger=1 - ) + if restart: + yield from self.restart_cameras( + name=name, file_suffix=name, num_images=num_images, frames_per_trigger=num_images + ) # yield from self.open_shutter() - for i in range(num_images): - logger.debug(f"Acquiring flat image {i+1}/{num_images}.") - yield from self.stubs.trigger(min_wait=exposure_time) - yield from self.stubs.read(group="monitored", point_id=self.scan.point_id) - self.scan.point_id += 1 - yield from self.complete_and_restore_configs(name=name) - yield from self.update_live_processing_references(ref_type="dark") + yield from self.stubs.trigger(min_wait=exposure_time * num_images) + yield from self.complete() + yield from self.update_live_processing_references(ref_type="flat") + + if restore: + yield from self.restore_configs(name=name) logger.info("Flat image acquisition complete.") + self.connector.send_client_info("") - def acquire_references(self, num_darks: int, num_flats: int, exp_time: float, name: str): - yield from self.acquire_dark(num_darks, exposure_time=exp_time, name=name) - yield from self.acquire_flat(num_flats, exposure_time=exp_time, name=name) + def acquire_references(self, num_darks: int, num_flats: int, exp_time: float, restart=True, restore=True): + yield from self.acquire_dark(num_darks, exposure_time=exp_time, restart=restart, restore=restore) + yield from self.acquire_flat(num_flats, exposure_time=exp_time, restart=restart, restore=restore) +class AcquireDark(ScanBase): + scan_name = "acquire_dark" + gui_config = {"Acquisition Parameters": ["num_images", "exp_time"]} + + def __init__(self, num_images: int, exp_time: float, **kwargs): + """ + Acquire dark images. + + Args: + num_images (int): Number of dark images to acquire. + exp_time (float): Exposure time for each dark image in seconds. + + Returns: + ScanReport + """ + frames_per_trigger = num_images if num_images > 0 else 1 + super().__init__(frames_per_trigger=frames_per_trigger, exp_time=exp_time, **kwargs) + self.components = TomoComponents(self) + + def scan_report_instructions(self): + yield from self.components.scan_report_instructions() + + def scan_core(self): + yield from self.components.acquire_dark( + self.frames_per_trigger, self.exp_time, restart=False + ) + + +class AcquireFlat(ScanBase): + scan_name = "acquire_flat" + gui_config = {"Acquisition Parameters": ["num_images", "exp_time"]} + + def __init__(self, num_images: int, exp_time: float, **kwargs): + """ + Acquire flat images. + + Args: + num_images (int): Number of flat images to acquire. + exp_time (float): Exposure time for each flat image in seconds. + frames_per_trigger (int): Number of frames to acquire per trigger. + + Returns: + ScanReport + """ + frames_per_trigger = num_images if num_images > 0 else 1 + super().__init__(frames_per_trigger=frames_per_trigger, exp_time=exp_time, **kwargs) + self.components = TomoComponents(self) + + def scan_report_instructions(self): + yield from self.components.scan_report_instructions() + + def scan_core(self): + yield from self.components.acquire_flat( + self.frames_per_trigger, self.exp_time, restart=False + ) + + +class AcquireReferences(ScanBase): + scan_name = "acquire_refs" + gui_config = {"Acquisition Parameters": ["num_darks", "num_flats", "exp_time"]} + + def __init__(self, num_darks: int, num_flats: int, exp_time: float, **kwargs): + """ + Acquire flats and darks. + + Args: + num_darks (int): Number of dark images to acquire. + num_flats (int): Number of flat images to acquire. + exp_time (float): Exposure time for each flat image in seconds. + frames_per_trigger (int): Number of frames to acquire per trigger. + + Returns: + ScanReport + """ + super().__init__(exp_time=exp_time, **kwargs) + self.num_darks = num_darks + self.num_flats = num_flats + self.components = TomoComponents(self) + + def scan_report_instructions(self): + yield from self.components.scan_report_instructions() + + def pre_scan(self): + yield from self.components.acquire_references(self.num_darks, self.num_flats, self.exp_time) + + def scan_core(self): + yield None + + class TomoScan(LineScan): scan_name = "tomo_line_scan" @@ -184,23 +311,19 @@ class TomoScan(LineScan): self.num_flats = num_flats self.components = TomoComponents(self) - def prepare_positions(self): - yield from super().prepare_positions() - self.num_pos += 2 * (self.num_darks + self.num_flats) - def pre_scan(self): yield from self.components.acquire_dark(self.num_darks, self.exp_time, name="pre_scan_dark") yield from self.components.acquire_flat(self.num_flats, self.exp_time, name="pre_scan_flat") yield from super().pre_scan() - def finalize(self): - yield from super().finalize() - yield from self.components.acquire_dark( - self.num_darks, self.exp_time, name="post_scan_dark" - ) - yield from self.components.acquire_flat( - self.num_flats, self.exp_time, name="post_scan_flat" - ) + # def finalize(self): + # yield from super().finalize() + # yield from self.components.acquire_dark( + # self.num_darks, self.exp_time, name="post_scan_dark" + # ) + # yield from self.components.acquire_flat( + # self.num_flats, self.exp_time, name="post_scan_flat" + # ) class TomoFlyScan(AsyncFlyScanBase):