basics and tests done.

This commit is contained in:
2022-07-12 17:15:28 +02:00
parent 2b077152f4
commit b5a54de804
6 changed files with 433 additions and 3 deletions

190
src/cristallina/plot.py Normal file
View File

@@ -0,0 +1,190 @@
import re
from collections import defaultdict
import matplotlib
from matplotlib import pyplot as plt
import warnings
# because of https://github.com/kornia/kornia/issues/1425
warnings.simplefilter("ignore", DeprecationWarning)
import numpy as np
from tqdm import tqdm
from matplotlib import patches
from pathlib import Path
from sfdata import SFDataFiles, sfdatafile, SFScanInfo
import jungfrau_utils as ju
from .utils import ROI
def ju_patch_less_verbose(ju_module):
"""Quick monkey patch to suppress verbose messages from gain & pedestal file searcher."""
ju_module.swissfel_helpers._locate_gain_file = ju_module.swissfel_helpers.locate_gain_file
ju_module.swissfel_helpers._locate_pedestal_file = ju_module.swissfel_helpers.locate_pedestal_file
def less_verbose_gain(*args, **kwargs):
kwargs["verbose"] = False
return ju_module.swissfel_helpers._locate_gain_file(*args, **kwargs)
def less_verbose_pedestal(*args, **kwargs):
kwargs["verbose"] = False
return ju_module.swissfel_helpers._locate_pedestal_file(*args, **kwargs)
# ju_module.swissfel_helpers.locate_gain_file = less_verbose_gain
# ju_module.swissfel_helpers.locate_pedestal_file = less_verbose_pedestal
ju_module.file_adapter.locate_gain_file = less_verbose_gain
ju_module.file_adapter.locate_pedestal_file = less_verbose_pedestal
ju_patch_less_verbose(ju)
def plot_correlation(x, y, ax=None, **ax_kwargs):
"""
Plots the correlation of x and y in a normalized scatterplot.
If no axis is given a figure and axis are created.
Returns: The axis object and the correlation coefficient between
x and y.
"""
xstd = np.std(x)
ystd = np.std(y)
xnorm = (x - np.mean(x)) / xstd
ynorm = (y - np.mean(y)) / ystd
n = len(ys)
r = 1 / (n) * sum(xnorm * ynorm)
if ax is None:
fig, ax = plt.subplots()
if ax_kwargs is not None:
ax.set(**ax_kwargs)
ax.plot(xnorm, ynorm, "o")
ax.text(0.95, 0.05, f"r = {r:.2f}", transform=ax.transAxes, horizontalalignment="right")
return ax, r
def plot_channel(data, channel_name, ax=None):
channel_dim = len(data[channel_name].shape)
# dim == 3: a 2D Image
# dim == 2: an array per pulse (probably)
# dim == 1: a single value per pulse (probably)
plot_f = {
1: plot_1d_channel,
2: plot_2d_channel,
3: plot_image_channel,
}
plot_f[channel_dim](data, channel_name, ax)
def axis_styling(ax, channel_name, description):
ax.set_title(channel_name)
# ax.set_xlabel('x')
# ax.set_ylabel('a.u.')
ax.ticklabel_format(useOffset=False)
ax.text(
0.05,
0.05,
description,
transform=ax.transAxes,
horizontalalignment="left",
bbox=dict(boxstyle="round", color="lightgrey"),
)
def plot_1d_channel(data, channel_name, ax=None):
"""
Plots channel data for a channel that contains a single numeric value per pulse.
"""
try:
mean, std = np.mean(data[channel_name].data), np.std(data[channel_name].data)
n_entries_per_frame = data[channel_name].shape
except TypeError:
print(f"Cannot parse channel {channel_name}. Check dimensionality.")
return
y_data = data[channel_name].data
if ax is None:
fig, ax = plt.subplots(constrained_layout=True)
ax.plot(y_data)
description = f"mean: {mean:.2e},\nstd: {std:.2e}"
axis_styling(ax, channel_name, description)
def plot_2d_channel(data, channel_name, ax=None):
"""
Plots channel data for a channel that contains a 1d array of numeric values per pulse.
"""
try:
mean, std = np.mean(data[channel_name].data), np.std(data[channel_name].data)
# data[channel_name].data
mean_over_frames = np.mean(data[channel_name].data, axis=0)
except TypeError:
print(f"Unknown data in channel {channel_name}.")
return
y_data = mean_over_frames
if ax is None:
fig, ax = plt.subplots(constrained_layout=True)
ax.plot(y_data)
description = f"mean: {mean:.2e},\nstd: {std:.2e}"
axis_styling(ax, channel_name, description)
def plot_image_channel(data, channel_name, pulse=0, ax=None, rois=None, norms=None):
"""
Plots channel data for a channel that contains an image (2d array) of numeric values per pulse.
"""
im = data[channel_name].data[pulse]
if ax is None:
fig, ax = plt.subplots(constrained_layout=True)
std = im.std()
mean = im.mean()
if norms is None:
norm = matplotlib.colors.Normalize(vmin=mean - std, vmax=mean + std)
else:
norm = matplotlib.colors.Normalize(vmin=norms[0], vmax=norms[1])
ax.imshow(im, norm=norm)
ax.invert_yaxis()
if rois is not None:
# Plot rois if given
for i, roi in enumerate(rois):
# Create a rectangle with ([bottom left corner coordinates], width, height)
rect = patches.Rectangle(
[roi.left, roi.bottom], roi.width, roi.height,
linewidth=3,
edgecolor=f"C{i}",
facecolor="none",
label=roi.name,
)
ax.add_patch(rect)
description = f"mean: {mean:.2e},\nstd: {std:.2e}"
axis_styling(ax, channel_name, description)
plt.legend(loc=4)