cleaned up caching and small fixes

This commit is contained in:
2023-09-14 11:29:52 +02:00
parent 787c863cd1
commit be01a75c68
4 changed files with 88 additions and 34 deletions

View File

@@ -15,9 +15,6 @@ from joblib import Parallel, delayed, Memory
from . import utils
from .utils import ROI
memory = None
def setup_cachedirs(pgroup=None, cachedir=None):
"""
Sets the path to a persistent cache directory either from the given p-group (e.g. "p20841")
@@ -26,7 +23,6 @@ def setup_cachedirs(pgroup=None, cachedir=None):
If heuristics fail we use "/tmp" as a non-persistent alternative.
"""
global memory
if cachedir is not None:
# explicit directory given, use this choice
memory = Memory(cachedir, verbose=0, compress=2)
@@ -51,13 +47,13 @@ def setup_cachedirs(pgroup=None, cachedir=None):
return memory
setup_cachedirs()
memory = None
memory = setup_cachedirs()
@memory.cache(ignore=["batch_size"]) # we ignore batch_size for caching purposes
def perform_image_calculations(
fileset,
filesets,
channel="JF16T03V01",
alignment_channels=None,
batch_size=10,
@@ -67,7 +63,7 @@ def perform_image_calculations(
):
"""
Performs one or more calculations ("sum", "mean" or "std") for a given region of interest (roi)
for an image channel from a fileset (e.g. "run0352/data/acq0001.*.h5" or step.fnames from a SFScanInfo object).
for an image channel from a fileset (e.g. ["run0352/data/acq0001.*.h5"] or step.fnames from a SFScanInfo object).
Allows alignment, i.e. reducing only to a common subset with other channels.
@@ -85,7 +81,7 @@ def perform_image_calculations(
"std": ["mean", np.std],
}
with SFDataFiles(*fileset) as data:
with SFDataFiles(*filesets) as data:
if alignment_channels is not None:
channels = [channel] + [ch for ch in alignment_channels]
else:
@@ -124,7 +120,7 @@ def perform_image_calculations(
@memory.cache(ignore=["batch_size"]) # we ignore batch_size for caching purposes
def sum_images(
fileset,
filesets,
channel="JF16T03V01",
alignment_channels=None,
batch_size=10,
@@ -146,7 +142,7 @@ def sum_images(
"""
return perform_image_calculations(
fileset,
filesets,
channel=channel,
alignment_channels=alignment_channels,
batch_size=batch_size,
@@ -157,7 +153,7 @@ def sum_images(
def get_contrast_images(
fileset,
filesets,
channel="JF16T03V01",
alignment_channels=None,
batch_size=10,
@@ -169,7 +165,7 @@ def get_contrast_images(
"""
return perform_image_calculations(
fileset,
filesets,
channel=channel,
alignment_channels=alignment_channels,
batch_size=batch_size,
@@ -181,7 +177,7 @@ def get_contrast_images(
@memory.cache(ignore=["batch_size"]) # we ignore batch_size for caching purposes
def perform_image_stack_sum(
fileset,
filesets,
channel="JF16T03V01",
alignment_channels=None,
batch_size=10,
@@ -210,7 +206,7 @@ def perform_image_stack_sum(
"std": ["mean", np.std],
}
with SFDataFiles(*fileset) as data:
with SFDataFiles(*filesets) as data:
if alignment_channels is not None:
channels = [channel] + [ch for ch in alignment_channels]
else:

View File

@@ -5,6 +5,7 @@ import matplotlib
from matplotlib import pyplot as plt
import warnings
# because of https://github.com/kornia/kornia/issues/1425
warnings.simplefilter("ignore", DeprecationWarning)
@@ -43,6 +44,7 @@ def ju_patch_less_verbose(ju_module):
ju_patch_less_verbose(ju)
def plot_correlation(x, y, ax=None, **ax_kwargs):
"""
Plots the correlation of x and y in a normalized scatterplot.
@@ -59,7 +61,7 @@ def plot_correlation(x, y, ax=None, **ax_kwargs):
ynorm = (y - np.mean(y)) / ystd
n = len(y)
r = 1 / (n) * sum(xnorm * ynorm)
if ax is None:
@@ -73,10 +75,11 @@ def plot_correlation(x, y, ax=None, **ax_kwargs):
return ax, r
def plot_channel(data : SFDataFiles, channel_name, ax=None):
"""
Plots a given channel from an SFDataFiles object.
def plot_channel(data: SFDataFiles, channel_name, ax=None):
"""
Plots a given channel from an SFDataFiles object.
Optionally: a matplotlib axis to plot into
"""
@@ -95,7 +98,6 @@ def plot_channel(data : SFDataFiles, channel_name, ax=None):
def axis_styling(ax, channel_name, description):
ax.set_title(channel_name)
# ax.set_xlabel('x')
# ax.set_ylabel('a.u.')
@@ -110,7 +112,7 @@ def axis_styling(ax, channel_name, description):
)
def plot_1d_channel(data : SFDataFiles, channel_name, ax=None):
def plot_1d_channel(data: SFDataFiles, channel_name, ax=None):
"""
Plots channel data for a channel that contains a single numeric value per pulse.
"""
@@ -131,7 +133,7 @@ def plot_1d_channel(data : SFDataFiles, channel_name, ax=None):
axis_styling(ax, channel_name, description)
def plot_2d_channel(data : SFDataFiles, channel_name, ax=None):
def plot_2d_channel(data: SFDataFiles, channel_name, ax=None):
"""
Plots channel data for a channel that contains a 1d array of numeric values per pulse.
"""
@@ -153,22 +155,22 @@ def plot_2d_channel(data : SFDataFiles, channel_name, ax=None):
axis_styling(ax, channel_name, description)
def plot_image_channel(data : SFDataFiles, channel_name, pulse=0, ax=None, rois=None, norms=None, log_colorscale=False):
def plot_detector_image(image_data, channel_name=None, ax=None, rois=None, norms=None, log_colorscale=False):
"""
Plots channel data for a channel that contains an image (2d array) of numeric values per pulse.
Optional:
Optional:
- rois: draw a rectangular patch for the given roi(s)
- norms: [min, max] values for colormap
- log_colorscale: True for a logarithmic colormap
"""
im = data[channel_name][pulse]
im = image_data
def log_transform(z):
return np.log(np.clip(z, 1E-12, np.max(z)))
return np.log(np.clip(z, 1e-12, np.max(z)))
if log_colorscale:
im = log_transform(im)
im = log_transform(im)
if ax is None:
fig, ax = plt.subplots(constrained_layout=True)
@@ -189,7 +191,9 @@ def plot_image_channel(data : SFDataFiles, channel_name, pulse=0, ax=None, rois=
for i, roi in enumerate(rois):
# Create a rectangle with ([bottom left corner coordinates], width, height)
rect = patches.Rectangle(
[roi.left, roi.bottom], roi.width, roi.height,
[roi.left, roi.bottom],
roi.width,
roi.height,
linewidth=3,
edgecolor=f"C{i}",
facecolor="none",
@@ -199,9 +203,26 @@ def plot_image_channel(data : SFDataFiles, channel_name, pulse=0, ax=None, rois=
description = f"mean: {mean:.2e},\nstd: {std:.2e}"
axis_styling(ax, channel_name, description)
plt.legend(loc=4)
ax.legend(loc=4)
def plot_spectrum_channel(data : SFDataFiles, channel_name_x, channel_name_y, average=True, pulse=0, ax=None):
def plot_image_channel(data: SFDataFiles, channel_name, pulse=0, ax=None, rois=None, norms=None, log_colorscale=False):
    """
    Plots a single image of an image channel, i.e. a channel holding one 2d array per pulse.

    The image is looked up as data[channel_name][pulse] and the actual drawing
    is delegated to plot_detector_image.

    Optional (all passed through unchanged to plot_detector_image):
        - ax: a matplotlib axis to plot into
        - rois: draw a rectangular patch for the given roi(s)
        - norms: [min, max] values for colormap
        - log_colorscale: True for a logarithmic colormap
    """
    plot_detector_image(
        data[channel_name][pulse],
        channel_name=channel_name,
        ax=ax,
        rois=rois,
        norms=norms,
        log_colorscale=log_colorscale,
    )
def plot_spectrum_channel(data: SFDataFiles, channel_name_x, channel_name_y, average=True, pulse=0, ax=None):
"""
Plots channel data for two channels where the first is taken as the (constant) x-axis
and the second as the y-axis (here we take by default the mean over the individual pulses).
@@ -217,12 +238,11 @@ def plot_spectrum_channel(data : SFDataFiles, channel_name_x, channel_name_y, av
y_data = mean_over_frames
else:
y_data = data[channel_name_y].data[pulse]
if ax is None:
fig, ax = plt.subplots(constrained_layout=True)
ax.plot(data[channel_name_x].data[0], y_data)
description = None # f"mean: {mean:.2e},\nstd: {std:.2e}"
description = None # f"mean: {mean:.2e},\nstd: {std:.2e}"
ax.set_xlabel(channel_name_x)
axis_styling(ax, channel_name_y, description)

View File

@@ -1,6 +1,7 @@
import yaml
import os
import json
import re
from pathlib import Path
from collections import defaultdict
@@ -19,7 +20,7 @@ from joblib import Parallel, delayed, cpu_count
def collect_runs_metadata(pgroup_data_dir: str | os.PathLike):
"""
Generates metadata overview for all runs in a given p-group data directory.
Generates metadata overview for all runs in a given p-group data directory (e.g. "/sf/cristallina/data/p21261").
Not all available metadata is included, we skip e.g. lists of recorded channels.
@@ -74,6 +75,18 @@ def collect_runs_metadata(pgroup_data_dir: str | os.PathLike):
df = pd.DataFrame(measurements)
# add run and acquisition numbers from fileset
# kind of added after the fact but now everything is in place
run_no, acq_no = [], []
for fileset in df["files"]:
pattern = r".*run(\d{4})/data/acq(\d{4})"
m = re.match(pattern, fileset)
run_no.append(int(m.groups()[0]))
acq_no.append(int(m.groups()[1]))
df['run'] = run_no
df['acq'] = acq_no
return df
@@ -274,6 +287,18 @@ def process_run(
Parallel(n_jobs=n_jobs, verbose=10)(delayed(process_step)(i) for i in range(len(scan)))
def is_processed_JF_file(filepath, detector="JF16T03V01"):
    """Checks if a given .h5 file from a Jungfrau detector has been processed or only
    contains raw adc values.

    A file counts as processed when the group "data/<detector>" contains
    an entry named "meta".

    Parameters:
        filepath: path of the .h5 file to inspect.
        detector: name of the detector group below "data".

    Returns:
        True if "meta" is present in "data/<detector>", False otherwise.

    Raises:
        ValueError: if the file has no "data/<detector>" group.
    """
    import h5py

    # Open read-only and via a context manager so the file handle is always
    # released, also when the detector group is missing.
    with h5py.File(filepath, "r") as f:
        try:
            detector_group = f["data"][detector]
        except KeyError as err:
            raise ValueError(
                f"{filepath} does not seem to be an Jungfrau file from the detector {detector}."
            ) from err

        # The membership test must happen while the file is still open.
        return "meta" in detector_group
class ROI:
"""Definition of region of interest (ROI) in image coordinates.

View File

@@ -27,6 +27,19 @@ def test_image_calculations():
453842.6]
assert np.allclose(res["JF16T03V01_intensity"], intensity)
def test_joblib_memory():
    """We need joblib for fast caching of intermediate results in all cases, so we
    check that the basic function caching to disk works.

    A small function is wrapped with the package-level joblib memory, called once,
    and then the call is expected to be reported as present in the cache.
    """

    def calc_example(x):
        return x**2

    calc_cached = cristallina.analysis.memory.cache(calc_example)

    # The first call computes the result and stores it in the cache.
    assert calc_cached(8) == 64
    # The same arguments must now be found in the cache.
    assert calc_cached.check_call_in_cache(8)
def test_minimal_2d_gaussian():
image = np.array([[0,0,0,0,0],