diff --git a/README.md b/README.md index 227ba92..580dd62 100644 --- a/README.md +++ b/README.md @@ -136,8 +136,9 @@ options: * `'cbd_num_threads': int` - Number of threads to use for peak finder algorithm. * `'cbd_negative_handler': 'mask'/'zero'/'shift'/''` - [optional] Method to handle negative values in converted frames, defaults to `''` (do not handle). * `'cbd_min_hit_streaks': int` - [optional] Minimum number of discovered streaks to categorize frame as a hit, defaults to 5. - * `'cbd_mask_rois': list[(int, int, int, int)]` - [optional] list of `(y_min, y_max, x_min, x_max)` coordinates of ROIs to mask out during peak finding. + * `'cbd_mask_rois': list[(int, int, int, int)]` - [optional] list of `(x_min, x_max, y_min, y_max)` coordinates of ROIs to mask out during peak finding. * `'cbd_crop_roi': [int, int, int, int]` - [optional] run streak finder on a cropped image, e.g. one quadrant, for purpose of spedup. + * `'cbd_lookahead': int` - [optional] Number of linelets considered at the ends of a streak to be added to the streak. Algorithm Output: * `'number_of_streaks': int` - Indicates the count of identified streaks. diff --git a/dap/algos/addmaskfile.py b/dap/algos/addmaskfile.py index 4f82ee5..f337f70 100644 --- a/dap/algos/addmaskfile.py +++ b/dap/algos/addmaskfile.py @@ -18,14 +18,14 @@ def calc_apply_additional_mask_from_file(results, pixel_mask_pf): # Support for hdf5 and npy if mask_file.endswith(".npy"): try: - mask = np.load(mask_file).astype(np.bool) + mask = np.load(mask_file).astype(bool) except Exception as error: results["mask_error"] = f"Error loading mask data from NumPy file {mask_file}:\n{error}" return else: try: with h5py.File(mask_file, "r") as h5f: - mask = h5f[mask_dataset][:].astype(np.bool) + mask = h5f[mask_dataset][:].astype(bool) except Exception as error: results["mask_error"] = f"Error loading mask from hdf5 file {mask_file}:\n{error}" return diff --git a/dap/algos/streakfind.py b/dap/algos/streakfind.py index fc91a28..0c3a7f6 100644 --- a/dap/algos/streakfind.py +++ b/dap/algos/streakfind.py @@ -8,11 +8,12 @@ https://github.com/simply-nicky/streak_finder/ """ import h5py import numpy as np -from streak_finder import CrystData from streak_finder.label import Structure2D +from streak_finder.streak_finder import detect_peaks, detect_streaks, filter_peaks +from streak_finder._src.src.median import median -DEFAULT_NUM_THREADS = 16 DEFAULT_MIN_HIT_STREAKS = 5 +DEFAULT_NUM_THREADS = 16 def _handle_negative_values(data, mask, handler: str): @@ -37,23 +38,23 @@ def calc_streakfinder_analysis(results, data, pf_pixel_mask): _handle_negative_values(data, pf_pixel_mask, negative_val_handler) try: - cryst_data = _generate_cryst_data(results, data, pf_pixel_mask) + snr = _calc_snr(results, data, pf_pixel_mask) except Exception as error: # Broad exception - we don't want to break anything here - results["cbd_error"] = f"Error processing CBD data:\n{error}" + results["cbd_error"] = f"SNR - Error processing CBD data:\n{error}" return data if do_snr: - # Changes data and mask in-place - data = cryst_data.snr[0].copy() + # Changes data in-place + data = snr try: - _calc_streakfinder_analysis(results, cryst_data) + _calc_streakfinder_analysis(results, snr, pf_pixel_mask) except Exception as error: # Broad exception - we don't want to break anything here - results["cbd_error"] = f"Error processing CBD data:\n{error}" + results["cbd_error"] = f"StreakFind - Error processing CBD data:\n{error}" return data -def _generate_cryst_data(results, data, pf_pixel_mask) -> CrystData: +def _calc_snr(results, data, pf_pixel_mask): params_required = [ "cbd_whitefield_data_file", "cbd_std_data_file", @@ -72,8 +73,6 @@ def _generate_cryst_data(results, data, pf_pixel_mask) -> CrystData: whitefield_dataset = results.get("cbd_whitefield_dataset", "entry/crystallography/whitefield") std_dataset = results.get("cbd_std_dataset", "entry/crystallography/std") - num_threads = results.get("cbd_num_threads", DEFAULT_NUM_THREADS) - with h5py.File(whitefield_data_file, "r") as hf: whitefield = hf[whitefield_dataset][:] @@ -86,22 +85,42 @@ def _generate_cryst_data(results, data, pf_pixel_mask) -> CrystData: else: mask_dataset = results.get("cbd_mask_dataset", "entry/instrument/detector/mask") with h5py.File(mask_data_file, "r") as hf: - mask = hf[mask_dataset][:].astype(np.bool) + mask = hf[mask_dataset][:].astype(bool) mask *= pf_pixel_mask - - data = CrystData( - data=data[np.newaxis, :], - mask=mask, - std=std, - whitefield=whitefield - ) if scale_whitefield: - data = data.scale_whitefield(method='median', num_threads=num_threads) - - data = data.update_snr() - return data + _scale_whitefield(data, mask, whitefield, std, results.get("cbd_num_threads", DEFAULT_NUM_THREADS)) + snr = np.divide( + data * mask - whitefield, + std, + out=np.zeros_like(data), + where=(std!=0.0) + ) + return snr -def _calc_streakfinder_analysis(results, cryst_data: CrystData): +def _scale_whitefield(data, mask, whitefield, std, num_threads): + mask = mask & (std > 0.0) + y = np.divide(data, std, out=np.zeros_like(data), where=mask)[mask] + w = np.divide(whitefield, std, out=np.zeros_like(data), where=mask)[mask] + + scales = median(y * w, axis=0, num_threads=num_threads) / \ + median(w * w, axis=0, num_threads=num_threads) + whitefield *= scales + +def _get_concentric_only_mask(x_center, y_center, crop_roi, streak_lines, threshold=0.33): + if x_center is not None and y_center is not None: + if crop_roi is not None: + x_center -= crop_roi[0] + y_center -= crop_roi[2] + centers = np.mean(streak_lines.reshape(-1, 2, 2), axis=1) + norm = np.stack([streak_lines[:, 3] - streak_lines[:, 1], + streak_lines[:, 0] - streak_lines[:, 2]], axis=-1) + r = centers - np.asarray([x_center, y_center]) + prod = np.sum(norm * r, axis=-1)[..., None] + proj = r - prod * norm / np.sum(norm ** 2, axis=-1)[..., None] + streaks_mask = np.sqrt(np.sum(proj ** 2, axis=-1)) / np.sqrt(np.sum(r ** 2, axis=-1)) < threshold + return streaks_mask + +def _calc_streakfinder_analysis(results, snr, mask): do_streakfinder_analysis = results.get("do_streakfinder_analysis", False) if not do_streakfinder_analysis: return @@ -121,8 +140,8 @@ def _calc_streakfinder_analysis(results, cryst_data: CrystData): ] if not all(param in results.keys() for param in params_required): - print(f"ERROR: Not enough parameters for streak finder analysis. Skipping.\n" - f"{params_required=}") + results["cbd_error"] = (f"ERROR: Not enough parameters for streak finder analysis. " + f"Skipping.\n{params_required=}") return peak_structure_radius = results["cbd_peak_structure_radius"] # peak @@ -142,25 +161,29 @@ def _calc_streakfinder_analysis(results, cryst_data: CrystData): x_center = results.get("beam_center_x", None) y_center = results.get("beam_center_y", None) - mask_rois = results.get("cbd_mask_rois", []) # list of [y_min, y_max, x_min, x_max] - + mask_rois = results.get("cbd_mask_rois", []) # list of [x_min, x_max, y_min, y_max] for mask_roi in mask_rois: - cryst_data = cryst_data.mask_region(mask_roi) + mask[mask_roi[2]: mask_roi[3], mask_roi[0]: mask_roi[1]] = False - crop_roi = results.get("cbd_crop_roi", None) # [y_min, y_max, x_min, x_max] - - if crop_roi is not None: - crop_roi_t = [crop_roi[2], crop_roi[3], crop_roi[0], crop_roi[1]]# y0, y1, x0, x1 - cryst_data = cryst_data.crop(crop_roi_t) + crop_roi = results.get("cbd_crop_roi", None) # [x_min, x_max, y_min, y_max] + lookahead = results.get("cbd_lookahead", 1) peaks_structure = Structure2D(peak_structure_radius, peak_structure_rank) streaks_structure = Structure2D(streak_structure_radius, streak_structure_rank) - - det_obj = cryst_data.streak_detector(streaks_structure) - peaks = det_obj.detect_peaks(peak_vmin, npts, peaks_structure, num_threads) - detected = det_obj.detect_streaks(peaks, xtol, streak_vmin, min_size, nfa=nfa, - num_threads=num_threads) + if crop_roi is None: + region = snr + region_mask = mask + else: + region = snr[crop_roi[2]: crop_roi[3], crop_roi[0]: crop_roi[1]] + region_mask = mask[crop_roi[2]: crop_roi[3], crop_roi[0]: crop_roi[1]] + peaks = detect_peaks(region, region_mask, peaks_structure.rank, peak_vmin, + num_threads=num_threads) + peaks = filter_peaks(peaks, region, region_mask, peaks_structure, peak_vmin, npts, + num_threads=num_threads) + + detected = detect_streaks(peaks, region, region_mask, streaks_structure, xtol, streak_vmin, min_size, + lookahead, nfa, num_threads=num_threads) if isinstance(detected, list): detected = detected[0] @@ -173,33 +196,28 @@ def _calc_streakfinder_analysis(results, cryst_data: CrystData): results["bragg_counts"] = [] return - streaks = det_obj.to_streaks(detected) + streak_lines = detected.to_lines() detected_streaks = np.array(detected.streaks) - streak_lines = streaks.lines # Adjust to crop region if crop_roi is not None: shift = [crop_roi[0], crop_roi[2], crop_roi[0], crop_roi[2]] streak_lines = streak_lines + shift - if x_center is not None and y_center is not None: - if crop_roi is not None: - x_center -= crop_roi[0] - y_center -= crop_roi[2] - streaks_mask = streaks.concentric_only(x_center, y_center) - streak_lines = streak_lines[streaks_mask] - detected_streaks = detected_streaks[streaks_mask] + concentric_streaks_mask = _get_concentric_only_mask(x_center, y_center, crop_roi, streak_lines) + if concentric_streaks_mask is not None: + streak_lines = streak_lines[concentric_streaks_mask] + detected_streaks = detected_streaks[concentric_streaks_mask] streak_lengths = np.sqrt( - np.pow((streak_lines[..., 2] - streak_lines[..., 0]), 2) + - np.pow((streak_lines[..., 2] - streak_lines[..., 0]), 2) + np.power((streak_lines[..., 2] - streak_lines[..., 0]), 2) + + np.power((streak_lines[..., 2] - streak_lines[..., 0]), 2) ).tolist() streak_lines = streak_lines.T _, number_of_streaks = streak_lines.shape - list_result = streak_lines.tolist() # arr(4, n_lines); 0coord x0, y0, x1, y1 - + list_result = streak_lines.tolist() # arr(4, n_lines); coord x0, y0, x1, y1 bragg_counts = [streak.total_mass() for streak in detected_streaks] results["number_of_streaks"] = number_of_streaks