diff --git a/dap/algos/streakfind.py b/dap/algos/streakfind.py index fd1c373..10971bd 100644 --- a/dap/algos/streakfind.py +++ b/dap/algos/streakfind.py @@ -221,18 +221,19 @@ def _calc_streakfinder_analysis(results, snr, mask): def _get_concentric_only_mask(x_center, y_center, crop_roi, streak_lines, threshold=0.33): - if x_center is not None and y_center is not None: - if crop_roi is not None: - x_center -= crop_roi[0] - y_center -= crop_roi[2] - centers = np.mean(streak_lines.reshape(-1, 2, 2), axis=1) - norm = np.stack([streak_lines[:, 3] - streak_lines[:, 1], - streak_lines[:, 0] - streak_lines[:, 2]], axis=-1) - r = centers - np.asarray([x_center, y_center]) - prod = np.sum(norm * r, axis=-1)[..., None] - proj = r - prod * norm / np.sum(norm ** 2, axis=-1)[..., None] - streaks_mask = np.sqrt(np.sum(proj ** 2, axis=-1)) / np.sqrt(np.sum(r ** 2, axis=-1)) < threshold - return streaks_mask + if x_center is None or y_center is None: + return None + if crop_roi is not None: + x_center -= crop_roi[0] + y_center -= crop_roi[2] + centers = np.mean(streak_lines.reshape(-1, 2, 2), axis=1) + norm = np.stack([streak_lines[:, 3] - streak_lines[:, 1], + streak_lines[:, 0] - streak_lines[:, 2]], axis=-1) + r = centers - np.asarray([x_center, y_center]) + prod = np.sum(norm * r, axis=-1)[..., None] + proj = r - prod * norm / np.sum(norm ** 2, axis=-1)[..., None] + streaks_mask = np.sqrt(np.sum(proj ** 2, axis=-1)) / np.sqrt(np.sum(r ** 2, axis=-1)) < threshold + return streaks_mask