diff --git a/dap/algos/streakfind.py b/dap/algos/streakfind.py index 8023960..981a214 100644 --- a/dap/algos/streakfind.py +++ b/dap/algos/streakfind.py @@ -14,6 +14,7 @@ from streak_finder.streak_finder import detect_peaks, detect_streaks, filter_pea from streak_finder._src.src.median import robust_lsq, median DEFAULT_MIN_HIT_STREAKS = 5 +DEFAULT_NUM_THREADS = 16 def _handle_negative_values(data, mask, handler: str): @@ -89,7 +90,7 @@ def _calc_snr(results, data, pf_pixel_mask): mask = hf[mask_dataset][:].astype(bool) mask *= pf_pixel_mask if scale_whitefield: - _scale_whitefield(data, mask, whitefield, std) + _scale_whitefield(data, mask, whitefield, std, results.get("cbd_num_threads", DEFAULT_NUM_THREADS)) snr = np.divide( data * mask - whitefield, std, @@ -101,21 +102,13 @@ def _calc_snr(results, data, pf_pixel_mask): del mask return snr -def _scale_whitefield(data, mask, whitefield, std, - r0: float = 0.0, r1: float = 0.5, - n_iter: int = 12, lm: float = 9.0, num_threads: int = 16 - ): +def _scale_whitefield(data, mask, whitefield, std, num_threads): mask = mask & (std > 0.0) y = np.divide(data, std, out=np.zeros_like(data), where=mask)[mask] - W = np.divide(whitefield, std, out=np.zeros_like(data), where=mask)[mask] - # - # scales = robust_lsq(W=W, y=y, axis=0, r0=r0, r1=r1, n_iter=n_iter, lm=lm, - # num_threads=num_threads) + w = np.divide(whitefield, std, out=np.zeros_like(data), where=mask)[mask] - scales = median(y * W, axis=0, num_threads=num_threads) / \ - median(W * W, axis=0, num_threads=num_threads) - print(f"{scales=}") - scales=np.ravel(scales) + scales = median(y * w, axis=0, num_threads=num_threads) / \ + median(w * w, axis=0, num_threads=num_threads) whitefield *= scales def _calc_streakfinder_analysis(results, snr, mask): @@ -173,19 +166,19 @@ def _calc_streakfinder_analysis(results, snr, mask): peaks_structure = Structure2D(peak_structure_radius, peak_structure_rank) streaks_structure = Structure2D(streak_structure_radius, streak_structure_rank) - peaks = detect_peaks(snr, mask, peaks_structure.rank, peak_vmin, + if crop_roi is None: + region = snr + else: + region = snr[crop_roi[0]: crop_roi[1], crop_roi[2]: crop_roi[3]] + peaks = detect_peaks(region, mask, peaks_structure.rank, peak_vmin, num_threads=num_threads) - peaks = filter_peaks(peaks, snr, mask, peaks_structure, peak_vmin, npts, + peaks = filter_peaks(peaks, region, mask, peaks_structure, peak_vmin, npts, num_threads=num_threads) - # det_obj = cryst_data.streak_detector(streaks_structure) - - - # peaks = det_obj.detect_peaks(peak_vmin, npts, peaks_structure, num_threads) - # detected = det_obj.detect_streaks(peaks, xtol, streak_vmin, min_size, nfa=nfa, - # num_threads=num_threads) - detected = detect_streaks(peaks, snr, mask, streaks_structure, xtol, streak_vmin, min_size, + detected = detect_streaks(peaks, region, mask, streaks_structure, xtol, streak_vmin, min_size, lookahead, nfa, num_threads=num_threads) + del peaks + if isinstance(detected, list): detected = detected[0] @@ -197,17 +190,9 @@ def _calc_streakfinder_analysis(results, snr, mask): results["bragg_counts"] = [] return - - if isinstance(detected, list): - streaks = [np.asarray(pattern.to_lines()) for pattern in detected] - else: - streaks = [np.asarray(detected.to_lines()),] - streak_lines = np.concatenate(streaks) - # rstreaks = Streaks(index=IndexArray(idxs), lines=lines) - - # streaks = det_obj.to_streaks(detected) + streak_lines = detected.to_lines() detected_streaks = np.array(detected.streaks) - # streak_lines = streaks.lines + del detected # Adjust to crop region if crop_roi is not None: @@ -215,19 +200,17 @@ def _calc_streakfinder_analysis(results, snr, mask): streak_lines = streak_lines + shift if x_center is not None and y_center is not None: - # if crop_roi is not None: - # x_center -= crop_roi[0] - # y_center -= crop_roi[2] + if crop_roi is not None: + x_center -= crop_roi[0] + y_center -= crop_roi[2] threshold = 0.33 - centers =np.mean(streak_lines.reshape(-1, 2, 2), axis=1) - norm =np.stack([streak_lines[:, 3] - streak_lines[:, 1], + centers = np.mean(streak_lines.reshape(-1, 2, 2), axis=1) + norm = np.stack([streak_lines[:, 3] - streak_lines[:, 1], streak_lines[:, 0] - streak_lines[:, 2]], axis=-1) r = centers -np.asarray([x_center, y_center]) - prod =np.sum(norm * r, axis=-1)[..., None] + prod = np.sum(norm * r, axis=-1)[..., None] proj = r - prod * norm /np.sum(norm ** 2, axis=-1)[..., None] - streaks_mask =np.sqrt(np.sum(proj ** 2, axis=-1)) /np.sqrt(np.sum(r ** 2, axis=-1)) < threshold - # return mask - # streaks_mask = streaks.concentric_only(x_center, y_center) + streaks_mask = np.sqrt(np.sum(proj ** 2, axis=-1)) /np.sqrt(np.sum(r ** 2, axis=-1)) < threshold streak_lines = streak_lines[streaks_mask] detected_streaks = detected_streaks[streaks_mask] @@ -240,9 +223,11 @@ def _calc_streakfinder_analysis(results, snr, mask): _, number_of_streaks = streak_lines.shape list_result = streak_lines.tolist() # arr(4, n_lines); 0coord x0, y0, x1, y1 - bragg_counts = [streak.total_mass() for streak in detected_streaks] + del detected_streaks + del streak_lines + results["number_of_streaks"] = number_of_streaks results["is_hit_frame"] = (number_of_streaks > min_hit_streaks) results["streaks"] = list_result