# cristallina_analysis_package/src/cristallina/analysis.py
# (727 lines, 21 KiB, Python, executable file)

import re
from collections import defaultdict
from pathlib import Path
from typing import Optional
import logging
import numpy as np
from matplotlib import pyplot as plt
import lmfit
from sfdata import SFDataFiles, sfdatafile, SFScanInfo
from joblib import Memory
from . import utils
from . import channels
from .utils import ROI
logger = logging.getLogger(__name__)
def setup_cachedirs(pgroup=None, cachedir=None):
    """
    Set up a persistent joblib cache directory.

    The path is derived either from the given p-group (e.g. "p20841") or from
    an explicitly given directory. If heuristics fail we use "/tmp" as a
    non-persistent alternative.

    Args:
        pgroup: p-group identifier such as "p20841"; if None, it is extracted
            heuristically via utils.heuristic_extract_pgroup().
        cachedir: explicit cache directory; overrides the p-group logic entirely.

    Returns:
        A joblib.Memory instance configured with the chosen cache directory.
    """
    if cachedir is not None:
        # explicit directory given, use this choice
        return Memory(cachedir, verbose=0, compress=2)

    try:
        if pgroup is None:
            pgroup_no = utils.heuristic_extract_pgroup()
        else:
            # split e.g. "p2343" into ['p', '2343', ''] and keep the numeric part
            parts = re.split(r"(\d.*)", pgroup)
            pgroup_no = parts[-2]

        candidates = [
            f"/sf/cristallina/data/p{pgroup_no}/work",
            f"/sf/cristallina/data/p{pgroup_no}/res",
        ]
        for cache_parent_dir in candidates:
            if Path(cache_parent_dir).exists():
                cachedir = Path(cache_parent_dir) / "cachedir"
                break
        else:
            # BUGFIX: previously cachedir silently stayed None here, which
            # disables joblib caching entirely (Memory(location=None)).
            logger.warning("No existing work/res directory found for p-group. Using /tmp as cachedir.")
            cachedir = "/tmp"
    except (KeyError, IndexError) as e:
        # IndexError covers a malformed pgroup string where re.split yields too few parts
        logger.warning(f"Could not determine p-group due to {e}. Using default cachedir.")
        cachedir = "/das/work/units/cristallina/p19739/cachedir"

    try:
        memory = Memory(cachedir, verbose=0, compress=2)
    except PermissionError as e:
        logger.warning(f"Could not use cachedir {cachedir} due to {e}. Using /tmp instead.")
        cachedir = "/tmp"
        memory = Memory(cachedir, verbose=0, compress=2)
    return memory
# Module-level joblib Memory instance; the @memory.cache decorators below use
# it to persist expensive image reductions across sessions.
memory = setup_cachedirs()
@memory.cache(ignore=["batch_size"])  # we ignore batch_size for caching purposes
def perform_image_calculations(
    filesets,
    channel="JF16T03V01",
    alignment_channels=None,
    batch_size=10,
    roi: Optional[ROI] = None,
    preview=False,
    operations=("sum",),  # tuple instead of list: avoids the mutable-default pitfall
    lower_cutoff_threshold=None,
    upper_cutoff_threshold=None,
):
    """
    Performs one or more calculations ("sum", "mean" or "std") for a given region of interest (roi)
    for an image channel from a fileset (e.g. ["run0352/data/acq0001.*.h5"] or step.fnames from a SFScanInfo object).

    Allows alignment, i.e. reducing only to a common subset with other channels.
    Calculations are performed in batches to reduce maximum memory requirements.
    Preview only applies calculation to first batch and returns.

    Returns a dictionary ({"JF16T03V01_intensity":[11, 18, 21, 55, ...]})
    with the given channel values for each pulse and corresponding pulse id.
    """
    # maps operation name -> (result key suffix, numpy reduction function)
    possible_operations = {
        "sum": ["intensity", np.sum],
        "mean": ["mean", np.mean],
        # BUGFIX: the label here used to be "mean" as well, so requesting both
        # "mean" and "std" mixed their values under the same result key.
        "std": ["std", np.std],
    }
    with SFDataFiles(*filesets) as data:
        # renamed from `channels` to avoid shadowing the `channels` module import
        if alignment_channels is not None:
            selected_channels = [channel] + [ch for ch in alignment_channels]
        else:
            selected_channels = [channel]
        subset = data[selected_channels]
        # keep only pulses that are present in all selected channels
        subset.drop_missing()
        Images = subset[channel]

        res = defaultdict(list)
        res["roi"] = repr(roi)

        for index_slice, im in Images.in_batches(batch_size):
            if roi is None:
                im_ROI = im[:]
            else:
                im_ROI = im[:, roi.rows, roi.cols]
            # use cutoffs to set pixel values outside the window to 0
            if lower_cutoff_threshold is not None:
                im_ROI = np.where(im_ROI < lower_cutoff_threshold, 0, im_ROI)
            if upper_cutoff_threshold is not None:
                im_ROI = np.where(im_ROI > upper_cutoff_threshold, 0, im_ROI)
            # iterate over all operations, reducing each image to one value
            for op in operations:
                label, func = possible_operations[op]
                res[f"{channel}_{label}"].extend(func(im_ROI, axis=(1, 2)))
            res["pids"].extend(Images.pids[index_slice])
            # only return first batch
            if preview:
                break
    return res
@memory.cache(ignore=["batch_size"])  # we ignore batch_size for caching purposes
def sum_images(
    filesets,
    channel="JF16T03V01",
    alignment_channels=None,
    batch_size=10,
    roi: Optional[ROI] = None,
    preview=False,
    lower_cutoff_threshold=None,
    upper_cutoff_threshold=None,  # new, optional: forwarded to perform_image_calculations
):
    """
    Sums a given region of interest (roi) for an image channel from a
    given fileset (e.g. "run0352/data/acq0001.*.h5" or step.fnames from a SFScanInfo object).

    Thin convenience wrapper around perform_image_calculations with operations=["sum"].
    Allows alignment, i.e. reducing only to a common subset with other channels.
    Summation is performed in batches to reduce maximum memory requirements.
    Preview only sums and returns the first batch.

    Returns a dictionary ({"JF16T03V01_intensity":[11, 18, 21, 55, ...]})
    with the given channel intensity for each pulse and corresponding pulse id.
    """
    return perform_image_calculations(
        filesets,
        channel=channel,
        alignment_channels=alignment_channels,
        batch_size=batch_size,
        roi=roi,
        preview=preview,
        operations=["sum"],
        lower_cutoff_threshold=lower_cutoff_threshold,
        # previously not forwarded although the underlying function supports it
        upper_cutoff_threshold=upper_cutoff_threshold,
    )
@memory.cache(ignore=["batch_size"])  # we ignore batch_size for caching purposes
def perform_image_stack_sum(
    filesets,
    channel="JF16T03V01",
    alignment_channels=None,
    batch_size=10,
    roi: Optional[ROI] = None,
    preview=False,
    lower_cutoff_threshold=None,  # in keV
    upper_cutoff_threshold=None,  # in keV
):
    """
    Sum an image channel along the pulse dimension, optionally restricted
    to a region of interest (roi).

    The fileset can be e.g. "run0352/data/acq0001.*.h5" or step.fnames from a
    SFScanInfo object. Alignment channels reduce the data to the subset of
    pulses shared with those channels. Work happens batch-wise to bound peak
    memory use; with preview=True only the first batch is accumulated.

    Returns:
        A 2D array with the summed image over all pulses without missing data.
    """
    with SFDataFiles(*filesets) as data:
        wanted = [channel] if alignment_channels is None else [channel, *alignment_channels]
        subset = data[wanted]
        subset.drop_missing()
        images = subset[channel]

        # determine the accumulator shape from the first (optionally cropped) image
        first = images[0]
        first_cropped = first[:] if roi is None else first[roi.rows, roi.cols]
        accumulated = np.zeros(first_cropped[0].shape)

        for _, batch in images.in_batches(batch_size):
            block = batch if roi is None else batch[:, roi.rows, roi.cols]
            # zero out pixels outside the keV window, if thresholds are given
            if lower_cutoff_threshold is not None:
                block = np.where(block < lower_cutoff_threshold, 0, block)
            if upper_cutoff_threshold is not None:
                block = np.where(block > upper_cutoff_threshold, 0, block)
            accumulated = accumulated + np.sum(block, axis=(0))
            # preview mode stops after the first batch
            if preview:
                break
    return accumulated
@memory.cache(ignore=["batch_size"])  # we ignore batch_size for caching purposes
def perform_image_roi_crop(
    filesets,
    channel="JF16T03V01",
    alignment_channels=None,
    batch_size=10,
    roi: Optional[ROI] = None,
    preview=False,
    lower_cutoff=None,
    upper_cutoff=np.inf,
):
    """
    Cuts out a region of interest (ROI) for an image channel
    from a fileset (e.g. "run0352/data/acq0001.*.h5" or step.fnames from a SFScanInfo object).

    Drops missing data from output and allows alignment,
    i.e. reducing only to a common subset with other channels.
    Calculations are performed in batches to reduce maximum memory requirements.
    Lower and upper cutoff allow to threshold the data (values outside the
    [lower_cutoff, upper_cutoff] window are set to 0).
    Preview only applies calculation to first batch and returns.

    Returns: A 1D array (along the pulses recorded without missing) of 2D images.
    Beware though: this can create a rather large array that exceeds available memory.
    TODO: should we create a complete channel here instead of returning `raw` data?
    """
    with SFDataFiles(*filesets) as data:
        # renamed from `channels` to avoid shadowing the `channels` module import
        if alignment_channels is not None:
            selected_channels = [channel] + [ch for ch in alignment_channels]
        else:
            selected_channels = [channel]
        subset = data[selected_channels]
        subset.drop_missing()
        Images = subset[channel]

        rois_within_batch = []
        for index_slice, im in Images.in_batches(batch_size):
            if roi is None:
                im_ROI = im
            else:
                im_ROI = im[:, roi.rows, roi.cols]
            # BUGFIX: thresholding used to be gated on lower_cutoff only, so a
            # finite upper_cutoff was silently ignored when lower_cutoff was None.
            if lower_cutoff is not None or upper_cutoff != np.inf:
                effective_lower = -np.inf if lower_cutoff is None else lower_cutoff
                im_ROI = np.where((im_ROI < effective_lower) | (im_ROI > upper_cutoff), 0, im_ROI)
            rois_within_batch.extend(im_ROI)
            # only return first batch
            if preview:
                break
    return np.array(rois_within_batch)
@memory.cache()
def calculate_JF_stacks(
    scan: SFScanInfo,
    lower_cutoff_threshold=7.6,  # in keV
    upper_cutoff_threshold=None,  # in keV
    exclude_steps: list[int] | None = None,
    recompute: bool = False,
    detector: str = "JF16T03V02",
):
    """Calculate image stacks for JF detectors in a scan.

    Args:
        scan (SFScanInfo): The scan object containing scan steps.
        lower_cutoff_threshold: Threshold to apply to pixel values before summation in keV.
        upper_cutoff_threshold: Upper threshold to apply to pixel values before summation in keV.
        exclude_steps: List of step indices to exclude from processing. Defaults to None.
        recompute: If True, forces recomputation even if cached results are available. Defaults to False.
            NOTE(review): this flag is not referenced anywhere in the body, so it
            currently has no effect — confirm intended behavior.
        detector: The detector channel to process. Defaults to "JF16T03V02" (JF 1.5M).

    Returns:
        stacks: Array of summed image stacks, one per (non-excluded) step.
        I0_norms: Array of I0 normalizations, one per (non-excluded) step.
    """
    stacks = []
    I0_norms = []
    for i, step in enumerate(scan):
        if exclude_steps is not None and i in exclude_steps:
            logger.info(f"skipping step {i}")
            continue
        # restrict to the JF detector/diagnostic channels actually present in this step
        JF_channels = [channels.JF, channels.JF_8M, channels.JF_I0, channels.GASMONITOR]
        available_channels = [ch.name for ch in step.channels]
        selected_channels = [ch for ch in JF_channels if ch in available_channels]
        subset = step[*selected_channels]
        # compare pulse ids before/after alignment so data loss is visible in the log
        pids_before = subset.pids.copy()
        subset.drop_missing()
        pids_after = subset.pids.copy()
        if set(pids_before) != set(pids_after):
            logger.warning(
                f"Step {i}: dropped {set(pids_before) - set(pids_after)} pulse IDs due to missing data."
            )
        # we only consider the JF files here
        files = [f.fname for f in step.files if "JF" in f.fname]
        stack = perform_image_stack_sum(
            files,
            channel=detector,
            lower_cutoff_threshold=lower_cutoff_threshold,
            upper_cutoff_threshold=upper_cutoff_threshold,
        )
        stacks.append(stack)
        # TODO: define roi for I0 detector
        stack_I0 = perform_image_roi_crop(
            files,
            channel=channels.JF_I0,
            lower_cutoff=2,
        )
        # total I0 signal over all pulses and pixels of this step
        I0_norm = np.sum(stack_I0, axis=(0, 1, 2))
        I0_norms.append(I0_norm)
    return np.array(stacks), np.array(I0_norms)
@memory.cache(ignore=["batch_size"])  # we ignore batch_size for caching purposes
def calculate_image_histograms(
    filesets,
    channel="JF16T03V01",
    alignment_channels=None,
    batch_size=10,
    roi: Optional[ROI] = None,
    preview=False,
    lower_cutoff_threshold=None,
    bins=None,
):
    """
    Calculates a histogram for a given region of interest (roi)
    for an image channel from a fileset (e.g. "run0352/data/acq0001.*.h5" or step.fnames from a SFScanInfo object).

    Allows alignment, i.e. reducing only to a common subset with other channels.
    Calculations are performed in batches to reduce maximum memory requirements.
    If `bins` is None, bin edges are estimated from the first batch via numpy's
    "auto" heuristic (after ROI cropping and thresholding).
    Preview only applies calculation to first batch and returns.

    Returns:
        (histogram, bins)
    """

    def _crop_and_threshold(batch):
        # apply the ROI crop and the lower cutoff to one batch of images
        if roi is not None:
            batch = batch[:, roi.rows, roi.cols]
        if lower_cutoff_threshold is not None:
            batch = np.where(batch < lower_cutoff_threshold, 0, batch)
        return batch

    with SFDataFiles(*filesets) as data:
        # renamed from `channels` to avoid shadowing the `channels` module import
        if alignment_channels is not None:
            selected_channels = [channel] + [ch for ch in alignment_channels]
        else:
            selected_channels = [channel]
        subset = data[selected_channels]
        subset.drop_missing()
        Images = subset[channel]

        if bins is None:
            # estimate the bin edges from the first batch only
            for _, im in Images.in_batches(batch_size):
                im_ROI = _crop_and_threshold(im)
                # BUGFIX: the edges used to be computed from the full image
                # `im` instead of the cropped/thresholded `im_ROI`
                bins = np.histogram_bin_edges(im_ROI.flatten(), bins="auto")
                break

        all_hist = np.zeros(len(bins) - 1)
        for _, im in Images.in_batches(batch_size):
            im_ROI = _crop_and_threshold(im)
            # BUGFIX: the histogram used to be taken over the full image `im`,
            # ignoring the ROI and the cutoff entirely; also removed a `summed`
            # accumulator that was never returned.
            hist, _ = np.histogram(im_ROI.flatten(), bins=bins)
            all_hist += hist
            # only return first batch
            if preview:
                break
    return all_hist, bins
def fit_2d_gaussian(image, roi: Optional[ROI] = None, plot=False):
    """
    2D Gaussian fit using LMFit for a given image and an optional region of interest.

    Args:
        image: 2D array with the image data.
        roi: optional region of interest; the fit runs on the cropped image and
            the returned center is shifted back to full-image coordinates.
        plot: if True, plot the fit results.

    Returns:
        (center_x, center_y, result): the x, y coordinates of the center and the
        lmfit result object which contains further fit statistics.
    """
    # given an image and optional ROI
    if roi is not None:
        im = image[roi.rows, roi.cols]
    else:
        im = image

    len_y, len_x = im.shape
    y = np.arange(len_y)
    x = np.arange(len_x)
    x, y = np.meshgrid(x, y)  # here now a 2D mesh
    x, y = x.ravel(), y.ravel()  # and all back into sequences of 1D arrays
    z = im.ravel()  # and this also as a 1D

    model = lmfit.models.Gaussian2dModel()
    params = model.guess(z.astype("float"), x.astype("float"), y.astype("float"))
    result = model.fit(
        z,
        x=x,
        y=y,
        params=params,
        method="leastsq",
        verbose=False,
        nan_policy=None,
        max_nfev=None,
    )

    if roi is not None:
        # convert back to original image coordinates
        # BUGFIX: use .value so both branches return plain floats
        # (this branch previously added an lmfit Parameter object)
        # NOTE(review): offsets use roi.left/roi.bottom — confirm this matches
        # the ROI row/column convention used elsewhere in the package.
        center_x = roi.left + result.params["centerx"].value
        center_y = roi.bottom + result.params["centery"].value
    else:
        center_x = result.params["centerx"].value
        center_y = result.params["centery"].value

    if plot:
        _plot_2d_gaussian_fit(im, z, model, result)
    return center_x, center_y, result
def _plot_2d_gaussian_fit(im, z, model, result):
    """Plot helper: show the data, the fitted model, their residual, and the
    fit overlaid with contour lines in a 2x2 figure.

    Args:
        im: 2D image (possibly ROI-cropped) that was fitted; only used for its shape.
        z: the raveled 1D data values used in the fit.
        model: the lmfit model object (must expose .func(X, Y, **best_values)).
        result: the lmfit fit result providing .best_values.
    """
    # local import so scipy is only required when plotting is requested
    from scipy.interpolate import griddata

    len_y, len_x = im.shape
    X, Y = np.meshgrid(np.arange(len_x), np.arange(len_y))
    # re-grid the raveled 1D data back onto the 2D mesh for plotting
    Z = griddata((X.ravel(), Y.ravel()), z, (X, Y), method="linear", fill_value=np.nan)

    fig, axs = plt.subplots(2, 2, figsize=(10, 10), layout="constrained")
    for ax in axs.ravel():
        ax.axis("equal")
        ax.set_xlabel("x")
        ax.set_ylabel("y")
    # shared color scale so data and fit panels are directly comparable
    vmax = np.max(Z)

    ax = axs[0, 0]
    art = ax.pcolormesh(X, Y, Z, vmin=0, vmax=vmax, shading="auto")
    fig.colorbar(art, ax=ax, label="z", shrink=0.5)
    ax.set_title("Data")

    ax = axs[0, 1]
    fit = model.func(X, Y, **result.best_values)
    art = ax.pcolormesh(X, Y, fit, vmin=0, vmax=vmax, shading="auto")
    fig.colorbar(art, ax=ax, label="z", shrink=0.5)
    ax.set_title("Fit")

    ax = axs[1, 0]
    fit = model.func(X, Y, **result.best_values)
    # residual on a tight gray scale (+/- 5% of the data maximum)
    art = ax.pcolormesh(X, Y, Z - fit, vmin=-0.05 * vmax, vmax=0.05 * vmax, cmap="gray", shading="auto")
    fig.colorbar(art, ax=ax, label="z", shrink=0.5)
    ax.set_title("Data - Fit")

    ax = axs[1, 1]
    fit = model.func(X, Y, **result.best_values)
    art = ax.pcolormesh(X, Y, fit, vmin=0, vmax=vmax, shading="auto")
    # red contours of the fit drawn on top of the fit surface
    ax.contour(X, Y, fit, 8, colors="r", alpha=0.4)
    fig.colorbar(art, ax=ax, label="z", shrink=0.5)
    ax.set_title("Data & Fit")

    fig.suptitle("2D Gaussian fit results")
def gaussian2d_rot_model(
    x,
    y=0.0,
    amplitude=1.0,
    centerx=0.0,
    centery=0.0,
    sigmax=1.0,
    sigmay=1.0,
    rotation=0,
    background=0,
):
    """Two-dimensional Gaussian with a rotation (in radians) about the center
    plus a constant background, built on lmfit's gaussian2d line shape.

    Both the sampling coordinates and the center are rotated by the same
    angle before evaluating the axis-aligned Gaussian.
    """
    cos_r = np.cos(rotation)
    sin_r = np.sin(rotation)
    # rotate the evaluation grid and the center by the same angle
    rotated_cx = centerx * cos_r - centery * sin_r
    rotated_cy = centerx * sin_r + centery * cos_r
    rotated_x = x * cos_r - y * sin_r
    rotated_y = x * sin_r + y * cos_r
    peak = lmfit.models.gaussian2d(
        rotated_x,
        y=rotated_y,
        amplitude=amplitude,
        centerx=rotated_cx,
        centery=rotated_cy,
        sigmax=sigmax,
        sigmay=sigmay,
    )
    return peak + background
def fit_2d_gaussian_rotated(
    image,
    roi=None,
    plot=False,
    vary_rotation=True,
    vary_background=False,
):
    """
    2D Gaussian fit with rotation using LMFit for a given image and an optional region of interest.

    As the number of free parameters for this kind of fit is large, issues with
    convergence appear often. Here we first fit without rotation and use the
    obtained parameters as a starting guess for the rotated fit.

    Args:
        image: 2D array with the image data.
        roi: optional region of interest; the fit runs on the crop and the
            returned center is shifted back to full-image coordinates.
        plot: if True, plot the fit results.
        vary_rotation: if True, the rotation angle is freed in the second fit stage.
        vary_background: if True, a constant background is fitted as well.

    Returns:
        (center_x, center_y, result): the x, y coordinates of the center and the
        lmfit result object which contains further fit statistics.
    """
    # given an image and optional ROI
    if roi is not None:
        im = image[roi.rows, roi.cols]
    else:
        im = image

    len_y, len_x = im.shape
    y = np.arange(len_y)
    x = np.arange(len_x)
    x, y = np.meshgrid(x, y)  # here now a 2D mesh
    x, y = x.ravel(), y.ravel()  # and all back into sequences of 1D arrays
    z = im.ravel()  # and this also as a 1D

    mod = lmfit.Model(gaussian2d_rot_model, independent_vars=["x", "y"])

    # Guess parameters, this is one possible approach
    mod.set_param_hint("amplitude", value=np.max(z) * 0.75, min=0, vary=True)
    mod.set_param_hint("centerx", value=np.mean(x) / 2, vary=True)
    mod.set_param_hint("centery", value=np.mean(y) / 2, vary=True)
    mod.set_param_hint("sigmax", value=np.mean(x) / 10, vary=True)
    mod.set_param_hint("sigmay", value=np.mean(y) / 10, vary=True)
    mod.set_param_hint("rotation", value=0.0, min=-np.pi / 2, max=np.pi / 2, vary=False)
    mod.set_param_hint("background", value=0.0, vary=vary_background)
    params = mod.make_params(verbose=False)

    # first fit without rotation; iterations are capped since this only
    # provides a coarse starting guess
    result = mod.fit(
        z,
        x=x,
        y=y,
        params=params,
        method="leastsq",
        verbose=False,
        nan_policy=None,
        max_nfev=20,
    )

    # now refining with rotation (freed only if vary_rotation is True)
    params = result.params
    params["rotation"].set(vary=vary_rotation)
    result = mod.fit(
        z,
        x=x,
        y=y,
        params=params,
        method="leastsq",
        verbose=False,
        nan_policy=None,
        max_nfev=None,
    )

    if roi is not None:
        # convert back to original image coordinates
        # BUGFIX: use .value so both branches return plain floats
        # (this branch previously added an lmfit Parameter object)
        # NOTE(review): offsets use roi.left/roi.bottom — confirm this matches
        # the ROI row/column convention used elsewhere in the package.
        center_x = roi.left + result.params["centerx"].value
        center_y = roi.bottom + result.params["centery"].value
    else:
        center_x = result.params["centerx"].value
        center_y = result.params["centery"].value

    if plot:
        _plot_2d_gaussian_fit(im, z, mod, result)
    return center_x, center_y, result
def fit_1d_gaussian(x, y, use_offset=True, ax=None, print_results=False):
    """
    Fit a 1D Gaussian plus an optional constant offset using LMFIT.

    Initial parameters are guessed heuristically from the data; when
    use_offset is False the constant is fixed at zero.

    Returns:
        lmfit.model.ModelResult
    """
    gauss = lmfit.models.GaussianModel()
    const = lmfit.models.ConstantModel()
    composite = gauss + const

    if use_offset:
        pars = const.make_params(c=np.median(y))
    else:
        pars = const.make_params(c=0)
        pars["c"].vary = False
    # heuristic amplitude guess: half the peak-to-peak range of the data
    pars += gauss.guess(y, x, amplitude=(np.max(y) - np.min(y)) / 2)

    fit_result = composite.fit(y, pars, x=x)

    if print_results:
        print(fit_result.fit_report())
    if ax is not None:
        ax.plot(x, fit_result.best_fit, label="fit")
    return fit_result