Fix memory leak in streak finder #3

Merged
augustin_s merged 15 commits from ext-dorofe_e/dap:cbd_bugfix into main 2025-07-15 10:21:21 +02:00
3 changed files with 74 additions and 55 deletions

View File

@@ -136,8 +136,9 @@ options:
* `'cbd_num_threads': int` - Number of threads to use for peak finder algorithm.
* `'cbd_negative_handler': 'mask'/'zero'/'shift'/''` - [optional] Method to handle negative values in converted frames, defaults to `''` (do not handle).
* `'cbd_min_hit_streaks': int` - [optional] Minimum number of discovered streaks to categorize frame as a hit, defaults to 5.
* `'cbd_mask_rois': list[(int, int, int, int)]` - [optional] list of `(y_min, y_max, x_min, x_max)` coordinates of ROIs to mask out during peak finding.
* `'cbd_mask_rois': list[(int, int, int, int)]` - [optional] list of `(x_min, x_max, y_min, y_max)` coordinates of ROIs to mask out during peak finding.
* `'cbd_crop_roi': [int, int, int, int]` - [optional] run streak finder on a cropped image, e.g. one quadrant, for purpose of spedup.
* `'cbd_lookahead': int` - [optional] Number of linelets considered at the ends of a streak to be added to the streak.
Algorithm Output:
* `'number_of_streaks': int` - Indicates the count of identified streaks.

View File

@@ -18,14 +18,14 @@ def calc_apply_additional_mask_from_file(results, pixel_mask_pf):
# Support for hdf5 and npy
if mask_file.endswith(".npy"):
try:
mask = np.load(mask_file).astype(np.bool)
mask = np.load(mask_file).astype(bool)
except Exception as error:
results["mask_error"] = f"Error loading mask data from NumPy file {mask_file}:\n{error}"
return
else:
try:
with h5py.File(mask_file, "r") as h5f:
mask = h5f[mask_dataset][:].astype(np.bool)
mask = h5f[mask_dataset][:].astype(bool)
except Exception as error:
results["mask_error"] = f"Error loading mask from hdf5 file {mask_file}:\n{error}"
return

View File

@@ -8,11 +8,12 @@ https://github.com/simply-nicky/streak_finder/
"""
import h5py
import numpy as np
from streak_finder import CrystData
from streak_finder.label import Structure2D
from streak_finder.streak_finder import detect_peaks, detect_streaks, filter_peaks
from streak_finder._src.src.median import median
DEFAULT_NUM_THREADS = 16
DEFAULT_MIN_HIT_STREAKS = 5
DEFAULT_NUM_THREADS = 16
def _handle_negative_values(data, mask, handler: str):
@@ -37,23 +38,23 @@ def calc_streakfinder_analysis(results, data, pf_pixel_mask):
_handle_negative_values(data, pf_pixel_mask, negative_val_handler)
try:
cryst_data = _generate_cryst_data(results, data, pf_pixel_mask)
snr = _calc_snr(results, data, pf_pixel_mask)
except Exception as error: # Broad exception - we don't want to break anything here
results["cbd_error"] = f"Error processing CBD data:\n{error}"
results["cbd_error"] = f"SNR - Error processing CBD data:\n{error}"
return data
if do_snr:
# Changes data and mask in-place
data = cryst_data.snr[0].copy()
# Changes data in-place
data = snr
try:
_calc_streakfinder_analysis(results, cryst_data)
_calc_streakfinder_analysis(results, snr, pf_pixel_mask)
except Exception as error: # Broad exception - we don't want to break anything here
results["cbd_error"] = f"Error processing CBD data:\n{error}"
results["cbd_error"] = f"StreakFind - Error processing CBD data:\n{error}"
return data
def _generate_cryst_data(results, data, pf_pixel_mask) -> CrystData:
def _calc_snr(results, data, pf_pixel_mask):
params_required = [
"cbd_whitefield_data_file",
"cbd_std_data_file",
@@ -72,8 +73,6 @@ def _generate_cryst_data(results, data, pf_pixel_mask) -> CrystData:
whitefield_dataset = results.get("cbd_whitefield_dataset", "entry/crystallography/whitefield")
std_dataset = results.get("cbd_std_dataset", "entry/crystallography/std")
num_threads = results.get("cbd_num_threads", DEFAULT_NUM_THREADS)
with h5py.File(whitefield_data_file, "r") as hf:
whitefield = hf[whitefield_dataset][:]
@@ -86,22 +85,42 @@ def _generate_cryst_data(results, data, pf_pixel_mask) -> CrystData:
else:
mask_dataset = results.get("cbd_mask_dataset", "entry/instrument/detector/mask")
with h5py.File(mask_data_file, "r") as hf:
mask = hf[mask_dataset][:].astype(np.bool)
mask = hf[mask_dataset][:].astype(bool)
mask *= pf_pixel_mask
data = CrystData(
data=data[np.newaxis, :],
mask=mask,
std=std,
whitefield=whitefield
)
if scale_whitefield:
data = data.scale_whitefield(method='median', num_threads=num_threads)
data = data.update_snr()
return data
_scale_whitefield(data, mask, whitefield, std, results.get("cbd_num_threads", DEFAULT_NUM_THREADS))
snr = np.divide(
data * mask - whitefield,
std,
out=np.zeros_like(data),
where=(std!=0.0)
)
return snr
def _calc_streakfinder_analysis(results, cryst_data: CrystData):
def _scale_whitefield(data, mask, whitefield, std, num_threads):
ext-dorofe_e marked this conversation as resolved Outdated

TODO: dels are likely unnecessary; run more tests to make sure

TODO: `del`s are likely unnecessary; run more tests to make sure
mask = mask & (std > 0.0)
y = np.divide(data, std, out=np.zeros_like(data), where=mask)[mask]
w = np.divide(whitefield, std, out=np.zeros_like(data), where=mask)[mask]
scales = median(y * w, axis=0, num_threads=num_threads) / \
median(w * w, axis=0, num_threads=num_threads)
whitefield *= scales
def _get_concentric_only_mask(x_center, y_center, crop_roi, streak_lines, threshold=0.33):
if x_center is not None and y_center is not None:
if crop_roi is not None:
x_center -= crop_roi[0]
y_center -= crop_roi[2]
centers = np.mean(streak_lines.reshape(-1, 2, 2), axis=1)
norm = np.stack([streak_lines[:, 3] - streak_lines[:, 1],
streak_lines[:, 0] - streak_lines[:, 2]], axis=-1)
r = centers - np.asarray([x_center, y_center])
prod = np.sum(norm * r, axis=-1)[..., None]
proj = r - prod * norm / np.sum(norm ** 2, axis=-1)[..., None]
streaks_mask = np.sqrt(np.sum(proj ** 2, axis=-1)) / np.sqrt(np.sum(r ** 2, axis=-1)) < threshold
return streaks_mask
def _calc_streakfinder_analysis(results, snr, mask):
do_streakfinder_analysis = results.get("do_streakfinder_analysis", False)
if not do_streakfinder_analysis:
return
@@ -121,8 +140,8 @@ def _calc_streakfinder_analysis(results, cryst_data: CrystData):
]
if not all(param in results.keys() for param in params_required):
print(f"ERROR: Not enough parameters for streak finder analysis. Skipping.\n"
f"{params_required=}")
results["cbd_error"] = (f"ERROR: Not enough parameters for streak finder analysis. "
f"Skipping.\n{params_required=}")
return
peak_structure_radius = results["cbd_peak_structure_radius"] # peak
@@ -142,25 +161,29 @@ def _calc_streakfinder_analysis(results, cryst_data: CrystData):
x_center = results.get("beam_center_x", None)
y_center = results.get("beam_center_y", None)
mask_rois = results.get("cbd_mask_rois", []) # list of [y_min, y_max, x_min, x_max]
mask_rois = results.get("cbd_mask_rois", []) # list of [x_min, x_max, y_min, y_max]
for mask_roi in mask_rois:
cryst_data = cryst_data.mask_region(mask_roi)
mask[mask_roi[2]: mask_roi[3], mask_roi[0]: mask_roi[1]] = False
crop_roi = results.get("cbd_crop_roi", None) # [y_min, y_max, x_min, x_max]
if crop_roi is not None:
crop_roi_t = [crop_roi[2], crop_roi[3], crop_roi[0], crop_roi[1]]# y0, y1, x0, x1
cryst_data = cryst_data.crop(crop_roi_t)
crop_roi = results.get("cbd_crop_roi", None) # [x_min, x_max, y_min, y_max]
lookahead = results.get("cbd_lookahead", 1)
peaks_structure = Structure2D(peak_structure_radius, peak_structure_rank)
streaks_structure = Structure2D(streak_structure_radius, streak_structure_rank)
det_obj = cryst_data.streak_detector(streaks_structure)
peaks = det_obj.detect_peaks(peak_vmin, npts, peaks_structure, num_threads)
detected = det_obj.detect_streaks(peaks, xtol, streak_vmin, min_size, nfa=nfa,
num_threads=num_threads)
if crop_roi is None:
region = snr
region_mask = mask
else:
region = snr[crop_roi[2]: crop_roi[3], crop_roi[0]: crop_roi[1]]
region_mask = mask[crop_roi[2]: crop_roi[3], crop_roi[0]: crop_roi[1]]
peaks = detect_peaks(region, region_mask, peaks_structure.rank, peak_vmin,
num_threads=num_threads)
peaks = filter_peaks(peaks, region, region_mask, peaks_structure, peak_vmin, npts,
num_threads=num_threads)
detected = detect_streaks(peaks, region, region_mask, streaks_structure, xtol, streak_vmin, min_size,
lookahead, nfa, num_threads=num_threads)
if isinstance(detected, list):
detected = detected[0]
@@ -173,33 +196,28 @@ def _calc_streakfinder_analysis(results, cryst_data: CrystData):
results["bragg_counts"] = []
return
streaks = det_obj.to_streaks(detected)
streak_lines = detected.to_lines()
detected_streaks = np.array(detected.streaks)
streak_lines = streaks.lines
# Adjust to crop region
if crop_roi is not None:
shift = [crop_roi[0], crop_roi[2], crop_roi[0], crop_roi[2]]
streak_lines = streak_lines + shift
if x_center is not None and y_center is not None:
if crop_roi is not None:
x_center -= crop_roi[0]
y_center -= crop_roi[2]
streaks_mask = streaks.concentric_only(x_center, y_center)
streak_lines = streak_lines[streaks_mask]
detected_streaks = detected_streaks[streaks_mask]
concentric_streaks_mask = _get_concentric_only_mask(x_center, y_center, crop_roi, streak_lines)
if concentric_streaks_mask is not None:
ext-dorofe_e marked this conversation as resolved Outdated

TODO: move concentric filter to a separate method

TODO: move concentric filter to a separate method
streak_lines = streak_lines[concentric_streaks_mask]
detected_streaks = detected_streaks[concentric_streaks_mask]
streak_lengths = np.sqrt(
np.pow((streak_lines[..., 2] - streak_lines[..., 0]), 2) +
np.pow((streak_lines[..., 2] - streak_lines[..., 0]), 2)
np.power((streak_lines[..., 2] - streak_lines[..., 0]), 2) +
np.power((streak_lines[..., 2] - streak_lines[..., 0]), 2)
).tolist()
streak_lines = streak_lines.T
_, number_of_streaks = streak_lines.shape
list_result = streak_lines.tolist() # arr(4, n_lines); 0coord x0, y0, x1, y1
list_result = streak_lines.tolist() # arr(4, n_lines); coord x0, y0, x1, y1
bragg_counts = [streak.total_mass() for streak in detected_streaks]
results["number_of_streaks"] = number_of_streaks