diff --git a/README.md b/README.md index 227ba92..580dd62 100644 --- a/README.md +++ b/README.md @@ -136,8 +136,9 @@ options: * `'cbd_num_threads': int` - Number of threads to use for peak finder algorithm. * `'cbd_negative_handler': 'mask'/'zero'/'shift'/''` - [optional] Method to handle negative values in converted frames, defaults to `''` (do not handle). * `'cbd_min_hit_streaks': int` - [optional] Minimum number of discovered streaks to categorize frame as a hit, defaults to 5. - * `'cbd_mask_rois': list[(int, int, int, int)]` - [optional] list of `(y_min, y_max, x_min, x_max)` coordinates of ROIs to mask out during peak finding. + * `'cbd_mask_rois': list[(int, int, int, int)]` - [optional] list of `(x_min, x_max, y_min, y_max)` coordinates of ROIs to mask out during peak finding. * `'cbd_crop_roi': [int, int, int, int]` - [optional] run streak finder on a cropped image, e.g. one quadrant, for purpose of spedup. + * `'cbd_lookahead': int` - [optional] Number of linelets considered at the ends of a streak to be added to the streak. Algorithm Output: * `'number_of_streaks': int` - Indicates the count of identified streaks. diff --git a/dap/algos/streakfind.py b/dap/algos/streakfind.py index cd892f0..8ee069c 100644 --- a/dap/algos/streakfind.py +++ b/dap/algos/streakfind.py @@ -11,7 +11,7 @@ import numpy as np from psutil import virtual_memory from streak_finder.label import Structure2D from streak_finder.streak_finder import detect_peaks, detect_streaks, filter_peaks -from streak_finder._src.src.median import robust_lsq, median +from streak_finder._src.src.median import median DEFAULT_MIN_HIT_STREAKS = 5 DEFAULT_NUM_THREADS = 16 @@ -97,9 +97,6 @@ def _calc_snr(results, data, pf_pixel_mask): out=np.zeros_like(data), where=(std!=0.0) ) - del whitefield - del std - del mask return snr def _scale_whitefield(data, mask, whitefield, std, num_threads): @@ -110,7 +107,21 @@ def _scale_whitefield(data, mask, whitefield, std, num_threads): scales = median(y * w, axis=0, num_threads=num_threads) / \ median(w * w, axis=0, num_threads=num_threads) whitefield *= scales - + +def _get_concentric_only_mask(x_center, y_center, crop_roi, streak_lines, threshold=0.33): + if x_center is not None and y_center is not None: + if crop_roi is not None: + x_center -= crop_roi[0] + y_center -= crop_roi[2] + centers = np.mean(streak_lines.reshape(-1, 2, 2), axis=1) + norm = np.stack([streak_lines[:, 3] - streak_lines[:, 1], + streak_lines[:, 0] - streak_lines[:, 2]], axis=-1) + r = centers - np.asarray([x_center, y_center]) + prod = np.sum(norm * r, axis=-1)[..., None] + proj = r - prod * norm / np.sum(norm ** 2, axis=-1)[..., None] + streaks_mask = np.sqrt(np.sum(proj ** 2, axis=-1)) / np.sqrt(np.sum(r ** 2, axis=-1)) < threshold + return streaks_mask + def _calc_streakfinder_analysis(results, snr, mask): do_streakfinder_analysis = results.get("do_streakfinder_analysis", False) if not do_streakfinder_analysis: @@ -152,16 +163,13 @@ def _calc_streakfinder_analysis(results, snr, mask): x_center = results.get("beam_center_x", None) y_center = results.get("beam_center_y", None) - # mask_rois = results.get("cbd_mask_rois", []) # list of [y_min, y_max, x_min, x_max] - # - # for mask_roi in mask_rois: - # cryst_data = cryst_data.mask_region(mask_roi) + mask_rois = results.get("cbd_mask_rois", []) # list of [x_min, x_max, y_min, y_max] - crop_roi = results.get("cbd_crop_roi", None) # [y_min, y_max, x_min, x_max] + for mask_roi in mask_rois: + mask[mask_roi[2]: mask_roi[3], mask_roi[0]: mask_roi[1]] = False + + crop_roi = results.get("cbd_crop_roi", None) # [x_min, x_max, y_min, y_max] lookahead = results.get("cbd_lookahead", 1) - # if crop_roi is not None: - # crop_roi_t = [crop_roi[2], crop_roi[3], crop_roi[0], crop_roi[1]] # y0, y1, x0, x1 - # cryst_data = cryst_data.crop(crop_roi_t) peaks_structure = Structure2D(peak_structure_radius, peak_structure_rank) streaks_structure = Structure2D(streak_structure_radius, streak_structure_rank) @@ -179,7 +187,6 @@ def _calc_streakfinder_analysis(results, snr, mask): detected = detect_streaks(peaks, region, region_mask, streaks_structure, xtol, streak_vmin, min_size, lookahead, nfa, num_threads=num_threads) - del peaks if isinstance(detected, list): detected = detected[0] @@ -194,27 +201,16 @@ def _calc_streakfinder_analysis(results, snr, mask): streak_lines = detected.to_lines() detected_streaks = np.array(detected.streaks) - del detected # Adjust to crop region if crop_roi is not None: shift = [crop_roi[0], crop_roi[2], crop_roi[0], crop_roi[2]] streak_lines = streak_lines + shift - if x_center is not None and y_center is not None: - if crop_roi is not None: - x_center -= crop_roi[0] - y_center -= crop_roi[2] - threshold = 0.33 - centers = np.mean(streak_lines.reshape(-1, 2, 2), axis=1) - norm = np.stack([streak_lines[:, 3] - streak_lines[:, 1], - streak_lines[:, 0] - streak_lines[:, 2]], axis=-1) - r = centers -np.asarray([x_center, y_center]) - prod = np.sum(norm * r, axis=-1)[..., None] - proj = r - prod * norm /np.sum(norm ** 2, axis=-1)[..., None] - streaks_mask = np.sqrt(np.sum(proj ** 2, axis=-1)) /np.sqrt(np.sum(r ** 2, axis=-1)) < threshold - streak_lines = streak_lines[streaks_mask] - detected_streaks = detected_streaks[streaks_mask] + concentric_streaks_mask = _get_concentric_only_mask(x_center, y_center, crop_roi, streak_lines) + if concentric_streaks_mask is not None: + streak_lines = streak_lines[concentric_streaks_mask] + detected_streaks = detected_streaks[concentric_streaks_mask] streak_lengths = np.sqrt( np.power((streak_lines[..., 2] - streak_lines[..., 0]), 2) + @@ -224,12 +220,9 @@ def _calc_streakfinder_analysis(results, snr, mask): streak_lines = streak_lines.T _, number_of_streaks = streak_lines.shape - list_result = streak_lines.tolist() # arr(4, n_lines); 0coord x0, y0, x1, y1 + list_result = streak_lines.tolist() # arr(4, n_lines); coord x0, y0, x1, y1 bragg_counts = [streak.total_mass() for streak in detected_streaks] - del detected_streaks - del streak_lines - results["number_of_streaks"] = number_of_streaks results["is_hit_frame"] = (number_of_streaks > min_hit_streaks) results["streaks"] = list_result