cleaned up caching and small fixes
This commit is contained in:
@@ -15,9 +15,6 @@ from joblib import Parallel, delayed, Memory
|
||||
from . import utils
|
||||
from .utils import ROI
|
||||
|
||||
memory = None
|
||||
|
||||
|
||||
def setup_cachedirs(pgroup=None, cachedir=None):
|
||||
"""
|
||||
Sets the path to a persistent cache directory either from the given p-group (e.g. "p20841")
|
||||
@@ -26,7 +23,6 @@ def setup_cachedirs(pgroup=None, cachedir=None):
|
||||
If heuristics fail we use "/tmp" as a non-persistent alternative.
|
||||
"""
|
||||
|
||||
global memory
|
||||
if cachedir is not None:
|
||||
# explicit directory given, use this choice
|
||||
memory = Memory(cachedir, verbose=0, compress=2)
|
||||
@@ -51,13 +47,13 @@ def setup_cachedirs(pgroup=None, cachedir=None):
|
||||
|
||||
return memory
|
||||
|
||||
|
||||
setup_cachedirs()
|
||||
memory = None
|
||||
memory = setup_cachedirs()
|
||||
|
||||
|
||||
@memory.cache(ignore=["batch_size"]) # we ignore batch_size for caching purposes
|
||||
def perform_image_calculations(
|
||||
fileset,
|
||||
filesets,
|
||||
channel="JF16T03V01",
|
||||
alignment_channels=None,
|
||||
batch_size=10,
|
||||
@@ -67,7 +63,7 @@ def perform_image_calculations(
|
||||
):
|
||||
"""
|
||||
Performs one or more calculations ("sum", "mean" or "std") for a given region of interest (roi)
|
||||
for an image channel from a fileset (e.g. "run0352/data/acq0001.*.h5" or step.fnames from a SFScanInfo object).
|
||||
for an image channel from a fileset (e.g. ["run0352/data/acq0001.*.h5"] or step.fnames from a SFScanInfo object).
|
||||
|
||||
Allows alignment, i.e. reducing only to a common subset with other channels.
|
||||
|
||||
@@ -85,7 +81,7 @@ def perform_image_calculations(
|
||||
"std": ["mean", np.std],
|
||||
}
|
||||
|
||||
with SFDataFiles(*fileset) as data:
|
||||
with SFDataFiles(*filesets) as data:
|
||||
if alignment_channels is not None:
|
||||
channels = [channel] + [ch for ch in alignment_channels]
|
||||
else:
|
||||
@@ -124,7 +120,7 @@ def perform_image_calculations(
|
||||
|
||||
@memory.cache(ignore=["batch_size"]) # we ignore batch_size for caching purposes
|
||||
def sum_images(
|
||||
fileset,
|
||||
filesets,
|
||||
channel="JF16T03V01",
|
||||
alignment_channels=None,
|
||||
batch_size=10,
|
||||
@@ -146,7 +142,7 @@ def sum_images(
|
||||
"""
|
||||
|
||||
return perform_image_calculations(
|
||||
fileset,
|
||||
filesets,
|
||||
channel=channel,
|
||||
alignment_channels=alignment_channels,
|
||||
batch_size=batch_size,
|
||||
@@ -157,7 +153,7 @@ def sum_images(
|
||||
|
||||
|
||||
def get_contrast_images(
|
||||
fileset,
|
||||
filesets,
|
||||
channel="JF16T03V01",
|
||||
alignment_channels=None,
|
||||
batch_size=10,
|
||||
@@ -169,7 +165,7 @@ def get_contrast_images(
|
||||
"""
|
||||
|
||||
return perform_image_calculations(
|
||||
fileset,
|
||||
filesets,
|
||||
channel=channel,
|
||||
alignment_channels=alignment_channels,
|
||||
batch_size=batch_size,
|
||||
@@ -181,7 +177,7 @@ def get_contrast_images(
|
||||
|
||||
@memory.cache(ignore=["batch_size"]) # we ignore batch_size for caching purposes
|
||||
def perform_image_stack_sum(
|
||||
fileset,
|
||||
filesets,
|
||||
channel="JF16T03V01",
|
||||
alignment_channels=None,
|
||||
batch_size=10,
|
||||
@@ -210,7 +206,7 @@ def perform_image_stack_sum(
|
||||
"std": ["mean", np.std],
|
||||
}
|
||||
|
||||
with SFDataFiles(*fileset) as data:
|
||||
with SFDataFiles(*filesets) as data:
|
||||
if alignment_channels is not None:
|
||||
channels = [channel] + [ch for ch in alignment_channels]
|
||||
else:
|
||||
|
||||
@@ -5,6 +5,7 @@ import matplotlib
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
import warnings
|
||||
|
||||
# because of https://github.com/kornia/kornia/issues/1425
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
|
||||
@@ -43,6 +44,7 @@ def ju_patch_less_verbose(ju_module):
|
||||
|
||||
ju_patch_less_verbose(ju)
|
||||
|
||||
|
||||
def plot_correlation(x, y, ax=None, **ax_kwargs):
|
||||
"""
|
||||
Plots the correlation of x and y in a normalized scatterplot.
|
||||
@@ -59,7 +61,7 @@ def plot_correlation(x, y, ax=None, **ax_kwargs):
|
||||
ynorm = (y - np.mean(y)) / ystd
|
||||
|
||||
n = len(y)
|
||||
|
||||
|
||||
r = 1 / (n) * sum(xnorm * ynorm)
|
||||
|
||||
if ax is None:
|
||||
@@ -73,10 +75,11 @@ def plot_correlation(x, y, ax=None, **ax_kwargs):
|
||||
|
||||
return ax, r
|
||||
|
||||
def plot_channel(data : SFDataFiles, channel_name, ax=None):
|
||||
"""
|
||||
Plots a given channel from an SFDataFiles object.
|
||||
|
||||
|
||||
def plot_channel(data: SFDataFiles, channel_name, ax=None):
|
||||
"""
|
||||
Plots a given channel from an SFDataFiles object.
|
||||
|
||||
Optionally: a matplotlib axis to plot into
|
||||
"""
|
||||
|
||||
@@ -95,7 +98,6 @@ def plot_channel(data : SFDataFiles, channel_name, ax=None):
|
||||
|
||||
|
||||
def axis_styling(ax, channel_name, description):
|
||||
|
||||
ax.set_title(channel_name)
|
||||
# ax.set_xlabel('x')
|
||||
# ax.set_ylabel('a.u.')
|
||||
@@ -110,7 +112,7 @@ def axis_styling(ax, channel_name, description):
|
||||
)
|
||||
|
||||
|
||||
def plot_1d_channel(data : SFDataFiles, channel_name, ax=None):
|
||||
def plot_1d_channel(data: SFDataFiles, channel_name, ax=None):
|
||||
"""
|
||||
Plots channel data for a channel that contains a single numeric value per pulse.
|
||||
"""
|
||||
@@ -131,7 +133,7 @@ def plot_1d_channel(data : SFDataFiles, channel_name, ax=None):
|
||||
axis_styling(ax, channel_name, description)
|
||||
|
||||
|
||||
def plot_2d_channel(data : SFDataFiles, channel_name, ax=None):
|
||||
def plot_2d_channel(data: SFDataFiles, channel_name, ax=None):
|
||||
"""
|
||||
Plots channel data for a channel that contains a 1d array of numeric values per pulse.
|
||||
"""
|
||||
@@ -153,22 +155,22 @@ def plot_2d_channel(data : SFDataFiles, channel_name, ax=None):
|
||||
axis_styling(ax, channel_name, description)
|
||||
|
||||
|
||||
def plot_image_channel(data : SFDataFiles, channel_name, pulse=0, ax=None, rois=None, norms=None, log_colorscale=False):
|
||||
def plot_detector_image(image_data, channel_name=None, ax=None, rois=None, norms=None, log_colorscale=False):
|
||||
"""
|
||||
Plots channel data for a channel that contains an image (2d array) of numeric values per pulse.
|
||||
Optional:
|
||||
Optional:
|
||||
- rois: draw a rectangular patch for the given roi(s)
|
||||
- norms: [min, max] values for colormap
|
||||
- log_colorscale: True for a logarithmic colormap
|
||||
"""
|
||||
|
||||
im = data[channel_name][pulse]
|
||||
im = image_data
|
||||
|
||||
def log_transform(z):
|
||||
return np.log(np.clip(z, 1E-12, np.max(z)))
|
||||
return np.log(np.clip(z, 1e-12, np.max(z)))
|
||||
|
||||
if log_colorscale:
|
||||
im = log_transform(im)
|
||||
im = log_transform(im)
|
||||
|
||||
if ax is None:
|
||||
fig, ax = plt.subplots(constrained_layout=True)
|
||||
@@ -189,7 +191,9 @@ def plot_image_channel(data : SFDataFiles, channel_name, pulse=0, ax=None, rois=
|
||||
for i, roi in enumerate(rois):
|
||||
# Create a rectangle with ([bottom left corner coordinates], width, height)
|
||||
rect = patches.Rectangle(
|
||||
[roi.left, roi.bottom], roi.width, roi.height,
|
||||
[roi.left, roi.bottom],
|
||||
roi.width,
|
||||
roi.height,
|
||||
linewidth=3,
|
||||
edgecolor=f"C{i}",
|
||||
facecolor="none",
|
||||
@@ -199,9 +203,26 @@ def plot_image_channel(data : SFDataFiles, channel_name, pulse=0, ax=None, rois=
|
||||
|
||||
description = f"mean: {mean:.2e},\nstd: {std:.2e}"
|
||||
axis_styling(ax, channel_name, description)
|
||||
plt.legend(loc=4)
|
||||
ax.legend(loc=4)
|
||||
|
||||
def plot_spectrum_channel(data : SFDataFiles, channel_name_x, channel_name_y, average=True, pulse=0, ax=None):
|
||||
|
||||
def plot_image_channel(data: SFDataFiles, channel_name, pulse=0, ax=None, rois=None, norms=None, log_colorscale=False):
|
||||
"""
|
||||
Plots channel data for a channel that contains an image (2d array) of numeric values per pulse.
|
||||
Optional:
|
||||
- rois: draw a rectangular patch for the given roi(s)
|
||||
- norms: [min, max] values for colormap
|
||||
- log_colorscale: True for a logarithmic colormap
|
||||
"""
|
||||
|
||||
image_data = data[channel_name][pulse]
|
||||
|
||||
plot_detector_image(
|
||||
image_data, channel_name=channel_name, ax=ax, rois=rois, norms=norms, log_colorscale=log_colorscale
|
||||
)
|
||||
|
||||
|
||||
def plot_spectrum_channel(data: SFDataFiles, channel_name_x, channel_name_y, average=True, pulse=0, ax=None):
|
||||
"""
|
||||
Plots channel data for two channels where the first is taken as the (constant) x-axis
|
||||
and the second as the y-axis (here we take by default the mean over the individual pulses).
|
||||
@@ -217,12 +238,11 @@ def plot_spectrum_channel(data : SFDataFiles, channel_name_x, channel_name_y, av
|
||||
y_data = mean_over_frames
|
||||
else:
|
||||
y_data = data[channel_name_y].data[pulse]
|
||||
|
||||
|
||||
if ax is None:
|
||||
fig, ax = plt.subplots(constrained_layout=True)
|
||||
|
||||
ax.plot(data[channel_name_x].data[0], y_data)
|
||||
description = None # f"mean: {mean:.2e},\nstd: {std:.2e}"
|
||||
description = None # f"mean: {mean:.2e},\nstd: {std:.2e}"
|
||||
ax.set_xlabel(channel_name_x)
|
||||
axis_styling(ax, channel_name_y, description)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import yaml
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
|
||||
@@ -19,7 +20,7 @@ from joblib import Parallel, delayed, cpu_count
|
||||
|
||||
def collect_runs_metadata(pgroup_data_dir: str | os.PathLike):
|
||||
"""
|
||||
Generates metadata overview for all runs in a given p-group data directory.
|
||||
Generates metadata overview for all runs in a given p-group data directory (e.g. "/sf/cristallina/data/p21261").
|
||||
|
||||
Not all available metadata is included, we skip e.g. lists of recorded channels.
|
||||
|
||||
@@ -74,6 +75,18 @@ def collect_runs_metadata(pgroup_data_dir: str | os.PathLike):
|
||||
|
||||
df = pd.DataFrame(measurements)
|
||||
|
||||
# add run and acquisition numbers from fileset
|
||||
# kind of added after the fact but now everything is in place
|
||||
run_no, acq_no = [], []
|
||||
for fileset in df["files"]:
|
||||
pattern = r".*run(\d{4})/data/acq(\d{4})"
|
||||
m = re.match(pattern, fileset)
|
||||
run_no.append(int(m.groups()[0]))
|
||||
acq_no.append(int(m.groups()[1]))
|
||||
|
||||
df['run'] = run_no
|
||||
df['acq'] = acq_no
|
||||
|
||||
return df
|
||||
|
||||
|
||||
@@ -274,6 +287,18 @@ def process_run(
|
||||
|
||||
Parallel(n_jobs=n_jobs, verbose=10)(delayed(process_step)(i) for i in range(len(scan)))
|
||||
|
||||
def is_processed_JF_file(filepath, detector='JF16T03V01'):
|
||||
""" Checks if a given .h5 file from a Jungfrau detector has been processed or only
|
||||
contains raw adc values.
|
||||
"""
|
||||
import h5py
|
||||
f = h5py.File(filepath)
|
||||
try:
|
||||
data = f['data'][detector]
|
||||
except KeyError:
|
||||
raise ValueError(f"{filepath} does not seem to be an Jungfrau file from the detector {detector}.")
|
||||
return 'meta' in f['data'][detector].keys()
|
||||
|
||||
|
||||
class ROI:
|
||||
"""Definition of region of interest (ROI) in image coordinates.
|
||||
|
||||
@@ -27,6 +27,19 @@ def test_image_calculations():
|
||||
453842.6]
|
||||
assert np.allclose(res["JF16T03V01_intensity"], intensity)
|
||||
|
||||
def test_joblib_memory():
|
||||
""" We need joblib for fast caching of intermediate results in all cases. So we check
|
||||
if the basic function caching to disk works.
|
||||
"""
|
||||
def calc_example(x):
|
||||
return x**2
|
||||
|
||||
calc_cached = cristallina.analysis.memory.cache(calc_example)
|
||||
|
||||
assert calc_cached(8) == 64
|
||||
assert calc_cached.check_call_in_cache(8) == True
|
||||
|
||||
|
||||
def test_minimal_2d_gaussian():
|
||||
|
||||
image = np.array([[0,0,0,0,0],
|
||||
|
||||
Reference in New Issue
Block a user