Fix memory leak in streak finder #3
@@ -136,8 +136,9 @@ options:
|
||||
* `'cbd_num_threads': int` - Number of threads to use for peak finder algorithm.
|
||||
* `'cbd_negative_handler': 'mask'/'zero'/'shift'/''` - [optional] Method to handle negative values in converted frames, defaults to `''` (do not handle).
|
||||
* `'cbd_min_hit_streaks': int` - [optional] Minimum number of discovered streaks to categorize frame as a hit, defaults to 5.
|
||||
* `'cbd_mask_rois': list[(int, int, int, int)]` - [optional] list of `(y_min, y_max, x_min, x_max)` coordinates of ROIs to mask out during peak finding.
|
||||
* `'cbd_mask_rois': list[(int, int, int, int)]` - [optional] list of `(x_min, x_max, y_min, y_max)` coordinates of ROIs to mask out during peak finding.
|
||||
* `'cbd_crop_roi': [int, int, int, int]` - [optional] run streak finder on a cropped image, e.g. one quadrant, for purpose of spedup.
|
||||
* `'cbd_lookahead': int` - [optional] Number of linelets considered at the ends of a streak to be added to the streak.
|
||||
|
||||
Algorithm Output:
|
||||
* `'number_of_streaks': int` - Indicates the count of identified streaks.
|
||||
|
||||
@@ -11,7 +11,7 @@ import numpy as np
|
||||
from psutil import virtual_memory
|
||||
from streak_finder.label import Structure2D
|
||||
from streak_finder.streak_finder import detect_peaks, detect_streaks, filter_peaks
|
||||
from streak_finder._src.src.median import robust_lsq, median
|
||||
from streak_finder._src.src.median import median
|
||||
|
||||
DEFAULT_MIN_HIT_STREAKS = 5
|
||||
DEFAULT_NUM_THREADS = 16
|
||||
@@ -97,9 +97,6 @@ def _calc_snr(results, data, pf_pixel_mask):
|
||||
out=np.zeros_like(data),
|
||||
where=(std!=0.0)
|
||||
)
|
||||
del whitefield
|
||||
del std
|
||||
del mask
|
||||
return snr
|
||||
|
ext-dorofe_e marked this conversation as resolved
Outdated
|
||||
|
||||
def _scale_whitefield(data, mask, whitefield, std, num_threads):
|
||||
@@ -110,7 +107,21 @@ def _scale_whitefield(data, mask, whitefield, std, num_threads):
|
||||
scales = median(y * w, axis=0, num_threads=num_threads) / \
|
||||
median(w * w, axis=0, num_threads=num_threads)
|
||||
whitefield *= scales
|
||||
|
||||
|
||||
def _get_concentric_only_mask(x_center, y_center, crop_roi, streak_lines, threshold=0.33):
|
||||
if x_center is not None and y_center is not None:
|
||||
if crop_roi is not None:
|
||||
x_center -= crop_roi[0]
|
||||
y_center -= crop_roi[2]
|
||||
centers = np.mean(streak_lines.reshape(-1, 2, 2), axis=1)
|
||||
norm = np.stack([streak_lines[:, 3] - streak_lines[:, 1],
|
||||
streak_lines[:, 0] - streak_lines[:, 2]], axis=-1)
|
||||
r = centers - np.asarray([x_center, y_center])
|
||||
prod = np.sum(norm * r, axis=-1)[..., None]
|
||||
proj = r - prod * norm / np.sum(norm ** 2, axis=-1)[..., None]
|
||||
streaks_mask = np.sqrt(np.sum(proj ** 2, axis=-1)) / np.sqrt(np.sum(r ** 2, axis=-1)) < threshold
|
||||
return streaks_mask
|
||||
|
||||
def _calc_streakfinder_analysis(results, snr, mask):
|
||||
do_streakfinder_analysis = results.get("do_streakfinder_analysis", False)
|
||||
if not do_streakfinder_analysis:
|
||||
@@ -152,16 +163,13 @@ def _calc_streakfinder_analysis(results, snr, mask):
|
||||
x_center = results.get("beam_center_x", None)
|
||||
y_center = results.get("beam_center_y", None)
|
||||
|
||||
# mask_rois = results.get("cbd_mask_rois", []) # list of [y_min, y_max, x_min, x_max]
|
||||
#
|
||||
# for mask_roi in mask_rois:
|
||||
# cryst_data = cryst_data.mask_region(mask_roi)
|
||||
mask_rois = results.get("cbd_mask_rois", []) # list of [x_min, x_max, y_min, y_max]
|
||||
|
||||
crop_roi = results.get("cbd_crop_roi", None) # [y_min, y_max, x_min, x_max]
|
||||
for mask_roi in mask_rois:
|
||||
mask[mask_roi[2]: mask_roi[3], mask_roi[0]: mask_roi[1]] = False
|
||||
|
||||
crop_roi = results.get("cbd_crop_roi", None) # [x_min, x_max, y_min, y_max]
|
||||
lookahead = results.get("cbd_lookahead", 1)
|
||||
# if crop_roi is not None:
|
||||
# crop_roi_t = [crop_roi[2], crop_roi[3], crop_roi[0], crop_roi[1]] # y0, y1, x0, x1
|
||||
# cryst_data = cryst_data.crop(crop_roi_t)
|
||||
|
||||
peaks_structure = Structure2D(peak_structure_radius, peak_structure_rank)
|
||||
streaks_structure = Structure2D(streak_structure_radius, streak_structure_rank)
|
||||
@@ -179,7 +187,6 @@ def _calc_streakfinder_analysis(results, snr, mask):
|
||||
|
||||
detected = detect_streaks(peaks, region, region_mask, streaks_structure, xtol, streak_vmin, min_size,
|
||||
lookahead, nfa, num_threads=num_threads)
|
||||
del peaks
|
||||
|
||||
if isinstance(detected, list):
|
||||
detected = detected[0]
|
||||
@@ -194,27 +201,16 @@ def _calc_streakfinder_analysis(results, snr, mask):
|
||||
|
||||
streak_lines = detected.to_lines()
|
||||
detected_streaks = np.array(detected.streaks)
|
||||
del detected
|
||||
|
||||
# Adjust to crop region
|
||||
if crop_roi is not None:
|
||||
shift = [crop_roi[0], crop_roi[2], crop_roi[0], crop_roi[2]]
|
||||
streak_lines = streak_lines + shift
|
||||
|
ext-dorofe_e marked this conversation as resolved
Outdated
ext-dorofe_e
commented
TODO: move concentric filter to a separate method TODO: move concentric filter to a separate method
|
||||
|
||||
if x_center is not None and y_center is not None:
|
||||
if crop_roi is not None:
|
||||
x_center -= crop_roi[0]
|
||||
y_center -= crop_roi[2]
|
||||
threshold = 0.33
|
||||
centers = np.mean(streak_lines.reshape(-1, 2, 2), axis=1)
|
||||
norm = np.stack([streak_lines[:, 3] - streak_lines[:, 1],
|
||||
streak_lines[:, 0] - streak_lines[:, 2]], axis=-1)
|
||||
r = centers -np.asarray([x_center, y_center])
|
||||
prod = np.sum(norm * r, axis=-1)[..., None]
|
||||
proj = r - prod * norm /np.sum(norm ** 2, axis=-1)[..., None]
|
||||
streaks_mask = np.sqrt(np.sum(proj ** 2, axis=-1)) /np.sqrt(np.sum(r ** 2, axis=-1)) < threshold
|
||||
streak_lines = streak_lines[streaks_mask]
|
||||
detected_streaks = detected_streaks[streaks_mask]
|
||||
concentric_streaks_mask = _get_concentric_only_mask(x_center, y_center, crop_roi, streak_lines)
|
||||
if concentric_streaks_mask is not None:
|
||||
streak_lines = streak_lines[concentric_streaks_mask]
|
||||
detected_streaks = detected_streaks[concentric_streaks_mask]
|
||||
|
||||
streak_lengths = np.sqrt(
|
||||
np.power((streak_lines[..., 2] - streak_lines[..., 0]), 2) +
|
||||
@@ -224,12 +220,9 @@ def _calc_streakfinder_analysis(results, snr, mask):
|
||||
streak_lines = streak_lines.T
|
||||
_, number_of_streaks = streak_lines.shape
|
||||
|
||||
list_result = streak_lines.tolist() # arr(4, n_lines); 0coord x0, y0, x1, y1
|
||||
list_result = streak_lines.tolist() # arr(4, n_lines); coord x0, y0, x1, y1
|
||||
bragg_counts = [streak.total_mass() for streak in detected_streaks]
|
||||
|
||||
del detected_streaks
|
||||
del streak_lines
|
||||
|
||||
results["number_of_streaks"] = number_of_streaks
|
||||
results["is_hit_frame"] = (number_of_streaks > min_hit_streaks)
|
||||
results["streaks"] = list_result
|
||||
|
||||
Reference in New Issue
Block a user
TODO:
dels are likely unnecessary; run more tests to make sure