Fix memory leak in streak finder #3

Merged
augustin_s merged 15 commits from ext-dorofe_e/dap:cbd_bugfix into main 2025-07-15 10:21:21 +02:00
2 changed files with 28 additions and 34 deletions
Showing only changes of commit 118b97acfb - Show all commits

View File

@@ -136,8 +136,9 @@ options:
* `'cbd_num_threads': int` - Number of threads to use for peak finder algorithm.
* `'cbd_negative_handler': 'mask'/'zero'/'shift'/''` - [optional] Method to handle negative values in converted frames, defaults to `''` (do not handle).
* `'cbd_min_hit_streaks': int` - [optional] Minimum number of discovered streaks to categorize frame as a hit, defaults to 5.
* `'cbd_mask_rois': list[(int, int, int, int)]` - [optional] list of `(y_min, y_max, x_min, x_max)` coordinates of ROIs to mask out during peak finding.
* `'cbd_mask_rois': list[(int, int, int, int)]` - [optional] list of `(x_min, x_max, y_min, y_max)` coordinates of ROIs to mask out during peak finding.
* `'cbd_crop_roi': [int, int, int, int]` - [optional] run streak finder on a cropped image, e.g. one quadrant, for purpose of spedup.
* `'cbd_lookahead': int` - [optional] Number of linelets considered at the ends of a streak to be added to the streak.
Algorithm Output:
* `'number_of_streaks': int` - Indicates the count of identified streaks.

View File

@@ -11,7 +11,7 @@ import numpy as np
from psutil import virtual_memory
from streak_finder.label import Structure2D
from streak_finder.streak_finder import detect_peaks, detect_streaks, filter_peaks
from streak_finder._src.src.median import robust_lsq, median
from streak_finder._src.src.median import median
DEFAULT_MIN_HIT_STREAKS = 5
DEFAULT_NUM_THREADS = 16
@@ -97,9 +97,6 @@ def _calc_snr(results, data, pf_pixel_mask):
out=np.zeros_like(data),
where=(std!=0.0)
)
del whitefield
del std
del mask
return snr
ext-dorofe_e marked this conversation as resolved Outdated

TODO: dels are likely unnecessary; run more tests to make sure

TODO: `del`s are likely unnecessary; run more tests to make sure
def _scale_whitefield(data, mask, whitefield, std, num_threads):
@@ -110,7 +107,21 @@ def _scale_whitefield(data, mask, whitefield, std, num_threads):
scales = median(y * w, axis=0, num_threads=num_threads) / \
median(w * w, axis=0, num_threads=num_threads)
whitefield *= scales
def _get_concentric_only_mask(x_center, y_center, crop_roi, streak_lines, threshold=0.33):
if x_center is not None and y_center is not None:
if crop_roi is not None:
x_center -= crop_roi[0]
y_center -= crop_roi[2]
centers = np.mean(streak_lines.reshape(-1, 2, 2), axis=1)
norm = np.stack([streak_lines[:, 3] - streak_lines[:, 1],
streak_lines[:, 0] - streak_lines[:, 2]], axis=-1)
r = centers - np.asarray([x_center, y_center])
prod = np.sum(norm * r, axis=-1)[..., None]
proj = r - prod * norm / np.sum(norm ** 2, axis=-1)[..., None]
streaks_mask = np.sqrt(np.sum(proj ** 2, axis=-1)) / np.sqrt(np.sum(r ** 2, axis=-1)) < threshold
return streaks_mask
def _calc_streakfinder_analysis(results, snr, mask):
do_streakfinder_analysis = results.get("do_streakfinder_analysis", False)
if not do_streakfinder_analysis:
@@ -152,16 +163,13 @@ def _calc_streakfinder_analysis(results, snr, mask):
x_center = results.get("beam_center_x", None)
y_center = results.get("beam_center_y", None)
# mask_rois = results.get("cbd_mask_rois", []) # list of [y_min, y_max, x_min, x_max]
#
# for mask_roi in mask_rois:
# cryst_data = cryst_data.mask_region(mask_roi)
mask_rois = results.get("cbd_mask_rois", []) # list of [x_min, x_max, y_min, y_max]
crop_roi = results.get("cbd_crop_roi", None) # [y_min, y_max, x_min, x_max]
for mask_roi in mask_rois:
mask[mask_roi[2]: mask_roi[3], mask_roi[0]: mask_roi[1]] = False
crop_roi = results.get("cbd_crop_roi", None) # [x_min, x_max, y_min, y_max]
lookahead = results.get("cbd_lookahead", 1)
# if crop_roi is not None:
# crop_roi_t = [crop_roi[2], crop_roi[3], crop_roi[0], crop_roi[1]] # y0, y1, x0, x1
# cryst_data = cryst_data.crop(crop_roi_t)
peaks_structure = Structure2D(peak_structure_radius, peak_structure_rank)
streaks_structure = Structure2D(streak_structure_radius, streak_structure_rank)
@@ -179,7 +187,6 @@ def _calc_streakfinder_analysis(results, snr, mask):
detected = detect_streaks(peaks, region, region_mask, streaks_structure, xtol, streak_vmin, min_size,
lookahead, nfa, num_threads=num_threads)
del peaks
if isinstance(detected, list):
detected = detected[0]
@@ -194,27 +201,16 @@ def _calc_streakfinder_analysis(results, snr, mask):
streak_lines = detected.to_lines()
detected_streaks = np.array(detected.streaks)
del detected
# Adjust to crop region
if crop_roi is not None:
shift = [crop_roi[0], crop_roi[2], crop_roi[0], crop_roi[2]]
streak_lines = streak_lines + shift
ext-dorofe_e marked this conversation as resolved Outdated

TODO: move concentric filter to a separate method

TODO: move concentric filter to a separate method
if x_center is not None and y_center is not None:
if crop_roi is not None:
x_center -= crop_roi[0]
y_center -= crop_roi[2]
threshold = 0.33
centers = np.mean(streak_lines.reshape(-1, 2, 2), axis=1)
norm = np.stack([streak_lines[:, 3] - streak_lines[:, 1],
streak_lines[:, 0] - streak_lines[:, 2]], axis=-1)
r = centers -np.asarray([x_center, y_center])
prod = np.sum(norm * r, axis=-1)[..., None]
proj = r - prod * norm /np.sum(norm ** 2, axis=-1)[..., None]
streaks_mask = np.sqrt(np.sum(proj ** 2, axis=-1)) /np.sqrt(np.sum(r ** 2, axis=-1)) < threshold
streak_lines = streak_lines[streaks_mask]
detected_streaks = detected_streaks[streaks_mask]
concentric_streaks_mask = _get_concentric_only_mask(x_center, y_center, crop_roi, streak_lines)
if concentric_streaks_mask is not None:
streak_lines = streak_lines[concentric_streaks_mask]
detected_streaks = detected_streaks[concentric_streaks_mask]
streak_lengths = np.sqrt(
np.power((streak_lines[..., 2] - streak_lines[..., 0]), 2) +
@@ -224,12 +220,9 @@ def _calc_streakfinder_analysis(results, snr, mask):
streak_lines = streak_lines.T
_, number_of_streaks = streak_lines.shape
list_result = streak_lines.tolist() # arr(4, n_lines); 0coord x0, y0, x1, y1
list_result = streak_lines.tolist() # arr(4, n_lines); coord x0, y0, x1, y1
bragg_counts = [streak.total_mass() for streak in detected_streaks]
del detected_streaks
del streak_lines
results["number_of_streaks"] = number_of_streaks
results["is_hit_frame"] = (number_of_streaks > min_hit_streaks)
results["streaks"] = list_result