diff --git a/src/cristallina/analysis.py b/src/cristallina/analysis.py
index 7566d1d..ea7344f 100644
--- a/src/cristallina/analysis.py
+++ b/src/cristallina/analysis.py
@@ -1,5 +1,5 @@
 import re
-from collections import defaultdict, deque
+from collections import defaultdict
 from pathlib import Path
 from typing import Optional
 import logging
@@ -19,6 +19,7 @@ from .utils import ROI
 
 logger = logging.getLogger(__name__)
 
+
 def setup_cachedirs(pgroup=None, cachedir=None):
     """
     Sets the path to a persistent cache directory either from the given p-group (e.g. "p20841")
@@ -39,26 +40,30 @@ def setup_cachedirs(pgroup=None, cachedir=None):
         parts = re.split(r"(\d.*)", pgroup)  # ['p', '2343', '']
         pgroup_no = parts[-2]
 
-        candidates = [f"/sf/cristallina/data/p{pgroup_no}/work",
-                      f"/sf/cristallina/data/p{pgroup_no}/res",]
-
+        candidates = [
+            f"/sf/cristallina/data/p{pgroup_no}/work",
+            f"/sf/cristallina/data/p{pgroup_no}/res",
+        ]
+
         for cache_parent_dir in candidates:
             if Path(cache_parent_dir).exists():
                 cachedir = Path(cache_parent_dir) / "cachedir"
                 break
-
+
     except KeyError as e:
-        print(e)
+        logger.warning(f"Could not determine p-group due to {e}. Using default cachedir.")
        cachedir = "/das/work/units/cristallina/p19739/cachedir"
 
     try:
         memory = Memory(cachedir, verbose=0, compress=2)
     except PermissionError as e:
+        logger.warning(f"Could not use cachedir {cachedir} due to {e}. Using /tmp instead.")
         cachedir = "/tmp"
         memory = Memory(cachedir, verbose=0, compress=2)
-
+
     return memory
 
+
 memory = setup_cachedirs()
@@ -204,12 +209,12 @@ def perform_image_stack_sum(
     batch_size=10,
     roi: Optional[ROI] = None,
     preview=False,
-    # operations=["sum"],
-    lower_cutoff_threshold=None,
+    lower_cutoff_threshold=None,  # in keV
+    upper_cutoff_threshold=None,  # in keV
 ):
     """
-    Performs one or more calculations ("sum", "mean" or "std") for a given region of interest (roi)
-    for an image channel from a fileset (e.g. "run0352/data/acq0001.*.h5" or step.fnames from a SFScanInfo object).
+    Performs summation for a given region of interest (roi) for an image channel
+    from a fileset (e.g. "run0352/data/acq0001.*.h5" or step.fnames from a SFScanInfo object).
 
     Allows alignment, i.e. reducing only to a common subset with other channels.
 
@@ -217,16 +222,9 @@ def perform_image_stack_sum(
 
     Preview only applies calculation to first batch and returns.
 
-    Returns a dictionary ({"JF16T03V01_intensity":[11, 18, 21, 55, ...]})
-    with the given channel values for each pulse and corresponding pulse id.
+    Returns: A 2D array with the image summed over all pulses; pulses with missing data are dropped.
     """
-    possible_operations = {
-        "sum": ["intensity", np.sum],
-        "mean": ["mean", np.mean],
-        "std": ["mean", np.std],
-    }
-
     with SFDataFiles(*filesets) as data:
         if alignment_channels is not None:
             channels = [channel] + [ch for ch in alignment_channels]
@@ -246,10 +244,8 @@ def perform_image_stack_sum(
             im_ROI = im[roi.rows, roi.cols]
 
         summed = np.zeros(im_ROI[0].shape)
-
-        for image_slice in Images.in_batches(batch_size):
-
-            index_slice, im = image_slice
+
+        for index_slice, im in Images.in_batches(batch_size):
 
             if roi is None:
                 im_ROI = im
@@ -258,6 +254,8 @@ def perform_image_stack_sum(
 
             if lower_cutoff_threshold is not None:
                 im_ROI = np.where(im_ROI < lower_cutoff_threshold, 0, im_ROI)
+            if upper_cutoff_threshold is not None:
+                im_ROI = np.where(im_ROI > upper_cutoff_threshold, 0, im_ROI)
 
             summed = summed + np.sum(im_ROI, axis=(0))
 
@@ -280,10 +278,10 @@ def perform_image_roi_crop(
     upper_cutoff=np.inf,
 ):
     """
-    Cuts out a region of interest (ROI) for an image channel 
+    Cuts out a region of interest (ROI) for an image channel
     from a fileset (e.g. "run0352/data/acq0001.*.h5" or step.fnames from a SFScanInfo object).
-    
-    Drops missing data from output and allows alignment, 
+
+    Drops missing data from output and allows alignment,
     i.e. reducing only to a common subset with other channels.
 
     Calculations are performed in batches to reduce maximum memory requirements.
@@ -310,10 +308,10 @@ def perform_image_roi_crop(
 
         Images = subset[channel]
 
-        rois_within_batch = list()
+        rois_within_batch = []
 
         for image_slice in Images.in_batches(batch_size):
-            
+
             index_slice, im = image_slice
 
             if roi is None:
@@ -329,20 +327,24 @@ def perform_image_roi_crop(
             # only return first batch
             if preview:
                 break
-    
+
     return np.array(rois_within_batch)
 
 
 @memory.cache()
-def calculate_JF_stacks(scan : SFScanInfo,
-                        lower_cutoff_threshold=7.6,
-                        exclude_steps: list[int] | None = None,
-                        recompute : bool=False,
-                        detector : str = "JF16T03V02"):
+def calculate_JF_stacks(
+    scan: SFScanInfo,
+    lower_cutoff_threshold=7.6,  # in keV
+    upper_cutoff_threshold=None,  # in keV
+    exclude_steps: list[int] | None = None,
+    recompute: bool = False,
+    detector: str = "JF16T03V02",
+):
     """Calculate image stacks for JF detectors in a scan.
 
     Args:
         scan (SFScanInfo): The scan object containing scan steps.
-        lower_cutoff_threshold: Threshold to apply to images before summation in keV.
+        lower_cutoff_threshold: Threshold to apply to pixel values before summation in keV.
+        upper_cutoff_threshold: Upper threshold to apply to pixel values before summation in keV.
         exclude_steps: List of step indices to exclude from processing. Defaults to None.
         recompute: If True, forces recomputation even if cached results are available. Defaults to False.
         detector: The detector channel to process. Defaults to "JF16T03V02" (JF 1.5M).
@@ -350,7 +352,7 @@ def calculate_JF_stacks(scan : SFScanInfo,
         stacks: List of summed image stacks for each step.
         I0_norms: List of I0 normalizations for each step.
     """
-    
+
     stacks = []
     I0_norms = []
 
@@ -369,17 +371,27 @@ def calculate_JF_stacks(scan : SFScanInfo,
         pids_after = subset.pids.copy()
 
         if set(pids_before) != set(pids_after):
-            logger.warning(f"Step {i}: dropped {set(pids_before) - set(pids_after)} pulse IDs due to missing data.")
-
+            logger.warning(
+                f"Step {i}: dropped {set(pids_before) - set(pids_after)} pulse IDs due to missing data."
+            )
+
         # we only consider the JF files here
         files = [f.fname for f in step.files if "JF" in f.fname]
 
-        stack = perform_image_stack_sum(files, channel=detector,
-                                        lower_cutoff_threshold=lower_cutoff_threshold)
+        stack = perform_image_stack_sum(
+            files,
+            channel=detector,
+            lower_cutoff_threshold=lower_cutoff_threshold,
+            upper_cutoff_threshold=upper_cutoff_threshold,
+        )
         stacks.append(stack)
-
+
         # TODO: define roi for I0 detector
-        stack_I0 = perform_image_roi_crop(files, channel=channels.JF_I0, lower_cutoff=2,)
+        stack_I0 = perform_image_roi_crop(
+            files,
+            channel=channels.JF_I0,
+            lower_cutoff=2,
+        )
         I0_norm = np.sum(stack_I0, axis=(0, 1, 2))
         I0_norms.append(I0_norm)
 
@@ -433,24 +445,24 @@ def calculate_image_histograms(
 
     if bins is None:
         for image_slice in Images.in_batches(batch_size):
-
+
            index_slice, im = image_slice
-
+
             if roi is None:
                 im_ROI = im
             else:
                 im_ROI = im[:, roi.rows, roi.cols]
-
+
             if lower_cutoff_threshold is not None:
                 im_ROI = np.where(im_ROI < lower_cutoff_threshold, 0, im_ROI)
-
-            bins = np.histogram_bin_edges(im.flatten(), bins='auto')
+
+            bins = np.histogram_bin_edges(im.flatten(), bins="auto")
             # only return first batch to calculate bins
             break
 
-    all_hist = np.zeros(len(bins)-1)
+    all_hist = np.zeros(len(bins) - 1)
     for image_slice in Images.in_batches(batch_size):
-
+
         index_slice, im = image_slice
 
         if roi is None:
@@ -461,7 +473,7 @@ def calculate_image_histograms(
         if lower_cutoff_threshold is not None:
            im_ROI = np.where(im_ROI < lower_cutoff_threshold, 0, im_ROI)
         if bins is None:
-            bins = np.histogram_bin_edges(im.flatten(), bins='auto')
+            bins = np.histogram_bin_edges(im.flatten(), bins="auto")
 
         summed = summed + np.sum(im_ROI, axis=(0))
         hist, _ = np.histogram(im.flatten(), bins=bins)
@@ -469,12 +481,10 @@ def calculate_image_histograms(
         # only return first batch
         if preview:
             break
-    
+
     return all_hist, bins
 
-
-
 
 def fit_2d_gaussian(image, roi: Optional[ROI] = None, plot=False):
     """
     2D Gaussian fit using LMFit for a given image and an optional region of interest.
@@ -502,7 +512,7 @@ def fit_2d_gaussian(image, roi: Optional[ROI] = None, plot=False):
     z = im.ravel()  # and this also as a 1D
 
     model = lmfit.models.Gaussian2dModel()
-    params = model.guess(z.astype('float'), x.astype('float'), y.astype('float'))
+    params = model.guess(z.astype("float"), x.astype("float"), y.astype("float"))
     result = model.fit(
         z,
         x=x,
@@ -533,7 +543,7 @@ def fit_2d_gaussian(image, roi: Optional[ROI] = None, plot=False):
     """Plot helper function to use the current image data, model and fit result
     and plots them together.
     """
-    from matplotlib import pyplot as plt
+    from scipy.interpolate import griddata
 
     len_y, len_x = im.shape
 
@@ -562,9 +572,7 @@ def _plot_2d_gaussian_fit(im, z, model, result):
 
     ax = axs[1, 0]
     fit = model.func(X, Y, **result.best_values)
-    art = ax.pcolormesh(
-        X, Y, Z - fit, vmin=-0.05 * vmax, vmax=0.05 * vmax, cmap="gray", shading="auto"
-    )
+    art = ax.pcolormesh(X, Y, Z - fit, vmin=-0.05 * vmax, vmax=0.05 * vmax, cmap="gray", shading="auto")
     fig.colorbar(art, ax=ax, label="z", shrink=0.5)
     ax.set_title("Data - Fit")
 
@@ -699,7 +707,7 @@ def fit_2d_gaussian_rotated(
     center_x = result.params["centerx"].value
     center_y = result.params["centery"].value
 
-    if plot == True:
+    if plot:
         _plot_2d_gaussian_fit(im, z, mod, result)
 
     return center_x, center_y, result
@@ -707,7 +715,7 @@ def fit_2d_gaussian_rotated(
 
 def fit_1d_gaussian(x, y, use_offset=True, ax=None, print_results=False):
     """
-    1D-Gaussian fit with optional constant offset using LMFIT. 
+    1D-Gaussian fit with optional constant offset using LMFIT.
     Uses a heuristic guess for initial parameters.
 
     Returns: lmfit.model.ModelResult
@@ -719,19 +727,23 @@ def fit_1d_gaussian(x, y, use_offset=True, ax=None, print_results=False):
     model = peak + offset
 
     if use_offset:
-        pars = offset.make_params(c = np.median(y))
+        pars = offset.make_params(c=np.median(y))
     else:
         pars = offset.make_params(c=0)
-        pars['c'].vary = False
-
-    pars += peak.guess(y, x, amplitude=(np.max(y)-np.min(y))/2)
-
-    result = model.fit(y, pars, x=x,)
+        pars["c"].vary = False
+
+    pars += peak.guess(y, x, amplitude=(np.max(y) - np.min(y)) / 2)
+
+    result = model.fit(
+        y,
+        pars,
+        x=x,
+    )
 
     if print_results:
         print(result.fit_report())
 
     if ax is not None:
-        ax.plot(x, result.best_fit, label='fit')
+        ax.plot(x, result.best_fit, label="fit")
 
     return result
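
For reference, a minimal usage sketch of the new upper_cutoff_threshold option (not part of the patch). It assumes SFScanInfo is importable from the sfdata package and that this module is importable as cristallina.analysis, which this diff does not confirm; the scan path and the 20 keV upper cutoff are placeholder values chosen for illustration only.

# Usage sketch only: import locations, run path and cutoff values are assumptions.
from sfdata import SFScanInfo
from cristallina.analysis import calculate_JF_stacks

scan = SFScanInfo("run0352/meta/scan.json")  # placeholder scan file path

# Pixel values below 7.6 keV or above 20 keV are zeroed before the per-step summation,
# mirroring the lower/upper np.where cutoffs introduced in perform_image_stack_sum above.
stacks, I0_norms = calculate_JF_stacks(
    scan,
    lower_cutoff_threshold=7.6,   # in keV
    upper_cutoff_threshold=20.0,  # in keV, placeholder value
    detector="JF16T03V02",
)

Since calculate_JF_stacks is wrapped by @memory.cache(), repeated calls with the same arguments should return the cached result; per the docstring, recompute=True forces reprocessing.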