TEST: try return crop roi; cleanup

This commit is contained in:
2025-07-14 23:04:17 +02:00
parent 9ed3335021
commit d30adf3e14

View File

@@ -14,6 +14,7 @@ from streak_finder.streak_finder import detect_peaks, detect_streaks, filter_pea
from streak_finder._src.src.median import robust_lsq, median
DEFAULT_MIN_HIT_STREAKS = 5
DEFAULT_NUM_THREADS = 16
def _handle_negative_values(data, mask, handler: str):
@@ -89,7 +90,7 @@ def _calc_snr(results, data, pf_pixel_mask):
mask = hf[mask_dataset][:].astype(bool)
mask *= pf_pixel_mask
if scale_whitefield:
_scale_whitefield(data, mask, whitefield, std)
_scale_whitefield(data, mask, whitefield, std, results.get("cbd_num_threads", DEFAULT_NUM_THREADS))
snr = np.divide(
data * mask - whitefield,
std,
@@ -101,21 +102,13 @@ def _calc_snr(results, data, pf_pixel_mask):
del mask
return snr
def _scale_whitefield(data, mask, whitefield, std,
r0: float = 0.0, r1: float = 0.5,
n_iter: int = 12, lm: float = 9.0, num_threads: int = 16
):
def _scale_whitefield(data, mask, whitefield, std, num_threads):
mask = mask & (std > 0.0)
y = np.divide(data, std, out=np.zeros_like(data), where=mask)[mask]
W = np.divide(whitefield, std, out=np.zeros_like(data), where=mask)[mask]
#
# scales = robust_lsq(W=W, y=y, axis=0, r0=r0, r1=r1, n_iter=n_iter, lm=lm,
# num_threads=num_threads)
w = np.divide(whitefield, std, out=np.zeros_like(data), where=mask)[mask]
scales = median(y * W, axis=0, num_threads=num_threads) / \
median(W * W, axis=0, num_threads=num_threads)
print(f"{scales=}")
scales=np.ravel(scales)
scales = median(y * w, axis=0, num_threads=num_threads) / \
median(w * w, axis=0, num_threads=num_threads)
whitefield *= scales
def _calc_streakfinder_analysis(results, snr, mask):
@@ -173,19 +166,19 @@ def _calc_streakfinder_analysis(results, snr, mask):
peaks_structure = Structure2D(peak_structure_radius, peak_structure_rank)
streaks_structure = Structure2D(streak_structure_radius, streak_structure_rank)
peaks = detect_peaks(snr, mask, peaks_structure.rank, peak_vmin,
if crop_roi is None:
region = snr
else:
region = snr[crop_roi[0]: crop_roi[1], crop_roi[2]: crop_roi[3]]
peaks = detect_peaks(region, mask, peaks_structure.rank, peak_vmin,
num_threads=num_threads)
peaks = filter_peaks(peaks, snr, mask, peaks_structure, peak_vmin, npts,
peaks = filter_peaks(peaks, region, mask, peaks_structure, peak_vmin, npts,
num_threads=num_threads)
# det_obj = cryst_data.streak_detector(streaks_structure)
# peaks = det_obj.detect_peaks(peak_vmin, npts, peaks_structure, num_threads)
# detected = det_obj.detect_streaks(peaks, xtol, streak_vmin, min_size, nfa=nfa,
# num_threads=num_threads)
detected = detect_streaks(peaks, snr, mask, streaks_structure, xtol, streak_vmin, min_size,
detected = detect_streaks(peaks, region, mask, streaks_structure, xtol, streak_vmin, min_size,
lookahead, nfa, num_threads=num_threads)
del peaks
if isinstance(detected, list):
detected = detected[0]
@@ -197,17 +190,9 @@ def _calc_streakfinder_analysis(results, snr, mask):
results["bragg_counts"] = []
return
if isinstance(detected, list):
streaks = [np.asarray(pattern.to_lines()) for pattern in detected]
else:
streaks = [np.asarray(detected.to_lines()),]
streak_lines = np.concatenate(streaks)
# rstreaks = Streaks(index=IndexArray(idxs), lines=lines)
# streaks = det_obj.to_streaks(detected)
streak_lines = detected.to_lines()
detected_streaks = np.array(detected.streaks)
# streak_lines = streaks.lines
del detected
# Adjust to crop region
if crop_roi is not None:
@@ -215,19 +200,17 @@ def _calc_streakfinder_analysis(results, snr, mask):
streak_lines = streak_lines + shift
if x_center is not None and y_center is not None:
# if crop_roi is not None:
# x_center -= crop_roi[0]
# y_center -= crop_roi[2]
if crop_roi is not None:
x_center -= crop_roi[0]
y_center -= crop_roi[2]
threshold = 0.33
centers =np.mean(streak_lines.reshape(-1, 2, 2), axis=1)
norm =np.stack([streak_lines[:, 3] - streak_lines[:, 1],
centers = np.mean(streak_lines.reshape(-1, 2, 2), axis=1)
norm = np.stack([streak_lines[:, 3] - streak_lines[:, 1],
streak_lines[:, 0] - streak_lines[:, 2]], axis=-1)
r = centers -np.asarray([x_center, y_center])
prod =np.sum(norm * r, axis=-1)[..., None]
prod = np.sum(norm * r, axis=-1)[..., None]
proj = r - prod * norm /np.sum(norm ** 2, axis=-1)[..., None]
streaks_mask =np.sqrt(np.sum(proj ** 2, axis=-1)) /np.sqrt(np.sum(r ** 2, axis=-1)) < threshold
# return mask
# streaks_mask = streaks.concentric_only(x_center, y_center)
streaks_mask = np.sqrt(np.sum(proj ** 2, axis=-1)) /np.sqrt(np.sum(r ** 2, axis=-1)) < threshold
streak_lines = streak_lines[streaks_mask]
detected_streaks = detected_streaks[streaks_mask]
@@ -240,9 +223,11 @@ def _calc_streakfinder_analysis(results, snr, mask):
_, number_of_streaks = streak_lines.shape
list_result = streak_lines.tolist() # arr(4, n_lines); 0coord x0, y0, x1, y1
bragg_counts = [streak.total_mass() for streak in detected_streaks]
del detected_streaks
del streak_lines
results["number_of_streaks"] = number_of_streaks
results["is_hit_frame"] = (number_of_streaks > min_hit_streaks)
results["streaks"] = list_result