Simplify plot creation in panel_hdf_param_study

This commit is contained in:
usov_i 2022-12-14 19:07:34 +01:00
parent 21562ee85b
commit 7503076a1b

View File

@ -6,37 +6,26 @@ import numpy as np
from bokeh.io import curdoc from bokeh.io import curdoc
from bokeh.layouts import column, gridplot, row from bokeh.layouts import column, gridplot, row
from bokeh.models import ( from bokeh.models import (
BasicTicker,
BoxZoomTool,
Button, Button,
CellEditor, CellEditor,
CheckboxGroup, CheckboxGroup,
ColumnDataSource, ColumnDataSource,
DataRange1d,
DataTable, DataTable,
Div, Div,
FileInput, FileInput,
Grid,
Image,
LinearAxis,
LinearColorMapper, LinearColorMapper,
MultiSelect, MultiSelect,
NumberEditor, NumberEditor,
NumberFormatter, NumberFormatter,
Panel, Panel,
PanTool,
Plot,
Range1d, Range1d,
ResetTool,
Scatter,
Select, Select,
Spinner, Spinner,
TableColumn, TableColumn,
Tabs, Tabs,
Title,
WheelZoomTool,
) )
from bokeh.palettes import Cividis256, Greys256, Plasma256 # pylint: disable=E0611 from bokeh.palettes import Cividis256, Greys256, Plasma256
from bokeh.plotting import figure
import pyzebra import pyzebra
@ -201,7 +190,7 @@ def create():
else: else:
metadata_table_source.data.update(temp=[None]) metadata_table_source.data.update(temp=[None])
update_overview_plot() _update_proj_plots()
def scan_table_source_callback(_attr, _old, _new): def scan_table_source_callback(_attr, _old, _new):
pass pass
@ -258,32 +247,32 @@ def create():
) )
param_select.on_change("value", param_select_callback) param_select.on_change("value", param_select_callback)
def update_overview_plot(): def _update_proj_plots():
scan = _get_selected_scan() scan = _get_selected_scan()
counts = scan["counts"] counts = scan["counts"]
n_im, n_y, n_x = counts.shape n_im, n_y, n_x = counts.shape
overview_x = np.mean(counts, axis=1) im_proj_x = np.mean(counts, axis=1)
overview_y = np.mean(counts, axis=2) im_proj_y = np.mean(counts, axis=2)
# normalize for simpler colormapping # normalize for simpler colormapping
overview_max_val = max(np.max(overview_x), np.max(overview_y)) im_proj_max_val = max(np.max(im_proj_x), np.max(im_proj_y))
overview_x = 1000 * overview_x / overview_max_val im_proj_x = 1000 * im_proj_x / im_proj_max_val
overview_y = 1000 * overview_y / overview_max_val im_proj_y = 1000 * im_proj_y / im_proj_max_val
overview_plot_x_image_source.data.update(image=[overview_x], dw=[n_x], dh=[n_im]) proj_x_image_source.data.update(image=[im_proj_x], dw=[n_x], dh=[n_im])
overview_plot_y_image_source.data.update(image=[overview_y], dw=[n_y], dh=[n_im]) proj_y_image_source.data.update(image=[im_proj_y], dw=[n_y], dh=[n_im])
if proj_auto_checkbox.active: if proj_auto_checkbox.active:
im_min = min(np.min(overview_x), np.min(overview_y)) im_min = min(np.min(im_proj_x), np.min(im_proj_y))
im_max = max(np.max(overview_x), np.max(overview_y)) im_max = max(np.max(im_proj_x), np.max(im_proj_y))
proj_display_min_spinner.value = im_min proj_display_min_spinner.value = im_min
proj_display_max_spinner.value = im_max proj_display_max_spinner.value = im_max
overview_plot_x_image_glyph.color_mapper.low = im_min proj_x_image_glyph.color_mapper.low = im_min
overview_plot_y_image_glyph.color_mapper.low = im_min proj_y_image_glyph.color_mapper.low = im_min
overview_plot_x_image_glyph.color_mapper.high = im_max proj_x_image_glyph.color_mapper.high = im_max
overview_plot_y_image_glyph.color_mapper.high = im_max proj_y_image_glyph.color_mapper.high = im_max
frame_range.start = 0 frame_range.start = 0
frame_range.end = n_im frame_range.end = n_im
@ -292,7 +281,7 @@ def create():
frame_range.bounds = (0, n_im) frame_range.bounds = (0, n_im)
scan_motor = scan["scan_motor"] scan_motor = scan["scan_motor"]
overview_plot_y.axis[1].axis_label = f"Scanning motor, {scan_motor}" proj_y_plot.axis[1].axis_label = f"Scanning motor, {scan_motor}"
var = scan[scan_motor] var = scan[scan_motor]
var_start = var[0] var_start = var[0]
@ -310,81 +299,52 @@ def create():
scanning_motor_range = Range1d(0, 1, bounds=(0, 1)) scanning_motor_range = Range1d(0, 1, bounds=(0, 1))
det_x_range = Range1d(0, IMAGE_W, bounds=(0, IMAGE_W)) det_x_range = Range1d(0, IMAGE_W, bounds=(0, IMAGE_W))
overview_plot_x = Plot( proj_x_plot = figure(
title=Title(text="Projections on X-axis"), title="Projections on X-axis",
x_axis_label="Coordinate X, pix",
y_axis_label="Frame",
x_range=det_x_range, x_range=det_x_range,
y_range=frame_range, y_range=frame_range,
extra_y_ranges={"scanning_motor": scanning_motor_range}, extra_y_ranges={"scanning_motor": scanning_motor_range},
plot_height=400, plot_height=400,
plot_width=IMAGE_PLOT_W - 3, plot_width=IMAGE_PLOT_W - 3,
tools="pan,box_zoom,wheel_zoom,reset",
active_scroll="wheel_zoom",
) )
# ---- tools proj_x_plot.yaxis.major_label_orientation = "vertical"
wheelzoomtool = WheelZoomTool(maintain_focus=False) proj_x_plot.toolbar.tools[2].maintain_focus = False
overview_plot_x.toolbar.logo = None
overview_plot_x.add_tools(PanTool(), BoxZoomTool(), wheelzoomtool, ResetTool())
overview_plot_x.toolbar.active_scroll = wheelzoomtool
# ---- axes proj_x_image_source = ColumnDataSource(
overview_plot_x.add_layout(LinearAxis(axis_label="Coordinate X, pix"), place="below")
overview_plot_x.add_layout(
LinearAxis(axis_label="Frame", major_label_orientation="vertical"), place="left"
)
# ---- grid lines
overview_plot_x.add_layout(Grid(dimension=0, ticker=BasicTicker()))
overview_plot_x.add_layout(Grid(dimension=1, ticker=BasicTicker()))
# ---- rgba image glyph
overview_plot_x_image_source = ColumnDataSource(
dict(image=[np.zeros((1, 1), dtype="float32")], x=[0], y=[0], dw=[IMAGE_W], dh=[1]) dict(image=[np.zeros((1, 1), dtype="float32")], x=[0], y=[0], dw=[IMAGE_W], dh=[1])
) )
overview_plot_x_image_glyph = Image(image="image", x="x", y="y", dw="dw", dh="dh") proj_x_image_glyph = proj_x_plot.image(source=proj_x_image_source).glyph
overview_plot_x.add_glyph(
overview_plot_x_image_source, overview_plot_x_image_glyph, name="image_glyph"
)
det_y_range = Range1d(0, IMAGE_H, bounds=(0, IMAGE_H)) det_y_range = Range1d(0, IMAGE_H, bounds=(0, IMAGE_H))
overview_plot_y = Plot( proj_y_plot = figure(
title=Title(text="Projections on Y-axis"), title="Projections on Y-axis",
x_axis_label="Coordinate Y, pix",
y_axis_label="Scanning motor",
y_axis_location="right",
x_range=det_y_range, x_range=det_y_range,
y_range=frame_range, y_range=frame_range,
extra_y_ranges={"scanning_motor": scanning_motor_range}, extra_y_ranges={"scanning_motor": scanning_motor_range},
plot_height=400, plot_height=400,
plot_width=IMAGE_PLOT_H + 22, plot_width=IMAGE_PLOT_H + 22,
tools="pan,box_zoom,wheel_zoom,reset",
active_scroll="wheel_zoom",
) )
# ---- tools proj_y_plot.yaxis.y_range_name = "scanning_motor"
wheelzoomtool = WheelZoomTool(maintain_focus=False) proj_y_plot.yaxis.major_label_orientation = "vertical"
overview_plot_y.toolbar.logo = None proj_y_plot.toolbar.tools[2].maintain_focus = False
overview_plot_y.add_tools(PanTool(), BoxZoomTool(), wheelzoomtool, ResetTool())
overview_plot_y.toolbar.active_scroll = wheelzoomtool
# ---- axes proj_y_image_source = ColumnDataSource(
overview_plot_y.add_layout(LinearAxis(axis_label="Coordinate Y, pix"), place="below")
overview_plot_y.add_layout(
LinearAxis(
y_range_name="scanning_motor",
axis_label="Scanning motor",
major_label_orientation="vertical",
),
place="right",
)
# ---- grid lines
overview_plot_y.add_layout(Grid(dimension=0, ticker=BasicTicker()))
overview_plot_y.add_layout(Grid(dimension=1, ticker=BasicTicker()))
# ---- rgba image glyph
overview_plot_y_image_source = ColumnDataSource(
dict(image=[np.zeros((1, 1), dtype="float32")], x=[0], y=[0], dw=[IMAGE_H], dh=[1]) dict(image=[np.zeros((1, 1), dtype="float32")], x=[0], y=[0], dw=[IMAGE_H], dh=[1])
) )
overview_plot_y_image_glyph = Image(image="image", x="x", y="y", dw="dw", dh="dh") proj_y_image_glyph = proj_y_plot.image(source=proj_y_image_source).glyph
overview_plot_y.add_glyph(
overview_plot_y_image_source, overview_plot_y_image_glyph, name="image_glyph"
)
cmap_dict = { cmap_dict = {
"gray": Greys256, "gray": Greys256,
@ -394,8 +354,8 @@ def create():
} }
def colormap_callback(_attr, _old, new): def colormap_callback(_attr, _old, new):
overview_plot_x_image_glyph.color_mapper = LinearColorMapper(palette=cmap_dict[new]) proj_x_image_glyph.color_mapper = LinearColorMapper(palette=cmap_dict[new])
overview_plot_y_image_glyph.color_mapper = LinearColorMapper(palette=cmap_dict[new]) proj_y_image_glyph.color_mapper = LinearColorMapper(palette=cmap_dict[new])
colormap = Select(title="Colormap:", options=list(cmap_dict.keys()), width=210) colormap = Select(title="Colormap:", options=list(cmap_dict.keys()), width=210)
colormap.on_change("value", colormap_callback) colormap.on_change("value", colormap_callback)
@ -411,7 +371,7 @@ def create():
proj_display_min_spinner.disabled = False proj_display_min_spinner.disabled = False
proj_display_max_spinner.disabled = False proj_display_max_spinner.disabled = False
update_overview_plot() _update_proj_plots()
proj_auto_checkbox = CheckboxGroup( proj_auto_checkbox = CheckboxGroup(
labels=["Projections Intensity Range"], active=[0], width=145, margin=[10, 5, 0, 5] labels=["Projections Intensity Range"], active=[0], width=145, margin=[10, 5, 0, 5]
@ -420,8 +380,8 @@ def create():
def proj_display_max_spinner_callback(_attr, _old_value, new_value): def proj_display_max_spinner_callback(_attr, _old_value, new_value):
proj_display_min_spinner.high = new_value - PROJ_STEP proj_display_min_spinner.high = new_value - PROJ_STEP
overview_plot_x_image_glyph.color_mapper.high = new_value proj_x_image_glyph.color_mapper.high = new_value
overview_plot_y_image_glyph.color_mapper.high = new_value proj_y_image_glyph.color_mapper.high = new_value
proj_display_max_spinner = Spinner( proj_display_max_spinner = Spinner(
low=0 + PROJ_STEP, low=0 + PROJ_STEP,
@ -435,8 +395,8 @@ def create():
def proj_display_min_spinner_callback(_attr, _old_value, new_value): def proj_display_min_spinner_callback(_attr, _old_value, new_value):
proj_display_max_spinner.low = new_value + PROJ_STEP proj_display_max_spinner.low = new_value + PROJ_STEP
overview_plot_x_image_glyph.color_mapper.low = new_value proj_x_image_glyph.color_mapper.low = new_value
overview_plot_y_image_glyph.color_mapper.low = new_value proj_y_image_glyph.color_mapper.low = new_value
proj_display_min_spinner = Spinner( proj_display_min_spinner = Spinner(
low=0, low=0,
@ -471,21 +431,20 @@ def create():
if "fit" in s and fit_param: if "fit" in s and fit_param:
x.append(p) x.append(p)
y.append(s["fit"][fit_param]) y.append(s["fit"][fit_param])
param_plot_scatter_source.data.update(x=x, y=y) param_scatter_source.data.update(x=x, y=y)
# Parameter plot # Parameter plot
param_plot = Plot(x_range=DataRange1d(), y_range=DataRange1d(), plot_height=400, plot_width=700) param_plot = figure(
x_axis_label="Parameter",
y_axis_label="Fit parameter",
plot_height=400,
plot_width=700,
tools="pan,wheel_zoom,reset",
)
param_plot.add_layout(LinearAxis(axis_label="Fit parameter"), place="left") param_scatter_source = ColumnDataSource(dict(x=[], y=[]))
param_plot.add_layout(LinearAxis(axis_label="Parameter"), place="below") param_plot.circle(source=param_scatter_source)
param_plot.add_layout(Grid(dimension=0, ticker=BasicTicker()))
param_plot.add_layout(Grid(dimension=1, ticker=BasicTicker()))
param_plot_scatter_source = ColumnDataSource(dict(x=[], y=[]))
param_plot.add_glyph(param_plot_scatter_source, Scatter(x="x", y="y"))
param_plot.add_tools(PanTool(), WheelZoomTool(), ResetTool())
param_plot.toolbar.logo = None param_plot.toolbar.logo = None
def fit_param_select_callback(_attr, _old, _new): def fit_param_select_callback(_attr, _old, _new):
@ -553,12 +512,9 @@ def create():
proc_all_button, proc_all_button,
) )
layout_overview = column( layout_proj = column(
gridplot( gridplot(
[[overview_plot_x, overview_plot_y]], [[proj_x_plot, proj_y_plot]], toolbar_options={"logo": None}, toolbar_location="right"
toolbar_options=dict(logo=None),
merge_tools=True,
toolbar_location="left",
), ),
layout_controls, layout_controls,
) )
@ -566,7 +522,7 @@ def create():
# Plot tabs # Plot tabs
plots = Tabs( plots = Tabs(
tabs=[ tabs=[
Panel(child=layout_overview, title="single scan"), Panel(child=layout_proj, title="single scan"),
Panel(child=column(param_plot, row(fit_param_select)), title="parameter plot"), Panel(child=column(param_plot, row(fit_param_select)), title="parameter plot"),
] ]
) )