From 41f5157ba87c0502ec0634f6587e0fe3355932fc Mon Sep 17 00:00:00 2001 From: Ivan Usov Date: Wed, 14 Dec 2022 20:23:10 +0100 Subject: [PATCH] Simplify plot creation in panel_param_study --- pyzebra/app/panel_param_study.py | 174 ++++++++++++------------------- 1 file changed, 67 insertions(+), 107 deletions(-) diff --git a/pyzebra/app/panel_param_study.py b/pyzebra/app/panel_param_study.py index 2cd18ea..6293752 100644 --- a/pyzebra/app/panel_param_study.py +++ b/pyzebra/app/panel_param_study.py @@ -9,35 +9,23 @@ import numpy as np from bokeh.io import curdoc from bokeh.layouts import column, row from bokeh.models import ( - BasicTicker, Button, CellEditor, CheckboxEditor, CheckboxGroup, ColumnDataSource, CustomJS, - DataRange1d, DataTable, Div, Dropdown, FileInput, - Grid, HoverTool, - Image, - Legend, - Line, - LinearAxis, LinearColorMapper, - MultiLine, MultiSelect, NumberEditor, Panel, - PanTool, - Plot, RadioGroup, Range1d, - ResetTool, - Scatter, Select, Spacer, Span, @@ -45,10 +33,10 @@ from bokeh.models import ( TableColumn, Tabs, TextAreaInput, - WheelZoomTool, Whisker, ) from bokeh.palettes import Category10, Plasma256 +from bokeh.plotting import figure from scipy import interpolate import pyzebra @@ -123,7 +111,7 @@ def create(): file_list.append(os.path.basename(scan["original_filename"])) scan_table_source.data.update( - file=file_list, scan=scan_list, param=param, fit=[0] * len(scan_list), export=export, + file=file_list, scan=scan_list, param=param, fit=[0] * len(scan_list), export=export ) scan_table_source.selected.indices = [] scan_table_source.selected.indices = [0] @@ -281,12 +269,12 @@ def create(): x = scan[scan_motor] plot.axis[0].axis_label = scan_motor - plot_scatter_source.data.update(x=x, y=y, y_upper=y + y_err, y_lower=y - y_err) + scatter_source.data.update(x=x, y=y, y_upper=y + y_err, y_lower=y - y_err) fit = scan.get("fit") if fit is not None: x_fit = np.linspace(x[0], x[-1], 100) - plot_fit_source.data.update(x=x_fit, y=fit.eval(x=x_fit)) + fit_source.data.update(x=x_fit, y=fit.eval(x=x_fit)) x_bkg = [] y_bkg = [] @@ -302,15 +290,15 @@ def create(): xs_peak.append(x_fit) ys_peak.append(comps[f"f{i}_"]) - plot_bkg_source.data.update(x=x_bkg, y=y_bkg) - plot_peak_source.data.update(xs=xs_peak, ys=ys_peak) + bkg_source.data.update(x=x_bkg, y=y_bkg) + peak_source.data.update(xs=xs_peak, ys=ys_peak) fit_output_textinput.value = fit.fit_report() else: - plot_fit_source.data.update(x=[], y=[]) - plot_bkg_source.data.update(x=[], y=[]) - plot_peak_source.data.update(xs=[], ys=[]) + fit_source.data.update(x=[], y=[]) + bkg_source.data.update(x=[], y=[]) + peak_source.data.update(xs=[], ys=[]) fit_output_textinput.value = "" def _update_overview(): @@ -336,9 +324,9 @@ def create(): ov_plot.axis[0].axis_label = scan_motor ov_param_plot.axis[0].axis_label = scan_motor - ov_plot_mline_source.data.update(xs=xs, ys=ys, param=param, color=color_palette(len(xs))) + ov_mline_source.data.update(xs=xs, ys=ys, param=param, color=color_palette(len(xs))) - ov_param_plot_scatter_source.data.update(x=x, y=y) + ov_param_scatter_source.data.update(x=x, y=y) if y: x1, x2 = min(x), max(x) @@ -348,7 +336,7 @@ def create(): np.linspace(y1, y2, ov_param_plot.inner_height), ) image = interpolate.griddata((x, y), par, (grid_x, grid_y)) - ov_param_plot_image_source.data.update( + ov_param_image_source.data.update( image=[image], x=[x1], y=[y1], dw=[x2 - x1], dh=[y2 - y1] ) @@ -363,7 +351,7 @@ def create(): y_range.bounds = (y1, y2) else: - ov_param_plot_image_source.data.update(image=[], x=[], y=[], dw=[], dh=[]) + ov_param_image_source.data.update(image=[], x=[], y=[], dw=[], dh=[]) def _update_param_plot(): x = [] @@ -382,40 +370,31 @@ def create(): y_lower.append(param_fit_val - param_fit_std) y_upper.append(param_fit_val + param_fit_std) - param_plot_scatter_source.data.update(x=x, y=y, y_lower=y_lower, y_upper=y_upper) + param_scatter_source.data.update(x=x, y=y, y_lower=y_lower, y_upper=y_upper) # Main plot - plot = Plot( - x_range=DataRange1d(), - y_range=DataRange1d(only_visible=True), + plot = figure( + x_axis_label="Scan motor", + y_axis_label="Counts", plot_height=450, plot_width=700, + tools="pan,wheel_zoom,reset", ) - plot.add_layout(LinearAxis(axis_label="Counts"), place="left") - plot.add_layout(LinearAxis(axis_label="Scan motor"), place="below") - - plot.add_layout(Grid(dimension=0, ticker=BasicTicker())) - plot.add_layout(Grid(dimension=1, ticker=BasicTicker())) - - plot_scatter_source = ColumnDataSource(dict(x=[0], y=[0], y_upper=[0], y_lower=[0])) - plot_scatter = plot.add_glyph( - plot_scatter_source, Scatter(x="x", y="y", line_color="steelblue", fill_color="steelblue") + scatter_source = ColumnDataSource(dict(x=[0], y=[0], y_upper=[0], y_lower=[0])) + plot.circle( + source=scatter_source, line_color="steelblue", fill_color="steelblue", legend_label="data" ) - plot.add_layout(Whisker(source=plot_scatter_source, base="x", upper="y_upper", lower="y_lower")) + plot.add_layout(Whisker(source=scatter_source, base="x", upper="y_upper", lower="y_lower")) - plot_fit_source = ColumnDataSource(dict(x=[0], y=[0])) - plot_fit = plot.add_glyph(plot_fit_source, Line(x="x", y="y")) + fit_source = ColumnDataSource(dict(x=[0], y=[0])) + plot.line(source=fit_source, legend_label="best fit") - plot_bkg_source = ColumnDataSource(dict(x=[0], y=[0])) - plot_bkg = plot.add_glyph( - plot_bkg_source, Line(x="x", y="y", line_color="green", line_dash="dashed") - ) + bkg_source = ColumnDataSource(dict(x=[0], y=[0])) + plot.line(source=bkg_source, line_color="green", line_dash="dashed", legend_label="linear") - plot_peak_source = ColumnDataSource(dict(xs=[[0]], ys=[[0]])) - plot_peak = plot.add_glyph( - plot_peak_source, MultiLine(xs="xs", ys="ys", line_color="red", line_dash="dashed") - ) + peak_source = ColumnDataSource(dict(xs=[[0]], ys=[[0]])) + plot.multi_line(source=peak_source, line_color="red", line_dash="dashed", legend_label="peak") fit_from_span = Span(location=None, dimension="height", line_dash="dashed") plot.add_layout(fit_from_span) @@ -423,80 +402,61 @@ def create(): fit_to_span = Span(location=None, dimension="height", line_dash="dashed") plot.add_layout(fit_to_span) - plot.add_layout( - Legend( - items=[ - ("data", [plot_scatter]), - ("best fit", [plot_fit]), - ("peak", [plot_peak]), - ("linear", [plot_bkg]), - ], - location="top_left", - click_policy="hide", - ) - ) - - plot.add_tools(PanTool(), WheelZoomTool(), ResetTool()) + plot.y_range.only_visible = True plot.toolbar.logo = None + plot.legend.click_policy = "hide" # Overview multilines plot - ov_plot = Plot(x_range=DataRange1d(), y_range=DataRange1d(), plot_height=450, plot_width=700) + ov_plot = figure( + x_axis_label="Scan motor", + y_axis_label="Counts", + plot_height=450, + plot_width=700, + tools="pan,wheel_zoom,reset", + ) - ov_plot.add_layout(LinearAxis(axis_label="Counts"), place="left") - ov_plot.add_layout(LinearAxis(axis_label="Scan motor"), place="below") + ov_mline_source = ColumnDataSource(dict(xs=[], ys=[], param=[], color=[])) + ov_plot.multi_line(source=ov_mline_source, line_color="color") - ov_plot.add_layout(Grid(dimension=0, ticker=BasicTicker())) - ov_plot.add_layout(Grid(dimension=1, ticker=BasicTicker())) + ov_plot.add_tools(HoverTool(tooltips=[("param", "@param")])) - ov_plot_mline_source = ColumnDataSource(dict(xs=[], ys=[], param=[], color=[])) - ov_plot.add_glyph(ov_plot_mline_source, MultiLine(xs="xs", ys="ys", line_color="color")) - - hover_tool = HoverTool(tooltips=[("param", "@param")]) - ov_plot.add_tools(PanTool(), WheelZoomTool(), hover_tool, ResetTool()) - - ov_plot.add_tools(PanTool(), WheelZoomTool(), ResetTool()) ov_plot.toolbar.logo = None - # Overview perams plot - ov_param_plot = Plot(x_range=Range1d(), y_range=Range1d(), plot_height=450, plot_width=700) - - ov_param_plot.add_layout(LinearAxis(axis_label="Param"), place="left") - ov_param_plot.add_layout(LinearAxis(axis_label="Scan motor"), place="below") - - ov_param_plot.add_layout(Grid(dimension=0, ticker=BasicTicker())) - ov_param_plot.add_layout(Grid(dimension=1, ticker=BasicTicker())) + # Overview params plot + ov_param_plot = figure( + x_axis_label="Scan motor", + y_axis_label="Param", + x_range=Range1d(), + y_range=Range1d(), + plot_height=450, + plot_width=700, + tools="pan,wheel_zoom,reset", + ) color_mapper = LinearColorMapper(palette=Plasma256) - ov_param_plot_image_source = ColumnDataSource(dict(image=[], x=[], y=[], dw=[], dh=[])) - ov_param_plot.add_glyph( - ov_param_plot_image_source, - Image(image="image", x="x", y="y", dw="dw", dh="dh", color_mapper=color_mapper), - ) + ov_param_image_source = ColumnDataSource(dict(image=[], x=[], y=[], dw=[], dh=[])) + ov_param_plot.image(source=ov_param_image_source, color_mapper=color_mapper) - ov_param_plot_scatter_source = ColumnDataSource(dict(x=[], y=[])) - ov_param_plot.add_glyph( - ov_param_plot_scatter_source, Scatter(x="x", y="y", marker="dot", size=15), - ) + ov_param_scatter_source = ColumnDataSource(dict(x=[], y=[])) + ov_param_plot.dot(source=ov_param_scatter_source, size=15, color="black") - ov_param_plot.add_tools(PanTool(), WheelZoomTool(), ResetTool()) ov_param_plot.toolbar.logo = None # Parameter plot - param_plot = Plot(x_range=DataRange1d(), y_range=DataRange1d(), plot_height=400, plot_width=700) - - param_plot.add_layout(LinearAxis(axis_label="Fit parameter"), place="left") - param_plot.add_layout(LinearAxis(axis_label="Parameter"), place="below") - - param_plot.add_layout(Grid(dimension=0, ticker=BasicTicker())) - param_plot.add_layout(Grid(dimension=1, ticker=BasicTicker())) - - param_plot_scatter_source = ColumnDataSource(dict(x=[], y=[], y_upper=[], y_lower=[])) - param_plot.add_glyph(param_plot_scatter_source, Scatter(x="x", y="y")) - param_plot.add_layout( - Whisker(source=param_plot_scatter_source, base="x", upper="y_upper", lower="y_lower") + param_plot = figure( + x_axis_label="Parameter", + y_axis_label="Fit parameter", + plot_height=400, + plot_width=700, + tools="pan,wheel_zoom,reset", + ) + + param_scatter_source = ColumnDataSource(dict(x=[], y=[], y_upper=[], y_lower=[])) + param_plot.circle(source=param_scatter_source) + param_plot.add_layout( + Whisker(source=param_scatter_source, base="x", upper="y_upper", lower="y_lower") ) - param_plot.add_tools(PanTool(), WheelZoomTool(), ResetTool()) param_plot.toolbar.logo = None def fit_param_select_callback(_attr, _old, _new): @@ -684,7 +644,7 @@ def create(): n = len(params) fitparams = dict( - param=params, value=[None] * n, vary=[True] * n, min=[None] * n, max=[None] * n, + param=params, value=[None] * n, vary=[True] * n, min=[None] * n, max=[None] * n ) if function == "linear":