From 8ffad86f9ae9b741c3ad6d27aab937aa85cc19ba Mon Sep 17 00:00:00 2001 From: Lisa Dorofeeva Date: Mon, 14 Jul 2025 18:59:27 +0200 Subject: [PATCH 01/15] Use bool rather than np.bool to avoid deprecation errors --- dap/algos/addmaskfile.py | 4 ++-- dap/algos/streakfind.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dap/algos/addmaskfile.py b/dap/algos/addmaskfile.py index 4f82ee5..f337f70 100644 --- a/dap/algos/addmaskfile.py +++ b/dap/algos/addmaskfile.py @@ -18,14 +18,14 @@ def calc_apply_additional_mask_from_file(results, pixel_mask_pf): # Support for hdf5 and npy if mask_file.endswith(".npy"): try: - mask = np.load(mask_file).astype(np.bool) + mask = np.load(mask_file).astype(bool) except Exception as error: results["mask_error"] = f"Error loading mask data from NumPy file {mask_file}:\n{error}" return else: try: with h5py.File(mask_file, "r") as h5f: - mask = h5f[mask_dataset][:].astype(np.bool) + mask = h5f[mask_dataset][:].astype(bool) except Exception as error: results["mask_error"] = f"Error loading mask from hdf5 file {mask_file}:\n{error}" return diff --git a/dap/algos/streakfind.py b/dap/algos/streakfind.py index fc91a28..1b72053 100644 --- a/dap/algos/streakfind.py +++ b/dap/algos/streakfind.py @@ -86,7 +86,7 @@ def _generate_cryst_data(results, data, pf_pixel_mask) -> CrystData: else: mask_dataset = results.get("cbd_mask_dataset", "entry/instrument/detector/mask") with h5py.File(mask_data_file, "r") as hf: - mask = hf[mask_dataset][:].astype(np.bool) + mask = hf[mask_dataset][:].astype(bool) mask *= pf_pixel_mask data = CrystData( -- 2.49.1 From c09bdc5b4ee32d86005a9a63d7fbb4abe8e35e8c Mon Sep 17 00:00:00 2001 From: Lisa Dorofeeva Date: Mon, 14 Jul 2025 19:09:52 +0200 Subject: [PATCH 02/15] Explicitly delete cryst_data object after algorithm completion --- dap/algos/streakfind.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dap/algos/streakfind.py b/dap/algos/streakfind.py index 1b72053..d7a0e6f 100644 --- a/dap/algos/streakfind.py +++ b/dap/algos/streakfind.py @@ -50,6 +50,7 @@ def calc_streakfinder_analysis(results, data, pf_pixel_mask): _calc_streakfinder_analysis(results, cryst_data) except Exception as error: # Broad exception - we don't want to break anything here results["cbd_error"] = f"Error processing CBD data:\n{error}" + del cryst_data return data -- 2.49.1 From 6ebaf0ae6a71654a6628de66de92721e440764ca Mon Sep 17 00:00:00 2001 From: Lisa Dorofeeva Date: Mon, 14 Jul 2025 19:29:29 +0200 Subject: [PATCH 03/15] Try cleanup cryst data, track memory usage --- dap/algos/streakfind.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/dap/algos/streakfind.py b/dap/algos/streakfind.py index d7a0e6f..d1cdd00 100644 --- a/dap/algos/streakfind.py +++ b/dap/algos/streakfind.py @@ -8,13 +8,23 @@ https://github.com/simply-nicky/streak_finder/ """ import h5py import numpy as np -from streak_finder import CrystData +from psutil import virtual_memory +from streak_finder import CrystData as CrystDataBase from streak_finder.label import Structure2D DEFAULT_NUM_THREADS = 16 DEFAULT_MIN_HIT_STREAKS = 5 +class CrystData(CrystDataBase): + + def clear(self): + self.data = None + self.mask = None + self.snr = None + self.std = None + + def _handle_negative_values(data, mask, handler: str): if not handler or np.all(data>=0): return @@ -43,14 +53,17 @@ def calc_streakfinder_analysis(results, data, pf_pixel_mask): return data if do_snr: - # Changes data and mask in-place + # Changes data in-place data = cryst_data.snr[0].copy() try: _calc_streakfinder_analysis(results, cryst_data) except Exception as error: # Broad exception - we don't want to break anything here results["cbd_error"] = f"Error processing CBD data:\n{error}" + print(f"Deleting cryst data;\nmem used before {virtual_memory().used // pow(2, 30)} Gb ") + cryst_data.clear() del cryst_data + print(f"mem used after {virtual_memory().used // pow(2, 30)} Gb \n") return data -- 2.49.1 From 2ea2129d5b1bc1aedc828ea3788666f8f7c4acbe Mon Sep 17 00:00:00 2001 From: Lisa Dorofeeva Date: Mon, 14 Jul 2025 19:40:28 +0200 Subject: [PATCH 04/15] Try to get rid of cryst data entirely --- dap/algos/streakfind.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/dap/algos/streakfind.py b/dap/algos/streakfind.py index d1cdd00..9eacf80 100644 --- a/dap/algos/streakfind.py +++ b/dap/algos/streakfind.py @@ -54,20 +54,20 @@ def calc_streakfinder_analysis(results, data, pf_pixel_mask): if do_snr: # Changes data in-place - data = cryst_data.snr[0].copy() + data = cryst_data #.snr[0].copy() try: _calc_streakfinder_analysis(results, cryst_data) except Exception as error: # Broad exception - we don't want to break anything here results["cbd_error"] = f"Error processing CBD data:\n{error}" - print(f"Deleting cryst data;\nmem used before {virtual_memory().used // pow(2, 30)} Gb ") - cryst_data.clear() - del cryst_data + # print(f"Deleting cryst data;\nmem used before {virtual_memory().used // pow(2, 30)} Gb ") + # cryst_data.clear() + # del cryst_data print(f"mem used after {virtual_memory().used // pow(2, 30)} Gb \n") return data -def _generate_cryst_data(results, data, pf_pixel_mask) -> CrystData: +def _generate_cryst_data(results, data, pf_pixel_mask): # -> CrystData: params_required = [ "cbd_whitefield_data_file", "cbd_std_data_file", @@ -80,13 +80,13 @@ def _generate_cryst_data(results, data, pf_pixel_mask) -> CrystData: whitefield_data_file = results["cbd_whitefield_data_file"] std_data_file = results["cbd_std_data_file"] - scale_whitefield = results["cbd_scale_whitefield"] + # scale_whitefield = results["cbd_scale_whitefield"] # Using CXI Store specification as default whitefield_dataset = results.get("cbd_whitefield_dataset", "entry/crystallography/whitefield") std_dataset = results.get("cbd_std_dataset", "entry/crystallography/std") - num_threads = results.get("cbd_num_threads", DEFAULT_NUM_THREADS) + # num_threads = results.get("cbd_num_threads", DEFAULT_NUM_THREADS) with h5py.File(whitefield_data_file, "r") as hf: whitefield = hf[whitefield_dataset][:] @@ -103,17 +103,16 @@ def _generate_cryst_data(results, data, pf_pixel_mask) -> CrystData: mask = hf[mask_dataset][:].astype(bool) mask *= pf_pixel_mask - data = CrystData( - data=data[np.newaxis, :], - mask=mask, - std=std, - whitefield=whitefield + snr = np.divide( + data * mask - whitefield, + std, + out=np.zeros_like(data), + where=(std==0.0) ) - if scale_whitefield: - data = data.scale_whitefield(method='median', num_threads=num_threads) - - data = data.update_snr() - return data + del whitefield + del std + del mask + return snr def _calc_streakfinder_analysis(results, cryst_data: CrystData): do_streakfinder_analysis = results.get("do_streakfinder_analysis", False) -- 2.49.1 From 901587db7981761db028c7d1046573c8e5fde4ab Mon Sep 17 00:00:00 2001 From: Lisa Dorofeeva Date: Mon, 14 Jul 2025 20:58:20 +0200 Subject: [PATCH 05/15] TEST: try to get rid of data containers in streak finder package; make it simple --- dap/algos/streakfind.py | 79 ++++++++++++++++++++++++----------------- 1 file changed, 47 insertions(+), 32 deletions(-) diff --git a/dap/algos/streakfind.py b/dap/algos/streakfind.py index 9eacf80..6ebcd81 100644 --- a/dap/algos/streakfind.py +++ b/dap/algos/streakfind.py @@ -9,22 +9,13 @@ https://github.com/simply-nicky/streak_finder/ import h5py import numpy as np from psutil import virtual_memory -from streak_finder import CrystData as CrystDataBase from streak_finder.label import Structure2D +from streak_finder.streak_finder import detect_peaks, detect_streaks, filter_peaks DEFAULT_NUM_THREADS = 16 DEFAULT_MIN_HIT_STREAKS = 5 -class CrystData(CrystDataBase): - - def clear(self): - self.data = None - self.mask = None - self.snr = None - self.std = None - - def _handle_negative_values(data, mask, handler: str): if not handler or np.all(data>=0): return @@ -107,14 +98,14 @@ def _generate_cryst_data(results, data, pf_pixel_mask): # -> CrystData: data * mask - whitefield, std, out=np.zeros_like(data), - where=(std==0.0) + where=(std!=0.0) ) del whitefield del std del mask return snr -def _calc_streakfinder_analysis(results, cryst_data: CrystData): +def _calc_streakfinder_analysis(results, snr, mask): do_streakfinder_analysis = results.get("do_streakfinder_analysis", False) if not do_streakfinder_analysis: return @@ -155,26 +146,33 @@ def _calc_streakfinder_analysis(results, cryst_data: CrystData): x_center = results.get("beam_center_x", None) y_center = results.get("beam_center_y", None) - mask_rois = results.get("cbd_mask_rois", []) # list of [y_min, y_max, x_min, x_max] - - for mask_roi in mask_rois: - cryst_data = cryst_data.mask_region(mask_roi) + # mask_rois = results.get("cbd_mask_rois", []) # list of [y_min, y_max, x_min, x_max] + # + # for mask_roi in mask_rois: + # cryst_data = cryst_data.mask_region(mask_roi) crop_roi = results.get("cbd_crop_roi", None) # [y_min, y_max, x_min, x_max] - - if crop_roi is not None: - crop_roi_t = [crop_roi[2], crop_roi[3], crop_roi[0], crop_roi[1]]# y0, y1, x0, x1 - cryst_data = cryst_data.crop(crop_roi_t) + lookahead = results.get("cbd_lookahead", 1) + # if crop_roi is not None: + # crop_roi_t = [crop_roi[2], crop_roi[3], crop_roi[0], crop_roi[1]] # y0, y1, x0, x1 + # cryst_data = cryst_data.crop(crop_roi_t) peaks_structure = Structure2D(peak_structure_radius, peak_structure_rank) streaks_structure = Structure2D(streak_structure_radius, streak_structure_rank) - - det_obj = cryst_data.streak_detector(streaks_structure) - peaks = det_obj.detect_peaks(peak_vmin, npts, peaks_structure, num_threads) - detected = det_obj.detect_streaks(peaks, xtol, streak_vmin, min_size, nfa=nfa, - num_threads=num_threads) - + peaks = detect_peaks(snr, mask, peaks_structure.rank, peak_vmin, + num_threads=num_threads) + peaks = filter_peaks(peaks, snr, mask, peaks_structure, peak_vmin, npts, + num_threads=num_threads) + + # det_obj = cryst_data.streak_detector(streaks_structure) + + + # peaks = det_obj.detect_peaks(peak_vmin, npts, peaks_structure, num_threads) + # detected = det_obj.detect_streaks(peaks, xtol, streak_vmin, min_size, nfa=nfa, + # num_threads=num_threads) + detected = detect_streaks(peaks, snr, mask, streaks_structure, xtol, streak_vmin, min_size, + lookahead, nfa, num_threads=num_threads) if isinstance(detected, list): detected = detected[0] @@ -186,9 +184,17 @@ def _calc_streakfinder_analysis(results, cryst_data: CrystData): results["bragg_counts"] = [] return - streaks = det_obj.to_streaks(detected) + + if isinstance(detected, list): + streaks = [np.asarray(pattern.to_lines()) for pattern in detected] + else: + streaks = [np.asarray(detected.to_lines()),] + streak_lines = np.concatenate(streaks) + # rstreaks = Streaks(index=IndexArray(idxs), lines=lines) + + # streaks = det_obj.to_streaks(detected) detected_streaks = np.array(detected.streaks) - streak_lines = streaks.lines + # streak_lines = streaks.lines # Adjust to crop region if crop_roi is not None: @@ -196,10 +202,19 @@ def _calc_streakfinder_analysis(results, cryst_data: CrystData): streak_lines = streak_lines + shift if x_center is not None and y_center is not None: - if crop_roi is not None: - x_center -= crop_roi[0] - y_center -= crop_roi[2] - streaks_mask = streaks.concentric_only(x_center, y_center) + # if crop_roi is not None: + # x_center -= crop_roi[0] + # y_center -= crop_roi[2] + threshold = 0.33 + centers =np.mean(streak_lines.reshape(-1, 2, 2), axis=1) + norm =np.stack([streak_lines[:, 3] - streak_lines[:, 1], + streak_lines[:, 0] - streak_lines[:, 2]], axis=-1) + r = centers -np.asarray([x_center, y_center]) + prod =np.sum(norm * r, axis=-1)[..., None] + proj = r - prod * norm /np.sum(norm ** 2, axis=-1)[..., None] + streaks_mask =np.sqrt(np.sum(proj ** 2, axis=-1)) /np.sqrt(np.sum(r ** 2, axis=-1)) < threshold + # return mask + # streaks_mask = streaks.concentric_only(x_center, y_center) streak_lines = streak_lines[streaks_mask] detected_streaks = detected_streaks[streaks_mask] -- 2.49.1 From e3efd82b2c28e6bbd9c33ca13c7af12fd7dd8864 Mon Sep 17 00:00:00 2001 From: "Dorofeeva Elizaveta (EXT)" Date: Mon, 14 Jul 2025 21:25:59 +0200 Subject: [PATCH 06/15] Bugfixes, no leak visible, TODO: cleanup, wf norm, crop, maybe mask roi --- dap/algos/streakfind.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dap/algos/streakfind.py b/dap/algos/streakfind.py index 6ebcd81..0df56b3 100644 --- a/dap/algos/streakfind.py +++ b/dap/algos/streakfind.py @@ -48,7 +48,7 @@ def calc_streakfinder_analysis(results, data, pf_pixel_mask): data = cryst_data #.snr[0].copy() try: - _calc_streakfinder_analysis(results, cryst_data) + _calc_streakfinder_analysis(results, cryst_data, pf_pixel_mask) except Exception as error: # Broad exception - we don't want to break anything here results["cbd_error"] = f"Error processing CBD data:\n{error}" # print(f"Deleting cryst data;\nmem used before {virtual_memory().used // pow(2, 30)} Gb ") @@ -219,8 +219,8 @@ def _calc_streakfinder_analysis(results, snr, mask): detected_streaks = detected_streaks[streaks_mask] streak_lengths = np.sqrt( - np.pow((streak_lines[..., 2] - streak_lines[..., 0]), 2) + - np.pow((streak_lines[..., 2] - streak_lines[..., 0]), 2) + np.power((streak_lines[..., 2] - streak_lines[..., 0]), 2) + + np.power((streak_lines[..., 2] - streak_lines[..., 0]), 2) ).tolist() streak_lines = streak_lines.T -- 2.49.1 From 77d3ff9293e092ed783c5b61fddb03ef981c42d4 Mon Sep 17 00:00:00 2001 From: Lisa Dorofeeva Date: Mon, 14 Jul 2025 22:07:08 +0200 Subject: [PATCH 07/15] TEST scale whitefield --- dap/algos/streakfind.py | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/dap/algos/streakfind.py b/dap/algos/streakfind.py index 0df56b3..658d6fc 100644 --- a/dap/algos/streakfind.py +++ b/dap/algos/streakfind.py @@ -11,8 +11,8 @@ import numpy as np from psutil import virtual_memory from streak_finder.label import Structure2D from streak_finder.streak_finder import detect_peaks, detect_streaks, filter_peaks +from streak_finder._src.src.median import robust_lsq -DEFAULT_NUM_THREADS = 16 DEFAULT_MIN_HIT_STREAKS = 5 @@ -38,27 +38,24 @@ def calc_streakfinder_analysis(results, data, pf_pixel_mask): _handle_negative_values(data, pf_pixel_mask, negative_val_handler) try: - cryst_data = _generate_cryst_data(results, data, pf_pixel_mask) + snr = _calc_snr(results, data, pf_pixel_mask) except Exception as error: # Broad exception - we don't want to break anything here - results["cbd_error"] = f"Error processing CBD data:\n{error}" + results["cbd_error"] = f"SNR - Error processing CBD data:\n{error}" return data if do_snr: # Changes data in-place - data = cryst_data #.snr[0].copy() + data = snr try: - _calc_streakfinder_analysis(results, cryst_data, pf_pixel_mask) + _calc_streakfinder_analysis(results, snr, pf_pixel_mask) except Exception as error: # Broad exception - we don't want to break anything here - results["cbd_error"] = f"Error processing CBD data:\n{error}" - # print(f"Deleting cryst data;\nmem used before {virtual_memory().used // pow(2, 30)} Gb ") - # cryst_data.clear() - # del cryst_data + results["cbd_error"] = f"StreakFind - Error processing CBD data:\n{error}" print(f"mem used after {virtual_memory().used // pow(2, 30)} Gb \n") return data -def _generate_cryst_data(results, data, pf_pixel_mask): # -> CrystData: +def _calc_snr(results, data, pf_pixel_mask): params_required = [ "cbd_whitefield_data_file", "cbd_std_data_file", @@ -71,14 +68,12 @@ def _generate_cryst_data(results, data, pf_pixel_mask): # -> CrystData: whitefield_data_file = results["cbd_whitefield_data_file"] std_data_file = results["cbd_std_data_file"] - # scale_whitefield = results["cbd_scale_whitefield"] + scale_whitefield = results["cbd_scale_whitefield"] # Using CXI Store specification as default whitefield_dataset = results.get("cbd_whitefield_dataset", "entry/crystallography/whitefield") std_dataset = results.get("cbd_std_dataset", "entry/crystallography/std") - # num_threads = results.get("cbd_num_threads", DEFAULT_NUM_THREADS) - with h5py.File(whitefield_data_file, "r") as hf: whitefield = hf[whitefield_dataset][:] @@ -93,7 +88,8 @@ def _generate_cryst_data(results, data, pf_pixel_mask): # -> CrystData: with h5py.File(mask_data_file, "r") as hf: mask = hf[mask_dataset][:].astype(bool) mask *= pf_pixel_mask - + if scale_whitefield: + _scale_whitefield(data, mask, whitefield, std) snr = np.divide( data * mask - whitefield, std, @@ -105,6 +101,19 @@ def _generate_cryst_data(results, data, pf_pixel_mask): # -> CrystData: del mask return snr +def _scale_whitefield(data, mask, whitefield, std, + r0: float = 0.0, r1: float = 0.5, + n_iter: int = 12, lm: float = 9.0, num_threads: int = 1 + ): + mask = mask & (std > 0.0) + y = np.where(mask, data / std, 0.0)[mask] # must be newaxis + W = np.where(mask, whitefield / std, 0.0)[mask] # must be newaxis + + scales = robust_lsq(W=W[np.newaxis, :], y=y[np.newaxis, :], axis=1, r0=r0, r1=r1, n_iter=n_iter, lm=lm, + num_threads=num_threads) + scales=np.ravel(scales) + whitefield *= scales + def _calc_streakfinder_analysis(results, snr, mask): do_streakfinder_analysis = results.get("do_streakfinder_analysis", False) if not do_streakfinder_analysis: -- 2.49.1 From f1828cb4b5c78d296dd514d89733d0176d609a4b Mon Sep 17 00:00:00 2001 From: "Dorofeeva Elizaveta (EXT)" Date: Mon, 14 Jul 2025 22:26:58 +0200 Subject: [PATCH 08/15] True divide --- dap/algos/streakfind.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/dap/algos/streakfind.py b/dap/algos/streakfind.py index 658d6fc..a3630d1 100644 --- a/dap/algos/streakfind.py +++ b/dap/algos/streakfind.py @@ -103,11 +103,15 @@ def _calc_snr(results, data, pf_pixel_mask): def _scale_whitefield(data, mask, whitefield, std, r0: float = 0.0, r1: float = 0.5, - n_iter: int = 12, lm: float = 9.0, num_threads: int = 1 + n_iter: int = 12, lm: float = 9.0, num_threads: int = 16 ): mask = mask & (std > 0.0) - y = np.where(mask, data / std, 0.0)[mask] # must be newaxis - W = np.where(mask, whitefield / std, 0.0)[mask] # must be newaxis + #y = np.where(mask, data / std, 0.0)[mask] # must be newaxis + y = np.divide(data, std, out=np.zeros_like(data), where=mask)[mask] # must be newaxis + #W = np.where(mask, whitefield / std, 0.0)[mask] # must be newaxis + + W = np.divide(whitefield, std, out=np.zeros_like(data), where=mask)[mask] # must be newaxis + #W = np.where(mask, whitefield / std, 0.0)[mask] # must be newaxis scales = robust_lsq(W=W[np.newaxis, :], y=y[np.newaxis, :], axis=1, r0=r0, r1=r1, n_iter=n_iter, lm=lm, num_threads=num_threads) -- 2.49.1 From 90c0b9cf3d8d4794219ac348b19b6f12723a01bf Mon Sep 17 00:00:00 2001 From: "Dorofeeva Elizaveta (EXT)" Date: Mon, 14 Jul 2025 22:42:05 +0200 Subject: [PATCH 09/15] No newaxis, still extremely slow method for white field scale; TODO: use median instead --- dap/algos/streakfind.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/dap/algos/streakfind.py b/dap/algos/streakfind.py index a3630d1..2a21e28 100644 --- a/dap/algos/streakfind.py +++ b/dap/algos/streakfind.py @@ -106,15 +106,12 @@ def _scale_whitefield(data, mask, whitefield, std, n_iter: int = 12, lm: float = 9.0, num_threads: int = 16 ): mask = mask & (std > 0.0) - #y = np.where(mask, data / std, 0.0)[mask] # must be newaxis - y = np.divide(data, std, out=np.zeros_like(data), where=mask)[mask] # must be newaxis - #W = np.where(mask, whitefield / std, 0.0)[mask] # must be newaxis + y = np.divide(data, std, out=np.zeros_like(data), where=mask)[mask] + W = np.divide(whitefield, std, out=np.zeros_like(data), where=mask)[mask] - W = np.divide(whitefield, std, out=np.zeros_like(data), where=mask)[mask] # must be newaxis - #W = np.where(mask, whitefield / std, 0.0)[mask] # must be newaxis - - scales = robust_lsq(W=W[np.newaxis, :], y=y[np.newaxis, :], axis=1, r0=r0, r1=r1, n_iter=n_iter, lm=lm, + scales = robust_lsq(W=W, y=y, axis=0, r0=r0, r1=r1, n_iter=n_iter, lm=lm, num_threads=num_threads) + print(f"{scales=}") scales=np.ravel(scales) whitefield *= scales -- 2.49.1 From 9ed33350211e38e13c9b2c7a93585923f9fb3b51 Mon Sep 17 00:00:00 2001 From: Lisa Dorofeeva Date: Mon, 14 Jul 2025 22:44:03 +0200 Subject: [PATCH 10/15] TEST: use median method for whitefield scale --- dap/algos/streakfind.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/dap/algos/streakfind.py b/dap/algos/streakfind.py index 2a21e28..8023960 100644 --- a/dap/algos/streakfind.py +++ b/dap/algos/streakfind.py @@ -11,7 +11,7 @@ import numpy as np from psutil import virtual_memory from streak_finder.label import Structure2D from streak_finder.streak_finder import detect_peaks, detect_streaks, filter_peaks -from streak_finder._src.src.median import robust_lsq +from streak_finder._src.src.median import robust_lsq, median DEFAULT_MIN_HIT_STREAKS = 5 @@ -108,9 +108,12 @@ def _scale_whitefield(data, mask, whitefield, std, mask = mask & (std > 0.0) y = np.divide(data, std, out=np.zeros_like(data), where=mask)[mask] W = np.divide(whitefield, std, out=np.zeros_like(data), where=mask)[mask] + # + # scales = robust_lsq(W=W, y=y, axis=0, r0=r0, r1=r1, n_iter=n_iter, lm=lm, + # num_threads=num_threads) - scales = robust_lsq(W=W, y=y, axis=0, r0=r0, r1=r1, n_iter=n_iter, lm=lm, - num_threads=num_threads) + scales = median(y * W, axis=0, num_threads=num_threads) / \ + median(W * W, axis=0, num_threads=num_threads) print(f"{scales=}") scales=np.ravel(scales) whitefield *= scales -- 2.49.1 From d30adf3e14a3f326341eb9a80326928c5534dc11 Mon Sep 17 00:00:00 2001 From: Lisa Dorofeeva Date: Mon, 14 Jul 2025 23:04:17 +0200 Subject: [PATCH 11/15] TEST: try return crop roi; cleanup --- dap/algos/streakfind.py | 69 ++++++++++++++++------------------------- 1 file changed, 27 insertions(+), 42 deletions(-) diff --git a/dap/algos/streakfind.py b/dap/algos/streakfind.py index 8023960..981a214 100644 --- a/dap/algos/streakfind.py +++ b/dap/algos/streakfind.py @@ -14,6 +14,7 @@ from streak_finder.streak_finder import detect_peaks, detect_streaks, filter_pea from streak_finder._src.src.median import robust_lsq, median DEFAULT_MIN_HIT_STREAKS = 5 +DEFAULT_NUM_THREADS = 16 def _handle_negative_values(data, mask, handler: str): @@ -89,7 +90,7 @@ def _calc_snr(results, data, pf_pixel_mask): mask = hf[mask_dataset][:].astype(bool) mask *= pf_pixel_mask if scale_whitefield: - _scale_whitefield(data, mask, whitefield, std) + _scale_whitefield(data, mask, whitefield, std, results.get("cbd_num_threads", DEFAULT_NUM_THREADS)) snr = np.divide( data * mask - whitefield, std, @@ -101,21 +102,13 @@ def _calc_snr(results, data, pf_pixel_mask): del mask return snr -def _scale_whitefield(data, mask, whitefield, std, - r0: float = 0.0, r1: float = 0.5, - n_iter: int = 12, lm: float = 9.0, num_threads: int = 16 - ): +def _scale_whitefield(data, mask, whitefield, std, num_threads): mask = mask & (std > 0.0) y = np.divide(data, std, out=np.zeros_like(data), where=mask)[mask] - W = np.divide(whitefield, std, out=np.zeros_like(data), where=mask)[mask] - # - # scales = robust_lsq(W=W, y=y, axis=0, r0=r0, r1=r1, n_iter=n_iter, lm=lm, - # num_threads=num_threads) + w = np.divide(whitefield, std, out=np.zeros_like(data), where=mask)[mask] - scales = median(y * W, axis=0, num_threads=num_threads) / \ - median(W * W, axis=0, num_threads=num_threads) - print(f"{scales=}") - scales=np.ravel(scales) + scales = median(y * w, axis=0, num_threads=num_threads) / \ + median(w * w, axis=0, num_threads=num_threads) whitefield *= scales def _calc_streakfinder_analysis(results, snr, mask): @@ -173,19 +166,19 @@ def _calc_streakfinder_analysis(results, snr, mask): peaks_structure = Structure2D(peak_structure_radius, peak_structure_rank) streaks_structure = Structure2D(streak_structure_radius, streak_structure_rank) - peaks = detect_peaks(snr, mask, peaks_structure.rank, peak_vmin, + if crop_roi is None: + region = snr + else: + region = snr[crop_roi[0]: crop_roi[1], crop_roi[2]: crop_roi[3]] + peaks = detect_peaks(region, mask, peaks_structure.rank, peak_vmin, num_threads=num_threads) - peaks = filter_peaks(peaks, snr, mask, peaks_structure, peak_vmin, npts, + peaks = filter_peaks(peaks, region, mask, peaks_structure, peak_vmin, npts, num_threads=num_threads) - # det_obj = cryst_data.streak_detector(streaks_structure) - - - # peaks = det_obj.detect_peaks(peak_vmin, npts, peaks_structure, num_threads) - # detected = det_obj.detect_streaks(peaks, xtol, streak_vmin, min_size, nfa=nfa, - # num_threads=num_threads) - detected = detect_streaks(peaks, snr, mask, streaks_structure, xtol, streak_vmin, min_size, + detected = detect_streaks(peaks, region, mask, streaks_structure, xtol, streak_vmin, min_size, lookahead, nfa, num_threads=num_threads) + del peaks + if isinstance(detected, list): detected = detected[0] @@ -197,17 +190,9 @@ def _calc_streakfinder_analysis(results, snr, mask): results["bragg_counts"] = [] return - - if isinstance(detected, list): - streaks = [np.asarray(pattern.to_lines()) for pattern in detected] - else: - streaks = [np.asarray(detected.to_lines()),] - streak_lines = np.concatenate(streaks) - # rstreaks = Streaks(index=IndexArray(idxs), lines=lines) - - # streaks = det_obj.to_streaks(detected) + streak_lines = detected.to_lines() detected_streaks = np.array(detected.streaks) - # streak_lines = streaks.lines + del detected # Adjust to crop region if crop_roi is not None: @@ -215,19 +200,17 @@ def _calc_streakfinder_analysis(results, snr, mask): streak_lines = streak_lines + shift if x_center is not None and y_center is not None: - # if crop_roi is not None: - # x_center -= crop_roi[0] - # y_center -= crop_roi[2] + if crop_roi is not None: + x_center -= crop_roi[0] + y_center -= crop_roi[2] threshold = 0.33 - centers =np.mean(streak_lines.reshape(-1, 2, 2), axis=1) - norm =np.stack([streak_lines[:, 3] - streak_lines[:, 1], + centers = np.mean(streak_lines.reshape(-1, 2, 2), axis=1) + norm = np.stack([streak_lines[:, 3] - streak_lines[:, 1], streak_lines[:, 0] - streak_lines[:, 2]], axis=-1) r = centers -np.asarray([x_center, y_center]) - prod =np.sum(norm * r, axis=-1)[..., None] + prod = np.sum(norm * r, axis=-1)[..., None] proj = r - prod * norm /np.sum(norm ** 2, axis=-1)[..., None] - streaks_mask =np.sqrt(np.sum(proj ** 2, axis=-1)) /np.sqrt(np.sum(r ** 2, axis=-1)) < threshold - # return mask - # streaks_mask = streaks.concentric_only(x_center, y_center) + streaks_mask = np.sqrt(np.sum(proj ** 2, axis=-1)) /np.sqrt(np.sum(r ** 2, axis=-1)) < threshold streak_lines = streak_lines[streaks_mask] detected_streaks = detected_streaks[streaks_mask] @@ -240,9 +223,11 @@ def _calc_streakfinder_analysis(results, snr, mask): _, number_of_streaks = streak_lines.shape list_result = streak_lines.tolist() # arr(4, n_lines); 0coord x0, y0, x1, y1 - bragg_counts = [streak.total_mass() for streak in detected_streaks] + del detected_streaks + del streak_lines + results["number_of_streaks"] = number_of_streaks results["is_hit_frame"] = (number_of_streaks > min_hit_streaks) results["streaks"] = list_result -- 2.49.1 From b2cc1e27a475698b3d1762c38e6dba6a19e3962b Mon Sep 17 00:00:00 2001 From: Lisa Dorofeeva Date: Mon, 14 Jul 2025 23:10:46 +0200 Subject: [PATCH 12/15] TEST: crop data AND mask --- dap/algos/streakfind.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dap/algos/streakfind.py b/dap/algos/streakfind.py index 981a214..6e7ecf6 100644 --- a/dap/algos/streakfind.py +++ b/dap/algos/streakfind.py @@ -168,14 +168,16 @@ def _calc_streakfinder_analysis(results, snr, mask): if crop_roi is None: region = snr + region_mask = mask else: region = snr[crop_roi[0]: crop_roi[1], crop_roi[2]: crop_roi[3]] - peaks = detect_peaks(region, mask, peaks_structure.rank, peak_vmin, + region_mask = mask[crop_roi[0]: crop_roi[1], crop_roi[2]: crop_roi[3]] + peaks = detect_peaks(region, region_mask, peaks_structure.rank, peak_vmin, num_threads=num_threads) - peaks = filter_peaks(peaks, region, mask, peaks_structure, peak_vmin, npts, + peaks = filter_peaks(peaks, region, region_mask, peaks_structure, peak_vmin, npts, num_threads=num_threads) - detected = detect_streaks(peaks, region, mask, streaks_structure, xtol, streak_vmin, min_size, + detected = detect_streaks(peaks, region, region_mask, streaks_structure, xtol, streak_vmin, min_size, lookahead, nfa, num_threads=num_threads) del peaks -- 2.49.1 From ad0752aa92e7acb30ee24daeaeb948f20036e8d7 Mon Sep 17 00:00:00 2001 From: "Dorofeeva Elizaveta (EXT)" Date: Mon, 14 Jul 2025 23:35:26 +0200 Subject: [PATCH 13/15] Correct region mask --- dap/algos/streakfind.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dap/algos/streakfind.py b/dap/algos/streakfind.py index 6e7ecf6..cd892f0 100644 --- a/dap/algos/streakfind.py +++ b/dap/algos/streakfind.py @@ -170,8 +170,8 @@ def _calc_streakfinder_analysis(results, snr, mask): region = snr region_mask = mask else: - region = snr[crop_roi[0]: crop_roi[1], crop_roi[2]: crop_roi[3]] - region_mask = mask[crop_roi[0]: crop_roi[1], crop_roi[2]: crop_roi[3]] + region = snr[crop_roi[2]: crop_roi[3], crop_roi[0]: crop_roi[1]] + region_mask = mask[crop_roi[2]: crop_roi[3], crop_roi[0]: crop_roi[1]] peaks = detect_peaks(region, region_mask, peaks_structure.rank, peak_vmin, num_threads=num_threads) peaks = filter_peaks(peaks, region, region_mask, peaks_structure, peak_vmin, npts, -- 2.49.1 From 118b97acfbdbf769e99370c83cf3afb472c15194 Mon Sep 17 00:00:00 2001 From: Lisa Dorofeeva Date: Tue, 15 Jul 2025 08:35:08 +0200 Subject: [PATCH 14/15] Restore mask ROI fuctionality; Concentric streaks filter in a function; Cleanup --- README.md | 3 ++- dap/algos/streakfind.py | 59 ++++++++++++++++++----------------------- 2 files changed, 28 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index 227ba92..580dd62 100644 --- a/README.md +++ b/README.md @@ -136,8 +136,9 @@ options: * `'cbd_num_threads': int` - Number of threads to use for peak finder algorithm. * `'cbd_negative_handler': 'mask'/'zero'/'shift'/''` - [optional] Method to handle negative values in converted frames, defaults to `''` (do not handle). * `'cbd_min_hit_streaks': int` - [optional] Minimum number of discovered streaks to categorize frame as a hit, defaults to 5. - * `'cbd_mask_rois': list[(int, int, int, int)]` - [optional] list of `(y_min, y_max, x_min, x_max)` coordinates of ROIs to mask out during peak finding. + * `'cbd_mask_rois': list[(int, int, int, int)]` - [optional] list of `(x_min, x_max, y_min, y_max)` coordinates of ROIs to mask out during peak finding. * `'cbd_crop_roi': [int, int, int, int]` - [optional] run streak finder on a cropped image, e.g. one quadrant, for purpose of spedup. + * `'cbd_lookahead': int` - [optional] Number of linelets considered at the ends of a streak to be added to the streak. Algorithm Output: * `'number_of_streaks': int` - Indicates the count of identified streaks. diff --git a/dap/algos/streakfind.py b/dap/algos/streakfind.py index cd892f0..8ee069c 100644 --- a/dap/algos/streakfind.py +++ b/dap/algos/streakfind.py @@ -11,7 +11,7 @@ import numpy as np from psutil import virtual_memory from streak_finder.label import Structure2D from streak_finder.streak_finder import detect_peaks, detect_streaks, filter_peaks -from streak_finder._src.src.median import robust_lsq, median +from streak_finder._src.src.median import median DEFAULT_MIN_HIT_STREAKS = 5 DEFAULT_NUM_THREADS = 16 @@ -97,9 +97,6 @@ def _calc_snr(results, data, pf_pixel_mask): out=np.zeros_like(data), where=(std!=0.0) ) - del whitefield - del std - del mask return snr def _scale_whitefield(data, mask, whitefield, std, num_threads): @@ -110,7 +107,21 @@ def _scale_whitefield(data, mask, whitefield, std, num_threads): scales = median(y * w, axis=0, num_threads=num_threads) / \ median(w * w, axis=0, num_threads=num_threads) whitefield *= scales - + +def _get_concentric_only_mask(x_center, y_center, crop_roi, streak_lines, threshold=0.33): + if x_center is not None and y_center is not None: + if crop_roi is not None: + x_center -= crop_roi[0] + y_center -= crop_roi[2] + centers = np.mean(streak_lines.reshape(-1, 2, 2), axis=1) + norm = np.stack([streak_lines[:, 3] - streak_lines[:, 1], + streak_lines[:, 0] - streak_lines[:, 2]], axis=-1) + r = centers - np.asarray([x_center, y_center]) + prod = np.sum(norm * r, axis=-1)[..., None] + proj = r - prod * norm / np.sum(norm ** 2, axis=-1)[..., None] + streaks_mask = np.sqrt(np.sum(proj ** 2, axis=-1)) / np.sqrt(np.sum(r ** 2, axis=-1)) < threshold + return streaks_mask + def _calc_streakfinder_analysis(results, snr, mask): do_streakfinder_analysis = results.get("do_streakfinder_analysis", False) if not do_streakfinder_analysis: @@ -152,16 +163,13 @@ def _calc_streakfinder_analysis(results, snr, mask): x_center = results.get("beam_center_x", None) y_center = results.get("beam_center_y", None) - # mask_rois = results.get("cbd_mask_rois", []) # list of [y_min, y_max, x_min, x_max] - # - # for mask_roi in mask_rois: - # cryst_data = cryst_data.mask_region(mask_roi) + mask_rois = results.get("cbd_mask_rois", []) # list of [x_min, x_max, y_min, y_max] - crop_roi = results.get("cbd_crop_roi", None) # [y_min, y_max, x_min, x_max] + for mask_roi in mask_rois: + mask[mask_roi[2]: mask_roi[3], mask_roi[0]: mask_roi[1]] = False + + crop_roi = results.get("cbd_crop_roi", None) # [x_min, x_max, y_min, y_max] lookahead = results.get("cbd_lookahead", 1) - # if crop_roi is not None: - # crop_roi_t = [crop_roi[2], crop_roi[3], crop_roi[0], crop_roi[1]] # y0, y1, x0, x1 - # cryst_data = cryst_data.crop(crop_roi_t) peaks_structure = Structure2D(peak_structure_radius, peak_structure_rank) streaks_structure = Structure2D(streak_structure_radius, streak_structure_rank) @@ -179,7 +187,6 @@ def _calc_streakfinder_analysis(results, snr, mask): detected = detect_streaks(peaks, region, region_mask, streaks_structure, xtol, streak_vmin, min_size, lookahead, nfa, num_threads=num_threads) - del peaks if isinstance(detected, list): detected = detected[0] @@ -194,27 +201,16 @@ def _calc_streakfinder_analysis(results, snr, mask): streak_lines = detected.to_lines() detected_streaks = np.array(detected.streaks) - del detected # Adjust to crop region if crop_roi is not None: shift = [crop_roi[0], crop_roi[2], crop_roi[0], crop_roi[2]] streak_lines = streak_lines + shift - if x_center is not None and y_center is not None: - if crop_roi is not None: - x_center -= crop_roi[0] - y_center -= crop_roi[2] - threshold = 0.33 - centers = np.mean(streak_lines.reshape(-1, 2, 2), axis=1) - norm = np.stack([streak_lines[:, 3] - streak_lines[:, 1], - streak_lines[:, 0] - streak_lines[:, 2]], axis=-1) - r = centers -np.asarray([x_center, y_center]) - prod = np.sum(norm * r, axis=-1)[..., None] - proj = r - prod * norm /np.sum(norm ** 2, axis=-1)[..., None] - streaks_mask = np.sqrt(np.sum(proj ** 2, axis=-1)) /np.sqrt(np.sum(r ** 2, axis=-1)) < threshold - streak_lines = streak_lines[streaks_mask] - detected_streaks = detected_streaks[streaks_mask] + concentric_streaks_mask = _get_concentric_only_mask(x_center, y_center, crop_roi, streak_lines) + if concentric_streaks_mask is not None: + streak_lines = streak_lines[concentric_streaks_mask] + detected_streaks = detected_streaks[concentric_streaks_mask] streak_lengths = np.sqrt( np.power((streak_lines[..., 2] - streak_lines[..., 0]), 2) + @@ -224,12 +220,9 @@ def _calc_streakfinder_analysis(results, snr, mask): streak_lines = streak_lines.T _, number_of_streaks = streak_lines.shape - list_result = streak_lines.tolist() # arr(4, n_lines); 0coord x0, y0, x1, y1 + list_result = streak_lines.tolist() # arr(4, n_lines); coord x0, y0, x1, y1 bragg_counts = [streak.total_mass() for streak in detected_streaks] - del detected_streaks - del streak_lines - results["number_of_streaks"] = number_of_streaks results["is_hit_frame"] = (number_of_streaks > min_hit_streaks) results["streaks"] = list_result -- 2.49.1 From abae2288e793bda74a8bc6c33ac744ce48a305a0 Mon Sep 17 00:00:00 2001 From: Lisa Dorofeeva Date: Tue, 15 Jul 2025 08:50:35 +0200 Subject: [PATCH 15/15] Final cleanup --- dap/algos/streakfind.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/dap/algos/streakfind.py b/dap/algos/streakfind.py index 8ee069c..0c3a7f6 100644 --- a/dap/algos/streakfind.py +++ b/dap/algos/streakfind.py @@ -8,7 +8,6 @@ https://github.com/simply-nicky/streak_finder/ """ import h5py import numpy as np -from psutil import virtual_memory from streak_finder.label import Structure2D from streak_finder.streak_finder import detect_peaks, detect_streaks, filter_peaks from streak_finder._src.src.median import median @@ -52,7 +51,6 @@ def calc_streakfinder_analysis(results, data, pf_pixel_mask): _calc_streakfinder_analysis(results, snr, pf_pixel_mask) except Exception as error: # Broad exception - we don't want to break anything here results["cbd_error"] = f"StreakFind - Error processing CBD data:\n{error}" - print(f"mem used after {virtual_memory().used // pow(2, 30)} Gb \n") return data @@ -142,8 +140,8 @@ def _calc_streakfinder_analysis(results, snr, mask): ] if not all(param in results.keys() for param in params_required): - print(f"ERROR: Not enough parameters for streak finder analysis. Skipping.\n" - f"{params_required=}") + results["cbd_error"] = (f"ERROR: Not enough parameters for streak finder analysis. " + f"Skipping.\n{params_required=}") return peak_structure_radius = results["cbd_peak_structure_radius"] # peak @@ -164,7 +162,6 @@ def _calc_streakfinder_analysis(results, snr, mask): y_center = results.get("beam_center_y", None) mask_rois = results.get("cbd_mask_rois", []) # list of [x_min, x_max, y_min, y_max] - for mask_roi in mask_rois: mask[mask_roi[2]: mask_roi[3], mask_roi[0]: mask_roi[1]] = False -- 2.49.1