From be01a75c688ce69ef88dd1666e2d5b3f4395d642 Mon Sep 17 00:00:00 2001 From: Alexander Steppke Date: Thu, 14 Sep 2023 11:29:52 +0200 Subject: [PATCH] cleaned up caching and small fixes --- src/cristallina/analysis.py | 26 ++++++++--------- src/cristallina/plot.py | 56 +++++++++++++++++++++++++------------ src/cristallina/utils.py | 27 +++++++++++++++++- tests/test_analysis.py | 13 +++++++++ 4 files changed, 88 insertions(+), 34 deletions(-) diff --git a/src/cristallina/analysis.py b/src/cristallina/analysis.py index a9ed7d2..447ea0f 100644 --- a/src/cristallina/analysis.py +++ b/src/cristallina/analysis.py @@ -15,9 +15,6 @@ from joblib import Parallel, delayed, Memory from . import utils from .utils import ROI -memory = None - - def setup_cachedirs(pgroup=None, cachedir=None): """ Sets the path to a persistent cache directory either from the given p-group (e.g. "p20841") @@ -26,7 +23,6 @@ def setup_cachedirs(pgroup=None, cachedir=None): If heuristics fail we use "/tmp" as a non-persistent alternative. """ - global memory if cachedir is not None: # explicit directory given, use this choice memory = Memory(cachedir, verbose=0, compress=2) @@ -51,13 +47,13 @@ def setup_cachedirs(pgroup=None, cachedir=None): return memory - -setup_cachedirs() +memory = None +memory = setup_cachedirs() @memory.cache(ignore=["batch_size"]) # we ignore batch_size for caching purposes def perform_image_calculations( - fileset, + filesets, channel="JF16T03V01", alignment_channels=None, batch_size=10, @@ -67,7 +63,7 @@ def perform_image_calculations( ): """ Performs one or more calculations ("sum", "mean" or "std") for a given region of interest (roi) - for an image channel from a fileset (e.g. "run0352/data/acq0001.*.h5" or step.fnames from a SFScanInfo object). + for an image channel from a fileset (e.g. ["run0352/data/acq0001.*.h5"] or step.fnames from a SFScanInfo object). Allows alignment, i.e. reducing only to a common subset with other channels. 
@@ -85,7 +81,7 @@ def perform_image_calculations( "std": ["mean", np.std], } - with SFDataFiles(*fileset) as data: + with SFDataFiles(*filesets) as data: if alignment_channels is not None: channels = [channel] + [ch for ch in alignment_channels] else: @@ -124,7 +120,7 @@ def perform_image_calculations( @memory.cache(ignore=["batch_size"]) # we ignore batch_size for caching purposes def sum_images( - fileset, + filesets, channel="JF16T03V01", alignment_channels=None, batch_size=10, @@ -146,7 +142,7 @@ def sum_images( """ return perform_image_calculations( - fileset, + filesets, channel=channel, alignment_channels=alignment_channels, batch_size=batch_size, @@ -157,7 +153,7 @@ def sum_images( def get_contrast_images( - fileset, + filesets, channel="JF16T03V01", alignment_channels=None, batch_size=10, @@ -169,7 +165,7 @@ def get_contrast_images( """ return perform_image_calculations( - fileset, + filesets, channel=channel, alignment_channels=alignment_channels, batch_size=batch_size, @@ -181,7 +177,7 @@ def get_contrast_images( @memory.cache(ignore=["batch_size"]) # we ignore batch_size for caching purposes def perform_image_stack_sum( - fileset, + filesets, channel="JF16T03V01", alignment_channels=None, batch_size=10, @@ -210,7 +206,7 @@ def perform_image_stack_sum( "std": ["mean", np.std], } - with SFDataFiles(*fileset) as data: + with SFDataFiles(*filesets) as data: if alignment_channels is not None: channels = [channel] + [ch for ch in alignment_channels] else: diff --git a/src/cristallina/plot.py b/src/cristallina/plot.py index be5cae3..da25a85 100644 --- a/src/cristallina/plot.py +++ b/src/cristallina/plot.py @@ -5,6 +5,7 @@ import matplotlib from matplotlib import pyplot as plt import warnings + # because of https://github.com/kornia/kornia/issues/1425 warnings.simplefilter("ignore", DeprecationWarning) @@ -43,6 +44,7 @@ def ju_patch_less_verbose(ju_module): ju_patch_less_verbose(ju) + def plot_correlation(x, y, ax=None, **ax_kwargs): """ Plots the correlation of x and y in a normalized scatterplot. @@ -59,7 +61,7 @@ def plot_correlation(x, y, ax=None, **ax_kwargs): ynorm = (y - np.mean(y)) / ystd n = len(y) - + r = 1 / (n) * sum(xnorm * ynorm) if ax is None: @@ -73,10 +75,11 @@ def plot_correlation(x, y, ax=None, **ax_kwargs): return ax, r -def plot_channel(data : SFDataFiles, channel_name, ax=None): - """ - Plots a given channel from an SFDataFiles object. - + +def plot_channel(data: SFDataFiles, channel_name, ax=None): + """ + Plots a given channel from an SFDataFiles object. + Optionally: a matplotlib axis to plot into """ @@ -95,7 +98,6 @@ def plot_channel(data : SFDataFiles, channel_name, ax=None): def axis_styling(ax, channel_name, description): - ax.set_title(channel_name) # ax.set_xlabel('x') # ax.set_ylabel('a.u.') @@ -110,7 +112,7 @@ def axis_styling(ax, channel_name, description): ) -def plot_1d_channel(data : SFDataFiles, channel_name, ax=None): +def plot_1d_channel(data: SFDataFiles, channel_name, ax=None): """ Plots channel data for a channel that contains a single numeric value per pulse. """ @@ -131,7 +133,7 @@ def plot_1d_channel(data : SFDataFiles, channel_name, ax=None): axis_styling(ax, channel_name, description) -def plot_2d_channel(data : SFDataFiles, channel_name, ax=None): +def plot_2d_channel(data: SFDataFiles, channel_name, ax=None): """ Plots channel data for a channel that contains a 1d array of numeric values per pulse. 
""" @@ -153,22 +155,22 @@ def plot_2d_channel(data : SFDataFiles, channel_name, ax=None): axis_styling(ax, channel_name, description) -def plot_image_channel(data : SFDataFiles, channel_name, pulse=0, ax=None, rois=None, norms=None, log_colorscale=False): +def plot_detector_image(image_data, channel_name=None, ax=None, rois=None, norms=None, log_colorscale=False): """ Plots channel data for a channel that contains an image (2d array) of numeric values per pulse. - Optional: + Optional: - rois: draw a rectangular patch for the given roi(s) - norms: [min, max] values for colormap - log_colorscale: True for a logarithmic colormap """ - im = data[channel_name][pulse] + im = image_data def log_transform(z): - return np.log(np.clip(z, 1E-12, np.max(z))) + return np.log(np.clip(z, 1e-12, np.max(z))) if log_colorscale: - im = log_transform(im) + im = log_transform(im) if ax is None: fig, ax = plt.subplots(constrained_layout=True) @@ -189,7 +191,9 @@ def plot_image_channel(data : SFDataFiles, channel_name, pulse=0, ax=None, rois= for i, roi in enumerate(rois): # Create a rectangle with ([bottom left corner coordinates], width, height) rect = patches.Rectangle( - [roi.left, roi.bottom], roi.width, roi.height, + [roi.left, roi.bottom], + roi.width, + roi.height, linewidth=3, edgecolor=f"C{i}", facecolor="none", @@ -199,9 +203,26 @@ def plot_image_channel(data : SFDataFiles, channel_name, pulse=0, ax=None, rois= description = f"mean: {mean:.2e},\nstd: {std:.2e}" axis_styling(ax, channel_name, description) - plt.legend(loc=4) + ax.legend(loc=4) -def plot_spectrum_channel(data : SFDataFiles, channel_name_x, channel_name_y, average=True, pulse=0, ax=None): + +def plot_image_channel(data: SFDataFiles, channel_name, pulse=0, ax=None, rois=None, norms=None, log_colorscale=False): + """ + Plots channel data for a channel that contains an image (2d array) of numeric values per pulse. + Optional: + - rois: draw a rectangular patch for the given roi(s) + - norms: [min, max] values for colormap + - log_colorscale: True for a logarithmic colormap + """ + + image_data = data[channel_name][pulse] + + plot_detector_image( + image_data, channel_name=channel_name, ax=ax, rois=rois, norms=norms, log_colorscale=log_colorscale + ) + + +def plot_spectrum_channel(data: SFDataFiles, channel_name_x, channel_name_y, average=True, pulse=0, ax=None): """ Plots channel data for two channels where the first is taken as the (constant) x-axis and the second as the y-axis (here we take by default the mean over the individual pulses). @@ -217,12 +238,11 @@ def plot_spectrum_channel(data : SFDataFiles, channel_name_x, channel_name_y, av y_data = mean_over_frames else: y_data = data[channel_name_y].data[pulse] - if ax is None: fig, ax = plt.subplots(constrained_layout=True) ax.plot(data[channel_name_x].data[0], y_data) - description = None # f"mean: {mean:.2e},\nstd: {std:.2e}" + description = None # f"mean: {mean:.2e},\nstd: {std:.2e}" ax.set_xlabel(channel_name_x) axis_styling(ax, channel_name_y, description) diff --git a/src/cristallina/utils.py b/src/cristallina/utils.py index 9117545..eb09fb6 100644 --- a/src/cristallina/utils.py +++ b/src/cristallina/utils.py @@ -1,6 +1,7 @@ import yaml import os import json +import re from pathlib import Path from collections import defaultdict @@ -19,7 +20,7 @@ from joblib import Parallel, delayed, cpu_count def collect_runs_metadata(pgroup_data_dir: str | os.PathLike): """ - Generates metadata overview for all runs in a given p-group data directory. 
+    Generates metadata overview for all runs in a given p-group data directory (e.g. "/sf/cristallina/data/p21261").
 
     Not all available metadata is included, we skip e.g. lists of recorded channels.
 
@@ -74,6 +75,18 @@ def collect_runs_metadata(pgroup_data_dir: str | os.PathLike):
 
     df = pd.DataFrame(measurements)
 
+    # add run and acquisition numbers parsed from the fileset paths
+    # (added after the fact, but now everything is in place)
+    run_no, acq_no = [], []
+    for fileset in df["files"]:
+        pattern = r".*run(\d{4})/data/acq(\d{4})"
+        m = re.match(pattern, fileset)
+        run_no.append(int(m.groups()[0]))
+        acq_no.append(int(m.groups()[1]))
+
+    df['run'] = run_no
+    df['acq'] = acq_no
+
     return df
 
 
@@ -274,6 +287,18 @@ def process_run(
     Parallel(n_jobs=n_jobs, verbose=10)(delayed(process_step)(i) for i in range(len(scan)))
 
 
+def is_processed_JF_file(filepath, detector='JF16T03V01'):
+    """Checks if a given .h5 file from a Jungfrau detector has been processed or only
+    contains raw adc values.
+    """
+    import h5py
+    with h5py.File(filepath, "r") as f:
+        try:
+            detector_group = f['data'][detector]
+        except KeyError:
+            raise ValueError(f"{filepath} does not seem to be a Jungfrau file from the detector {detector}.")
+        return 'meta' in detector_group.keys()
+
 
 class ROI:
     """Definition of region of interest (ROI) in image coordinates.
diff --git a/tests/test_analysis.py b/tests/test_analysis.py
index 26a989f..9fd5f79 100644
--- a/tests/test_analysis.py
+++ b/tests/test_analysis.py
@@ -27,6 +27,19 @@ def test_image_calculations():
                  453842.6]
     assert np.allclose(res["JF16T03V01_intensity"], intensity)
 
+def test_joblib_memory():
+    """We need joblib for fast caching of intermediate results in all cases, so we
+    check that basic function caching to disk works.
+    """
+    def calc_example(x):
+        return x**2
+
+    calc_cached = cristallina.analysis.memory.cache(calc_example)
+
+    assert calc_cached(8) == 64
+    assert calc_cached.check_call_in_cache(8)
+
+
 def test_minimal_2d_gaussian():
     image = np.array([[0,0,0,0,0],
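
Usage sketch (illustrative only, not part of the patch): how the reworked pieces fit together after this change. It relies only on signatures visible in the hunks above — setup_cachedirs() returning the joblib Memory, is_processed_JF_file(), and plot_detector_image() taking a plain 2D array — while the cache directory, file path and image data are made-up placeholders.

    import numpy as np
    from cristallina import analysis, plot, utils

    # setup_cachedirs() now returns the Memory instance instead of setting a global;
    # an explicit cachedir bypasses the p-group heuristics (placeholder path).
    memory = analysis.setup_cachedirs(cachedir="/tmp/cristallina_cache")

    @memory.cache
    def expensive(x):
        # any costly helper can reuse the same on-disk cache
        return x ** 2

    assert expensive(8) == 64

    # check whether a (hypothetical) Jungfrau file has already been processed
    already_processed = utils.is_processed_JF_file("run0352/data/acq0001.JF16T03V01.h5")

    # plot_detector_image() takes a plain 2D array; random placeholder data here
    image = np.random.poisson(5.0, size=(512, 1024)).astype(float)
    plot.plot_detector_image(image, channel_name="JF16T03V01", log_colorscale=True)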