import re
from collections import defaultdict

import matplotlib
from matplotlib import pyplot as plt

import warnings

# because of https://github.com/kornia/kornia/issues/1425
warnings.simplefilter("ignore", DeprecationWarning)

import numpy as np
from tqdm import tqdm
from matplotlib import patches
from pathlib import Path
import matplotlib as mpl

from sfdata import SFDataFiles, sfdatafile, SFScanInfo
import jungfrau_utils as ju

from . import utils
from .utils import ROI

# setup style sheet
plt.style.use("cristallina.cristallina_style")


def ju_patch_less_verbose(ju_module):
    """Quick monkey patch to suppress verbose messages from gain & pedestal file searcher.

    Not required for newer versions of ju.

    The original ``swissfel_helpers.locate_gain_file`` and
    ``swissfel_helpers.locate_pedestal_file`` are stored as
    ``_locate_gain_file`` / ``_locate_pedestal_file`` and wrapped so that
    ``verbose=False`` is always passed. The wrappers are installed on
    ``ju_module.file_adapter``. Does nothing if ``ju_module`` has no
    ``swissfel_helpers`` attribute.
    """
    if not hasattr(ju_module, "swissfel_helpers"):
        return

    ju_module.swissfel_helpers._locate_gain_file = ju_module.swissfel_helpers.locate_gain_file
    ju_module.swissfel_helpers._locate_pedestal_file = ju_module.swissfel_helpers.locate_pedestal_file

    def less_verbose_gain(*args, **kwargs):
        kwargs["verbose"] = False
        return ju_module.swissfel_helpers._locate_gain_file(*args, **kwargs)

    def less_verbose_pedestal(*args, **kwargs):
        kwargs["verbose"] = False
        return ju_module.swissfel_helpers._locate_pedestal_file(*args, **kwargs)

    # The wrappers are installed on file_adapter only (presumably the module
    # through which jungfrau_utils looks these helpers up); the functions on
    # swissfel_helpers itself are left untouched.
    ju_module.file_adapter.locate_gain_file = less_verbose_gain
    ju_module.file_adapter.locate_pedestal_file = less_verbose_pedestal


ju_patch_less_verbose(ju)


def plot_correlation(x, y, ax=None, **ax_kwargs):
    """
    Plots the correlation of x and y in a normalized scatterplot.

    Both inputs are standardized (mean subtracted, divided by the population
    standard deviation) and plotted against each other; the Pearson
    correlation coefficient r is written into the lower right corner.

    If no axis is given a figure and axis are created.

    Parameters:
        x, y: sequences of equal length.
        ax: optional matplotlib axis to plot into.
        **ax_kwargs: forwarded to ``ax.set`` (e.g. xlabel, ylabel, title).

    Returns:
        The axis object and the correlation coefficient between x and y.
    """
    x = np.asarray(x)
    y = np.asarray(y)

    xnorm = (x - np.mean(x)) / np.std(x)
    ynorm = (y - np.mean(y)) / np.std(y)

    # The mean of the product of the standardized variables is Pearson's r.
    r = np.mean(xnorm * ynorm)

    if ax is None:
        _, ax = plt.subplots()

    # **ax_kwargs is always a dict; only call ax.set when something was passed.
    if ax_kwargs:
        ax.set(**ax_kwargs)

    ax.plot(xnorm, ynorm, "o")
    ax.text(0.95, 0.05, f"r = {r:.2f}", transform=ax.transAxes, horizontalalignment="right")

    return ax, r


def plot_channel(data: SFDataFiles, channel_name, ax=None):
    """
    Plots a given channel from an SFDataFiles object.

    The plotting routine is chosen from the number of dimensions of the
    channel (``len(data[channel_name].shape)``):
        1 -> plot_1d_channel, 2 -> plot_2d_channel, 3 -> plot_image_channel.
    Any other dimensionality raises a KeyError.

    Optionally: a matplotlib axis to plot into
    """
    channel_dim = len(data[channel_name].shape)
    # dim == 3: a 2D Image
    # dim == 2: an array per pulse (probably)
    # dim == 1: a single value per pulse (probably)
    plot_f = {
        1: plot_1d_channel,
        2: plot_2d_channel,
        3: plot_image_channel,
    }

    plot_f[channel_dim](data, channel_name, ax=ax)


def axis_styling(ax, channel_name, description):
    """
    Applies the common styling: the channel name as title, no offset on the
    tick labels and ``description`` in a rounded light-grey box in the lower
    left corner of the axis.
    """
    ax.set_title(channel_name)
    ax.ticklabel_format(useOffset=False)
    ax.text(
        0.05,
        0.05,
        description,
        transform=ax.transAxes,
        horizontalalignment="left",
        bbox=dict(boxstyle="round", color="lightgrey"),
    )


def plot_1d_channel(data: SFDataFiles, channel_name, ax=None):
    """
    Plots channel data for a channel that contains a single numeric value per pulse.

    The mean and standard deviation over all pulses are shown in a text box.
    If the data cannot be reduced numerically (TypeError) a message is printed
    and nothing is plotted.
    """
    try:
        # Fetch the channel data once; it is reused for the statistics and the plot.
        y_data = data[channel_name].data
        mean, std = np.mean(y_data), np.std(y_data)
    except TypeError:
        print(f"Cannot parse channel {channel_name}. Check dimensionality.")
        return

    if ax is None:
        _, ax = plt.subplots(constrained_layout=True)

    ax.plot(y_data)
    description = f"mean: {mean:.2e},\nstd: {std:.2e}"
    axis_styling(ax, channel_name, description)


def plot_2d_channel(data: SFDataFiles, channel_name, ax=None):
    """
    Plots channel data for a channel that contains a 1d array of numeric values per pulse.

    The array averaged over all pulses (axis 0) is plotted; the overall mean
    and standard deviation are shown in a text box. If the data cannot be
    reduced numerically (TypeError) a message is printed and nothing is plotted.
    """
    try:
        channel_data = data[channel_name].data
        mean, std = np.mean(channel_data), np.std(channel_data)
        mean_over_frames = np.mean(channel_data, axis=0)
    except TypeError:
        print(f"Unknown data in channel {channel_name}.")
        return

    if ax is None:
        _, ax = plt.subplots(constrained_layout=True)

    ax.plot(mean_over_frames)
    description = f"mean: {mean:.2e},\nstd: {std:.2e}"
    axis_styling(ax, channel_name, description)


def plot_detector_image(
    image_data,
    title=None,
    comment=None,
    ax=None,
    rois=None,
    norms=None,
    log_colorscale=False,
    show_legend=True,
    ax_colormap=matplotlib.colormaps["viridis"],
    **fig_kw,
):
    """
    Plots an image (2d array) of numeric values, e.g. a single detector frame.

    Optional:
    - title: Title of the plot
    - comment: text placed above the mean/std statistics in the text box
    - ax: a matplotlib axis to plot into (otherwise a new figure is created,
      with ``fig_kw`` forwarded to ``plt.subplots``)
    - rois: draw a rectangular patch for the given roi(s); each roi needs
      ``left``, ``bottom``, ``width``, ``height`` and ``name`` attributes
    - norms: [min, max] values for colormap; default is mean -/+ one standard
      deviation of the (possibly log-transformed) image, ignoring NaN pixels
    - log_colorscale: True for a logarithmic colormap (values below 1e-12 are
      clipped to 1e-12 before taking the log)
    - show_legend: True if the statistics text box should be drawn
    - ax_colormap: a matplotlib colormap (viridis by default)

    The y axis is inverted so that row 0 of the image is at the bottom.
    """
    im = image_data

    if log_colorscale:
        # Clip from below only: non-positive pixels map to log(1e-12) instead
        # of -inf, and NaN pixels stay confined to themselves instead of
        # propagating through an upper bound taken from np.max.
        im = np.log(np.clip(im, 1e-12, None))

    if ax is None:
        _, ax = plt.subplots(constrained_layout=True, **fig_kw)

    # NaN-aware statistics so that masked (NaN) pixels do not turn the
    # default colour scale and the displayed numbers into NaN.
    std = np.nanstd(im)
    mean = np.nanmean(im)

    if norms is None:
        norm = matplotlib.colors.Normalize(vmin=mean - std, vmax=mean + std)
    else:
        norm = matplotlib.colors.Normalize(vmin=norms[0], vmax=norms[1])

    ax.imshow(im, norm=norm, cmap=ax_colormap)
    ax.invert_yaxis()

    if rois is not None:
        for i, roi in enumerate(rois):
            # Rectangle([left, bottom], width, height); one colour-cycle entry per roi
            rect = patches.Rectangle(
                [roi.left, roi.bottom],
                roi.width,
                roi.height,
                linewidth=2,
                edgecolor=f"C{i}",
                facecolor="none",
                label=roi.name,
            )
            ax.add_patch(rect)

    if not show_legend:
        description = ""
    elif comment is not None:
        description = f"{comment}\nmean: {mean:.2e},\nstd: {std:.2e}"
    else:
        description = f"mean: {mean:.2e},\nstd: {std:.2e}"

    axis_styling(ax, title, description)


def plot_image_channel(data: SFDataFiles, channel_name, pulse=0, ax=None, rois=None, norms=None, log_colorscale=False):
    """
    Plots channel data for a channel that contains an image (2d array) of numeric values per pulse.

    The image of the given ``pulse`` (``data[channel_name][pulse]``) is drawn
    with plot_detector_image, using the channel name as title.

    Optional:
    - rois: draw a rectangular patch for the given roi(s)
    - norms: [min, max] values for colormap
    - log_colorscale: True for a logarithmic colormap
    """
    image_data = data[channel_name][pulse]
    plot_detector_image(image_data, title=channel_name, ax=ax, rois=rois, norms=norms, log_colorscale=log_colorscale)


def plot_spectrum_channel(data: SFDataFiles, channel_name_x, channel_name_y, average=True, pulse=0, ax=None):
    """
    Plots channel data for two channels where the first is taken as the (constant) x-axis
    and the second as the y-axis (here we take by default the mean over the individual pulses).

    The x values are taken from the first pulse of ``channel_name_x``. With
    ``average=False`` the y values of the given ``pulse`` are plotted instead
    of the mean over pulses. If the y data cannot be reduced numerically
    (TypeError) a message is printed and nothing is plotted.
    """
    try:
        # Fetch the y channel data once; it is reused for both plotting modes.
        y_channel_data = data[channel_name_y].data
        mean_over_frames = np.mean(y_channel_data, axis=0)
    except TypeError:
        print(f"Unknown data in channel {channel_name_y}.")
        return

    y_data = mean_over_frames if average else y_channel_data[pulse]

    if ax is None:
        _, ax = plt.subplots(constrained_layout=True)

    ax.plot(data[channel_name_x].data[0], y_data)

    # No statistics box for spectra.
    description = None
    ax.set_xlabel(channel_name_x)
    axis_styling(ax, channel_name_y, description)


def line_plot_with_colorbar(
    xs,
    ys,
    colors,
    cmap=plt.cm.viridis,
    markers="o",
    markersize=6,
    alpha=1,
    title=None,
    xlabel=None,
    ylabel=None,
    cbar_label=None,
    **fig_kw,
):
    """Plot lines with colorbar.

    xs, ys -> array of arrays
    colors -> array

    Each (x, y) pair from ``zip(xs, ys, colors)`` is drawn as one line in a new
    figure (``fig_kw`` is forwarded to ``plt.subplots``), coloured by the
    matching entry of ``colors`` mapped through ``cmap`` over
    [min(colors), max(colors)]. A colorbar with a tick at every value in
    ``colors`` is added to the right.
    """
    fig, ax = plt.subplots(1, 1, constrained_layout=True, **fig_kw)

    # normalise to [0..1]
    norm = mpl.colors.Normalize(vmin=np.min(colors), vmax=np.max(colors))

    # A ScalarMappable provides both the value -> RGBA mapping for the lines
    # and the source for the colorbar; set_array([]) gives it an (empty) data array.
    s_m = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
    s_m.set_array([])

    for x, y, col in zip(xs, ys, colors):
        ax.plot(x, y, color=s_m.to_rgba(col), marker=markers, markersize=markersize, alpha=alpha)

    if title:
        # Title the figure created above explicitly, not whatever pyplot considers current.
        fig.suptitle(title)
    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)

    fig.colorbar(s_m, ax=ax, ticks=colors, label=cbar_label, alpha=alpha)