TEST: try return crop roi; cleanup
This commit is contained in:
@@ -14,6 +14,7 @@ from streak_finder.streak_finder import detect_peaks, detect_streaks, filter_pea
|
||||
from streak_finder._src.src.median import robust_lsq, median
|
||||
|
||||
DEFAULT_MIN_HIT_STREAKS = 5
|
||||
DEFAULT_NUM_THREADS = 16
|
||||
|
||||
|
||||
def _handle_negative_values(data, mask, handler: str):
|
||||
@@ -89,7 +90,7 @@ def _calc_snr(results, data, pf_pixel_mask):
|
||||
mask = hf[mask_dataset][:].astype(bool)
|
||||
mask *= pf_pixel_mask
|
||||
if scale_whitefield:
|
||||
_scale_whitefield(data, mask, whitefield, std)
|
||||
_scale_whitefield(data, mask, whitefield, std, results.get("cbd_num_threads", DEFAULT_NUM_THREADS))
|
||||
snr = np.divide(
|
||||
data * mask - whitefield,
|
||||
std,
|
||||
@@ -101,21 +102,13 @@ def _calc_snr(results, data, pf_pixel_mask):
|
||||
del mask
|
||||
return snr
|
||||
|
||||
def _scale_whitefield(data, mask, whitefield, std,
|
||||
r0: float = 0.0, r1: float = 0.5,
|
||||
n_iter: int = 12, lm: float = 9.0, num_threads: int = 16
|
||||
):
|
||||
def _scale_whitefield(data, mask, whitefield, std, num_threads):
|
||||
mask = mask & (std > 0.0)
|
||||
y = np.divide(data, std, out=np.zeros_like(data), where=mask)[mask]
|
||||
W = np.divide(whitefield, std, out=np.zeros_like(data), where=mask)[mask]
|
||||
#
|
||||
# scales = robust_lsq(W=W, y=y, axis=0, r0=r0, r1=r1, n_iter=n_iter, lm=lm,
|
||||
# num_threads=num_threads)
|
||||
w = np.divide(whitefield, std, out=np.zeros_like(data), where=mask)[mask]
|
||||
|
||||
scales = median(y * W, axis=0, num_threads=num_threads) / \
|
||||
median(W * W, axis=0, num_threads=num_threads)
|
||||
print(f"{scales=}")
|
||||
scales=np.ravel(scales)
|
||||
scales = median(y * w, axis=0, num_threads=num_threads) / \
|
||||
median(w * w, axis=0, num_threads=num_threads)
|
||||
whitefield *= scales
|
||||
|
||||
def _calc_streakfinder_analysis(results, snr, mask):
|
||||
@@ -173,19 +166,19 @@ def _calc_streakfinder_analysis(results, snr, mask):
|
||||
peaks_structure = Structure2D(peak_structure_radius, peak_structure_rank)
|
||||
streaks_structure = Structure2D(streak_structure_radius, streak_structure_rank)
|
||||
|
||||
peaks = detect_peaks(snr, mask, peaks_structure.rank, peak_vmin,
|
||||
if crop_roi is None:
|
||||
region = snr
|
||||
else:
|
||||
region = snr[crop_roi[0]: crop_roi[1], crop_roi[2]: crop_roi[3]]
|
||||
peaks = detect_peaks(region, mask, peaks_structure.rank, peak_vmin,
|
||||
num_threads=num_threads)
|
||||
peaks = filter_peaks(peaks, snr, mask, peaks_structure, peak_vmin, npts,
|
||||
peaks = filter_peaks(peaks, region, mask, peaks_structure, peak_vmin, npts,
|
||||
num_threads=num_threads)
|
||||
|
||||
# det_obj = cryst_data.streak_detector(streaks_structure)
|
||||
|
||||
|
||||
# peaks = det_obj.detect_peaks(peak_vmin, npts, peaks_structure, num_threads)
|
||||
# detected = det_obj.detect_streaks(peaks, xtol, streak_vmin, min_size, nfa=nfa,
|
||||
# num_threads=num_threads)
|
||||
detected = detect_streaks(peaks, snr, mask, streaks_structure, xtol, streak_vmin, min_size,
|
||||
detected = detect_streaks(peaks, region, mask, streaks_structure, xtol, streak_vmin, min_size,
|
||||
lookahead, nfa, num_threads=num_threads)
|
||||
del peaks
|
||||
|
||||
if isinstance(detected, list):
|
||||
detected = detected[0]
|
||||
|
||||
@@ -197,17 +190,9 @@ def _calc_streakfinder_analysis(results, snr, mask):
|
||||
results["bragg_counts"] = []
|
||||
return
|
||||
|
||||
|
||||
if isinstance(detected, list):
|
||||
streaks = [np.asarray(pattern.to_lines()) for pattern in detected]
|
||||
else:
|
||||
streaks = [np.asarray(detected.to_lines()),]
|
||||
streak_lines = np.concatenate(streaks)
|
||||
# rstreaks = Streaks(index=IndexArray(idxs), lines=lines)
|
||||
|
||||
# streaks = det_obj.to_streaks(detected)
|
||||
streak_lines = detected.to_lines()
|
||||
detected_streaks = np.array(detected.streaks)
|
||||
# streak_lines = streaks.lines
|
||||
del detected
|
||||
|
||||
# Adjust to crop region
|
||||
if crop_roi is not None:
|
||||
@@ -215,19 +200,17 @@ def _calc_streakfinder_analysis(results, snr, mask):
|
||||
streak_lines = streak_lines + shift
|
||||
|
||||
if x_center is not None and y_center is not None:
|
||||
# if crop_roi is not None:
|
||||
# x_center -= crop_roi[0]
|
||||
# y_center -= crop_roi[2]
|
||||
if crop_roi is not None:
|
||||
x_center -= crop_roi[0]
|
||||
y_center -= crop_roi[2]
|
||||
threshold = 0.33
|
||||
centers =np.mean(streak_lines.reshape(-1, 2, 2), axis=1)
|
||||
norm =np.stack([streak_lines[:, 3] - streak_lines[:, 1],
|
||||
centers = np.mean(streak_lines.reshape(-1, 2, 2), axis=1)
|
||||
norm = np.stack([streak_lines[:, 3] - streak_lines[:, 1],
|
||||
streak_lines[:, 0] - streak_lines[:, 2]], axis=-1)
|
||||
r = centers -np.asarray([x_center, y_center])
|
||||
prod =np.sum(norm * r, axis=-1)[..., None]
|
||||
prod = np.sum(norm * r, axis=-1)[..., None]
|
||||
proj = r - prod * norm /np.sum(norm ** 2, axis=-1)[..., None]
|
||||
streaks_mask =np.sqrt(np.sum(proj ** 2, axis=-1)) /np.sqrt(np.sum(r ** 2, axis=-1)) < threshold
|
||||
# return mask
|
||||
# streaks_mask = streaks.concentric_only(x_center, y_center)
|
||||
streaks_mask = np.sqrt(np.sum(proj ** 2, axis=-1)) /np.sqrt(np.sum(r ** 2, axis=-1)) < threshold
|
||||
streak_lines = streak_lines[streaks_mask]
|
||||
detected_streaks = detected_streaks[streaks_mask]
|
||||
|
||||
@@ -240,9 +223,11 @@ def _calc_streakfinder_analysis(results, snr, mask):
|
||||
_, number_of_streaks = streak_lines.shape
|
||||
|
||||
list_result = streak_lines.tolist() # arr(4, n_lines); 0coord x0, y0, x1, y1
|
||||
|
||||
bragg_counts = [streak.total_mass() for streak in detected_streaks]
|
||||
|
||||
del detected_streaks
|
||||
del streak_lines
|
||||
|
||||
results["number_of_streaks"] = number_of_streaks
|
||||
results["is_hit_frame"] = (number_of_streaks > min_hit_streaks)
|
||||
results["streaks"] = list_result
|
||||
|
||||
Reference in New Issue
Block a user