TEST: try to get rid of data containers in streak finder package; make it simple

This commit is contained in:
2025-07-14 20:58:20 +02:00
parent 2ea2129d5b
commit 901587db79
+47 -32
View File
@@ -9,22 +9,13 @@ https://github.com/simply-nicky/streak_finder/
import h5py
import numpy as np
from psutil import virtual_memory
from streak_finder import CrystData as CrystDataBase
from streak_finder.label import Structure2D
from streak_finder.streak_finder import detect_peaks, detect_streaks, filter_peaks
DEFAULT_NUM_THREADS = 16
DEFAULT_MIN_HIT_STREAKS = 5
class CrystData(CrystDataBase):
def clear(self):
self.data = None
self.mask = None
self.snr = None
self.std = None
def _handle_negative_values(data, mask, handler: str):
if not handler or np.all(data>=0):
return
@@ -107,14 +98,14 @@ def _generate_cryst_data(results, data, pf_pixel_mask): # -> CrystData:
data * mask - whitefield,
std,
out=np.zeros_like(data),
where=(std==0.0)
where=(std!=0.0)
)
del whitefield
del std
del mask
return snr
def _calc_streakfinder_analysis(results, cryst_data: CrystData):
def _calc_streakfinder_analysis(results, snr, mask):
do_streakfinder_analysis = results.get("do_streakfinder_analysis", False)
if not do_streakfinder_analysis:
return
@@ -155,26 +146,33 @@ def _calc_streakfinder_analysis(results, cryst_data: CrystData):
x_center = results.get("beam_center_x", None)
y_center = results.get("beam_center_y", None)
mask_rois = results.get("cbd_mask_rois", []) # list of [y_min, y_max, x_min, x_max]
for mask_roi in mask_rois:
cryst_data = cryst_data.mask_region(mask_roi)
# mask_rois = results.get("cbd_mask_rois", []) # list of [y_min, y_max, x_min, x_max]
#
# for mask_roi in mask_rois:
# cryst_data = cryst_data.mask_region(mask_roi)
crop_roi = results.get("cbd_crop_roi", None) # [y_min, y_max, x_min, x_max]
if crop_roi is not None:
crop_roi_t = [crop_roi[2], crop_roi[3], crop_roi[0], crop_roi[1]]# y0, y1, x0, x1
cryst_data = cryst_data.crop(crop_roi_t)
lookahead = results.get("cbd_lookahead", 1)
# if crop_roi is not None:
# crop_roi_t = [crop_roi[2], crop_roi[3], crop_roi[0], crop_roi[1]] # y0, y1, x0, x1
# cryst_data = cryst_data.crop(crop_roi_t)
peaks_structure = Structure2D(peak_structure_radius, peak_structure_rank)
streaks_structure = Structure2D(streak_structure_radius, streak_structure_rank)
det_obj = cryst_data.streak_detector(streaks_structure)
peaks = det_obj.detect_peaks(peak_vmin, npts, peaks_structure, num_threads)
detected = det_obj.detect_streaks(peaks, xtol, streak_vmin, min_size, nfa=nfa,
num_threads=num_threads)
peaks = detect_peaks(snr, mask, peaks_structure.rank, peak_vmin,
num_threads=num_threads)
peaks = filter_peaks(peaks, snr, mask, peaks_structure, peak_vmin, npts,
num_threads=num_threads)
# det_obj = cryst_data.streak_detector(streaks_structure)
# peaks = det_obj.detect_peaks(peak_vmin, npts, peaks_structure, num_threads)
# detected = det_obj.detect_streaks(peaks, xtol, streak_vmin, min_size, nfa=nfa,
# num_threads=num_threads)
detected = detect_streaks(peaks, snr, mask, streaks_structure, xtol, streak_vmin, min_size,
lookahead, nfa, num_threads=num_threads)
if isinstance(detected, list):
detected = detected[0]
@@ -186,9 +184,17 @@ def _calc_streakfinder_analysis(results, cryst_data: CrystData):
results["bragg_counts"] = []
return
streaks = det_obj.to_streaks(detected)
if isinstance(detected, list):
streaks = [np.asarray(pattern.to_lines()) for pattern in detected]
else:
streaks = [np.asarray(detected.to_lines()),]
streak_lines = np.concatenate(streaks)
# rstreaks = Streaks(index=IndexArray(idxs), lines=lines)
# streaks = det_obj.to_streaks(detected)
detected_streaks = np.array(detected.streaks)
streak_lines = streaks.lines
# streak_lines = streaks.lines
# Adjust to crop region
if crop_roi is not None:
@@ -196,10 +202,19 @@ def _calc_streakfinder_analysis(results, cryst_data: CrystData):
streak_lines = streak_lines + shift
if x_center is not None and y_center is not None:
if crop_roi is not None:
x_center -= crop_roi[0]
y_center -= crop_roi[2]
streaks_mask = streaks.concentric_only(x_center, y_center)
# if crop_roi is not None:
# x_center -= crop_roi[0]
# y_center -= crop_roi[2]
threshold = 0.33
centers =np.mean(streak_lines.reshape(-1, 2, 2), axis=1)
norm =np.stack([streak_lines[:, 3] - streak_lines[:, 1],
streak_lines[:, 0] - streak_lines[:, 2]], axis=-1)
r = centers -np.asarray([x_center, y_center])
prod =np.sum(norm * r, axis=-1)[..., None]
proj = r - prod * norm /np.sum(norm ** 2, axis=-1)[..., None]
streaks_mask =np.sqrt(np.sum(proj ** 2, axis=-1)) /np.sqrt(np.sum(r ** 2, axis=-1)) < threshold
# return mask
# streaks_mask = streaks.concentric_only(x_center, y_center)
streak_lines = streak_lines[streaks_mask]
detected_streaks = detected_streaks[streaks_mask]