From 77d3ff9293e092ed783c5b61fddb03ef981c42d4 Mon Sep 17 00:00:00 2001 From: Lisa Dorofeeva Date: Mon, 14 Jul 2025 22:07:08 +0200 Subject: [PATCH] TEST scale whitefield --- dap/algos/streakfind.py | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/dap/algos/streakfind.py b/dap/algos/streakfind.py index 0df56b3..658d6fc 100644 --- a/dap/algos/streakfind.py +++ b/dap/algos/streakfind.py @@ -11,8 +11,8 @@ import numpy as np from psutil import virtual_memory from streak_finder.label import Structure2D from streak_finder.streak_finder import detect_peaks, detect_streaks, filter_peaks +from streak_finder._src.src.median import robust_lsq -DEFAULT_NUM_THREADS = 16 DEFAULT_MIN_HIT_STREAKS = 5 @@ -38,27 +38,24 @@ def calc_streakfinder_analysis(results, data, pf_pixel_mask): _handle_negative_values(data, pf_pixel_mask, negative_val_handler) try: - cryst_data = _generate_cryst_data(results, data, pf_pixel_mask) + snr = _calc_snr(results, data, pf_pixel_mask) except Exception as error: # Broad exception - we don't want to break anything here - results["cbd_error"] = f"Error processing CBD data:\n{error}" + results["cbd_error"] = f"SNR - Error processing CBD data:\n{error}" return data if do_snr: # Changes data in-place - data = cryst_data #.snr[0].copy() + data = snr try: - _calc_streakfinder_analysis(results, cryst_data, pf_pixel_mask) + _calc_streakfinder_analysis(results, snr, pf_pixel_mask) except Exception as error: # Broad exception - we don't want to break anything here - results["cbd_error"] = f"Error processing CBD data:\n{error}" - # print(f"Deleting cryst data;\nmem used before {virtual_memory().used // pow(2, 30)} Gb ") - # cryst_data.clear() - # del cryst_data + results["cbd_error"] = f"StreakFind - Error processing CBD data:\n{error}" print(f"mem used after {virtual_memory().used // pow(2, 30)} Gb \n") return data -def _generate_cryst_data(results, data, pf_pixel_mask): # -> CrystData: +def _calc_snr(results, data, pf_pixel_mask): params_required = [ "cbd_whitefield_data_file", "cbd_std_data_file", @@ -71,14 +68,12 @@ def _generate_cryst_data(results, data, pf_pixel_mask): # -> CrystData: whitefield_data_file = results["cbd_whitefield_data_file"] std_data_file = results["cbd_std_data_file"] - # scale_whitefield = results["cbd_scale_whitefield"] + scale_whitefield = results["cbd_scale_whitefield"] # Using CXI Store specification as default whitefield_dataset = results.get("cbd_whitefield_dataset", "entry/crystallography/whitefield") std_dataset = results.get("cbd_std_dataset", "entry/crystallography/std") - # num_threads = results.get("cbd_num_threads", DEFAULT_NUM_THREADS) - with h5py.File(whitefield_data_file, "r") as hf: whitefield = hf[whitefield_dataset][:] @@ -93,7 +88,8 @@ def _generate_cryst_data(results, data, pf_pixel_mask): # -> CrystData: with h5py.File(mask_data_file, "r") as hf: mask = hf[mask_dataset][:].astype(bool) mask *= pf_pixel_mask - + if scale_whitefield: + _scale_whitefield(data, mask, whitefield, std) snr = np.divide( data * mask - whitefield, std, @@ -105,6 +101,19 @@ def _generate_cryst_data(results, data, pf_pixel_mask): # -> CrystData: del mask return snr +def _scale_whitefield(data, mask, whitefield, std, + r0: float = 0.0, r1: float = 0.5, + n_iter: int = 12, lm: float = 9.0, num_threads: int = 1 + ): + mask = mask & (std > 0.0) + y = np.where(mask, data / std, 0.0)[mask] # must be newaxis + W = np.where(mask, whitefield / std, 0.0)[mask] # must be newaxis + + scales = robust_lsq(W=W[np.newaxis, :], y=y[np.newaxis, :], axis=1, r0=r0, r1=r1, n_iter=n_iter, lm=lm, + num_threads=num_threads) + scales=np.ravel(scales) + whitefield *= scales + def _calc_streakfinder_analysis(results, snr, mask): do_streakfinder_analysis = results.get("do_streakfinder_analysis", False) if not do_streakfinder_analysis: