# cristallina_analysis_package/src/cristallina/plot.py
# (file-viewer snapshot: 2025-04-02 04:15:31 +02:00, 301 lines, 8.9 KiB, Python)

import re
from collections import defaultdict
import matplotlib
from matplotlib import pyplot as plt
import warnings
# because of https://github.com/kornia/kornia/issues/1425
warnings.simplefilter("ignore", DeprecationWarning)
import numpy as np
from tqdm import tqdm
from matplotlib import patches
from pathlib import Path
import matplotlib as mpl
from sfdata import SFDataFiles, sfdatafile, SFScanInfo
import jungfrau_utils as ju
from . import utils
from .utils import ROI
# Set up the Cristallina matplotlib style sheet as soon as this module is imported.
# NOTE(review): plt.style.use presumably updates matplotlib's global rcParams, i.e.
# it affects every figure in the session, not only those made here — confirm intended.
# NOTE(review): "cristallina.cristallina_style" is a dotted package-style name; confirm
# the installed matplotlib version supports style sheets shipped inside packages.
plt.style.use("cristallina.cristallina_style")
def ju_patch_less_verbose(ju_module):
    """Quick monkey patch to suppress verbose messages from gain & pedestal file searcher.

    Not required for newer versions of ju.

    If ``ju_module`` provides ``swissfel_helpers.locate_gain_file``,
    ``swissfel_helpers.locate_pedestal_file`` and a ``file_adapter`` submodule,
    ``file_adapter.locate_gain_file`` / ``file_adapter.locate_pedestal_file`` are
    replaced by wrappers that always call the originals with ``verbose=False``.
    The originals stay available as ``swissfel_helpers._locate_gain_file`` /
    ``swissfel_helpers._locate_pedestal_file`` (kept for compatibility with the
    previous version of this patch). ``swissfel_helpers`` itself is not modified
    otherwise.

    The patch is best-effort: when any of these attributes is missing (another
    jungfrau_utils layout), the module is left untouched instead of raising
    ``AttributeError`` while this plotting module is being imported.

    Parameters
    ----------
    ju_module : module
        The imported ``jungfrau_utils`` module (or any object with the same
        attribute layout).
    """
    import functools

    helpers = getattr(ju_module, "swissfel_helpers", None)
    adapter = getattr(ju_module, "file_adapter", None)
    # getattr(None, ..., None) is None, so a missing helpers module is covered too.
    orig_gain = getattr(helpers, "locate_gain_file", None)
    orig_pedestal = getattr(helpers, "locate_pedestal_file", None)
    if adapter is None or orig_gain is None or orig_pedestal is None:
        return

    helpers._locate_gain_file = orig_gain
    helpers._locate_pedestal_file = orig_pedestal

    # The originals are captured in the closures, so calling this function a
    # second time wraps the same originals again instead of stacking wrappers.
    @functools.wraps(orig_gain)
    def less_verbose_gain(*args, **kwargs):
        kwargs["verbose"] = False
        return orig_gain(*args, **kwargs)

    @functools.wraps(orig_pedestal)
    def less_verbose_pedestal(*args, **kwargs):
        kwargs["verbose"] = False
        return orig_pedestal(*args, **kwargs)

    adapter.locate_gain_file = less_verbose_gain
    adapter.locate_pedestal_file = less_verbose_pedestal
# Apply the verbosity patch to the imported jungfrau_utils module (`ju`) once, at import time.
ju_patch_less_verbose(ju)
def plot_correlation(x, y, ax=None, **ax_kwargs):
    """
    Plots the correlation of x and y in a normalized scatterplot.

    Both inputs are converted to z-scores (mean 0, population standard
    deviation 1) and plotted against each other. The Pearson correlation
    coefficient is the mean of the element-wise product of the z-scores.
    If no axis is given a figure and axis are created.

    Parameters
    ----------
    x, y : array_like
        Equally long 1d sequences (lists, tuples, numpy arrays, ...). They are
        paired by position, not by any index the container may carry.
    ax : matplotlib axis, optional
        Axis to plot into.
    **ax_kwargs
        Forwarded to ``ax.set`` (e.g. ``title=``, ``xlabel=``).

    Returns: The axis object and the correlation coefficient between
    x and y. The coefficient is NaN if x or y is constant (zero std).
    """
    # asarray: lists/tuples support the arithmetic below, and pandas objects are
    # paired by position instead of being aligned by index.
    x = np.asarray(x)
    y = np.asarray(y)
    xnorm = (x - x.mean()) / x.std()
    ynorm = (y - y.mean()) / y.std()
    r = np.mean(xnorm * ynorm)

    if ax is None:
        _, ax = plt.subplots()
    if ax_kwargs:
        ax.set(**ax_kwargs)
    ax.plot(xnorm, ynorm, "o")
    ax.text(0.95, 0.05, f"r = {r:.2f}", transform=ax.transAxes, horizontalalignment="right")
    return ax, r
def plot_channel(data: SFDataFiles, channel_name, ax=None):
    """
    Plots a given channel from an SFDataFiles object, picking the plot type
    from the number of dimensions of the channel's shape:

    - 1 dimension:  a single value per pulse (probably) -> plot_1d_channel
    - 2 dimensions: an array per pulse (probably)       -> plot_2d_channel
    - 3 dimensions: a 2D image per pulse                -> plot_image_channel

    Any other number of dimensions raises a KeyError.

    Optionally: a matplotlib axis to plot into (``ax``).
    """
    n_dims = len(data[channel_name].shape)
    plotter = {
        1: plot_1d_channel,
        2: plot_2d_channel,
        3: plot_image_channel,
    }[n_dims]
    plotter(data, channel_name, ax=ax)
def axis_styling(ax, channel_name, description):
    """
    Apply the common look of the channel plots to ``ax``.

    The channel name becomes the axis title, the tick-label offset is switched
    off, and ``description`` is written into a rounded light-grey box anchored
    near the lower-left corner (coordinates are fractions of the axes). Axis
    labels are left to the caller.
    """
    label_box = {"boxstyle": "round", "color": "lightgrey"}
    anchor_x = anchor_y = 0.05

    ax.set_title(channel_name)
    ax.ticklabel_format(useOffset=False)
    ax.text(
        anchor_x,
        anchor_y,
        description,
        transform=ax.transAxes,
        horizontalalignment="left",
        bbox=label_box,
    )
def plot_1d_channel(data: SFDataFiles, channel_name, ax=None):
    """
    Plots channel data for a channel that contains a single numeric value per pulse.

    The values are drawn as a line over the pulse index. The mean and standard
    deviation over all pulses are shown in the description box.

    Parameters
    ----------
    data : SFDataFiles
        Opened data files containing ``channel_name``.
    channel_name : str
        Channel to plot.
    ax : matplotlib axis, optional
        Axis to plot into; a new figure is created if omitted.

    If numpy cannot reduce the channel data (``TypeError``), a message is
    printed and nothing is plotted.
    """
    try:
        # Fetch the channel data once and reuse it for the statistics and the
        # plot. NOTE(review): `.data` presumably (re)loads from file on each
        # access — confirm against sfdata.
        y_data = data[channel_name].data
        mean, std = np.mean(y_data), np.std(y_data)
    except TypeError:
        print(f"Cannot parse channel {channel_name}. Check dimensionality.")
        return

    if ax is None:
        _, ax = plt.subplots(constrained_layout=True)
    ax.plot(y_data)
    description = f"mean: {mean:.2e},\nstd: {std:.2e}"
    axis_styling(ax, channel_name, description)
def plot_2d_channel(data: SFDataFiles, channel_name, ax=None):
    """
    Plots channel data for a channel that contains a 1d array of numeric values per pulse.

    The array averaged over all pulses (mean along axis 0) is drawn as a line.
    The mean and standard deviation over all values of all pulses are shown in
    the description box.

    Parameters
    ----------
    data : SFDataFiles
        Opened data files containing ``channel_name``.
    channel_name : str
        Channel to plot.
    ax : matplotlib axis, optional
        Axis to plot into; a new figure is created if omitted.

    If numpy cannot reduce the channel data (``TypeError``), a message is
    printed and nothing is plotted.
    """
    try:
        # Fetch the channel data once instead of once per statistic.
        # NOTE(review): `.data` presumably (re)loads from file on each access —
        # confirm against sfdata.
        channel_data = data[channel_name].data
        mean, std = np.mean(channel_data), np.std(channel_data)
        mean_over_frames = np.mean(channel_data, axis=0)
    except TypeError:
        print(f"Unknown data in channel {channel_name}.")
        return

    if ax is None:
        _, ax = plt.subplots(constrained_layout=True)
    ax.plot(mean_over_frames)
    description = f"mean: {mean:.2e},\nstd: {std:.2e}"
    axis_styling(ax, channel_name, description)
def plot_detector_image(
    image_data,
    title=None,
    comment=None,
    ax=None,
    rois=None,
    norms=None,
    log_colorscale=False,
    show_legend=True,
    ax_colormap=matplotlib.colormaps["viridis"],
    **fig_kw,
):
    """
    Plots channel data for a channel that contains an image (2d array) of numeric values per pulse.

    The image is drawn with ``imshow`` and the y-axis is inverted afterwards.
    Without ``norms``, the color scale spans mean ± std of the displayed
    (possibly log-transformed) image. NaN pixels are ignored for these
    statistics.

    Optional:
    - title: Title of the plot
    - comment: extra text placed above mean/std in the description box
    - ax: matplotlib axis to draw into; if omitted a new figure is created and
      ``fig_kw`` is forwarded to ``plt.subplots``
    - rois: draw a rectangular patch for the given roi(s); each roi needs
      ``left``, ``bottom``, ``width``, ``height`` and ``name``
    - norms: [min, max] values for colormap
    - log_colorscale: True for a logarithmic colormap (natural log; values
      below 1e-12 are raised to 1e-12 first)
    - show_legend: True if the description box (comment, mean, std) should be drawn
    - ax_colormap: a matplotlib colormap (viridis by default)
    """

    def log_transform(z):
        # Clip from below only. An upper bound of np.max(z) would be smaller
        # than the lower bound for images whose maximum is < 1e-12 (e.g. all
        # zeros). np.clip then sets every pixel to that maximum, giving log(0) = -inf.
        return np.log(np.clip(z, 1e-12, None))

    im = log_transform(image_data) if log_colorscale else image_data

    if ax is None:
        _, ax = plt.subplots(constrained_layout=True, **fig_kw)

    # nan-aware: a single NaN pixel would otherwise turn the color limits and
    # the printed statistics into NaN. Identical to mean/std for NaN-free images.
    std = np.nanstd(im)
    mean = np.nanmean(im)

    if norms is None:
        norm = matplotlib.colors.Normalize(vmin=mean - std, vmax=mean + std)
    else:
        norm = matplotlib.colors.Normalize(vmin=norms[0], vmax=norms[1])

    ax.imshow(im, norm=norm, cmap=ax_colormap)
    ax.invert_yaxis()

    if rois is not None:
        for i, roi in enumerate(rois):
            # Rectangle from the bottom-left corner, width and height;
            # each ROI gets the next color of the property cycle.
            rect = patches.Rectangle(
                [roi.left, roi.bottom],
                roi.width,
                roi.height,
                linewidth=2,
                edgecolor=f"C{i}",
                facecolor="none",
                label=roi.name,
            )
            ax.add_patch(rect)

    if not show_legend:
        description = ""
    elif comment is not None:
        description = f"{comment}\nmean: {mean:.2e},\nstd: {std:.2e}"
    else:
        description = f"mean: {mean:.2e},\nstd: {std:.2e}"
    axis_styling(ax, title, description)
def plot_image_channel(
    data: SFDataFiles,
    channel_name,
    pulse=0,
    ax=None,
    rois=None,
    norms=None,
    log_colorscale=False,
    **detector_kwargs,
):
    """
    Plots channel data for a channel that contains an image (2d array) of numeric values per pulse.

    Only the image of one pulse (``data[channel_name][pulse]``) is plotted,
    titled with the channel name.

    Optional:
    - pulse: index of the pulse whose image is shown (0 by default)
    - ax: matplotlib axis to plot into
    - rois: draw a rectangular patch for the given roi(s)
    - norms: [min, max] values for colormap
    - log_colorscale: True for a logarithmic colormap
    - detector_kwargs: any further keyword arguments are passed unchanged to
      ``plot_detector_image`` (e.g. ``comment``, ``show_legend``,
      ``ax_colormap`` or figure keywords such as ``figsize``)
    """
    image_data = data[channel_name][pulse]
    plot_detector_image(
        image_data,
        title=channel_name,
        ax=ax,
        rois=rois,
        norms=norms,
        log_colorscale=log_colorscale,
        **detector_kwargs,
    )
def plot_spectrum_channel(data: SFDataFiles, channel_name_x, channel_name_y, average=True, pulse=0, ax=None):
    """
    Plots channel data for two channels where the first is taken as the (constant) x-axis
    and the second as the y-axis (here we take by default the mean over the individual pulses).

    Parameters
    ----------
    data : SFDataFiles
        Opened data files containing both channels.
    channel_name_x : str
        Channel whose first pulse (``.data[0]``) provides the x values.
    channel_name_y : str
        Channel plotted on the y-axis.
    average : bool
        True: plot the mean over pulses (axis 0). False: plot the single pulse ``pulse``.
    pulse : int
        Pulse index used when ``average`` is False.
    ax : matplotlib axis, optional
        Axis to plot into; a new figure is created if omitted.

    If numpy cannot average the y data (``TypeError``), a message is printed
    and nothing is plotted.
    """
    try:
        # Read the y channel once. It is reused for the average and for the
        # single-pulse case.
        y_all = data[channel_name_y].data
        mean_over_frames = np.mean(y_all, axis=0)
    except TypeError:
        print(f"Unknown data in channel {channel_name_y}.")
        return

    y_data = mean_over_frames if average else y_all[pulse]

    if ax is None:
        _, ax = plt.subplots(constrained_layout=True)
    ax.plot(data[channel_name_x].data[0], y_data)
    ax.set_xlabel(channel_name_x)
    # No mean/std text for spectra: the description is deliberately None.
    axis_styling(ax, channel_name_y, None)
def line_plot_with_colorbar(xs, ys, colors, cmap=plt.cm.viridis,
                            markers='o', markersize=6, alpha=1,
                            title=None, xlabel=None, ylabel=None, cbar_label=None,
                            **fig_kw):
    """Plot one line per (x, y) pair, colored by a scalar value, plus a matching colorbar.

    xs, ys     -> array of arrays (one x array and one y array per line)
    colors     -> array with one scalar per line; mapped through ``cmap`` after a
                  linear normalization from min(colors) to max(colors), and also
                  used as the colorbar tick positions
    markers, markersize, alpha -> line marker style (alpha also applies to the colorbar)
    title      -> figure suptitle (only set if truthy)
    xlabel, ylabel -> axis labels (only set if truthy)
    cbar_label -> colorbar label
    fig_kw     -> forwarded to plt.subplots
    """
    fig, ax = plt.subplots(1, 1, constrained_layout=True, **fig_kw)

    # linear mapping of the color values onto the [0, 1] range of the colormap
    color_scale = mpl.colors.Normalize(vmin=np.min(colors), vmax=np.max(colors))
    mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=color_scale)
    # the mappable carries no data of its own; it only converts values and feeds the colorbar
    mappable.set_array([])

    line_style = dict(marker=markers, markersize=markersize, alpha=alpha)
    for x_vals, y_vals, value in zip(xs, ys, colors):
        ax.plot(x_vals, y_vals, color=mappable.to_rgba(value), **line_style)

    if title:
        plt.suptitle(title)
    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)

    fig.colorbar(mappable, ax=ax, ticks=colors, label=cbar_label, alpha=alpha)