cleaned up caching and small fixes

This commit is contained in:
2023-09-14 11:29:52 +02:00
parent 787c863cd1
commit be01a75c68
4 changed files with 88 additions and 34 deletions

View File

@@ -5,6 +5,7 @@ import matplotlib
from matplotlib import pyplot as plt
import warnings
# because of https://github.com/kornia/kornia/issues/1425
warnings.simplefilter("ignore", DeprecationWarning)
@@ -43,6 +44,7 @@ def ju_patch_less_verbose(ju_module):
ju_patch_less_verbose(ju)
def plot_correlation(x, y, ax=None, **ax_kwargs):
"""
Plots the correlation of x and y in a normalized scatterplot.
@@ -59,7 +61,7 @@ def plot_correlation(x, y, ax=None, **ax_kwargs):
ynorm = (y - np.mean(y)) / ystd
n = len(y)
r = 1 / (n) * sum(xnorm * ynorm)
if ax is None:
@@ -73,10 +75,11 @@ def plot_correlation(x, y, ax=None, **ax_kwargs):
return ax, r
def plot_channel(data : SFDataFiles, channel_name, ax=None):
"""
Plots a given channel from an SFDataFiles object.
def plot_channel(data: SFDataFiles, channel_name, ax=None):
"""
Plots a given channel from an SFDataFiles object.
Optionally: a matplotlib axis to plot into
"""
@@ -95,7 +98,6 @@ def plot_channel(data : SFDataFiles, channel_name, ax=None):
def axis_styling(ax, channel_name, description):
ax.set_title(channel_name)
# ax.set_xlabel('x')
# ax.set_ylabel('a.u.')
@@ -110,7 +112,7 @@ def axis_styling(ax, channel_name, description):
)
def plot_1d_channel(data : SFDataFiles, channel_name, ax=None):
def plot_1d_channel(data: SFDataFiles, channel_name, ax=None):
"""
Plots channel data for a channel that contains a single numeric value per pulse.
"""
@@ -131,7 +133,7 @@ def plot_1d_channel(data : SFDataFiles, channel_name, ax=None):
axis_styling(ax, channel_name, description)
def plot_2d_channel(data : SFDataFiles, channel_name, ax=None):
def plot_2d_channel(data: SFDataFiles, channel_name, ax=None):
"""
Plots channel data for a channel that contains a 1d array of numeric values per pulse.
"""
@@ -153,22 +155,22 @@ def plot_2d_channel(data : SFDataFiles, channel_name, ax=None):
axis_styling(ax, channel_name, description)
def plot_image_channel(data : SFDataFiles, channel_name, pulse=0, ax=None, rois=None, norms=None, log_colorscale=False):
def plot_detector_image(image_data, channel_name=None, ax=None, rois=None, norms=None, log_colorscale=False):
"""
Plots channel data for a channel that contains an image (2d array) of numeric values per pulse.
Optional:
Optional:
- rois: draw a rectangular patch for the given roi(s)
- norms: [min, max] values for colormap
- log_colorscale: True for a logarithmic colormap
"""
im = data[channel_name][pulse]
im = image_data
def log_transform(z):
return np.log(np.clip(z, 1E-12, np.max(z)))
return np.log(np.clip(z, 1e-12, np.max(z)))
if log_colorscale:
im = log_transform(im)
im = log_transform(im)
if ax is None:
fig, ax = plt.subplots(constrained_layout=True)
@@ -189,7 +191,9 @@ def plot_image_channel(data : SFDataFiles, channel_name, pulse=0, ax=None, rois=
for i, roi in enumerate(rois):
# Create a rectangle with ([bottom left corner coordinates], width, height)
rect = patches.Rectangle(
[roi.left, roi.bottom], roi.width, roi.height,
[roi.left, roi.bottom],
roi.width,
roi.height,
linewidth=3,
edgecolor=f"C{i}",
facecolor="none",
@@ -199,9 +203,26 @@ def plot_image_channel(data : SFDataFiles, channel_name, pulse=0, ax=None, rois=
description = f"mean: {mean:.2e},\nstd: {std:.2e}"
axis_styling(ax, channel_name, description)
plt.legend(loc=4)
ax.legend(loc=4)
def plot_spectrum_channel(data : SFDataFiles, channel_name_x, channel_name_y, average=True, pulse=0, ax=None):
def plot_image_channel(data: SFDataFiles, channel_name, pulse=0, ax=None, rois=None, norms=None, log_colorscale=False):
"""
Plots channel data for a channel that contains an image (2d array) of numeric values per pulse.
Optional:
- rois: draw a rectangular patch for the given roi(s)
- norms: [min, max] values for colormap
- log_colorscale: True for a logarithmic colormap
"""
image_data = data[channel_name][pulse]
plot_detector_image(
image_data, channel_name=channel_name, ax=ax, rois=rois, norms=norms, log_colorscale=log_colorscale
)
def plot_spectrum_channel(data: SFDataFiles, channel_name_x, channel_name_y, average=True, pulse=0, ax=None):
"""
Plots channel data for two channels where the first is taken as the (constant) x-axis
and the second as the y-axis (here we take by default the mean over the individual pulses).
@@ -217,12 +238,11 @@ def plot_spectrum_channel(data : SFDataFiles, channel_name_x, channel_name_y, av
y_data = mean_over_frames
else:
y_data = data[channel_name_y].data[pulse]
if ax is None:
fig, ax = plt.subplots(constrained_layout=True)
ax.plot(data[channel_name_x].data[0], y_data)
description = None # f"mean: {mean:.2e},\nstd: {std:.2e}"
description = None # f"mean: {mean:.2e},\nstd: {std:.2e}"
ax.set_xlabel(channel_name_x)
axis_styling(ax, channel_name_y, description)