From 901587db7981761db028c7d1046573c8e5fde4ab Mon Sep 17 00:00:00 2001 From: Lisa Dorofeeva Date: Mon, 14 Jul 2025 20:58:20 +0200 Subject: [PATCH] TEST: try to get rid of data containers in streak finder package; make it simple --- dap/algos/streakfind.py | 79 ++++++++++++++++++++++++----------------- 1 file changed, 47 insertions(+), 32 deletions(-) diff --git a/dap/algos/streakfind.py b/dap/algos/streakfind.py index 9eacf80..6ebcd81 100644 --- a/dap/algos/streakfind.py +++ b/dap/algos/streakfind.py @@ -9,22 +9,13 @@ https://github.com/simply-nicky/streak_finder/ import h5py import numpy as np from psutil import virtual_memory -from streak_finder import CrystData as CrystDataBase from streak_finder.label import Structure2D +from streak_finder.streak_finder import detect_peaks, detect_streaks, filter_peaks DEFAULT_NUM_THREADS = 16 DEFAULT_MIN_HIT_STREAKS = 5 -class CrystData(CrystDataBase): - - def clear(self): - self.data = None - self.mask = None - self.snr = None - self.std = None - - def _handle_negative_values(data, mask, handler: str): if not handler or np.all(data>=0): return @@ -107,14 +98,14 @@ def _generate_cryst_data(results, data, pf_pixel_mask): # -> CrystData: data * mask - whitefield, std, out=np.zeros_like(data), - where=(std==0.0) + where=(std!=0.0) ) del whitefield del std del mask return snr -def _calc_streakfinder_analysis(results, cryst_data: CrystData): +def _calc_streakfinder_analysis(results, snr, mask): do_streakfinder_analysis = results.get("do_streakfinder_analysis", False) if not do_streakfinder_analysis: return @@ -155,26 +146,33 @@ def _calc_streakfinder_analysis(results, cryst_data: CrystData): x_center = results.get("beam_center_x", None) y_center = results.get("beam_center_y", None) - mask_rois = results.get("cbd_mask_rois", []) # list of [y_min, y_max, x_min, x_max] - - for mask_roi in mask_rois: - cryst_data = cryst_data.mask_region(mask_roi) + # mask_rois = results.get("cbd_mask_rois", []) # list of [y_min, y_max, x_min, x_max] + # + # for mask_roi in mask_rois: + # cryst_data = cryst_data.mask_region(mask_roi) crop_roi = results.get("cbd_crop_roi", None) # [y_min, y_max, x_min, x_max] - - if crop_roi is not None: - crop_roi_t = [crop_roi[2], crop_roi[3], crop_roi[0], crop_roi[1]]# y0, y1, x0, x1 - cryst_data = cryst_data.crop(crop_roi_t) + lookahead = results.get("cbd_lookahead", 1) + # if crop_roi is not None: + # crop_roi_t = [crop_roi[2], crop_roi[3], crop_roi[0], crop_roi[1]] # y0, y1, x0, x1 + # cryst_data = cryst_data.crop(crop_roi_t) peaks_structure = Structure2D(peak_structure_radius, peak_structure_rank) streaks_structure = Structure2D(streak_structure_radius, streak_structure_rank) - - det_obj = cryst_data.streak_detector(streaks_structure) - peaks = det_obj.detect_peaks(peak_vmin, npts, peaks_structure, num_threads) - detected = det_obj.detect_streaks(peaks, xtol, streak_vmin, min_size, nfa=nfa, - num_threads=num_threads) - + peaks = detect_peaks(snr, mask, peaks_structure.rank, peak_vmin, + num_threads=num_threads) + peaks = filter_peaks(peaks, snr, mask, peaks_structure, peak_vmin, npts, + num_threads=num_threads) + + # det_obj = cryst_data.streak_detector(streaks_structure) + + + # peaks = det_obj.detect_peaks(peak_vmin, npts, peaks_structure, num_threads) + # detected = det_obj.detect_streaks(peaks, xtol, streak_vmin, min_size, nfa=nfa, + # num_threads=num_threads) + detected = detect_streaks(peaks, snr, mask, streaks_structure, xtol, streak_vmin, min_size, + lookahead, nfa, num_threads=num_threads) if isinstance(detected, list): detected = detected[0] @@ -186,9 +184,17 @@ def _calc_streakfinder_analysis(results, cryst_data: CrystData): results["bragg_counts"] = [] return - streaks = det_obj.to_streaks(detected) + + if isinstance(detected, list): + streaks = [np.asarray(pattern.to_lines()) for pattern in detected] + else: + streaks = [np.asarray(detected.to_lines()),] + streak_lines = np.concatenate(streaks) + # rstreaks = Streaks(index=IndexArray(idxs), lines=lines) + + # streaks = det_obj.to_streaks(detected) detected_streaks = np.array(detected.streaks) - streak_lines = streaks.lines + # streak_lines = streaks.lines # Adjust to crop region if crop_roi is not None: @@ -196,10 +202,19 @@ def _calc_streakfinder_analysis(results, cryst_data: CrystData): streak_lines = streak_lines + shift if x_center is not None and y_center is not None: - if crop_roi is not None: - x_center -= crop_roi[0] - y_center -= crop_roi[2] - streaks_mask = streaks.concentric_only(x_center, y_center) + # if crop_roi is not None: + # x_center -= crop_roi[0] + # y_center -= crop_roi[2] + threshold = 0.33 + centers =np.mean(streak_lines.reshape(-1, 2, 2), axis=1) + norm =np.stack([streak_lines[:, 3] - streak_lines[:, 1], + streak_lines[:, 0] - streak_lines[:, 2]], axis=-1) + r = centers -np.asarray([x_center, y_center]) + prod =np.sum(norm * r, axis=-1)[..., None] + proj = r - prod * norm /np.sum(norm ** 2, axis=-1)[..., None] + streaks_mask =np.sqrt(np.sum(proj ** 2, axis=-1)) /np.sqrt(np.sum(r ** 2, axis=-1)) < threshold + # return mask + # streaks_mask = streaks.concentric_only(x_center, y_center) streak_lines = streak_lines[streaks_mask] detected_streaks = detected_streaks[streaks_mask]