Fix memory leak in streak finder #3
@@ -136,8 +136,9 @@ options:
|
||||
* `'cbd_num_threads': int` - Number of threads to use for peak finder algorithm.
|
||||
* `'cbd_negative_handler': 'mask'/'zero'/'shift'/''` - [optional] Method to handle negative values in converted frames, defaults to `''` (do not handle).
|
||||
* `'cbd_min_hit_streaks': int` - [optional] Minimum number of discovered streaks to categorize frame as a hit, defaults to 5.
|
||||
* `'cbd_mask_rois': list[(int, int, int, int)]` - [optional] list of `(y_min, y_max, x_min, x_max)` coordinates of ROIs to mask out during peak finding.
|
||||
* `'cbd_mask_rois': list[(int, int, int, int)]` - [optional] list of `(x_min, x_max, y_min, y_max)` coordinates of ROIs to mask out during peak finding.
|
||||
* `'cbd_crop_roi': [int, int, int, int]` - [optional] run streak finder on a cropped image, e.g. one quadrant, for purpose of spedup.
|
||||
* `'cbd_lookahead': int` - [optional] Number of linelets considered at the ends of a streak to be added to the streak.
|
||||
|
||||
Algorithm Output:
|
||||
* `'number_of_streaks': int` - Indicates the count of identified streaks.
|
||||
|
||||
@@ -18,14 +18,14 @@ def calc_apply_additional_mask_from_file(results, pixel_mask_pf):
|
||||
# Support for hdf5 and npy
|
||||
if mask_file.endswith(".npy"):
|
||||
try:
|
||||
mask = np.load(mask_file).astype(np.bool)
|
||||
mask = np.load(mask_file).astype(bool)
|
||||
except Exception as error:
|
||||
results["mask_error"] = f"Error loading mask data from NumPy file {mask_file}:\n{error}"
|
||||
return
|
||||
else:
|
||||
try:
|
||||
with h5py.File(mask_file, "r") as h5f:
|
||||
mask = h5f[mask_dataset][:].astype(np.bool)
|
||||
mask = h5f[mask_dataset][:].astype(bool)
|
||||
except Exception as error:
|
||||
results["mask_error"] = f"Error loading mask from hdf5 file {mask_file}:\n{error}"
|
||||
return
|
||||
|
||||
@@ -8,11 +8,12 @@ https://github.com/simply-nicky/streak_finder/
|
||||
"""
|
||||
import h5py
|
||||
import numpy as np
|
||||
from streak_finder import CrystData
|
||||
from streak_finder.label import Structure2D
|
||||
from streak_finder.streak_finder import detect_peaks, detect_streaks, filter_peaks
|
||||
from streak_finder._src.src.median import median
|
||||
|
||||
DEFAULT_NUM_THREADS = 16
|
||||
DEFAULT_MIN_HIT_STREAKS = 5
|
||||
DEFAULT_NUM_THREADS = 16
|
||||
|
||||
|
||||
def _handle_negative_values(data, mask, handler: str):
|
||||
@@ -37,23 +38,23 @@ def calc_streakfinder_analysis(results, data, pf_pixel_mask):
|
||||
_handle_negative_values(data, pf_pixel_mask, negative_val_handler)
|
||||
|
||||
try:
|
||||
cryst_data = _generate_cryst_data(results, data, pf_pixel_mask)
|
||||
snr = _calc_snr(results, data, pf_pixel_mask)
|
||||
except Exception as error: # Broad exception - we don't want to break anything here
|
||||
results["cbd_error"] = f"Error processing CBD data:\n{error}"
|
||||
results["cbd_error"] = f"SNR - Error processing CBD data:\n{error}"
|
||||
return data
|
||||
|
||||
if do_snr:
|
||||
# Changes data and mask in-place
|
||||
data = cryst_data.snr[0].copy()
|
||||
# Changes data in-place
|
||||
data = snr
|
||||
|
||||
try:
|
||||
_calc_streakfinder_analysis(results, cryst_data)
|
||||
_calc_streakfinder_analysis(results, snr, pf_pixel_mask)
|
||||
except Exception as error: # Broad exception - we don't want to break anything here
|
||||
results["cbd_error"] = f"Error processing CBD data:\n{error}"
|
||||
results["cbd_error"] = f"StreakFind - Error processing CBD data:\n{error}"
|
||||
|
||||
return data
|
||||
|
||||
def _generate_cryst_data(results, data, pf_pixel_mask) -> CrystData:
|
||||
def _calc_snr(results, data, pf_pixel_mask):
|
||||
params_required = [
|
||||
"cbd_whitefield_data_file",
|
||||
"cbd_std_data_file",
|
||||
@@ -72,8 +73,6 @@ def _generate_cryst_data(results, data, pf_pixel_mask) -> CrystData:
|
||||
whitefield_dataset = results.get("cbd_whitefield_dataset", "entry/crystallography/whitefield")
|
||||
std_dataset = results.get("cbd_std_dataset", "entry/crystallography/std")
|
||||
|
||||
num_threads = results.get("cbd_num_threads", DEFAULT_NUM_THREADS)
|
||||
|
||||
with h5py.File(whitefield_data_file, "r") as hf:
|
||||
whitefield = hf[whitefield_dataset][:]
|
||||
|
||||
@@ -86,22 +85,42 @@ def _generate_cryst_data(results, data, pf_pixel_mask) -> CrystData:
|
||||
else:
|
||||
mask_dataset = results.get("cbd_mask_dataset", "entry/instrument/detector/mask")
|
||||
with h5py.File(mask_data_file, "r") as hf:
|
||||
mask = hf[mask_dataset][:].astype(np.bool)
|
||||
mask = hf[mask_dataset][:].astype(bool)
|
||||
mask *= pf_pixel_mask
|
||||
|
||||
data = CrystData(
|
||||
data=data[np.newaxis, :],
|
||||
mask=mask,
|
||||
std=std,
|
||||
whitefield=whitefield
|
||||
)
|
||||
if scale_whitefield:
|
||||
data = data.scale_whitefield(method='median', num_threads=num_threads)
|
||||
|
||||
data = data.update_snr()
|
||||
return data
|
||||
_scale_whitefield(data, mask, whitefield, std, results.get("cbd_num_threads", DEFAULT_NUM_THREADS))
|
||||
snr = np.divide(
|
||||
data * mask - whitefield,
|
||||
std,
|
||||
out=np.zeros_like(data),
|
||||
where=(std!=0.0)
|
||||
)
|
||||
return snr
|
||||
|
||||
def _calc_streakfinder_analysis(results, cryst_data: CrystData):
|
||||
def _scale_whitefield(data, mask, whitefield, std, num_threads):
|
||||
|
ext-dorofe_e marked this conversation as resolved
Outdated
|
||||
mask = mask & (std > 0.0)
|
||||
y = np.divide(data, std, out=np.zeros_like(data), where=mask)[mask]
|
||||
w = np.divide(whitefield, std, out=np.zeros_like(data), where=mask)[mask]
|
||||
|
||||
scales = median(y * w, axis=0, num_threads=num_threads) / \
|
||||
median(w * w, axis=0, num_threads=num_threads)
|
||||
whitefield *= scales
|
||||
|
||||
def _get_concentric_only_mask(x_center, y_center, crop_roi, streak_lines, threshold=0.33):
|
||||
if x_center is not None and y_center is not None:
|
||||
if crop_roi is not None:
|
||||
x_center -= crop_roi[0]
|
||||
y_center -= crop_roi[2]
|
||||
centers = np.mean(streak_lines.reshape(-1, 2, 2), axis=1)
|
||||
norm = np.stack([streak_lines[:, 3] - streak_lines[:, 1],
|
||||
streak_lines[:, 0] - streak_lines[:, 2]], axis=-1)
|
||||
r = centers - np.asarray([x_center, y_center])
|
||||
prod = np.sum(norm * r, axis=-1)[..., None]
|
||||
proj = r - prod * norm / np.sum(norm ** 2, axis=-1)[..., None]
|
||||
streaks_mask = np.sqrt(np.sum(proj ** 2, axis=-1)) / np.sqrt(np.sum(r ** 2, axis=-1)) < threshold
|
||||
return streaks_mask
|
||||
|
||||
def _calc_streakfinder_analysis(results, snr, mask):
|
||||
do_streakfinder_analysis = results.get("do_streakfinder_analysis", False)
|
||||
if not do_streakfinder_analysis:
|
||||
return
|
||||
@@ -121,8 +140,8 @@ def _calc_streakfinder_analysis(results, cryst_data: CrystData):
|
||||
]
|
||||
|
||||
if not all(param in results.keys() for param in params_required):
|
||||
print(f"ERROR: Not enough parameters for streak finder analysis. Skipping.\n"
|
||||
f"{params_required=}")
|
||||
results["cbd_error"] = (f"ERROR: Not enough parameters for streak finder analysis. "
|
||||
f"Skipping.\n{params_required=}")
|
||||
return
|
||||
|
||||
peak_structure_radius = results["cbd_peak_structure_radius"] # peak
|
||||
@@ -142,25 +161,29 @@ def _calc_streakfinder_analysis(results, cryst_data: CrystData):
|
||||
x_center = results.get("beam_center_x", None)
|
||||
y_center = results.get("beam_center_y", None)
|
||||
|
||||
mask_rois = results.get("cbd_mask_rois", []) # list of [y_min, y_max, x_min, x_max]
|
||||
|
||||
mask_rois = results.get("cbd_mask_rois", []) # list of [x_min, x_max, y_min, y_max]
|
||||
for mask_roi in mask_rois:
|
||||
cryst_data = cryst_data.mask_region(mask_roi)
|
||||
mask[mask_roi[2]: mask_roi[3], mask_roi[0]: mask_roi[1]] = False
|
||||
|
||||
crop_roi = results.get("cbd_crop_roi", None) # [y_min, y_max, x_min, x_max]
|
||||
|
||||
if crop_roi is not None:
|
||||
crop_roi_t = [crop_roi[2], crop_roi[3], crop_roi[0], crop_roi[1]]# y0, y1, x0, x1
|
||||
cryst_data = cryst_data.crop(crop_roi_t)
|
||||
crop_roi = results.get("cbd_crop_roi", None) # [x_min, x_max, y_min, y_max]
|
||||
lookahead = results.get("cbd_lookahead", 1)
|
||||
|
||||
peaks_structure = Structure2D(peak_structure_radius, peak_structure_rank)
|
||||
streaks_structure = Structure2D(streak_structure_radius, streak_structure_rank)
|
||||
|
||||
|
||||
det_obj = cryst_data.streak_detector(streaks_structure)
|
||||
peaks = det_obj.detect_peaks(peak_vmin, npts, peaks_structure, num_threads)
|
||||
detected = det_obj.detect_streaks(peaks, xtol, streak_vmin, min_size, nfa=nfa,
|
||||
num_threads=num_threads)
|
||||
if crop_roi is None:
|
||||
region = snr
|
||||
region_mask = mask
|
||||
else:
|
||||
region = snr[crop_roi[2]: crop_roi[3], crop_roi[0]: crop_roi[1]]
|
||||
region_mask = mask[crop_roi[2]: crop_roi[3], crop_roi[0]: crop_roi[1]]
|
||||
peaks = detect_peaks(region, region_mask, peaks_structure.rank, peak_vmin,
|
||||
num_threads=num_threads)
|
||||
peaks = filter_peaks(peaks, region, region_mask, peaks_structure, peak_vmin, npts,
|
||||
num_threads=num_threads)
|
||||
|
||||
detected = detect_streaks(peaks, region, region_mask, streaks_structure, xtol, streak_vmin, min_size,
|
||||
lookahead, nfa, num_threads=num_threads)
|
||||
|
||||
if isinstance(detected, list):
|
||||
detected = detected[0]
|
||||
@@ -173,33 +196,28 @@ def _calc_streakfinder_analysis(results, cryst_data: CrystData):
|
||||
results["bragg_counts"] = []
|
||||
return
|
||||
|
||||
streaks = det_obj.to_streaks(detected)
|
||||
streak_lines = detected.to_lines()
|
||||
detected_streaks = np.array(detected.streaks)
|
||||
streak_lines = streaks.lines
|
||||
|
||||
# Adjust to crop region
|
||||
if crop_roi is not None:
|
||||
shift = [crop_roi[0], crop_roi[2], crop_roi[0], crop_roi[2]]
|
||||
streak_lines = streak_lines + shift
|
||||
|
||||
if x_center is not None and y_center is not None:
|
||||
if crop_roi is not None:
|
||||
x_center -= crop_roi[0]
|
||||
y_center -= crop_roi[2]
|
||||
streaks_mask = streaks.concentric_only(x_center, y_center)
|
||||
streak_lines = streak_lines[streaks_mask]
|
||||
detected_streaks = detected_streaks[streaks_mask]
|
||||
concentric_streaks_mask = _get_concentric_only_mask(x_center, y_center, crop_roi, streak_lines)
|
||||
if concentric_streaks_mask is not None:
|
||||
|
ext-dorofe_e marked this conversation as resolved
Outdated
ext-dorofe_e
commented
TODO: move concentric filter to a separate method TODO: move concentric filter to a separate method
|
||||
streak_lines = streak_lines[concentric_streaks_mask]
|
||||
detected_streaks = detected_streaks[concentric_streaks_mask]
|
||||
|
||||
streak_lengths = np.sqrt(
|
||||
np.pow((streak_lines[..., 2] - streak_lines[..., 0]), 2) +
|
||||
np.pow((streak_lines[..., 2] - streak_lines[..., 0]), 2)
|
||||
np.power((streak_lines[..., 2] - streak_lines[..., 0]), 2) +
|
||||
np.power((streak_lines[..., 2] - streak_lines[..., 0]), 2)
|
||||
).tolist()
|
||||
|
||||
streak_lines = streak_lines.T
|
||||
_, number_of_streaks = streak_lines.shape
|
||||
|
||||
list_result = streak_lines.tolist() # arr(4, n_lines); 0coord x0, y0, x1, y1
|
||||
|
||||
list_result = streak_lines.tolist() # arr(4, n_lines); coord x0, y0, x1, y1
|
||||
bragg_counts = [streak.total_mass() for streak in detected_streaks]
|
||||
|
||||
results["number_of_streaks"] = number_of_streaks
|
||||
|
||||
Reference in New Issue
Block a user
TODO:
dels are likely unnecessary; run more tests to make sure