diff --git a/src/cristallina/analysis.py b/src/cristallina/analysis.py
new file mode 100644
index 0000000..449bd6c
--- /dev/null
+++ b/src/cristallina/analysis.py
@@ -0,0 +1,77 @@
+import re
+from collections import defaultdict
+
+import numpy as np
+
+from sfdata import SFDataFiles, sfdatafile, SFScanInfo
+
+import joblib
+from joblib import Parallel, delayed, Memory
+
+# TODO: generalize this for all analysis, i.e. find appropriate p-group
+location = "/das/work/units/cristallina/p19739/cachedir"
+memory = Memory(location, verbose=0, compress=2)
+
+import jungfrau_utils as ju
+
+# ju_patch_less_verbose lives in the plot module; it has to be imported here,
+# otherwise the call below fails with a NameError on import of this module.
+from .plot import ju_patch_less_verbose
+from .utils import ROI
+
+ju_patch_less_verbose(ju)
+
+
+@memory.cache
+def sum_images(fileset, channel="JF16T03V01", alignment_channels=None, batch_size=10, roi : ROI = None, preview=False):
+    """
+    Sums a given region of interest (roi) for an image channel from a
+    given fileset (e.g. "run0352/data/acq0001.*.h5").
+
+    fileset may be a single glob pattern or a list of patterns.
+
+    Allows alignment, i.e. reducing only to a common subset with other channels.
+
+    Summation is performed in batches, preview only sums and returns the first batch.
+
+    Returns: a dict with the keys "roi" (repr of the roi), "<channel>_intensity"
+    and "pids"; the latter two are lists with one entry per processed batch.
+    """
+
+    # A bare string must not be unpacked, that would pass every single
+    # character as its own pattern.
+    patterns = [fileset] if isinstance(fileset, str) else list(fileset)
+
+    with SFDataFiles(*patterns) as data:
+        if alignment_channels is not None:
+            channels = [channel, *alignment_channels]
+        else:
+            channels = [channel]
+
+        subset = data[channels]
+
+        subset.drop_missing()
+
+        Images = subset[channel]
+
+        res = defaultdict(list)
+        res["roi"] = repr(roi)
+
+        for image_slice in Images.in_batches(batch_size):
+
+            index_slice, im = image_slice
+
+            if roi is None:
+                im_ROI = im[:]
+            else:
+                im_ROI = im[:, roi.rows, roi.cols]
+
+            res[f"{channel}_intensity"].append(np.sum(im_ROI, axis=(1, 2)))
+            res["pids"].append(Images.pids[index_slice])
+
+            # only return first batch
+            if preview:
+                break
+
+    return res
+
diff --git a/src/cristallina/plot.py b/src/cristallina/plot.py
new file mode 100644
index 0000000..915a7b1
--- /dev/null
+++ b/src/cristallina/plot.py
@@ -0,0 +1,216 @@
+import re
+from collections import defaultdict
+
+import matplotlib
+from matplotlib import pyplot as plt
+
+import warnings
+# because of https://github.com/kornia/kornia/issues/1425
+warnings.simplefilter("ignore", DeprecationWarning)
+
+import numpy as np
+from tqdm import tqdm
+from matplotlib import patches
+
+from pathlib import Path
+
+from sfdata import SFDataFiles, sfdatafile, SFScanInfo
+import jungfrau_utils as ju
+
+from .utils import ROI
+
+
+def ju_patch_less_verbose(ju_module):
+    """Quick monkey patch to suppress verbose messages from gain & pedestal file searcher.
+
+    The original locator functions are kept as underscore-prefixed attributes on
+    ju_module.swissfel_helpers; the names looked up via ju_module.file_adapter are
+    replaced by wrappers that force verbose=False.
+    """
+    ju_module.swissfel_helpers._locate_gain_file = ju_module.swissfel_helpers.locate_gain_file
+    ju_module.swissfel_helpers._locate_pedestal_file = ju_module.swissfel_helpers.locate_pedestal_file
+
+    def less_verbose_gain(*args, **kwargs):
+        kwargs["verbose"] = False
+        return ju_module.swissfel_helpers._locate_gain_file(*args, **kwargs)
+
+    def less_verbose_pedestal(*args, **kwargs):
+        kwargs["verbose"] = False
+        return ju_module.swissfel_helpers._locate_pedestal_file(*args, **kwargs)
+
+    # ju_module.swissfel_helpers.locate_gain_file = less_verbose_gain
+    # ju_module.swissfel_helpers.locate_pedestal_file = less_verbose_pedestal
+
+    ju_module.file_adapter.locate_gain_file = less_verbose_gain
+    ju_module.file_adapter.locate_pedestal_file = less_verbose_pedestal
+
+
+ju_patch_less_verbose(ju)
+
+
+def plot_correlation(x, y, ax=None, **ax_kwargs):
+    """
+    Plots the correlation of x and y in a normalized scatterplot.
+    If no axis is given a figure and axis are created.
+
+    Additional keyword arguments are passed on to ax.set().
+
+    Returns: The axis object and the correlation coefficient between
+    x and y.
+    """
+
+    x = np.asarray(x)
+    y = np.asarray(y)
+
+    xstd = np.std(x)
+    ystd = np.std(y)
+
+    xnorm = (x - np.mean(x)) / xstd
+    ynorm = (y - np.mean(y)) / ystd
+
+    # was len(ys), an undefined name, so every call raised a NameError
+    n = len(y)
+
+    r = 1 / (n) * np.sum(xnorm * ynorm)
+
+    if ax is None:
+        fig, ax = plt.subplots()
+
+    if ax_kwargs:
+        ax.set(**ax_kwargs)
+
+    ax.plot(xnorm, ynorm, "o")
+    ax.text(0.95, 0.05, f"r = {r:.2f}", transform=ax.transAxes, horizontalalignment="right")
+
+    return ax, r
+
+
+def plot_channel(data, channel_name, ax=None):
+    """
+    Plots a channel with the plot function matching its dimensionality.
+    """
+
+    channel_dim = len(data[channel_name].shape)
+    # dim == 3: a 2D Image
+    # dim == 2: an array per pulse (probably)
+    # dim == 1: a single value per pulse (probably)
+
+    plot_f = {
+        1: plot_1d_channel,
+        2: plot_2d_channel,
+        3: plot_image_channel,
+    }
+
+    # ax has to be passed by keyword: the third positional parameter of
+    # plot_image_channel is pulse, not ax.
+    plot_f[channel_dim](data, channel_name, ax=ax)
+
+
+def axis_styling(ax, channel_name, description):
+    """
+    Sets the title of ax to channel_name and puts description into a text box
+    in the lower left corner.
+    """
+
+    ax.set_title(channel_name)
+    # ax.set_xlabel('x')
+    # ax.set_ylabel('a.u.')
+    ax.ticklabel_format(useOffset=False)
+    ax.text(
+        0.05,
+        0.05,
+        description,
+        transform=ax.transAxes,
+        horizontalalignment="left",
+        bbox=dict(boxstyle="round", color="lightgrey"),
+    )
+
+
+def plot_1d_channel(data, channel_name, ax=None):
+    """
+    Plots channel data for a channel that contains a single numeric value per pulse.
+    """
+    try:
+        # read the channel only once and reuse the values for statistics and plot
+        y_data = data[channel_name].data
+        mean, std = np.mean(y_data), np.std(y_data)
+    except TypeError:
+        print(f"Cannot parse channel {channel_name}. Check dimensionality.")
+        return
+
+    if ax is None:
+        fig, ax = plt.subplots(constrained_layout=True)
+
+    ax.plot(y_data)
+    description = f"mean: {mean:.2e},\nstd: {std:.2e}"
+    axis_styling(ax, channel_name, description)
+
+
+def plot_2d_channel(data, channel_name, ax=None):
+    """
+    Plots channel data for a channel that contains a 1d array of numeric values per pulse.
+
+    The plotted curve is the mean over all pulses.
+    """
+    try:
+        # read the channel only once and reuse the values for all statistics
+        values = data[channel_name].data
+        mean, std = np.mean(values), np.std(values)
+        mean_over_frames = np.mean(values, axis=0)
+    except TypeError:
+        print(f"Unknown data in channel {channel_name}.")
+        return
+
+    y_data = mean_over_frames
+
+    if ax is None:
+        fig, ax = plt.subplots(constrained_layout=True)
+
+    ax.plot(y_data)
+    description = f"mean: {mean:.2e},\nstd: {std:.2e}"
+    axis_styling(ax, channel_name, description)
+
+
+def plot_image_channel(data, channel_name, pulse=0, ax=None, rois=None, norms=None):
+    """
+    Plots channel data for a channel that contains an image (2d array) of numeric values per pulse.
+
+    Only the image of the given pulse is shown. The color scale covers mean +- std
+    of that image unless norms = (vmin, vmax) is given. ROIs given in rois are
+    drawn as labelled rectangles.
+    """
+
+    im = data[channel_name].data[pulse]
+
+    if ax is None:
+        fig, ax = plt.subplots(constrained_layout=True)
+
+    std = im.std()
+    mean = im.mean()
+
+    if norms is None:
+        norm = matplotlib.colors.Normalize(vmin=mean - std, vmax=mean + std)
+    else:
+        norm = matplotlib.colors.Normalize(vmin=norms[0], vmax=norms[1])
+
+    ax.imshow(im, norm=norm)
+    ax.invert_yaxis()
+
+    if rois is not None:
+        # Plot rois if given
+        for i, roi in enumerate(rois):
+            # Create a rectangle with ([bottom left corner coordinates], width, height)
+            rect = patches.Rectangle(
+                [roi.left, roi.bottom], roi.width, roi.height,
+                linewidth=3,
+                edgecolor=f"C{i}",
+                facecolor="none",
+                label=roi.name,
+            )
+            ax.add_patch(rect)
+
+    description = f"mean: {mean:.2e},\nstd: {std:.2e}"
+    axis_styling(ax, channel_name, description)
+    # legend on the axis that was drawn on, which is not necessarily the current axis
+    ax.legend(loc=4)
+
diff --git a/src/cristallina/utils.py b/src/cristallina/utils.py
new file mode 100644
index 0000000..03dbfc8
--- /dev/null
+++ b/src/cristallina/utils.py
@@ -0,0 +1,119 @@
+import yaml
+import os
+
+from sfdata import SFDataFiles, sfdatafile, SFScanInfo
+
+
+def print_run_info(
+    run_number=42, print_channels=True, extra_verbose=False, base_path="/sf/cristallina/data/p19739/raw/"
+):
+    """Prints overview of run information.
+
+    Extra verbose output contains all files and pids.
+    The channel list (of the first step only) is printed if print_channels is set.
+    """
+
+    scan = SFScanInfo(f"{base_path}/run{run_number:04}/meta/scan.json")
+
+    short = {}
+    for key, value in scan.info.items():
+        # only abbreviate lists that are long enough, this also avoids an
+        # IndexError for empty lists
+        if isinstance(value, list) and len(value) > 3:
+            short[key] = value[:2]
+            short[key].append("...")
+            short[key].append(value[-1])
+        else:
+            short[key] = value
+
+    if extra_verbose:
+        print(yaml.dump(scan.info, sort_keys=False, default_flow_style=False))
+    else:
+        print(yaml.dump(short, sort_keys=False, default_flow_style=False))
+
+    print(f"Number of steps: {len(scan.info['scan_readbacks'])}")
+
+    total_size = 0
+    for files_in_step in scan.info["scan_files"]:
+        for file in files_in_step:
+            try:
+                total_size += os.path.getsize(file)
+            except FileNotFoundError:
+                # best effort: files that are not (or no longer) on disk are skipped
+                pass
+
+    print(f"Total file size: {total_size/(1024*1024*1024):.1f} GB\n")
+
+    if print_channels:
+        for step in scan:
+            ch = step.channels
+            print("\n".join([str(c) for c in ch]))
+            # print only channels for first step
+            break
+
+
+class ROI:
+    """Definition of region of interest (ROI) in image coordinates.
+
+    Either the four edges (left, right, bottom, top) or a center together
+    with width and height have to be given.
+
+    Example: ROI(left=10, right=20, bottom=100, top=200)
+    """
+
+    def __init__(
+        self,
+        left: int = None,
+        right: int = None,
+        top: int = None,
+        bottom: int = None,
+        center_x: int = None,
+        center_y: int = None,
+        width: int = None,
+        height: int = None,
+        name: str = None,
+    ):
+
+        if None not in (left, right, bottom, top):
+            self.left, self.right, self.bottom, self.top, = (
+                left,
+                right,
+                bottom,
+                top,
+            )
+        elif None not in (center_x, center_y, width, height):
+            self.from_centers_widths(center_x, center_y, width, height)
+        else:
+            raise ValueError("No valid ROI definition.")
+
+        self.name = name
+
+    def from_centers_widths(self, center_x, center_y, width, height):
+        """Sets the edges from a center position and a width and height."""
+        # integer division: odd sizes are rounded down by one pixel
+        self.left = center_x - width // 2
+        self.right = center_x + width // 2
+
+        self.top = center_y + height // 2
+        self.bottom = center_y - height // 2
+
+    @property
+    def rows(self):
+        """Slice covering bottom to top."""
+        return slice(self.bottom, self.top)
+
+    @property
+    def cols(self):
+        """Slice covering left to right."""
+        return slice(self.left, self.right)
+
+    @property
+    def width(self):
+        return self.right - self.left
+
+    @property
+    def height(self):
+        return self.top - self.bottom
+
+    def __repr__(self):
+        return f"ROI(bottom={self.bottom}, top={self.top}, left={self.left}, right={self.right})"
diff --git a/tests/conftest.py b/tests/conftest.py
index a4b4bda..5e41c6d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,10 +1,29 @@
 """
-    Dummy conftest.py for cristallina.
+    conftest.py for cristallina.
 
-    If you don't know what this is for, just leave it empty.
     Read more about conftest.py under:
     - https://docs.pytest.org/en/stable/fixture.html
     - https://docs.pytest.org/en/stable/writing_plugins.html
 """
 
-# import pytest
+import pytest
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--runregression", action="store_true", default=False, help="run slow regression tests"
+    )
+
+
+def pytest_configure(config):
+    config.addinivalue_line("markers", "regression: mark test as a slow regression test to run")
+
+
+def pytest_collection_modifyitems(config, items):
+    if config.getoption("--runregression"):
+        # --runregression given in cli: do not skip slow tests
+        return
+    skip_regression = pytest.mark.skip(reason="need --runregression option to run")
+    for item in items:
+        if "regression" in item.keywords:
+            item.add_marker(skip_regression)
diff --git a/tests/test_regression.py b/tests/test_regression.py
new file mode 100644
index 0000000..99c9787
--- /dev/null
+++ b/tests/test_regression.py
@@ -0,0 +1,28 @@
+"""
+Regression tests for functionality that we assume to be working for our analysis.
+"""
+
+import pytest
+from sfdata import SFDataFiles, sfdatafile, SFScanInfo
+import tracemalloc
+
+
+@pytest.mark.regression
+def test_JU_memory():
+    base_path = "/sf/cristallina/data/p19739/raw/"
+    run_number = 49
+    averages = []
+
+    tracemalloc.start()
+    current, peak = tracemalloc.get_traced_memory()
+    print(f"Current memory usage is {current / 10**6:.1f}MB; Peak was {peak / 10**6:.1f}MB")
+
+
+    with SFDataFiles(f"{base_path}/run{run_number:04}/data/acq00*.h5") as data:
+        ch = data["JF16T03V01"]
+
+        current, peak = tracemalloc.get_traced_memory()
+        print(f"Current memory usage is {current / 10**6:.1f}MB; Peak was {peak / 10**6:.1f}MB")
+    tracemalloc.stop()
+
+    assert current/10**6 < 100, "Memory consumption should be below 100 MB"
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..0d5f020
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,21 @@
+import pytest
+
+from cristallina.utils import ROI, print_run_info
+
+__author__ = "Alexander Steppke"
+
+def test_print(capsys):
+
+    print_run_info(247)
+    captured = capsys.readouterr()
+    assert "15453208940" in captured.out
+    assert "LiTbF4_rocking" in captured.out
+
+
+
+def test_ROI():
+    """API Tests"""
+    r = ROI(left=1, right=2, top=4, bottom=2)
+
+    assert r.width == 1
+    assert r.height == 2