TEST: try to get rid of data containers in streak finder package; make it simple
This commit is contained in:
+47
-32
@@ -9,22 +9,13 @@ https://github.com/simply-nicky/streak_finder/
|
||||
import h5py
|
||||
import numpy as np
|
||||
from psutil import virtual_memory
|
||||
from streak_finder import CrystData as CrystDataBase
|
||||
from streak_finder.label import Structure2D
|
||||
from streak_finder.streak_finder import detect_peaks, detect_streaks, filter_peaks
|
||||
|
||||
DEFAULT_NUM_THREADS = 16
|
||||
DEFAULT_MIN_HIT_STREAKS = 5
|
||||
|
||||
|
||||
class CrystData(CrystDataBase):
|
||||
|
||||
def clear(self):
|
||||
self.data = None
|
||||
self.mask = None
|
||||
self.snr = None
|
||||
self.std = None
|
||||
|
||||
|
||||
def _handle_negative_values(data, mask, handler: str):
|
||||
if not handler or np.all(data>=0):
|
||||
return
|
||||
@@ -107,14 +98,14 @@ def _generate_cryst_data(results, data, pf_pixel_mask): # -> CrystData:
|
||||
data * mask - whitefield,
|
||||
std,
|
||||
out=np.zeros_like(data),
|
||||
where=(std==0.0)
|
||||
where=(std!=0.0)
|
||||
)
|
||||
del whitefield
|
||||
del std
|
||||
del mask
|
||||
return snr
|
||||
|
||||
def _calc_streakfinder_analysis(results, cryst_data: CrystData):
|
||||
def _calc_streakfinder_analysis(results, snr, mask):
|
||||
do_streakfinder_analysis = results.get("do_streakfinder_analysis", False)
|
||||
if not do_streakfinder_analysis:
|
||||
return
|
||||
@@ -155,26 +146,33 @@ def _calc_streakfinder_analysis(results, cryst_data: CrystData):
|
||||
x_center = results.get("beam_center_x", None)
|
||||
y_center = results.get("beam_center_y", None)
|
||||
|
||||
mask_rois = results.get("cbd_mask_rois", []) # list of [y_min, y_max, x_min, x_max]
|
||||
|
||||
for mask_roi in mask_rois:
|
||||
cryst_data = cryst_data.mask_region(mask_roi)
|
||||
# mask_rois = results.get("cbd_mask_rois", []) # list of [y_min, y_max, x_min, x_max]
|
||||
#
|
||||
# for mask_roi in mask_rois:
|
||||
# cryst_data = cryst_data.mask_region(mask_roi)
|
||||
|
||||
crop_roi = results.get("cbd_crop_roi", None) # [y_min, y_max, x_min, x_max]
|
||||
|
||||
if crop_roi is not None:
|
||||
crop_roi_t = [crop_roi[2], crop_roi[3], crop_roi[0], crop_roi[1]]# y0, y1, x0, x1
|
||||
cryst_data = cryst_data.crop(crop_roi_t)
|
||||
lookahead = results.get("cbd_lookahead", 1)
|
||||
# if crop_roi is not None:
|
||||
# crop_roi_t = [crop_roi[2], crop_roi[3], crop_roi[0], crop_roi[1]] # y0, y1, x0, x1
|
||||
# cryst_data = cryst_data.crop(crop_roi_t)
|
||||
|
||||
peaks_structure = Structure2D(peak_structure_radius, peak_structure_rank)
|
||||
streaks_structure = Structure2D(streak_structure_radius, streak_structure_rank)
|
||||
|
||||
|
||||
det_obj = cryst_data.streak_detector(streaks_structure)
|
||||
peaks = det_obj.detect_peaks(peak_vmin, npts, peaks_structure, num_threads)
|
||||
detected = det_obj.detect_streaks(peaks, xtol, streak_vmin, min_size, nfa=nfa,
|
||||
num_threads=num_threads)
|
||||
|
||||
peaks = detect_peaks(snr, mask, peaks_structure.rank, peak_vmin,
|
||||
num_threads=num_threads)
|
||||
peaks = filter_peaks(peaks, snr, mask, peaks_structure, peak_vmin, npts,
|
||||
num_threads=num_threads)
|
||||
|
||||
# det_obj = cryst_data.streak_detector(streaks_structure)
|
||||
|
||||
|
||||
# peaks = det_obj.detect_peaks(peak_vmin, npts, peaks_structure, num_threads)
|
||||
# detected = det_obj.detect_streaks(peaks, xtol, streak_vmin, min_size, nfa=nfa,
|
||||
# num_threads=num_threads)
|
||||
detected = detect_streaks(peaks, snr, mask, streaks_structure, xtol, streak_vmin, min_size,
|
||||
lookahead, nfa, num_threads=num_threads)
|
||||
if isinstance(detected, list):
|
||||
detected = detected[0]
|
||||
|
||||
@@ -186,9 +184,17 @@ def _calc_streakfinder_analysis(results, cryst_data: CrystData):
|
||||
results["bragg_counts"] = []
|
||||
return
|
||||
|
||||
streaks = det_obj.to_streaks(detected)
|
||||
|
||||
if isinstance(detected, list):
|
||||
streaks = [np.asarray(pattern.to_lines()) for pattern in detected]
|
||||
else:
|
||||
streaks = [np.asarray(detected.to_lines()),]
|
||||
streak_lines = np.concatenate(streaks)
|
||||
# rstreaks = Streaks(index=IndexArray(idxs), lines=lines)
|
||||
|
||||
# streaks = det_obj.to_streaks(detected)
|
||||
detected_streaks = np.array(detected.streaks)
|
||||
streak_lines = streaks.lines
|
||||
# streak_lines = streaks.lines
|
||||
|
||||
# Adjust to crop region
|
||||
if crop_roi is not None:
|
||||
@@ -196,10 +202,19 @@ def _calc_streakfinder_analysis(results, cryst_data: CrystData):
|
||||
streak_lines = streak_lines + shift
|
||||
|
||||
if x_center is not None and y_center is not None:
|
||||
if crop_roi is not None:
|
||||
x_center -= crop_roi[0]
|
||||
y_center -= crop_roi[2]
|
||||
streaks_mask = streaks.concentric_only(x_center, y_center)
|
||||
# if crop_roi is not None:
|
||||
# x_center -= crop_roi[0]
|
||||
# y_center -= crop_roi[2]
|
||||
threshold = 0.33
|
||||
centers =np.mean(streak_lines.reshape(-1, 2, 2), axis=1)
|
||||
norm =np.stack([streak_lines[:, 3] - streak_lines[:, 1],
|
||||
streak_lines[:, 0] - streak_lines[:, 2]], axis=-1)
|
||||
r = centers -np.asarray([x_center, y_center])
|
||||
prod =np.sum(norm * r, axis=-1)[..., None]
|
||||
proj = r - prod * norm /np.sum(norm ** 2, axis=-1)[..., None]
|
||||
streaks_mask =np.sqrt(np.sum(proj ** 2, axis=-1)) /np.sqrt(np.sum(r ** 2, axis=-1)) < threshold
|
||||
# return mask
|
||||
# streaks_mask = streaks.concentric_only(x_center, y_center)
|
||||
streak_lines = streak_lines[streaks_mask]
|
||||
detected_streaks = detected_streaks[streaks_mask]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user