improved persistent caching, added config module extension

This commit is contained in:
2023-01-29 14:17:43 +01:00
parent 8c271a310c
commit ba63bd7f7b
5 changed files with 60 additions and 24 deletions

View File

@@ -15,6 +15,10 @@ except PackageNotFoundError: # pragma: no cover
finally:
del version, PackageNotFoundError
from . import config
from . import utils
from . import analysis
from . import plot
# analysis requires a pgroup for persistent caching, we try
# heuristics but this can also be (re-)set after the import.
from . import analysis

View File

@@ -13,23 +13,40 @@ from joblib import Parallel, delayed, Memory
from . import utils
from .utils import ROI
# Ideally we automatically determine the cache directory from the p-group.
# If this fails we try a fixed path for now, and as a last resort fall back
# to the temporary directory.
try:
pgroup = utils.heuristic_extract_pgroup()
location = f"/das/work/units/cristallina/p{pgroup}/cachedir"
except KeyError as e:
print(e)
location = "/das/work/units/cristallina/p19739/cachedir"
memory = None
try:
memory = Memory(location, verbose=0, compress=2)
except PermissionError as e:
location = "/tmp"
memory = Memory(location, verbose=0, compress=2)
def setup_cachedirs(pgroup=None, cachedir=None):
    """
    Sets the module-level joblib ``memory`` used for persistent caching.

    The cache directory is chosen in this order:

    1. ``cachedir``, if given explicitly (used as-is, no fallback).
    2. The directory derived from ``pgroup`` (e.g. "p20841"), or from
       ``utils.heuristic_extract_pgroup()`` when ``pgroup`` is None.
    3. A fixed default p-group directory if the p-group lookup raises KeyError.
    4. "/tmp" as a non-persistent alternative if the chosen directory
       cannot be used because of missing permissions.

    Parameters:
        pgroup: p-group identifier such as "p20841"; the numeric part is
            extracted and used to build the cache path.
        cachedir: explicit path of the cache directory; overrides ``pgroup``.
    """
    # Local import so the function does not depend on a module-level `import re`.
    import re

    global memory

    if cachedir is not None:
        # explicit directory given, use this choice
        memory = Memory(cachedir, verbose=0, compress=2)
        return

    try:
        if pgroup is None:
            pgroup_no = utils.heuristic_extract_pgroup()
        else:
            # "p2343" -> ['p', '2343', '']; the captured digits are second to last
            parts = re.split(r'(\d.*)', pgroup)
            pgroup_no = parts[-2]
        cachedir = f"/das/work/units/cristallina/p{pgroup_no}/cachedir"
    except KeyError as e:
        print(e)
        cachedir = "/das/work/units/cristallina/p19739/cachedir"

    try:
        memory = Memory(cachedir, verbose=0, compress=2)
    except PermissionError:
        # Retry with the fallback directory itself; retrying with the same
        # unwritable `cachedir` would just raise PermissionError again.
        cachedir = "/tmp"
        memory = Memory(cachedir, verbose=0, compress=2)
setup_cachedirs()
@memory.cache(ignore=["batch_size"]) # we ignore batch_size for caching purposes
def sum_images(
@@ -42,11 +59,16 @@ def sum_images(
):
"""
Sums a given region of interest (roi) for an image channel from a
given fileset (e.g. "run0352/data/acq0001.*.h5").
given fileset (e.g. "run0352/data/acq0001.*.h5" or step.fnames from a SFScanInfo object).
Allows alignment, i.e. reducing only to a common subset with other channels.
Summation is performed in batches, preview only sums and returns the first batch.
Summation is performed in batches to reduce maximum memory requirements.
Preview only sums and returns the first batch.
Returns a dictionary ({"JF16T03V01_intensity":[11, 18, 21, 55, ...]})
with the given channel intensity for each pulse and corresponding pulse id.
"""
with SFDataFiles(*fileset) as data:

View File

@@ -0,0 +1,3 @@
# not used yet
# PGROUP = ''

View File

@@ -73,7 +73,12 @@ def plot_correlation(x, y, ax=None, **ax_kwargs):
return ax, r
def plot_channel(data, channel_name, ax=None):
def plot_channel(data : SFDataFiles, channel_name, ax=None):
"""
Plots a given channel from an SFDataFiles object.
Optionally: a matplotlib axis to plot into
"""
channel_dim = len(data[channel_name].shape)
# dim == 3: a 2D Image
@@ -105,7 +110,7 @@ def axis_styling(ax, channel_name, description):
)
def plot_1d_channel(data, channel_name, ax=None):
def plot_1d_channel(data : SFDataFiles, channel_name, ax=None):
"""
Plots channel data for a channel that contains a single numeric value per pulse.
"""
@@ -126,7 +131,7 @@ def plot_1d_channel(data, channel_name, ax=None):
axis_styling(ax, channel_name, description)
def plot_2d_channel(data, channel_name, ax=None):
def plot_2d_channel(data : SFDataFiles, channel_name, ax=None):
"""
Plots channel data for a channel that contains a 1d array of numeric values per pulse.
"""
@@ -148,7 +153,7 @@ def plot_2d_channel(data, channel_name, ax=None):
axis_styling(ax, channel_name, description)
def plot_image_channel(data, channel_name, pulse=0, ax=None, rois=None, norms=None):
def plot_image_channel(data : SFDataFiles, channel_name, pulse=0, ax=None, rois=None, norms=None):
"""
Plots channel data for a channel that contains an image (2d array) of numeric values per pulse.
"""
@@ -186,7 +191,7 @@ def plot_image_channel(data, channel_name, pulse=0, ax=None, rois=None, norms=No
axis_styling(ax, channel_name, description)
plt.legend(loc=4)
def plot_spectrum_channel(data, channel_name_x, channel_name_y, average=True, pulse=0, ax=None):
def plot_spectrum_channel(data : SFDataFiles, channel_name_x, channel_name_y, average=True, pulse=0, ax=None):
"""
Plots channel data for two channels where the first is taken as the (constant) x-axis
and the second as the y-axis (here we take by default the mean over the individual pulses).

View File

@@ -136,7 +136,9 @@ def process_run(run_number, rois,detector='JF16T03V01', calculate =None, only_sh
class ROI:
"""Definition of region of interest (ROI) in image coordinates.
Example: ROI(left=10, right=20, bottom=100, top=200)
Example: ROI(left=10, right=20, bottom=100, top=200).
Directions assume that lower left corner of image is at (x=0, y=0).
"""
def __init__(
@@ -202,7 +204,7 @@ class ROI:
return f"ROI(bottom={self.bottom},top={self.top},left={self.left},right={self.right})"
def __eq__(self, other):
# we disregard the name
# we disregard the name for comparisons
return (self.left, self.right, self.bottom, self.top) == (other.left, other.right, other.bottom, other.top)
def __ne__(self, other):