From 6e695a67346c254bd1291e9d0bbfe4e5cc05b914 Mon Sep 17 00:00:00 2001
From: Lisa Dorofeeva
Date: Thu, 3 Jul 2025 15:22:34 +0200
Subject: [PATCH] New streak finder

---
 dap/algos/__init__.py              |   3 +-
 dap/algos/streakfind.py            | 160 +++++++++++++++++++++--------
 dap/algos/whitefield_correction.py |   2 +
 dap/worker.py                      |   9 +-
 4 files changed, 125 insertions(+), 49 deletions(-)

diff --git a/dap/algos/__init__.py b/dap/algos/__init__.py
index b9bf2a5..7f0ff39 100644
--- a/dap/algos/__init__.py
+++ b/dap/algos/__init__.py
@@ -8,8 +8,7 @@ from .peakfind import calc_peakfinder_analysis
 from .radprof import calc_radial_integration
 from .roi import calc_roi
 from .spiana import calc_spi_analysis
-from .streakfind import calc_streakfinder_analysis
-from .whitefield_correction import calc_apply_whitefield_correction
+from .streakfind import calc_cbd_analysis
 from .thresh import calc_apply_threshold
 
diff --git a/dap/algos/streakfind.py b/dap/algos/streakfind.py
index 757a40e..51e1026 100644
--- a/dap/algos/streakfind.py
+++ b/dap/algos/streakfind.py
@@ -3,31 +3,99 @@
 Streak Finder algorithm implemented by CFEL Chapman group
 
 Requires Convergent beam streak finder package installed:
-https://github.com/simply-nicky/streak_finder
-(note g++ 11 required for building)
+https://github.com/simply-nicky/streak_finder/tree/swiss_fel
+(note g++ 11 required for building, numpy 2+ required)
 """
+import h5py
 import numpy as np
-from math import sqrt, pow
-
-from streak_finder import PatternStreakFinder
+
+from streak_finder import CrystData
 from streak_finder.label import Structure2D
-from skimage.measure import profile_line
+
+DEFAULT_NUM_THREADS = 16
+
+
+def calc_cbd_analysis(results, data, pf_pixel_mask):
+    try:
+        cryst_data = _generate_cryst_data(results, data, pf_pixel_mask)
+    except Exception as error:  # Broad exception - we don't want to break anything here
+        print(f"Error processing CBD data:\n{error}")
+        results["cbd_error"] = f"Error processing CBD data:\n{error}"
+        return data
+
+    if cryst_data is None:  # CBD correction disabled or not configured - return the image unchanged
+        return data
+
+    try:
+        _calc_streakfinder_analysis(results, cryst_data)
+    except Exception as error:  # Broad exception - we don't want to break anything here
+        print(f"Error processing CBD data:\n{error}")
+        results["cbd_error"] = f"Error processing CBD data:\n{error}"
+    return cryst_data.snr
+
+
+def _generate_cryst_data(results, data, pf_pixel_mask) -> CrystData:
+    do_snr = results.get("do_snr", False)
+    if not do_snr:
+        return
+
+    params_required = [
+        "whitefield_data_file",
+        "mask_data_file",
+        "std_data_file",
+        "scale_whitefield",  # Bool
+    ]
+
+    if not all([param in results.keys() for param in params_required]):
+        raise ValueError(f"ERROR: Not enough parameters for CBD correction. Skipping\n"
+                         f"{params_required=}")
+
+    whitefield_data_file = results["whitefield_data_file"]
+    mask_data_file = results["mask_data_file"]
+    std_data_file = results["std_data_file"]
+    scale_whitefield = results["scale_whitefield"]
+
+    # Using CXI Store specification as default
+    whitefield_dataset = results.get("whitefield_dataset", "entry/crystallography/whitefield")
+    mask_dataset = results.get("mask_dataset", "entry/instrument/detector/mask")
+    std_dataset = results.get("std_dataset", "entry/crystallography/std")
+
+    num_threads = results.get("num_threads", DEFAULT_NUM_THREADS)
+
+    with h5py.File(whitefield_data_file, "r") as hf:
+        whitefield = np.asarray(hf[whitefield_dataset])
+
+    with h5py.File(mask_data_file, "r") as hf:
+        mask = np.asarray(hf[mask_dataset])
+
+    with h5py.File(std_data_file, "r") as hf:
+        std = np.asarray(hf[std_dataset])
+
+    data = CrystData(
+        data=data.reshape((-1,) + data.shape[-2:]),
+        mask=mask * pf_pixel_mask,
+        std=std,
+        whitefield=whitefield
+    )
+    if scale_whitefield:
+        data = data.scale_whitefield(method='median', num_threads=num_threads)
+
+    return data
+
+
+def _calc_streakfinder_analysis(results, cryst_data: CrystData):
+    do_streakfinder_analysis = results.get("do_streakfinder_analysis", False)
+    if not do_streakfinder_analysis:
+        return
+
+    params_required = [
+        "sf_peak_structure_radius",
+        "sf_peak_structure_rank",
+        "sf_streak_structure_radius",
+        "sf_streak_structure_rank",
+        "sf_peak_vmin",
+        "sf_streak_vmin",
+        "sf_min_size",
+        "sf_npts",
+        "sf_xtol",
+        "sf_nfa",
+        "sf_num_threads",
+        # "beam_center_x",
+        # "beam_center_y"
+    ]
+
+    if not all([param in results.keys() for param in params_required]):
@@ -35,40 +103,50 @@ def calc_streakfinder_analysis(results, data, pixel_mask_sf):
               f"{params_required=}")
         return
 
-    radius = results["sf_structure_radius"]
-    rank = results["sf_structure_rank"]
+    peak_structure_radius = results["sf_peak_structure_radius"]  # peak
+    peak_structure_rank = results["sf_peak_structure_rank"]
+    streak_structure_radius = results["sf_streak_structure_radius"]  # streak
+    streak_structure_rank = results["sf_streak_structure_rank"]
+    peak_vmin = results["sf_peak_vmin"]  # peak
+    streak_vmin = results["sf_streak_vmin"]  # streak
     min_size = results["sf_min_size"]
-    vmin = results["sf_vmin"]
     npts = results["sf_npts"]
     xtol = results["sf_xtol"]
+    nfa = results["sf_nfa"]
+    num_threads = results["sf_num_threads"]
 
-    struct = Structure2D(radius, rank)
-    psf = PatternStreakFinder(
-        data=data,
-        mask=pixel_mask_sf,
-        structure=struct,
-        min_size=min_size
-    )
-    # Find peaks in a pattern. Returns a sparse set of peaks which values are above a threshold
-    # ``vmin`` that have a supporing set of a size larger than ``npts``. The minimal distance
-    # between peaks is ``2 * structure.radius``
-    peaks = psf.detect_peaks(vmin=vmin, npts=npts)
+    x_center = results.get("beam_center_x", None)
+    y_center = results.get("beam_center_y", None)
 
-    # Streak finding algorithm. Starting from the set of seed peaks, the lines are iteratively
-    # extended with a connectivity structure.
-    streaks = psf.detect_streaks(peaks=peaks, xtol=xtol, vmin=vmin).to_lines()
-    streak_lengths = []
-    bragg_counts = []
-    for streak in streaks:
-        x0, y0, x1, y1 = streak
-        streak_lengths.append(sqrt(pow((x1 - x0), 2) + pow((y1 - y0), 2)))
-        bragg_counts.append(float(np.sum(profile_line(data, (x0, y0), (x1, y1)))))
-    streak_lines = streaks.T
+    peaks_structure = Structure2D(peak_structure_radius, peak_structure_rank)
+    streaks_structure = Structure2D(streak_structure_radius, streak_structure_rank)
+
+    det_obj = cryst_data.streak_detector(streaks_structure)
+    peaks = det_obj.detect_peaks(peak_vmin, npts, peaks_structure, num_threads)
+    detected = det_obj.detect_streaks(peaks, xtol, streak_vmin, min_size, nfa=nfa,
+                                      num_threads=num_threads)
+    if isinstance(detected, list):
+        detected = detected[0]
+
+    streaks = det_obj.to_streaks(detected)
+
+    if x_center is not None and y_center is not None:
+        streaks = streaks.concentric_only(x_center, y_center)
+
+    streak_lines = streaks.lines
+    streak_lengths = np.sqrt(
+        np.pow((streak_lines[..., 2] - streak_lines[..., 0]), 2) +
+        np.pow((streak_lines[..., 3] - streak_lines[..., 1]), 2)
+    ).tolist()
+
+    streak_lines = streak_lines.T
     _, number_of_streaks = streak_lines.shape
     print(f"Found {number_of_streaks} streaks")
-    list_result = []
-    for line in streak_lines:  # arr(4, n_lines); 0coord x0, y0, x1, y1
-        list_result.append(line.tolist())
+
+    list_result = [line.tolist() for line in streak_lines]  # arr(4, n_lines); 0coord x0, y0, x1, y1
+    bragg_counts = [streak.total_mass() for streak in detected.streaks.values()]
+
     results.update({"number_of_streaks": number_of_streaks})
     results.update({"is_hit_frame": number_of_streaks > 0})
     results.update({"streaks": list_result})
diff --git a/dap/algos/whitefield_correction.py b/dap/algos/whitefield_correction.py
index 37785e3..cae9c14 100644
--- a/dap/algos/whitefield_correction.py
+++ b/dap/algos/whitefield_correction.py
@@ -70,3 +70,5 @@ def calc_apply_whitefield_correction(results, data):
                   f"{error=}")
     else:
         results["white_field_correction_applied"] = 1
+
+    return whitefield_image
\ No newline at end of file
diff --git a/dap/worker.py b/dap/worker.py
index c176d6f..30b31eb 100644
--- a/dap/worker.py
+++ b/dap/worker.py
@@ -3,7 +3,7 @@ import argparse
 import numpy as np
 from algos import (calc_apply_aggregation, calc_apply_threshold, calc_mask_pixels,
                    calc_peakfinder_analysis, calc_radial_integration, calc_roi, calc_spi_analysis,
-                   calc_streakfinder_analysis, calc_apply_whitefield_correction, JFData)
+                   calc_cbd_analysis, JFData)
 from utils import Aggregator, BufferedJSON, randskip, read_bit
 from zmqsocks import ZMQSockets
 
@@ -117,11 +117,8 @@ def work(backend_address, accumulator_host, accumulator_port, visualisation_host
 
     # ???
 
-    # White-field correction and streak finder processing for convergent-beam diffraction
-    print(f"Applying whitefield correction")
-    calc_apply_whitefield_correction(results, image)  # changes image in place
-    print(f"Searching streaks")
-    calc_streakfinder_analysis(results, image, pixel_mask_pf)
+    # Correction and streak finder processing for convergent-beam diffraction
+    image = calc_cbd_analysis(results, image, pixel_mask_pf)
     print(f"Done\n{results=}")
 
     image, aggregation_is_ready = calc_apply_aggregation(results, image, pixel_mask_pf, aggregator)
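-- 
Review note (not part of the patch): a minimal sketch of the parameter set the new
calc_cbd_analysis entry point looks up in `results`, assuming the worker wiring above.
The keys come from streakfind.py; all file paths and numeric values are illustrative
placeholders, not tested defaults.

    # Hypothetical configuration - keys from the patch, values are examples only.
    results = {
        # CBD correction (_generate_cryst_data)
        "do_snr": True,
        "whitefield_data_file": "/path/to/whitefield.h5",  # placeholder path
        "mask_data_file": "/path/to/mask.h5",              # placeholder path
        "std_data_file": "/path/to/std.h5",                # placeholder path
        "scale_whitefield": True,
        # optional dataset overrides; CXI Store locations are used by default

        # streak finder (_calc_streakfinder_analysis)
        "do_streakfinder_analysis": True,
        "sf_peak_structure_radius": 3,    # example value
        "sf_peak_structure_rank": 1,      # example value
        "sf_streak_structure_radius": 5,  # example value
        "sf_streak_structure_rank": 1,    # example value
        "sf_peak_vmin": 10.0,             # example threshold
        "sf_streak_vmin": 5.0,            # example threshold
        "sf_min_size": 5,                 # example value
        "sf_npts": 3,                     # example value
        "sf_xtol": 1.8,                   # example value
        "sf_nfa": 0,                      # example value
        "sf_num_threads": 16,
        # optional: keep only streaks consistent with the beam center (concentric_only)
        # "beam_center_x": 1024.0,
        # "beam_center_y": 512.0,
    }

    # The worker then replaces the raw image with the corrected (SNR) image:
    image = calc_cbd_analysis(results, image, pixel_mask_pf)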