Requires the convergent beam streak finder package to be installed:

https://github.com/simply-nicky/streak_finder

(note: g++ 11 is required for building)

https://github.com/simply-nicky/streak_finder/tree/swiss_fel

(note: g++ 11 is required for building; numpy 2+ is required)
"""

import h5py
import numpy as np
from math import sqrt, pow

from streak_finder import PatternStreakFinder
from streak_finder import CrystData
from streak_finder.label import Structure2D

from skimage.measure import profile_line

DEFAULT_NUM_THREADS = 16


def calc_cbd_analysis(results, data, pf_pixel_mask):
    try:
        # Build the CrystData container and run the streak finder analysis on it
        cryst_data = _generate_cryst_data(results, data, pf_pixel_mask)
        calc_streakfinder_analysis(results, cryst_data)
    except Exception as error:  # Broad exception - we don't want to break anything here
        print(f"Error processing CBD data:\n{error}")
        results["cbd_error"] = f"Error processing CBD data:\n{error}"
    return data


def calc_streakfinder_analysis(results, cryst_data: CrystData):
    do_streakfinder_analysis = results.get("do_streakfinder_analysis", False)
    if not do_streakfinder_analysis:
        print("No streak finder analysis")
        return None

    try:
        _calc_streakfinder_analysis(results, cryst_data)
    except Exception as error:  # Broad exception - we don't want to break anything here
        print(f"Error processing CBD data:\n{error}")
        results["cbd_error"] = f"Error processing CBD data:\n{error}"
    return cryst_data.snr


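# Illustrative usage sketch (not part of this module): how the hooks above might be
# called from a per-frame processing loop. The call site, the variable names
# ``frame_data`` and ``pf_pixel_mask``, and the configuration values shown are
# assumptions for illustration only.
#
#   results = {
#       "do_snr": True,
#       "do_streakfinder_analysis": True,
#       "whitefield_data_file": "whitefield.h5",
#       "mask_data_file": "mask.h5",
#       "std_data_file": "std.h5",
#       "scale_whitefield": True,
#       # ... plus the sf_* parameters required by _calc_streakfinder_analysis
#   }
#   corrected = calc_cbd_analysis(results, frame_data, pf_pixel_mask)
#   if results.get("is_hit_frame", False):
#       print(f"Found {results['number_of_streaks']} streaks")

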
def _generate_cryst_data(results, data, pf_pixel_mask) -> CrystData:
    do_snr = results.get("do_snr", False)
    if not do_snr:
        return

    params_required = [
        "sf_structure_radius",
        "sf_structure_rank",
        "whitefield_data_file",
        "mask_data_file",
        "std_data_file",
        "scale_whitefield",  # Bool
    ]

    if not all(param in results for param in params_required):
        raise ValueError(f"ERROR: Not enough parameters for CBD correction. Skipping\n"
                         f"{params_required=}")

    whitefield_data_file = results["whitefield_data_file"]
    mask_data_file = results["mask_data_file"]
    std_data_file = results["std_data_file"]
    scale_whitefield = results["scale_whitefield"]

    # Using CXI Store specification as default
    whitefield_dataset = results.get("whitefield_dataset", "entry/crystallography/whitefield")
    mask_dataset = results.get("mask_dataset", "entry/instrument/detector/mask")
    std_dataset = results.get("std_dataset", "entry/crystallography/std")

    num_threads = results.get("num_threads", DEFAULT_NUM_THREADS)

    with h5py.File(whitefield_data_file, "r") as hf:
        whitefield = np.asarray(hf[whitefield_dataset])

    with h5py.File(mask_data_file, "r") as hf:
        mask = np.asarray(hf[mask_dataset])

    with h5py.File(std_data_file, "r") as hf:
        std = np.asarray(hf[std_dataset])

    data = CrystData(
        data=data.reshape((-1,) + data.shape[-2:]),
        mask=mask * pf_pixel_mask,
        std=std,
        whitefield=whitefield
    )
    if scale_whitefield:
        data = data.scale_whitefield(method='median', num_threads=num_threads)

    return data


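# Illustrative sketch (an assumption, not part of this module): the whitefield,
# mask and std files read by _generate_cryst_data are expected to be HDF5 files
# holding a detector-shaped array under the default CXI-style dataset paths used
# above. Files with that layout could be created like this (the (512, 1024) shape
# is a placeholder):
#
#   import h5py
#   import numpy as np
#
#   with h5py.File("whitefield.h5", "w") as hf:
#       hf["entry/crystallography/whitefield"] = np.ones((512, 1024), dtype=np.float32)
#   with h5py.File("mask.h5", "w") as hf:
#       hf["entry/instrument/detector/mask"] = np.ones((512, 1024), dtype=np.uint8)
#   with h5py.File("std.h5", "w") as hf:
#       hf["entry/crystallography/std"] = np.ones((512, 1024), dtype=np.float32)

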
def _calc_streakfinder_analysis(results, cryst_data: CrystData):
    do_streakfinder_analysis = results.get("do_streakfinder_analysis", False)
    if not do_streakfinder_analysis:
        return

    params_required = [
        "sf_peak_structure_radius",
        "sf_peak_structure_rank",
        "sf_streak_structure_radius",
        "sf_streak_structure_rank",
        "sf_peak_vmin",
        "sf_streak_vmin",
        "sf_min_size",
        "sf_vmin",
        "sf_npts",
        "sf_xtol",
        "sf_nfa",
        "sf_num_threads",
        # "beam_center_x",
        # "beam_center_y"
    ]

    if not all(param in results for param in params_required):
        print(f"ERROR: Not enough parameters for streak finder analysis. Skipping\n"
              f"{params_required=}")
        return

    peak_structure_radius = results["sf_peak_structure_radius"]  # peak
    peak_structure_rank = results["sf_peak_structure_rank"]
    streak_structure_radius = results["sf_streak_structure_radius"]  # streak
    streak_structure_rank = results["sf_streak_structure_rank"]
    peak_vmin = results["sf_peak_vmin"]  # peak
    streak_vmin = results["sf_streak_vmin"]  # streak
    min_size = results["sf_min_size"]
    vmin = results["sf_vmin"]
    npts = results["sf_npts"]
    xtol = results["sf_xtol"]
    nfa = results["sf_nfa"]
    num_threads = results["sf_num_threads"]
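
    # Illustrative configuration example (values are placeholders chosen for
    # illustration, not tuned recommendations) satisfying the extraction above:
    #
    #   results.update({
    #       "sf_peak_structure_radius": 3, "sf_peak_structure_rank": 2,
    #       "sf_streak_structure_radius": 5, "sf_streak_structure_rank": 3,
    #       "sf_peak_vmin": 1.0, "sf_streak_vmin": 0.5,
    #       "sf_min_size": 5, "sf_vmin": 0.3, "sf_npts": 3,
    #       "sf_xtol": 1.8, "sf_nfa": 0, "sf_num_threads": DEFAULT_NUM_THREADS,
    #   })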

    x_center = results.get("beam_center_x", None)
    y_center = results.get("beam_center_y", None)

    peaks_structure = Structure2D(peak_structure_radius, peak_structure_rank)
    streaks_structure = Structure2D(streak_structure_radius, streak_structure_rank)

    det_obj = cryst_data.streak_detector(streaks_structure)
    # Find peaks in a pattern: returns a sparse set of peaks whose values are above the
    # ``vmin`` threshold (``peak_vmin`` here) and that have a supporting set larger than
    # ``npts``. The minimal distance between peaks is ``2 * structure.radius``
    peaks = det_obj.detect_peaks(peak_vmin, npts, peaks_structure, num_threads)

    # Streak finding algorithm. Starting from the set of seed peaks, the lines are
    # iteratively extended with a connectivity structure.
    detected = det_obj.detect_streaks(peaks, xtol, streak_vmin, min_size, nfa=nfa,
                                      num_threads=num_threads)
    if isinstance(detected, list):
        detected = detected[0]

    streaks = det_obj.to_streaks(detected)

    if x_center is not None and y_center is not None:
        streaks = streaks.concentric_only(x_center, y_center)

    streak_lines = streaks.lines
    streak_lengths = np.sqrt(
        np.pow((streak_lines[..., 2] - streak_lines[..., 0]), 2) +
        np.pow((streak_lines[..., 3] - streak_lines[..., 1]), 2)
    ).tolist()

    streak_lines = streak_lines.T
    _, number_of_streaks = streak_lines.shape
    print(f"Found {number_of_streaks} streaks")

    # streak_lines has shape (4, n_lines) after the transpose: rows hold all x0, y0, x1, y1
    list_result = [line.tolist() for line in streak_lines]
    bragg_counts = [streak.total_mass() for streak in detected.streaks.values()]

    results.update({"number_of_streaks": number_of_streaks})
    results.update({"is_hit_frame": number_of_streaks > 0})
    results.update({"streaks": list_result})
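

# Downstream consumption sketch (assumption: the shared ``results`` dictionary is
# read by whatever framework calls these hooks). The "streaks" entry holds four
# parallel lists, as noted above:
#
#   if results.get("is_hit_frame", False):
#       x0s, y0s, x1s, y1s = results["streaks"]
#       for x0, y0, x1, y1 in zip(x0s, y0s, x1s, y1s):
#           ...  # e.g. draw or log each detected streak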