def setup_cachedirs(pgroup=None, cachedir=None):
    """
    Sets the module-level joblib ``memory`` object used for persistent caching.

    The cache directory is chosen in this order:

    1. ``cachedir`` if given explicitly. It is used as-is and any error raised
       while creating the cache is propagated, since an explicit choice should
       not be silently replaced.
    2. The p-group work directory
       ``/das/work/units/cristallina/p<number>/cachedir``, where the number is
       taken from ``pgroup`` (e.g. "p20841", "20841" or 20841) or, if
       ``pgroup`` is None, from ``utils.heuristic_extract_pgroup()``.
    3. A fixed p-group directory if the heuristic raises a KeyError.
    4. "/tmp" as a non-persistent alternative if creating the cache in the
       chosen directory raises a PermissionError.

    Raises ValueError if ``pgroup`` is given but contains no digits.
    """
    # Imported locally so this function does not depend on the module-level
    # import block providing `re`.
    import re

    global memory

    if cachedir is not None:
        # explicit directory given, use this choice
        memory = Memory(cachedir, verbose=0, compress=2)
        return

    try:
        if pgroup is None:
            pgroup_no = utils.heuristic_extract_pgroup()
        else:
            # Accept "p20841", "20841" or an integer; keep only the number.
            match = re.search(r"\d+", str(pgroup))
            if match is None:
                raise ValueError(
                    f"Cannot extract a p-group number from {pgroup!r}, expected e.g. 'p20841'."
                )
            pgroup_no = match.group()
        cachedir = f"/das/work/units/cristallina/p{pgroup_no}/cachedir"
    except KeyError as e:
        # heuristic failed, fall back to a fixed p-group directory
        print(e)
        cachedir = "/das/work/units/cristallina/p19739/cachedir"

    try:
        memory = Memory(cachedir, verbose=0, compress=2)
    except PermissionError:
        # The chosen directory cannot be created: retry with "/tmp"
        # (non-persistent) instead of the same unwritable location.
        memory = Memory("/tmp", verbose=0, compress=2)
""" with SFDataFiles(*fileset) as data: diff --git a/src/cristallina/config.py b/src/cristallina/config.py new file mode 100644 index 0000000..a4f7357 --- /dev/null +++ b/src/cristallina/config.py @@ -0,0 +1,3 @@ + +# not used yet +# PGROUP = '' \ No newline at end of file diff --git a/src/cristallina/plot.py b/src/cristallina/plot.py index c7e11e9..e4fca2c 100644 --- a/src/cristallina/plot.py +++ b/src/cristallina/plot.py @@ -73,7 +73,12 @@ def plot_correlation(x, y, ax=None, **ax_kwargs): return ax, r -def plot_channel(data, channel_name, ax=None): +def plot_channel(data : SFDataFiles, channel_name, ax=None): + """ + Plots a given channel from an SFDataFiles object. + + Optionally: a matplotlib axis to plot into + """ channel_dim = len(data[channel_name].shape) # dim == 3: a 2D Image @@ -105,7 +110,7 @@ def axis_styling(ax, channel_name, description): ) -def plot_1d_channel(data, channel_name, ax=None): +def plot_1d_channel(data : SFDataFiles, channel_name, ax=None): """ Plots channel data for a channel that contains a single numeric value per pulse. """ @@ -126,7 +131,7 @@ def plot_1d_channel(data, channel_name, ax=None): axis_styling(ax, channel_name, description) -def plot_2d_channel(data, channel_name, ax=None): +def plot_2d_channel(data : SFDataFiles, channel_name, ax=None): """ Plots channel data for a channel that contains a 1d array of numeric values per pulse. """ @@ -148,7 +153,7 @@ def plot_2d_channel(data, channel_name, ax=None): axis_styling(ax, channel_name, description) -def plot_image_channel(data, channel_name, pulse=0, ax=None, rois=None, norms=None): +def plot_image_channel(data : SFDataFiles, channel_name, pulse=0, ax=None, rois=None, norms=None): """ Plots channel data for a channel that contains an image (2d array) of numeric values per pulse. 
""" @@ -186,7 +191,7 @@ def plot_image_channel(data, channel_name, pulse=0, ax=None, rois=None, norms=No axis_styling(ax, channel_name, description) plt.legend(loc=4) -def plot_spectrum_channel(data, channel_name_x, channel_name_y, average=True, pulse=0, ax=None): +def plot_spectrum_channel(data : SFDataFiles, channel_name_x, channel_name_y, average=True, pulse=0, ax=None): """ Plots channel data for two channels where the first is taken as the (constant) x-axis and the second as the y-axis (here we take by default the mean over the individual pulses). diff --git a/src/cristallina/utils.py b/src/cristallina/utils.py index 018913c..bdcbb07 100644 --- a/src/cristallina/utils.py +++ b/src/cristallina/utils.py @@ -136,7 +136,9 @@ def process_run(run_number, rois,detector='JF16T03V01', calculate =None, only_sh class ROI: """Definition of region of interest (ROI) in image coordinates. - Example: ROI(left=10, right=20, bottom=100, top=200) + Example: ROI(left=10, right=20, bottom=100, top=200). + + Directions assume that lower left corner of image is at (x=0, y=0). """ def __init__( @@ -202,7 +204,7 @@ class ROI: return f"ROI(bottom={self.bottom},top={self.top},left={self.left},right={self.right})" def __eq__(self, other): - # we disregard the name + # we disregard the name for comparisons return (self.left, self.right, self.bottom, self.top) == (other.left, other.right, other.bottom, other.top) def __ne__(self, other):