From 147ced2cb059f3dd6758ee7e87118845de7325bc Mon Sep 17 00:00:00 2001 From: wyzula-jan Date: Fri, 14 Feb 2025 13:22:34 +0100 Subject: [PATCH] feat(waveform): new Waveform widget based on NextGen PlotBase --- bec_widgets/cli/client.py | 567 ++++++ .../jupyter_console/jupyter_console_window.py | 28 +- .../widgets/containers/dock/dock_area.py | 11 +- .../widgets/plots_next_gen/__init__.py | 0 .../widgets/plots_next_gen/plot_base.py | 4 + .../toolbar_bundles/roi_bundle.py | 5 + .../plots_next_gen/waveform/__init__.py | 0 .../widgets/plots_next_gen/waveform/curve.py | 328 ++++ .../waveform/register_waveform.py | 15 + .../waveform/settings/__init__.py | 0 .../settings/curve_settings/__init__.py | 0 .../settings/curve_settings/curve_setting.py | 109 ++ .../settings/curve_settings/curve_tree.py | 538 ++++++ .../plots_next_gen/waveform/utils/__init__.py | 0 .../waveform/utils/roi_manager.py | 84 + .../plots_next_gen/waveform/waveform.py | 1611 +++++++++++++++++ .../waveform/waveform.pyproject | 1 + .../waveform/waveform_plugin.py | 54 + tests/unit_tests/client_mocks.py | 150 ++ tests/unit_tests/test_curve_settings.py | 367 ++++ tests/unit_tests/test_waveform_next_gen.py | 787 ++++++++ 21 files changed, 4648 insertions(+), 11 deletions(-) create mode 100644 bec_widgets/widgets/plots_next_gen/__init__.py create mode 100644 bec_widgets/widgets/plots_next_gen/waveform/__init__.py create mode 100644 bec_widgets/widgets/plots_next_gen/waveform/curve.py create mode 100644 bec_widgets/widgets/plots_next_gen/waveform/register_waveform.py create mode 100644 bec_widgets/widgets/plots_next_gen/waveform/settings/__init__.py create mode 100644 bec_widgets/widgets/plots_next_gen/waveform/settings/curve_settings/__init__.py create mode 100644 bec_widgets/widgets/plots_next_gen/waveform/settings/curve_settings/curve_setting.py create mode 100644 bec_widgets/widgets/plots_next_gen/waveform/settings/curve_settings/curve_tree.py create mode 100644 bec_widgets/widgets/plots_next_gen/waveform/utils/__init__.py create mode 100644 bec_widgets/widgets/plots_next_gen/waveform/utils/roi_manager.py create mode 100644 bec_widgets/widgets/plots_next_gen/waveform/waveform.py create mode 100644 bec_widgets/widgets/plots_next_gen/waveform/waveform.pyproject create mode 100644 bec_widgets/widgets/plots_next_gen/waveform/waveform_plugin.py create mode 100644 tests/unit_tests/test_curve_settings.py create mode 100644 tests/unit_tests/test_waveform_next_gen.py diff --git a/bec_widgets/cli/client.py b/bec_widgets/cli/client.py index 4ab60ac0..87233c2c 100644 --- a/bec_widgets/cli/client.py +++ b/bec_widgets/cli/client.py @@ -46,6 +46,7 @@ class Widgets(str, enum.Enum): StopButton = "StopButton" TextBox = "TextBox" VSCodeEditor = "VSCodeEditor" + Waveform = "Waveform" WebsiteWidget = "WebsiteWidget" @@ -2996,6 +2997,161 @@ class BECWaveformWidget(RPCBase): """ +class Curve(RPCBase): + @rpc_call + def remove(self): + """ + Remove the curve from the plot. + """ + + @property + @rpc_call + def _rpc_id(self) -> "str": + """ + Get the RPC ID of the widget. + """ + + @property + @rpc_call + def _config_dict(self) -> "dict": + """ + Get the configuration of the widget. + + Returns: + dict: The configuration of the widget. + """ + + @rpc_call + def set(self, **kwargs): + """ + Set the properties of the curve. + + Args: + **kwargs: Keyword arguments for the properties to be set. + + Possible properties: + - color: str + - symbol: str + - symbol_color: str + - symbol_size: int + - pen_width: int + - pen_style: Literal["solid", "dash", "dot", "dashdot"] + """ + + @rpc_call + def set_data(self, x: "list | np.ndarray", y: "list | np.ndarray"): + """ + Set the data of the curve. + + Args: + x(list|np.ndarray): The x data. + y(list|np.ndarray): The y data. + + Raises: + ValueError: If the source is not custom. + """ + + @rpc_call + def set_color(self, color: "str", symbol_color: "str | None" = None): + """ + Change the color of the curve. + + Args: + color(str): Color of the curve. + symbol_color(str, optional): Color of the symbol. Defaults to None. + """ + + @rpc_call + def set_color_map_z(self, colormap: "str"): + """ + Set the colormap for the scatter plot z gradient. + + Args: + colormap(str): Colormap for the scatter plot. + """ + + @rpc_call + def set_symbol(self, symbol: "str"): + """ + Change the symbol of the curve. + + Args: + symbol(str): Symbol of the curve. + """ + + @rpc_call + def set_symbol_color(self, symbol_color: "str"): + """ + Change the symbol color of the curve. + + Args: + symbol_color(str): Color of the symbol. + """ + + @rpc_call + def set_symbol_size(self, symbol_size: "int"): + """ + Change the symbol size of the curve. + + Args: + symbol_size(int): Size of the symbol. + """ + + @rpc_call + def set_pen_width(self, pen_width: "int"): + """ + Change the pen width of the curve. + + Args: + pen_width(int): Width of the pen. + """ + + @rpc_call + def set_pen_style(self, pen_style: "Literal['solid', 'dash', 'dot', 'dashdot']"): + """ + Change the pen style of the curve. + + Args: + pen_style(Literal["solid", "dash", "dot", "dashdot"]): Style of the pen. + """ + + @rpc_call + def get_data(self) -> "tuple[np.ndarray, np.ndarray]": + """ + Get the data of the curve. + Returns: + tuple[np.ndarray,np.ndarray]: X and Y data of the curve. + """ + + @property + @rpc_call + def dap_params(self): + """ + Get the dap parameters. + """ + + @property + @rpc_call + def dap_summary(self): + """ + Get the dap summary. + """ + + @property + @rpc_call + def dap_oversample(self): + """ + Get the dap oversample. + """ + + @dap_oversample.setter + @rpc_call + def dap_oversample(self): + """ + Get the dap oversample. + """ + + class DapComboBox(RPCBase): @rpc_call def select_y_axis(self, y_axis: str): @@ -3775,6 +3931,417 @@ class TextBox(RPCBase): class VSCodeEditor(RPCBase): ... +class Waveform(RPCBase): + @property + @rpc_call + def enable_toolbar(self) -> "bool": + """ + None + """ + + @enable_toolbar.setter + @rpc_call + def enable_toolbar(self) -> "bool": + """ + None + """ + + @property + @rpc_call + def enable_side_panel(self) -> "bool": + """ + None + """ + + @enable_side_panel.setter + @rpc_call + def enable_side_panel(self) -> "bool": + """ + None + """ + + @property + @rpc_call + def enable_fps_monitor(self) -> "bool": + """ + None + """ + + @enable_fps_monitor.setter + @rpc_call + def enable_fps_monitor(self) -> "bool": + """ + None + """ + + @rpc_call + def set(self, **kwargs): + """ + Set the properties of the plot widget. + + Args: + **kwargs: Keyword arguments for the properties to be set. + + Possible properties: + """ + + @property + @rpc_call + def title(self) -> "str": + """ + None + """ + + @title.setter + @rpc_call + def title(self) -> "str": + """ + None + """ + + @property + @rpc_call + def x_label(self) -> "str": + """ + None + """ + + @x_label.setter + @rpc_call + def x_label(self) -> "str": + """ + None + """ + + @property + @rpc_call + def y_label(self) -> "str": + """ + None + """ + + @y_label.setter + @rpc_call + def y_label(self) -> "str": + """ + None + """ + + @property + @rpc_call + def x_limits(self) -> "QPointF": + """ + None + """ + + @x_limits.setter + @rpc_call + def x_limits(self) -> "QPointF": + """ + None + """ + + @property + @rpc_call + def y_limits(self) -> "QPointF": + """ + None + """ + + @y_limits.setter + @rpc_call + def y_limits(self) -> "QPointF": + """ + None + """ + + @property + @rpc_call + def x_grid(self) -> "bool": + """ + None + """ + + @x_grid.setter + @rpc_call + def x_grid(self) -> "bool": + """ + None + """ + + @property + @rpc_call + def y_grid(self) -> "bool": + """ + None + """ + + @y_grid.setter + @rpc_call + def y_grid(self) -> "bool": + """ + None + """ + + @property + @rpc_call + def inner_axes(self) -> "bool": + """ + None + """ + + @inner_axes.setter + @rpc_call + def inner_axes(self) -> "bool": + """ + None + """ + + @property + @rpc_call + def outer_axes(self) -> "bool": + """ + None + """ + + @outer_axes.setter + @rpc_call + def outer_axes(self) -> "bool": + """ + None + """ + + @property + @rpc_call + def lock_aspect_ratio(self) -> "bool": + """ + None + """ + + @lock_aspect_ratio.setter + @rpc_call + def lock_aspect_ratio(self) -> "bool": + """ + None + """ + + @property + @rpc_call + def auto_range_x(self) -> "bool": + """ + None + """ + + @auto_range_x.setter + @rpc_call + def auto_range_x(self) -> "bool": + """ + None + """ + + @property + @rpc_call + def auto_range_y(self) -> "bool": + """ + None + """ + + @auto_range_y.setter + @rpc_call + def auto_range_y(self) -> "bool": + """ + None + """ + + @property + @rpc_call + def x_log(self) -> "bool": + """ + None + """ + + @x_log.setter + @rpc_call + def x_log(self) -> "bool": + """ + None + """ + + @property + @rpc_call + def y_log(self) -> "bool": + """ + None + """ + + @y_log.setter + @rpc_call + def y_log(self) -> "bool": + """ + None + """ + + @property + @rpc_call + def legend_label_size(self) -> "int": + """ + None + """ + + @legend_label_size.setter + @rpc_call + def legend_label_size(self) -> "int": + """ + None + """ + + @rpc_call + def __getitem__(self, key: "int | str"): + """ + None + """ + + @property + @rpc_call + def curves(self) -> "list[Curve]": + """ + Get the curves of the plot widget as a list. + + Returns: + list: List of curves. + """ + + @property + @rpc_call + def x_mode(self) -> "str": + """ + None + """ + + @x_mode.setter + @rpc_call + def x_mode(self) -> "str": + """ + None + """ + + @property + @rpc_call + def color_palette(self) -> "str": + """ + The color palette of the figure widget. + """ + + @color_palette.setter + @rpc_call + def color_palette(self) -> "str": + """ + The color palette of the figure widget. + """ + + @rpc_call + def plot( + self, + arg1: "list | np.ndarray | str | None" = None, + y: "list | np.ndarray | None" = None, + x: "list | np.ndarray | None" = None, + x_name: "str | None" = None, + y_name: "str | None" = None, + x_entry: "str | None" = None, + y_entry: "str | None" = None, + color: "str | None" = None, + label: "str | None" = None, + dap: "str | None" = None, + **kwargs, + ) -> "Curve": + """ + Plot a curve to the plot widget. + + Args: + arg1(list | np.ndarray | str | None): First argument, which can be x data, y data, or y_name. + y(list | np.ndarray): Custom y data to plot. + x(list | np.ndarray): Custom y data to plot. + x_name(str): Name of the x signal. + - "auto": Use the best effort signal. + - "timestamp": Use the timestamp signal. + - "index": Use the index signal. + - Custom signal name of a device from BEC. + y_name(str): The name of the device for the y-axis. + x_entry(str): The name of the entry for the x-axis. + y_entry(str): The name of the entry for the y-axis. + color(str): The color of the curve. + label(str): The label of the curve. + dap(str): The dap model to use for the curve, only available for sync devices. + If not specified, none will be added. + Use the same string as is the name of the LMFit model. + + Returns: + Curve: The curve object. + """ + + @rpc_call + def add_dap_curve( + self, + device_label: "str", + dap_name: "str", + color: "str | None" = None, + dap_oversample: "int" = 1, + **kwargs, + ) -> "Curve": + """ + Create a new DAP curve referencing the existing device curve `device_label`, + with the data processing model `dap_name`. + + Args: + device_label(str): The label of the device curve to add DAP to. + dap_name(str): The name of the DAP model to use. + color(str): The color of the curve. + dap_oversample(int): The oversampling factor for the DAP curve. + **kwargs + + Returns: + Curve: The new DAP curve. + """ + + @rpc_call + def remove_curve(self, curve: "int | str"): + """ + Remove a curve from the plot widget. + + Args: + curve(int|str): The curve to remove. Can be the order of the curve or the name of the curve. + """ + + @rpc_call + def update_with_scan_history(self, scan_index: "int" = None, scan_id: "str" = None): + """ + Update the scan curves with the data from the scan storage. + Provide only one of scan_id or scan_index. + + Args: + scan_id(str, optional): ScanID of the scan to be updated. Defaults to None. + scan_index(int, optional): Index of the scan to be updated. Defaults to None. + """ + + @rpc_call + def get_dap_params(self) -> "dict[str, dict]": + """ + Get the DAP parameters of all DAP curves. + + Returns: + dict[str, dict]: DAP parameters of all DAP curves. + """ + + @rpc_call + def get_dap_summary(self) -> "dict[str, dict]": + """ + Get the DAP summary of all DAP curves. + + Returns: + dict[str, dict]: DAP summary of all DAP curves. + """ + + class WebsiteWidget(RPCBase): @rpc_call def set_url(self, url: str) -> None: diff --git a/bec_widgets/examples/jupyter_console/jupyter_console_window.py b/bec_widgets/examples/jupyter_console/jupyter_console_window.py index 98e94bf8..9d4421c5 100644 --- a/bec_widgets/examples/jupyter_console/jupyter_console_window.py +++ b/bec_widgets/examples/jupyter_console/jupyter_console_window.py @@ -21,6 +21,7 @@ from bec_widgets.widgets.containers.figure import BECFigure from bec_widgets.widgets.containers.layout_manager.layout_manager import LayoutManagerWidget from bec_widgets.widgets.editors.jupyter_console.jupyter_console import BECJupyterConsole from bec_widgets.widgets.plots_next_gen.plot_base import PlotBase +from bec_widgets.widgets.plots_next_gen.waveform.waveform import Waveform class JupyterConsoleWindow(QWidget): # pragma: no cover: @@ -65,6 +66,7 @@ class JupyterConsoleWindow(QWidget): # pragma: no cover: "btn6": self.btn6, "pb": self.pb, "pi": self.pi, + "wfng": self.wfng, } ) @@ -100,7 +102,7 @@ class JupyterConsoleWindow(QWidget): # pragma: no cover: self.pb = PlotBase() self.pi = self.pb.plot_item fourth_tab_layout.addWidget(self.pb) - tab_widget.addTab(fourth_tab, "PltoBase") + tab_widget.addTab(fourth_tab, "PlotBase") tab_widget.setCurrentIndex(3) @@ -117,6 +119,15 @@ class JupyterConsoleWindow(QWidget): # pragma: no cover: self.btn5 = QPushButton("Button 5") self.btn6 = QPushButton("Button 6") + fifth_tab = QWidget() + fifth_tab_layout = QVBoxLayout(fifth_tab) + self.wfng = Waveform() + fifth_tab_layout.addWidget(self.wfng) + tab_widget.addTab(fifth_tab, "Waveform Next Gen") + tab_widget.setCurrentIndex(4) + # add stuff to the new Waveform widget + self._init_waveform() + # add stuff to figure self._init_figure() @@ -125,6 +136,13 @@ class JupyterConsoleWindow(QWidget): # pragma: no cover: self.setWindowTitle("Jupyter Console Window") + def _init_waveform(self): + # self.wfng._add_curve_custom(x=np.arange(10), y=np.random.rand(10), label="curve1") + # self.wfng._add_curve_custom(x=np.arange(10), y=np.random.rand(10), label="curve2") + # self.wfng._add_curve_custom(x=np.arange(10), y=np.random.rand(10), label="curve3") + self.wfng.plot(y_name="bpm4i", y_entry="bpm4i", dap="GaussianModel") + self.wfng.plot(y_name="bpm3a", y_entry="bpm3a", dap="GaussianModel") + def _init_figure(self): self.w1 = self.figure.plot(x_name="samx", y_name="bpm4i", row=0, col=0) self.w1.set( @@ -191,9 +209,11 @@ class JupyterConsoleWindow(QWidget): # pragma: no cover: self.im.image("waveform", "1d") self.d2 = self.dock.add_dock(name="dock_2", position="bottom") - self.wf = self.d2.add_widget("BECFigure", row=0, col=0) + self.wf = self.d2.add_widget("BECWaveformWidget", row=0, col=0) + self.wf.plot("bpm4i") + self.wf.plot("bpm3a") - self.mw = self.wf.multi_waveform(monitor="waveform") # , config=config) + self.mw = None # self.wf.multi_waveform(monitor="waveform") # , config=config) self.dock.save_state() @@ -219,7 +239,7 @@ if __name__ == "__main__": # pragma: no cover app.setApplicationName("Jupyter Console") app.setApplicationDisplayName("Jupyter Console") apply_theme("dark") - icon = material_icon("terminal", color="#434343", filled=True) + icon = material_icon("terminal", color=(255, 255, 255, 255), filled=True) app.setWindowIcon(icon) bec_dispatcher = BECDispatcher() diff --git a/bec_widgets/widgets/containers/dock/dock_area.py b/bec_widgets/widgets/containers/dock/dock_area.py index d9bce695..e0616d7a 100644 --- a/bec_widgets/widgets/containers/dock/dock_area.py +++ b/bec_widgets/widgets/containers/dock/dock_area.py @@ -26,7 +26,7 @@ from bec_widgets.widgets.editors.vscode.vscode import VSCodeEditor from bec_widgets.widgets.plots.image.image_widget import BECImageWidget from bec_widgets.widgets.plots.motor_map.motor_map_widget import BECMotorMapWidget from bec_widgets.widgets.plots.multi_waveform.multi_waveform_widget import BECMultiWaveformWidget -from bec_widgets.widgets.plots.waveform.waveform_widget import BECWaveformWidget +from bec_widgets.widgets.plots_next_gen.waveform.waveform import Waveform from bec_widgets.widgets.progress.ring_progress_bar.ring_progress_bar import RingProgressBar from bec_widgets.widgets.services.bec_queue.bec_queue import BECQueue from bec_widgets.widgets.services.bec_status_box.bec_status_box import BECStatusBox @@ -89,9 +89,7 @@ class BECDockArea(BECWidget, QWidget): label="Add Plot ", actions={ "waveform": MaterialIconAction( - icon_name=BECWaveformWidget.ICON_NAME, - tooltip="Add Waveform", - filled=True, + icon_name=Waveform.ICON_NAME, tooltip="Add Waveform", filled=True ), "multi_waveform": MaterialIconAction( icon_name=BECMultiWaveformWidget.ICON_NAME, @@ -171,7 +169,7 @@ class BECDockArea(BECWidget, QWidget): def _hook_toolbar(self): # Menu Plot self.toolbar.widgets["menu_plots"].widgets["waveform"].triggered.connect( - lambda: self.add_dock(widget="BECWaveformWidget", prefix="waveform") + lambda: self.add_dock(widget="Waveform", prefix="waveform") ) self.toolbar.widgets["menu_plots"].widgets["multi_waveform"].triggered.connect( lambda: self.add_dock(widget="BECMultiWaveformWidget", prefix="multi_waveform") @@ -472,8 +470,7 @@ class BECDockArea(BECWidget, QWidget): self.deleteLater() -if __name__ == "__main__": - from qtpy.QtWidgets import QApplication +if __name__ == "__main__": # pragma: no cover from bec_widgets.utils.colors import set_theme diff --git a/bec_widgets/widgets/plots_next_gen/__init__.py b/bec_widgets/widgets/plots_next_gen/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bec_widgets/widgets/plots_next_gen/plot_base.py b/bec_widgets/widgets/plots_next_gen/plot_base.py index 975c9591..31f92908 100644 --- a/bec_widgets/widgets/plots_next_gen/plot_base.py +++ b/bec_widgets/widgets/plots_next_gen/plot_base.py @@ -219,6 +219,10 @@ class PlotBase(BECWidget, QWidget): self.axis_settings_dialog = None self.toolbar.widgets["axis"].action.setChecked(False) + def reset_legend(self): + """In the case that the legend is not visible, reset it to be visible to top left corner""" + self.plot_item.legend.autoAnchor(50) + ################################################################################ # Toggle UI Elements ################################################################################ diff --git a/bec_widgets/widgets/plots_next_gen/toolbar_bundles/roi_bundle.py b/bec_widgets/widgets/plots_next_gen/toolbar_bundles/roi_bundle.py index 2a9505d2..6516c4d8 100644 --- a/bec_widgets/widgets/plots_next_gen/toolbar_bundles/roi_bundle.py +++ b/bec_widgets/widgets/plots_next_gen/toolbar_bundles/roi_bundle.py @@ -18,9 +18,14 @@ class ROIBundle(ToolbarBundle): crosshair = MaterialIconAction( icon_name="point_scan", tooltip="Show Crosshair", checkable=True ) + reset_legend = MaterialIconAction( + icon_name="restart_alt", tooltip="Reset the position of legend.", checkable=False + ) # Add them to the bundle self.add_action("crosshair", crosshair) + self.add_action("reset_legend", reset_legend) # Immediately connect signals crosshair.action.toggled.connect(self.target_widget.toggle_crosshair) + reset_legend.action.triggered.connect(self.target_widget.reset_legend) diff --git a/bec_widgets/widgets/plots_next_gen/waveform/__init__.py b/bec_widgets/widgets/plots_next_gen/waveform/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bec_widgets/widgets/plots_next_gen/waveform/curve.py b/bec_widgets/widgets/plots_next_gen/waveform/curve.py new file mode 100644 index 00000000..6890dff2 --- /dev/null +++ b/bec_widgets/widgets/plots_next_gen/waveform/curve.py @@ -0,0 +1,328 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +import numpy as np +import pyqtgraph as pg +from bec_lib import bec_logger +from pydantic import BaseModel, Field, field_validator +from qtpy import QtCore + +from bec_widgets.utils import BECConnector, Colors, ConnectionConfig + +if TYPE_CHECKING: # pragma: no cover + from bec_widgets.widgets.plots_next_gen.waveform.waveform import Waveform + +logger = bec_logger.logger + + +# noinspection PyDataclass +class DeviceSignal(BaseModel): + """The configuration of a signal in the 1D waveform widget.""" + + name: str + entry: str + dap: str | None = None + dap_oversample: int = 1 + + model_config: dict = {"validate_assignment": True} + + +# noinspection PyDataclass +class CurveConfig(ConnectionConfig): + parent_id: str | None = Field(None, description="The parent plot of the curve.") + label: str | None = Field(None, description="The label of the curve.") + color: str | tuple | None = Field(None, description="The color of the curve.") + symbol: str | None = Field("o", description="The symbol of the curve.") + symbol_color: str | tuple | None = Field( + None, description="The color of the symbol of the curve." + ) + symbol_size: int | None = Field(7, description="The size of the symbol of the curve.") + pen_width: int | None = Field(4, description="The width of the pen of the curve.") + pen_style: Literal["solid", "dash", "dot", "dashdot"] | None = Field( + "solid", description="The style of the pen of the curve." + ) + source: Literal["device", "dap", "custom"] = Field( + "custom", description="The source of the curve." + ) + signal: DeviceSignal | None = Field(None, description="The signal of the curve.") + parent_label: str | None = Field( + None, description="The label of the parent plot, only relevant for dap curves." + ) + + model_config: dict = {"validate_assignment": True} + + _validate_color = field_validator("color")(Colors.validate_color) + _validate_symbol_color = field_validator("symbol_color")(Colors.validate_color) + + +class Curve(BECConnector, pg.PlotDataItem): + USER_ACCESS = [ + "remove", + "_rpc_id", + "_config_dict", + "set", + "set_data", + "set_color", + "set_color_map_z", + "set_symbol", + "set_symbol_color", + "set_symbol_size", + "set_pen_width", + "set_pen_style", + "get_data", + "dap_params", + "dap_summary", + "dap_oversample", + "dap_oversample.setter", + ] + + def __init__( + self, + name: str | None = None, + config: CurveConfig | None = None, + gui_id: str | None = None, + parent_item: Waveform | None = None, + **kwargs, + ): + if config is None: + config = CurveConfig(label=name, widget_class=self.__class__.__name__) + self.config = config + else: + self.config = config + super().__init__(config=config, gui_id=gui_id) + pg.PlotDataItem.__init__(self, name=name) + + self.parent_item = parent_item + self.apply_config() + self.dap_params = None + self.dap_summary = None + if kwargs: + self.set(**kwargs) + + def apply_config(self, config: dict | CurveConfig | None = None, **kwargs) -> None: + """ + Apply the configuration to the curve. + + Args: + config(dict|CurveConfig, optional): The configuration to apply. + """ + + if config is not None: + if isinstance(config, dict): + config = CurveConfig(**config) + self.config = config + + pen_style_map = { + "solid": QtCore.Qt.SolidLine, + "dash": QtCore.Qt.DashLine, + "dot": QtCore.Qt.DotLine, + "dashdot": QtCore.Qt.DashDotLine, + } + pen_style = pen_style_map.get(self.config.pen_style, QtCore.Qt.SolidLine) + + pen = pg.mkPen(color=self.config.color, width=self.config.pen_width, style=pen_style) + self.setPen(pen) + + if self.config.symbol: + symbol_color = self.config.symbol_color or self.config.color + brush = pg.mkBrush(color=symbol_color) + + self.setSymbolBrush(brush) + self.setSymbolSize(self.config.symbol_size) + self.setSymbol(self.config.symbol) + + @property + def dap_params(self): + """ + Get the dap parameters. + """ + return self._dap_params + + @dap_params.setter + def dap_params(self, value): + """ + Set the dap parameters. + + Args: + value(dict): The dap parameters. + """ + self._dap_params = value + + @property + def dap_summary(self): + """ + Get the dap summary. + """ + return self._dap_report + + @dap_summary.setter + def dap_summary(self, value): + """ + Set the dap summary. + """ + self._dap_report = value + + @property + def dap_oversample(self): + """ + Get the dap oversample. + """ + return self.config.signal.dap_oversample + + @dap_oversample.setter + def dap_oversample(self, value): + """ + Set the dap oversample. + + Args: + value(int): The dap oversample. + """ + self.config.signal.dap_oversample = value + self.parent_item.request_dap() # do immediate request for dap update + + def set_data(self, x: list | np.ndarray, y: list | np.ndarray): + """ + Set the data of the curve. + + Args: + x(list|np.ndarray): The x data. + y(list|np.ndarray): The y data. + + Raises: + ValueError: If the source is not custom. + """ + if self.config.source == "custom": + self.setData(x, y) + else: + raise ValueError(f"Source {self.config.source} do not allow custom data setting.") + + def set(self, **kwargs): + """ + Set the properties of the curve. + + Args: + **kwargs: Keyword arguments for the properties to be set. + + Possible properties: + - color: str + - symbol: str + - symbol_color: str + - symbol_size: int + - pen_width: int + - pen_style: Literal["solid", "dash", "dot", "dashdot"] + """ + + # Mapping of keywords to setter methods + method_map = { + "color": self.set_color, + "color_map_z": self.set_color_map_z, + "symbol": self.set_symbol, + "symbol_color": self.set_symbol_color, + "symbol_size": self.set_symbol_size, + "pen_width": self.set_pen_width, + "pen_style": self.set_pen_style, + } + for key, value in kwargs.items(): + if key in method_map: + method_map[key](value) + else: + logger.warning(f"Warning: '{key}' is not a recognized property.") + + def set_color(self, color: str, symbol_color: str | None = None): + """ + Change the color of the curve. + + Args: + color(str): Color of the curve. + symbol_color(str, optional): Color of the symbol. Defaults to None. + """ + self.config.color = color + self.config.symbol_color = symbol_color or color + self.apply_config() + + def set_symbol(self, symbol: str): + """ + Change the symbol of the curve. + + Args: + symbol(str): Symbol of the curve. + """ + self.config.symbol = symbol + self.setSymbol(symbol) + self.updateItems() + + def set_symbol_color(self, symbol_color: str): + """ + Change the symbol color of the curve. + + Args: + symbol_color(str): Color of the symbol. + """ + self.config.symbol_color = symbol_color + self.apply_config() + + def set_symbol_size(self, symbol_size: int): + """ + Change the symbol size of the curve. + + Args: + symbol_size(int): Size of the symbol. + """ + self.config.symbol_size = symbol_size + self.apply_config() + + def set_pen_width(self, pen_width: int): + """ + Change the pen width of the curve. + + Args: + pen_width(int): Width of the pen. + """ + self.config.pen_width = pen_width + self.apply_config() + + def set_pen_style(self, pen_style: Literal["solid", "dash", "dot", "dashdot"]): + """ + Change the pen style of the curve. + + Args: + pen_style(Literal["solid", "dash", "dot", "dashdot"]): Style of the pen. + """ + self.config.pen_style = pen_style + self.apply_config() + + def set_color_map_z(self, colormap: str): + """ + Set the colormap for the scatter plot z gradient. + + Args: + colormap(str): Colormap for the scatter plot. + """ + self.config.color_map_z = colormap + self.apply_config() + self.parent_item.update_with_scan_history(-1) + + def get_data(self) -> tuple[np.ndarray, np.ndarray]: + """ + Get the data of the curve. + Returns: + tuple[np.ndarray,np.ndarray]: X and Y data of the curve. + """ + try: + x_data, y_data = self.getData() + except TypeError: + x_data, y_data = np.array([]), np.array([]) + return x_data, y_data + + def clear_data(self): + """ + Clear the data of the curve. + """ + self.setData([], []) + + def remove(self): + """Remove the curve from the plot.""" + # self.parent_item.removeItem(self) + self.parent_item.remove_curve(self.name()) + self.rpc_register.remove_rpc(self) diff --git a/bec_widgets/widgets/plots_next_gen/waveform/register_waveform.py b/bec_widgets/widgets/plots_next_gen/waveform/register_waveform.py new file mode 100644 index 00000000..541b1578 --- /dev/null +++ b/bec_widgets/widgets/plots_next_gen/waveform/register_waveform.py @@ -0,0 +1,15 @@ +def main(): # pragma: no cover + from qtpy import PYSIDE6 + + if not PYSIDE6: + print("PYSIDE6 is not available in the environment. Cannot patch designer.") + return + from PySide6.QtDesigner import QPyDesignerCustomWidgetCollection + + from bec_widgets.widgets.plots_next_gen.waveform.waveform_plugin import WaveformPlugin + + QPyDesignerCustomWidgetCollection.addCustomWidget(WaveformPlugin()) + + +if __name__ == "__main__": # pragma: no cover + main() diff --git a/bec_widgets/widgets/plots_next_gen/waveform/settings/__init__.py b/bec_widgets/widgets/plots_next_gen/waveform/settings/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bec_widgets/widgets/plots_next_gen/waveform/settings/curve_settings/__init__.py b/bec_widgets/widgets/plots_next_gen/waveform/settings/curve_settings/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bec_widgets/widgets/plots_next_gen/waveform/settings/curve_settings/curve_setting.py b/bec_widgets/widgets/plots_next_gen/waveform/settings/curve_settings/curve_setting.py new file mode 100644 index 00000000..eedb1db7 --- /dev/null +++ b/bec_widgets/widgets/plots_next_gen/waveform/settings/curve_settings/curve_setting.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from qtpy.QtWidgets import ( + QComboBox, + QGroupBox, + QHBoxLayout, + QLabel, + QSizePolicy, + QVBoxLayout, + QWidget, +) + +from bec_widgets.qt_utils.error_popups import SafeSlot +from bec_widgets.qt_utils.settings_dialog import SettingWidget +from bec_widgets.widgets.control.device_input.device_line_edit.device_line_edit import ( + DeviceLineEdit, +) +from bec_widgets.widgets.plots_next_gen.waveform.settings.curve_settings.curve_tree import CurveTree + +if TYPE_CHECKING: # pragma: no cover + from bec_widgets.widgets.plots_next_gen.waveform.waveform import Waveform + + +class CurveSetting(SettingWidget): + def __init__(self, parent=None, target_widget: Waveform = None, *args, **kwargs): + super().__init__(parent=parent, *args, **kwargs) + self.setProperty("skip_settings", True) + self.setObjectName("CurveSetting") + self.target_widget = target_widget + + self.layout = QVBoxLayout(self) + + self._init_x_box() + self._init_y_box() + + self.setFixedWidth(580) # TODO height is still debate + + def _init_x_box(self): + self.x_axis_box = QGroupBox("X Axis") + self.x_axis_box.layout = QHBoxLayout(self.x_axis_box) + self.x_axis_box.layout.setContentsMargins(10, 10, 10, 10) + self.x_axis_box.layout.setSpacing(10) + + self.mode_combo_label = QLabel("Mode") + self.mode_combo = QComboBox() + self.mode_combo.addItems(["auto", "index", "timestamp", "device"]) + + self.spacer = QWidget() + self.spacer.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) + + self.device_x_label = QLabel("Device") + self.device_x = DeviceLineEdit() + + self._get_x_mode_from_waveform() + self.switch_x_device_selection() + + self.mode_combo.currentTextChanged.connect(self.switch_x_device_selection) + + self.x_axis_box.layout.addWidget(self.mode_combo_label) + self.x_axis_box.layout.addWidget(self.mode_combo) + self.x_axis_box.layout.addWidget(self.spacer) + self.x_axis_box.layout.addWidget(self.device_x_label) + self.x_axis_box.layout.addWidget(self.device_x) + + self.x_axis_box.setFixedHeight(80) + self.layout.addWidget(self.x_axis_box) + + def _get_x_mode_from_waveform(self): + if self.target_widget.x_mode in ["auto", "index", "timestamp"]: + self.mode_combo.setCurrentText(self.target_widget.x_mode) + else: + self.mode_combo.setCurrentText("device") + + def switch_x_device_selection(self): + if self.mode_combo.currentText() == "device": + self.device_x.setEnabled(True) + self.device_x.setText(self.target_widget.x_axis_mode["name"]) + else: + self.device_x.setEnabled(False) + + def _init_y_box(self): + self.y_axis_box = QGroupBox("Y Axis") + self.y_axis_box.layout = QVBoxLayout(self.y_axis_box) + self.y_axis_box.layout.setContentsMargins(0, 0, 0, 0) + self.y_axis_box.layout.setSpacing(0) + + self.curve_manager = CurveTree(self, waveform=self.target_widget) + self.y_axis_box.layout.addWidget(self.curve_manager) + + self.layout.addWidget(self.y_axis_box) + + @SafeSlot() + def accept_changes(self): + """ + Accepts the changes made in the settings widget and applies them to the target widget. + """ + if self.mode_combo.currentText() == "device": + self.target_widget.x_mode = self.device_x.text() + else: + self.target_widget.x_mode = self.mode_combo.currentText() + self.curve_manager.send_curve_json() + + @SafeSlot() + def refresh(self): + """Refresh the curve tree and the x axis combo box in the case Waveform is modified from rpc.""" + self.curve_manager.refresh_from_waveform() + self._get_x_mode_from_waveform() diff --git a/bec_widgets/widgets/plots_next_gen/waveform/settings/curve_settings/curve_tree.py b/bec_widgets/widgets/plots_next_gen/waveform/settings/curve_settings/curve_tree.py new file mode 100644 index 00000000..aaef0543 --- /dev/null +++ b/bec_widgets/widgets/plots_next_gen/waveform/settings/curve_settings/curve_tree.py @@ -0,0 +1,538 @@ +from __future__ import annotations + +import json +from typing import TYPE_CHECKING + +from bec_qthemes._icon.material_icons import material_icon +from qtpy.QtGui import QColor +from qtpy.QtWidgets import ( + QColorDialog, + QComboBox, + QHBoxLayout, + QLabel, + QLineEdit, + QPushButton, + QSizePolicy, + QSpinBox, + QToolButton, + QTreeWidget, + QTreeWidgetItem, + QVBoxLayout, + QWidget, +) + +from bec_widgets.qt_utils.toolbar import MaterialIconAction, ModularToolBar +from bec_widgets.utils import ConnectionConfig, EntryValidator +from bec_widgets.utils.bec_widget import BECWidget +from bec_widgets.utils.colors import Colors +from bec_widgets.widgets.control.device_input.device_line_edit.device_line_edit import ( + DeviceLineEdit, +) +from bec_widgets.widgets.dap.dap_combo_box.dap_combo_box import DapComboBox +from bec_widgets.widgets.plots_next_gen.waveform.curve import CurveConfig, DeviceSignal +from bec_widgets.widgets.utility.visual.colormap_widget.colormap_widget import BECColorMapWidget + +if TYPE_CHECKING: # pragma: no cover + from bec_widgets.widgets.plots_next_gen.waveform.waveform import Waveform + + +class ColorButton(QPushButton): + """A QPushButton subclass that displays a color. + + The background is set to the given color and the button text is the hex code. + The text color is chosen automatically (black if the background is light, white if dark) + to guarantee good readability. + """ + + def __init__(self, color="#000000", parent=None): + """Initialize the color button. + + Args: + color (str): The initial color in hex format (e.g., '#000000'). + parent: Optional QWidget parent. + """ + super().__init__(parent) + self.set_color(color) + + def set_color(self, color): + """Set the button's color and update its appearance. + + Args: + color (str or QColor): The new color to assign. + """ + if isinstance(color, QColor): + self._color = color.name() + else: + self._color = color + self._update_appearance() + + def color(self): + """Return the current color in hex.""" + return self._color + + def _update_appearance(self): + """Update the button style based on the background color's brightness.""" + c = QColor(self._color) + brightness = c.lightnessF() + text_color = "#000000" if brightness > 0.5 else "#FFFFFF" + self.setStyleSheet(f"background-color: {self._color}; color: {text_color};") + self.setText(self._color) + + +class CurveRow(QTreeWidgetItem): + DELETE_BUTTON_COLOR = "#CC181E" + """A unified row that can represent either a device or a DAP curve. + + Columns: + 0: Actions (delete or "Add DAP" if source=device) + 1..2: DeviceLineEdit and QLineEdit if source=device, or "Model" label and DapComboBox if source=dap + 3: ColorButton + 4: Style QComboBox + 5: Pen width QSpinBox + 6: Symbol size QSpinBox + """ + + def __init__( + self, + tree: QTreeWidget, + parent_item: QTreeWidgetItem | None = None, + config: CurveConfig | None = None, + device_manager=None, + ): + if parent_item: + super().__init__(parent_item) + else: + # A top-level device row. + super().__init__(tree) + + self.tree = tree + self.parent_item = parent_item + self.curve_tree = tree.parent() # The CurveTree widget + self.curve_tree.all_items.append(self) # Track stable ordering + + self.dev = device_manager + self.entry_validator = EntryValidator(self.dev) + + self.config = config or CurveConfig() + self.source = self.config.source + + # Create column 0 (Actions) + self._init_actions() + # Create columns 1..2, depending on source + self._init_source_ui() + # Create columns 3..6 (color, style, width, symbol) + self._init_style_controls() + + def _init_actions(self): + """Create the actions widget in column 0, including a delete button and maybe 'Add DAP'.""" + self.actions_widget = QWidget() + actions_layout = QHBoxLayout(self.actions_widget) + actions_layout.setContentsMargins(0, 0, 0, 0) + actions_layout.setSpacing(0) + + # Delete button + self.delete_button = QToolButton() + delete_icon = material_icon( + "delete", + size=(20, 20), + convert_to_pixmap=False, + filled=False, + color=self.DELETE_BUTTON_COLOR, + ) + self.delete_button.setIcon(delete_icon) + self.delete_button.clicked.connect(lambda: self.remove_self()) + actions_layout.addWidget(self.delete_button) + + # If device row, add "Add DAP" button + if self.source == "device": + self.add_dap_button = QPushButton("DAP") + self.add_dap_button.clicked.connect(lambda: self.add_dap_row()) + actions_layout.addWidget(self.add_dap_button) + + self.tree.setItemWidget(self, 0, self.actions_widget) + + def _init_source_ui(self): + """Create columns 1 and 2. For device rows, we have device/entry edits; for dap rows, label/model combo.""" + if self.source == "device": + # Device row: columns 1..2 are device line edits + self.device_edit = DeviceLineEdit() + self.entry_edit = QLineEdit() # TODO in future will be signal line edit + if self.config.signal: + self.device_edit.setText(self.config.signal.name or "") + self.entry_edit.setText(self.config.signal.entry or "") + + self.tree.setItemWidget(self, 1, self.device_edit) + self.tree.setItemWidget(self, 2, self.entry_edit) + + else: + # DAP row: column1= "Model" label, column2= DapComboBox + self.label_widget = QLabel("Model") + self.tree.setItemWidget(self, 1, self.label_widget) + self.dap_combo = DapComboBox() + self.dap_combo.populate_fit_model_combobox() + # If config.signal has a dap + if self.config.signal and self.config.signal.dap: + dap_value = self.config.signal.dap + idx = self.dap_combo.fit_model_combobox.findText(dap_value) + if idx >= 0: + self.dap_combo.fit_model_combobox.setCurrentIndex(idx) + else: + self.dap_combo.select_fit_model("GaussianModel") # default + + self.tree.setItemWidget(self, 2, self.dap_combo) + + def _init_style_controls(self): + """Create columns 3..6: color button, style combo, width spin, symbol spin.""" + # Color in col 3 + self.color_button = ColorButton(self.config.color) + self.color_button.clicked.connect(lambda: self._select_color(self.color_button)) + self.tree.setItemWidget(self, 3, self.color_button) + + # Style in col 4 + self.style_combo = QComboBox() + self.style_combo.addItems(["solid", "dash", "dot", "dashdot"]) + idx = self.style_combo.findText(self.config.pen_style) + if idx >= 0: + self.style_combo.setCurrentIndex(idx) + self.tree.setItemWidget(self, 4, self.style_combo) + + # Pen width in col 5 + self.width_spin = QSpinBox() + self.width_spin.setRange(1, 20) + self.width_spin.setValue(self.config.pen_width) + self.tree.setItemWidget(self, 5, self.width_spin) + + # Symbol size in col 6 + self.symbol_spin = QSpinBox() + self.symbol_spin.setRange(1, 20) + self.symbol_spin.setValue(self.config.symbol_size) + self.tree.setItemWidget(self, 6, self.symbol_spin) + + def _select_color(self, button): + """ + Selects a new color using a color dialog and applies it to the specified button. Updates + related configuration properties based on the chosen color. + + Args: + button: The button widget whose color is being modified. + """ + current_color = QColor(button.color()) + chosen_color = QColorDialog.getColor(current_color, self.tree, "Select Curve Color") + if chosen_color.isValid(): + button.set_color(chosen_color) + self.config.color = chosen_color.name() + self.config.symbol_color = chosen_color.name() + + def add_dap_row(self): + """Create a new DAP row as a child. Only valid if source='device'.""" + if self.source != "device": + return + curve_tree = self.tree.parent() + parent_label = self.config.label + + # Inherit device name/entry + dev_name = "" + dev_entry = "" + if self.config.signal: + dev_name = self.config.signal.name + dev_entry = self.config.signal.entry + + # Create a new config for the DAP row + dap_cfg = CurveConfig( + widget_class="Curve", + source="dap", + parent_label=parent_label, + signal=DeviceSignal(name=dev_name, entry=dev_entry), + ) + new_dap = CurveRow(self.tree, parent_item=self, config=dap_cfg, device_manager=self.dev) + # Expand device row to show new child + self.tree.expandItem(self) + + # Give the new row a color from the buffer: + curve_tree._ensure_color_buffer_size() + idx = len(curve_tree.all_items) - 1 + new_col = curve_tree.color_buffer[idx] + new_dap.color_button.set_color(new_col) + new_dap.config.color = new_col + new_dap.config.symbol_color = new_col + + def remove_self(self): + """Remove this row from the tree and from the parent's item list.""" + # If top-level: + index = self.tree.indexOfTopLevelItem(self) + if index != -1: + self.tree.takeTopLevelItem(index) + else: + # If child item + if self.parent_item: + self.parent_item.removeChild(self) + # Also remove from all_items + curve_tree = self.tree.parent() + if self in curve_tree.all_items: + curve_tree.all_items.remove(self) + + def export_data(self) -> dict: + """Collect data from the GUI widgets, update config, and return as a dict. + + Returns: + dict: The serialized config based on the GUI state. + """ + if self.source == "device": + # Gather device name/entry + device_name = "" + device_entry = "" + if hasattr(self, "device_edit"): + device_name = self.device_edit.text() + if hasattr(self, "entry_edit"): + device_entry = self.entry_validator.validate_signal( + name=device_name, entry=self.entry_edit.text() + ) + self.entry_edit.setText(device_entry) + self.config.signal = DeviceSignal(name=device_name, entry=device_entry) + self.config.source = "device" + if not self.config.label: + self.config.label = f"{device_name}-{device_entry}".strip("-") + else: + # DAP logic + parent_conf_dict = {} + if self.parent_item: + parent_conf_dict = self.parent_item.export_data() + parent_conf = CurveConfig(**parent_conf_dict) + dev_name = "" + dev_entry = "" + if parent_conf.signal: + dev_name = parent_conf.signal.name + dev_entry = parent_conf.signal.entry + # Dap from the DapComboBox + new_dap = "GaussianModel" + if hasattr(self, "dap_combo"): + new_dap = self.dap_combo.fit_model_combobox.currentText() + self.config.signal = DeviceSignal(name=dev_name, entry=dev_entry, dap=new_dap) + self.config.source = "dap" + self.config.parent_label = parent_conf.label + self.config.label = f"{parent_conf.label}-{new_dap}".strip("-") + + # Common style fields + self.config.color = self.color_button.color() + self.config.symbol_color = self.color_button.color() + self.config.pen_style = self.style_combo.currentText() + self.config.pen_width = self.width_spin.value() + self.config.symbol_size = self.symbol_spin.value() + + return self.config.model_dump() + + +class CurveTree(BECWidget, QWidget): + """A tree widget that manages device and DAP curves.""" + + PLUGIN = False + RPC = False + + def __init__( + self, + parent: QWidget | None = None, + config: ConnectionConfig | None = None, + client=None, + gui_id: str | None = None, + waveform: Waveform | None = None, + ) -> None: + if config is None: + config = ConnectionConfig(widget_class=self.__class__.__name__) + super().__init__(client=client, gui_id=gui_id, config=config) + QWidget.__init__(self, parent=parent) + + self.waveform = waveform + if self.waveform and hasattr(self.waveform, "color_palette"): + self.color_palette = self.waveform.color_palette + else: + self.color_palette = "magma" + + self.get_bec_shortcuts() + + self.color_buffer = [] + self.all_items = [] + self.layout = QVBoxLayout(self) + self._init_toolbar() + self._init_tree() + self.refresh_from_waveform() + + def _init_toolbar(self): + """Initialize the toolbar with actions: add, send, refresh, expand, collapse, renormalize.""" + self.toolbar = ModularToolBar(target_widget=self, orientation="horizontal") + add = MaterialIconAction( + icon_name="add", tooltip="Add new curve", checkable=False, parent=self + ) + expand = MaterialIconAction( + icon_name="unfold_more", tooltip="Expand All DAP", checkable=False, parent=self + ) + collapse = MaterialIconAction( + icon_name="unfold_less", tooltip="Collapse All DAP", checkable=False, parent=self + ) + + self.toolbar.add_action("add", add, self) + self.toolbar.add_action("expand_all", expand, self) + self.toolbar.add_action("collapse_all", collapse, self) + + # Add colormap widget (not updating waveform's color_palette until Send is pressed) + self.spacer = QWidget() + self.spacer.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) + self.toolbar.addWidget(self.spacer) + + # Renormalize colors button + renorm_action = MaterialIconAction( + icon_name="palette", tooltip="Normalize All Colors", checkable=False, parent=self + ) + self.toolbar.add_action("renormalize_colors", renorm_action, self) + renorm_action.action.triggered.connect(lambda checked: self.renormalize_colors()) + + self.colormap_widget = BECColorMapWidget(cmap=self.color_palette or "magma") + self.toolbar.addWidget(self.colormap_widget) + self.colormap_widget.colormap_changed_signal.connect(self.handle_colormap_changed) + + add.action.triggered.connect(lambda checked: self.add_new_curve()) + expand.action.triggered.connect(lambda checked: self.expand_all_daps()) + collapse.action.triggered.connect(lambda checked: self.collapse_all_daps()) + + self.layout.addWidget(self.toolbar) + + def _init_tree(self): + """Initialize the QTreeWidget with 7 columns and compact widths.""" + self.tree = QTreeWidget() + self.tree.setColumnCount(7) + self.tree.setHeaderLabels(["Actions", "Name", "Entry", "Color", "Style", "Width", "Symbol"]) + self.tree.setColumnWidth(0, 90) + self.tree.setColumnWidth(1, 100) + self.tree.setColumnWidth(2, 100) + self.tree.setColumnWidth(3, 70) + self.tree.setColumnWidth(4, 80) + self.tree.setColumnWidth(5, 40) + self.tree.setColumnWidth(6, 40) + self.layout.addWidget(self.tree) + + def _init_color_buffer(self, size: int): + """ + Initializes the color buffer with a calculated set of colors based on the golden + angle sequence. + + Args: + size (int): The number of colors to be generated for the color buffer. + """ + self.color_buffer = Colors.golden_angle_color( + colormap=self.colormap_widget.colormap, num=size, format="HEX" + ) + + def _ensure_color_buffer_size(self): + """ + Ensures that the color buffer size meets the required number of items. + """ + current_count = len(self.color_buffer) + color_list = Colors.golden_angle_color( + colormap=self.color_palette, num=max(10, current_count + 1), format="HEX" + ) + self.color_buffer = color_list + + def handle_colormap_changed(self, new_cmap: str): + """ + Handles the updating of the color palette when the colormap is changed. + + Args: + new_cmap: The new colormap to be set as the color palette. + """ + self.color_palette = new_cmap + + def renormalize_colors(self): + """Overwrite all existing rows with new colors from the buffer in their creation order.""" + total = len(self.all_items) + self._ensure_color_buffer_size() + for idx, item in enumerate(self.all_items): + if hasattr(item, "color_button"): + new_col = self.color_buffer[idx] + item.color_button.set_color(new_col) + if hasattr(item, "config"): + item.config.color = new_col + item.config.symbol_color = new_col + + def add_new_curve(self, name: str = None, entry: str = None): + """Add a new device-type CurveRow with an assigned colormap color. + + Args: + name (str, optional): Device name. + entry (str, optional): Device entry. + style (str, optional): Pen style. Defaults to "solid". + width (int, optional): Pen width. Defaults to 4. + symbol_size (int, optional): Symbol size. Defaults to 7. + + Returns: + CurveRow: The newly created top-level row. + """ + cfg = CurveConfig( + widget_class="Curve", + parent_id=self.waveform.gui_id, + source="device", + signal=DeviceSignal(name=name or "", entry=entry or ""), + ) + new_row = CurveRow(self.tree, parent_item=None, config=cfg, device_manager=self.dev) + + # Assign color from the buffer ONLY to this new curve. + total_items = len(self.all_items) + self._ensure_color_buffer_size() + color_idx = total_items - 1 # new row is last + new_col = self.color_buffer[color_idx] + new_row.color_button.set_color(new_col) + new_row.config.color = new_col + new_row.config.symbol_color = new_col + + return new_row + + def send_curve_json(self): + """Send the current tree's config as JSON to the waveform, updating wavefrom.color_palette as well.""" + if self.waveform is not None: + self.waveform.color_palette = self.color_palette + data = self.export_all_curves() + json_data = json.dumps(data, indent=2) + if self.waveform is not None: + self.waveform.curve_json = json_data + + def export_all_curves(self) -> list: + """Recursively export data from each row. + + Returns: + list: A list of exported config dicts for every row (device and DAP). + """ + curves = [] + for i in range(self.tree.topLevelItemCount()): + item = self.tree.topLevelItem(i) + if isinstance(item, CurveRow): + curves.append(item.export_data()) + for j in range(item.childCount()): + child = item.child(j) + if isinstance(child, CurveRow): + curves.append(child.export_data()) + return curves + + def expand_all_daps(self): + """Expand all top-level rows to reveal child DAP rows.""" + for i in range(self.tree.topLevelItemCount()): + item = self.tree.topLevelItem(i) + self.tree.expandItem(item) + + def collapse_all_daps(self): + """Collapse all top-level rows, hiding child DAP rows.""" + for i in range(self.tree.topLevelItemCount()): + item = self.tree.topLevelItem(i) + self.tree.collapseItem(item) + + def refresh_from_waveform(self): + """Clear the tree and rebuild from the waveform's existing curves if any, else add sample rows.""" + if self.waveform is None: + return + self.tree.clear() + self.all_items = [] + + device_curves = [c for c in self.waveform.curves if c.config.source == "device"] + dap_curves = [c for c in self.waveform.curves if c.config.source == "dap"] + for dev in device_curves: + dr = CurveRow(self.tree, parent_item=None, config=dev.config, device_manager=self.dev) + for dap in dap_curves: + if dap.config.parent_label == dev.config.label: + CurveRow(self.tree, parent_item=dr, config=dap.config, device_manager=self.dev) diff --git a/bec_widgets/widgets/plots_next_gen/waveform/utils/__init__.py b/bec_widgets/widgets/plots_next_gen/waveform/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bec_widgets/widgets/plots_next_gen/waveform/utils/roi_manager.py b/bec_widgets/widgets/plots_next_gen/waveform/utils/roi_manager.py new file mode 100644 index 00000000..e40c5874 --- /dev/null +++ b/bec_widgets/widgets/plots_next_gen/waveform/utils/roi_manager.py @@ -0,0 +1,84 @@ +import pyqtgraph as pg +from qtpy.QtCore import QObject, Signal, Slot + +from bec_widgets.utils.colors import get_accent_colors +from bec_widgets.utils.linear_region_selector import LinearRegionWrapper + + +class WaveformROIManager(QObject): + """ + A reusable helper class that manages a single linear ROI region on a given plot item. + It provides signals to notify about region changes and active state. + """ + + roi_changed = Signal(tuple) # Emitted when the ROI (left, right) changes + roi_active = Signal(bool) # Emitted when ROI is enabled or disabled + + def __init__(self, plot_item: pg.PlotItem, parent=None): + super().__init__(parent) + self._plot_item = plot_item + self._roi_wrapper: LinearRegionWrapper | None = None + self._roi_region: tuple[float, float] | None = None + self._accent_colors = get_accent_colors() + + @property + def roi_region(self) -> tuple[float, float] | None: + return self._roi_region + + @roi_region.setter + def roi_region(self, value: tuple[float, float] | None): + self._roi_region = value + if self._roi_wrapper is not None and value is not None: + self._roi_wrapper.linear_region_selector.setRegion(value) + + @Slot(bool) + def toggle_roi(self, enabled: bool) -> None: + if enabled: + self._enable_roi() + else: + self._disable_roi() + + @Slot(tuple) + def select_roi(self, region: tuple[float, float]): + # If ROI not present, enabling it + if self._roi_wrapper is None: + self.toggle_roi(True) + self.roi_region = region + + def _enable_roi(self): + if self._roi_wrapper is not None: + # Already enabled + return + color = self._accent_colors.default + color.setAlpha(int(0.2 * 255)) + hover_color = self._accent_colors.default + hover_color.setAlpha(int(0.35 * 255)) + + self._roi_wrapper = LinearRegionWrapper( + self._plot_item, color=color, hover_color=hover_color, parent=self + ) + self._roi_wrapper.add_region_selector() + self._roi_wrapper.region_changed.connect(self._on_region_changed) + + # If we already had a region, apply it + if self._roi_region is not None: + self._roi_wrapper.linear_region_selector.setRegion(self._roi_region) + else: + self._roi_region = self._roi_wrapper.linear_region_selector.getRegion() + + self.roi_active.emit(True) + + def _disable_roi(self): + if self._roi_wrapper is not None: + self._roi_wrapper.region_changed.disconnect(self._on_region_changed) + self._roi_wrapper.cleanup() + self._roi_wrapper.deleteLater() + self._roi_wrapper = None + + self._roi_region = None + self.roi_active.emit(False) + + @Slot(tuple) + def _on_region_changed(self, region: tuple[float, float]): + self._roi_region = region + self.roi_changed.emit(region) diff --git a/bec_widgets/widgets/plots_next_gen/waveform/waveform.py b/bec_widgets/widgets/plots_next_gen/waveform/waveform.py new file mode 100644 index 00000000..1efc0169 --- /dev/null +++ b/bec_widgets/widgets/plots_next_gen/waveform/waveform.py @@ -0,0 +1,1611 @@ +from __future__ import annotations + +import json +from typing import Literal + +import lmfit +import numpy as np +import pyqtgraph as pg +from bec_lib import bec_logger, messages +from bec_lib.endpoints import MessageEndpoints +from pydantic import Field, ValidationError, field_validator +from qtpy.QtCore import Signal +from qtpy.QtWidgets import QDialog, QHBoxLayout, QMainWindow, QVBoxLayout, QWidget + +from bec_widgets.qt_utils.error_popups import SafeProperty, SafeSlot +from bec_widgets.qt_utils.settings_dialog import SettingsDialog +from bec_widgets.qt_utils.toolbar import MaterialIconAction +from bec_widgets.utils import ConnectionConfig +from bec_widgets.utils.bec_signal_proxy import BECSignalProxy +from bec_widgets.utils.colors import Colors, set_theme +from bec_widgets.widgets.dap.lmfit_dialog.lmfit_dialog import LMFitDialog +from bec_widgets.widgets.plots_next_gen.plot_base import PlotBase +from bec_widgets.widgets.plots_next_gen.waveform.curve import Curve, CurveConfig, DeviceSignal +from bec_widgets.widgets.plots_next_gen.waveform.settings.curve_settings.curve_setting import ( + CurveSetting, +) +from bec_widgets.widgets.plots_next_gen.waveform.utils.roi_manager import WaveformROIManager + +logger = bec_logger.logger + + +# noinspection PyDataclass +class WaveformConfig(ConnectionConfig): + color_palette: str | None = Field( + "magma", description="The color palette of the figure widget.", validate_default=True + ) + + model_config: dict = {"validate_assignment": True} + _validate_color_palette = field_validator("color_palette")(Colors.validate_color_map) + + +class Waveform(PlotBase): + PLUGIN = True + RPC = True + ICON_NAME = "show_chart" + USER_ACCESS = [ + # General PlotBase Settings + "enable_toolbar", + "enable_toolbar.setter", + "enable_side_panel", + "enable_side_panel.setter", + "enable_fps_monitor", + "enable_fps_monitor.setter", + "set", + "title", + "title.setter", + "x_label", + "x_label.setter", + "y_label", + "y_label.setter", + "x_limits", + "x_limits.setter", + "y_limits", + "y_limits.setter", + "x_grid", + "x_grid.setter", + "y_grid", + "y_grid.setter", + "inner_axes", + "inner_axes.setter", + "outer_axes", + "outer_axes.setter", + "lock_aspect_ratio", + "lock_aspect_ratio.setter", + "auto_range_x", + "auto_range_x.setter", + "auto_range_y", + "auto_range_y.setter", + "x_log", + "x_log.setter", + "y_log", + "y_log.setter", + "legend_label_size", + "legend_label_size.setter", + # Waveform Specific RPC Access + "__getitem__", + "curves", + "x_mode", + "x_mode.setter", + "color_palette", + "color_palette.setter", + "plot", + "add_dap_curve", + "remove_curve", + "update_with_scan_history", + "get_dap_params", + "get_dap_summary", + ] + + sync_signal_update = Signal() + async_signal_update = Signal() + request_dap_update = Signal() + unblock_dap_proxy = Signal() + dap_params_update = Signal(dict, dict) + dap_summary_update = Signal(dict, dict) + new_scan = Signal() + new_scan_id = Signal(str) + + roi_changed = Signal(tuple) + roi_active = Signal(bool) + roi_enable = Signal(bool) # enable toolbar icon + + def __init__( + self, + parent: QWidget | None = None, + config: WaveformConfig | None = None, + client=None, + gui_id: str | None = None, + popups: bool = True, + ): + if config is None: + config = WaveformConfig(widget_class=self.__class__.__name__) + super().__init__(parent=parent, config=config, client=client, gui_id=gui_id, popups=popups) + + # For PropertyManager identification + self.setObjectName("Waveform") + + # Curve data + self._sync_curves = [] + self._async_curves = [] + self._dap_curves = [] + self._mode: Literal["none", "sync", "async", "mixed"] = "none" + + # Scan data + self.old_scan_id = None + self.scan_id = None + self.scan_item = None + self.readout_priority = None + self.x_axis_mode = { + "name": "auto", + "entry": None, + "readout_priority": None, + "label_suffix": "", + } + + # Specific GUI elements + self._init_roi_manager() + self.dap_summary = None + self.dap_summary_dialog = None + self._enable_roi_toolbar_action(False) # default state where are no dap curves + self._init_curve_dialog() + self.curve_settings_dialog = None + + # Scan status update loop + self.bec_dispatcher.connect_slot(self.on_scan_status, MessageEndpoints.scan_status()) + self.bec_dispatcher.connect_slot(self.on_scan_progress, MessageEndpoints.scan_progress()) + + # Curve update loop + self.proxy_update_sync = pg.SignalProxy( + self.sync_signal_update, rateLimit=25, slot=self.update_sync_curves + ) + self.proxy_update_async = pg.SignalProxy( + self.async_signal_update, rateLimit=25, slot=self.update_async_curves + ) + self.proxy_dap_request = BECSignalProxy( + self.request_dap_update, rateLimit=25, slot=self.request_dap, timeout=10.0 + ) + self.unblock_dap_proxy.connect(self.proxy_dap_request.unblock_proxy) + self.roi_enable.connect(self._enable_roi_toolbar_action) + + self.update_with_scan_history(-1) + + # for updating a color scheme of curves + self._connect_to_theme_change() + + def __getitem__(self, key: int | str): + return self.get_curve(key) + + ################################################################################ + # Widget Specific GUI interactions + ################################################################################ + @SafeSlot(str) + def apply_theme(self, theme: str): + """ + Apply the theme to the widget. + + Args: + theme(str, optional): The theme to be applied. + """ + self._refresh_colors() + super().apply_theme(theme) + + def add_side_menus(self): + """ + Add side menus to the Waveform widget. + """ + super().add_side_menus() + self._add_dap_summary_side_menu() + + def add_popups(self): + """ + Add popups to the Waveform widget. + """ + super().add_popups() + LMFitDialog_action = MaterialIconAction( + icon_name="monitoring", tooltip="Open Fit Parameters", checkable=True, parent=self + ) + self.toolbar.add_action_to_bundle( + bundle_id="popup_bundle", + action_id="fit_params", + action=LMFitDialog_action, + target_widget=self, + ) + self.toolbar.widgets["fit_params"].action.triggered.connect(self.show_dap_summary_popup) + + ################################################################################ + # Roi manager + + def _init_roi_manager(self): + """ + Initialize the ROI manager for the Waveform widget. + """ + # Add toolbar icon + roi = MaterialIconAction( + icon_name="align_justify_space_between", + tooltip="Add ROI region for DAP", + checkable=True, + ) + self.toolbar.add_action_to_bundle( + bundle_id="roi", action_id="roi_linear", action=roi, target_widget=self + ) + self._roi_manager = WaveformROIManager(self.plot_item, parent=self) + + # Connect manager signals -> forward them via Waveform's own signals + self._roi_manager.roi_changed.connect(self.roi_changed) + self._roi_manager.roi_active.connect(self.roi_active) + + # Example: connect ROI changed to re-request DAP + self.roi_changed.connect(self._on_roi_changed_for_dap) + self._roi_manager.roi_active.connect(self.request_dap_update) + self.toolbar.widgets["roi_linear"].action.toggled.connect(self._roi_manager.toggle_roi) + + def _init_curve_dialog(self): + """ + Initializes the Curve dialog within the toolbar. + """ + curve_settings = MaterialIconAction( + icon_name="timeline", tooltip="Show Curve dialog.", checkable=True + ) + self.toolbar.add_action("curve", curve_settings, target_widget=self) + self.toolbar.widgets["curve"].action.triggered.connect(self.show_curve_settings_popup) + + def show_curve_settings_popup(self): + """ + Displays the curve settings popup to allow users to modify curve-related configurations. + """ + curve_action = self.toolbar.widgets["curve"].action + + if self.curve_settings_dialog is None or not self.curve_settings_dialog.isVisible(): + curve_setting = CurveSetting(target_widget=self) + self.curve_settings_dialog = SettingsDialog( + self, settings_widget=curve_setting, window_title="Curve Settings", modal=False + ) + self.curve_settings_dialog.setFixedWidth(580) + # When the dialog is closed, update the toolbar icon and clear the reference + self.curve_settings_dialog.finished.connect(self._curve_settings_closed) + self.curve_settings_dialog.show() + curve_action.setChecked(True) + else: + # If already open, bring it to the front + self.curve_settings_dialog.raise_() + self.curve_settings_dialog.activateWindow() + curve_action.setChecked(True) # keep it toggled + + def _curve_settings_closed(self): + """ + Slot for when the axis settings dialog is closed. + """ + self.curve_settings_dialog = None + self.toolbar.widgets["curve"].action.setChecked(False) + + @property + def roi_region(self) -> tuple[float, float] | None: + """ + Allows external code to get/set the ROI region easily via Waveform. + """ + return self._roi_manager.roi_region + + @roi_region.setter + def roi_region(self, value: tuple[float, float] | None): + """ + Set the ROI region limits. + + Args: + value(tuple[float, float] | None): The new ROI region limits. + """ + self._roi_manager.roi_region = value + + def select_roi(self, region: tuple[float, float]): + """ + Public method if you want the old `select_roi` style. + """ + self._roi_manager.select_roi(region) + + def toggle_roi(self, enabled: bool): + """ + Toggle the ROI on or off. + + Args: + enabled(bool): Whether to enable or disable the ROI. + """ + self._roi_manager.toggle_roi(enabled) + + def _on_roi_changed_for_dap(self): + """ + Whenever the ROI changes, you might want to re-request DAP with the new x_min, x_max. + """ + self.request_dap_update.emit() + + def _enable_roi_toolbar_action(self, enable: bool): + """ + Enable or disable the ROI toolbar action. + + Args: + enable(bool): Enable or disable the ROI toolbar action. + """ + self.toolbar.widgets["roi_linear"].action.setEnabled(enable) + if enable is False: + self.toolbar.widgets["roi_linear"].action.setChecked(False) + self._roi_manager.toggle_roi(False) + + ################################################################################ + # Dap Summary + + def _add_dap_summary_side_menu(self): + """ + Add the DAP summary to the side panel. + """ + self.dap_summary = LMFitDialog(parent=self) + self.side_panel.add_menu( + action_id="fit_params", + icon_name="monitoring", + tooltip="Open Fit Parameters", + widget=self.dap_summary, + title="Fit Parameters", + ) + self.dap_summary_update.connect(self.dap_summary.update_summary_tree) + + def show_dap_summary_popup(self): + """ + Show the DAP summary popup. + """ + fit_action = self.toolbar.widgets["fit_params"].action + if self.dap_summary_dialog is None or not self.dap_summary_dialog.isVisible(): + self.dap_summary = LMFitDialog(parent=self) + self.dap_summary_dialog = QDialog(modal=False) + self.dap_summary_dialog.layout = QVBoxLayout(self.dap_summary_dialog) + self.dap_summary_dialog.layout.addWidget(self.dap_summary) + self.dap_summary_update.connect(self.dap_summary.update_summary_tree) + self.dap_summary_dialog.finished.connect(self._dap_summary_closed) + self.dap_summary_dialog.show() + self._refresh_dap_signals() # Get current dap data + self.dap_summary_dialog.resize(300, 300) + fit_action.setChecked(True) + else: + # If already open, bring it to the front + self.dap_summary_dialog.raise_() + self.dap_summary_dialog.activateWindow() + fit_action.setChecked(True) # keep it toggle + + def _dap_summary_closed(self): + """ + Slot for when the axis settings dialog is closed. + """ + self.dap_summary_dialog.deleteLater() + self.dap_summary_dialog = None + self.toolbar.widgets["fit_params"].action.setChecked(False) + + def _get_dap_from_target_widget(self) -> None: + """Get the DAP data from the target widget and update the DAP dialog manually on creation.""" + dap_summary = self.get_dap_summary() + for curve_id, data in dap_summary.items(): + md = {"curve_id": curve_id} + self.dap_summary.update_summary_tree(data=data, metadata=md) + + @SafeSlot() + def get_dap_params(self) -> dict[str, dict]: + """ + Get the DAP parameters of all DAP curves. + + Returns: + dict[str, dict]: DAP parameters of all DAP curves. + """ + return {curve.name(): curve.dap_params for curve in self._dap_curves} + + @SafeSlot() + def get_dap_summary(self) -> dict[str, dict]: + """ + Get the DAP summary of all DAP curves. + + Returns: + dict[str, dict]: DAP summary of all DAP curves. + """ + return {curve.name(): curve.dap_summary for curve in self._dap_curves} + + ################################################################################ + # Widget Specific Properties + ################################################################################ + + @SafeProperty(str) + def x_mode(self) -> str: + return self.x_axis_mode["name"] + + @x_mode.setter + def x_mode(self, value: str): + self.x_axis_mode["name"] = value + self._switch_x_axis_item(mode=value) + self.async_signal_update.emit() + self.sync_signal_update.emit() + self.plot_item.enableAutoRange(x=True) + self.round_plot_widget.apply_plot_widget_style() # To keep the correct theme + + @SafeProperty(str) + def color_palette(self) -> str: + """ + The color palette of the figure widget. + """ + return self.config.color_palette + + @color_palette.setter + def color_palette(self, value: str): + """ + Set the color palette of the figure widget. + + Args: + value(str): The color palette to set. + """ + try: + self.config.color_palette = value + except ValidationError: + return + + colors = Colors.golden_angle_color( + colormap=self.config.color_palette, num=max(10, len(self.curves) + 1), format="HEX" + ) + for i, curve in enumerate(self.curves): + curve.set_color(colors[i]) + + @SafeProperty(str, designable=False, popup_error=True) + def curve_json(self) -> str: + """ + A JSON string property that serializes all curves' pydantic configs. + """ + raw_list = [] + for c in self.curves: + if c.config.source == "custom": # Do not serialize custom curves + continue + cfg_dict = c.config.model_dump() + raw_list.append(cfg_dict) + return json.dumps(raw_list, indent=2) + + @curve_json.setter + def curve_json(self, json_data: str): + """ + Load curves from a JSON string and add them to the plot, omitting custom source curves. + """ + try: + curve_configs = json.loads(json_data) + self.clear_all() + for cfg_dict in curve_configs: + if cfg_dict.get("source") == "custom": + logger.warning(f"Custom source curve '{cfg_dict['label']}' not loaded.") + continue + config = CurveConfig(**cfg_dict) + self._add_curve(config=config) + except json.JSONDecodeError as e: + logger.error(f"Failed to decode JSON: {e}") + + @property + def curves(self) -> list[Curve]: + """ + Get the curves of the plot widget as a list. + + Returns: + list: List of curves. + """ + return [item for item in self.plot_item.curves if isinstance(item, Curve)] + + ################################################################################ + # High Level methods for API + ################################################################################ + @SafeSlot(popup_error=True) + def plot( + self, + arg1: list | np.ndarray | str | None = None, + y: list | np.ndarray | None = None, + x: list | np.ndarray | None = None, + x_name: str | None = None, + y_name: str | None = None, + x_entry: str | None = None, + y_entry: str | None = None, + color: str | None = None, + label: str | None = None, + dap: str | None = None, + **kwargs, + ) -> Curve: + """ + Plot a curve to the plot widget. + + Args: + arg1(list | np.ndarray | str | None): First argument, which can be x data, y data, or y_name. + y(list | np.ndarray): Custom y data to plot. + x(list | np.ndarray): Custom y data to plot. + x_name(str): Name of the x signal. + - "auto": Use the best effort signal. + - "timestamp": Use the timestamp signal. + - "index": Use the index signal. + - Custom signal name of a device from BEC. + y_name(str): The name of the device for the y-axis. + x_entry(str): The name of the entry for the x-axis. + y_entry(str): The name of the entry for the y-axis. + color(str): The color of the curve. + label(str): The label of the curve. + dap(str): The dap model to use for the curve, only available for sync devices. + If not specified, none will be added. + Use the same string as is the name of the LMFit model. + + Returns: + Curve: The curve object. + """ + # 0) preallocate + source = "custom" + x_data = None + y_data = None + + # 1. Custom curve logic + if x is not None and y is not None: + source = "custom" + x_data = np.asarray(x) + y_data = np.asarray(y) + + if isinstance(arg1, str): + y_name = arg1 + elif isinstance(arg1, list): + if isinstance(y, list): + source = "custom" + x_data = np.asarray(arg1) + y_data = np.asarray(y) + if y is None: + source = "custom" + arr = np.asarray(arg1) + x_data = np.arange(len(arr)) + y_data = arr + elif isinstance(arg1, np.ndarray) and y is None: + if arg1.ndim == 1: + source = "custom" + x_data = np.arange(len(arg1)) + y_data = arg1 + if arg1.ndim == 2 and arg1.shape[1] == 2: + source = "custom" + x_data = arg1[:, 0] + y_data = arg1[:, 1] + + # If y_name is set => device data + if y_name is not None and x_data is None and y_data is None: + source = "device" + # Validate or obtain entry + y_entry = self.entry_validator.validate_signal(name=y_name, entry=y_entry) + + # If user gave x_name => store in x_axis_mode, but do not set data here + if x_name is not None: + self.x_mode = x_name + if x_name not in ["timestamp", "index", "auto"]: + self.x_axis_mode["entry"] = self.entry_validator.validate_signal(x_name, x_entry) + + # Decide label if not provided + if label is None: + if source == "custom": + label = f"Curve {len(self.curves) + 1}" + else: + label = f"{y_name}-{y_entry}" + + # If color not provided, generate from palette + if color is None: + color = self._generate_color_from_palette() + + # Build the config + config = CurveConfig( + widget_class="Curve", + parent_id=self.gui_id, + label=label, + color=color, + source=source, + **kwargs, + ) + + # If it's device-based, attach DeviceSignal + if source == "device": + config.signal = DeviceSignal(name=y_name, entry=y_entry) + + # CREATE THE CURVE + curve = self._add_curve(config=config, x_data=x_data, y_data=y_data) + + if dap is not None and source == "device": + self.add_dap_curve(device_label=curve.name(), dap_name=dap, **kwargs) + + return curve + + ################################################################################ + # Curve Management Methods + @SafeSlot() + def add_dap_curve( + self, + device_label: str, + dap_name: str, + color: str | None = None, + dap_oversample: int = 1, + **kwargs, + ) -> Curve: + """ + Create a new DAP curve referencing the existing device curve `device_label`, + with the data processing model `dap_name`. + + Args: + device_label(str): The label of the device curve to add DAP to. + dap_name(str): The name of the DAP model to use. + color(str): The color of the curve. + dap_oversample(int): The oversampling factor for the DAP curve. + **kwargs + + Returns: + Curve: The new DAP curve. + """ + + # 1) Find the existing device curve by label + device_curve = self._find_curve_by_label(device_label) + if not device_curve: + raise ValueError(f"No existing curve found with label '{device_label}'.") + if device_curve.config.source != "device": + raise ValueError( + f"Curve '{device_label}' is not a device curve. Only device curves can have DAP." + ) + + dev_name = device_curve.config.signal.name + dev_entry = device_curve.config.signal.entry + + # 2) Build a label for the new DAP curve + dap_label = f"{dev_name}-{dev_entry}-{dap_name}" + + # 3) Possibly raise if the DAP curve already exists + if self._check_curve_id(dap_label): + raise ValueError(f"DAP curve '{dap_label}' already exists.") + + if color is None: + color = self._generate_color_from_palette() + + # Build config for DAP + config = CurveConfig( + widget_class="Curve", + parent_id=self.gui_id, + label=dap_label, + color=color, + source="dap", + parent_label=device_label, + symbol="star", + **kwargs, + ) + + # Attach device signal with DAP + config.signal = DeviceSignal( + name=dev_name, entry=dev_entry, dap=dap_name, dap_oversample=dap_oversample + ) + + # 4) Create the DAP curve config using `_add_curve(...)` + dap_curve = self._add_curve(config=config) + + return dap_curve + + def _add_curve( + self, + config: CurveConfig, + x_data: np.ndarray | None = None, + y_data: np.ndarray | None = None, + ) -> Curve: + """ + Private method to finalize the creation of a new Curve in this Waveform widget + based on an already-built `CurveConfig`. + + Args: + config (CurveConfig): A fully populated pydantic model describing how to create and style the curve. + x_data (np.ndarray | None): If this is a custom curve (config.source == "custom"), optional x data array. + y_data (np.ndarray | None): If this is a custom curve (config.source == "custom"), optional y data array. + + Returns: + Curve: The newly created curve object. + + Raises: + ValueError: If a duplicate curve label/config is found, or if + custom data is missing for `source='custom'`. + """ + label = config.label + if not label: + # Fallback label + label = f"Curve {len(self.curves) + 1}" + config.label = label + + # Check for duplicates + if self._check_curve_id(label): + raise ValueError(f"Curve with ID '{label}' already exists in widget '{self.gui_id}'.") + + # If a user did not provide color in config, pick from palette + if not config.color: + config.color = self._generate_color_from_palette() + + # For custom data, ensure x_data, y_data + if config.source == "custom": + if x_data is None or y_data is None: + raise ValueError("For 'custom' curves, x_data and y_data must be provided.") + + # Actually create the Curve item + curve = self._add_curve_object(name=label, config=config) + + # If custom => set initial data + if config.source == "custom" and x_data is not None and y_data is not None: + curve.setData(x_data, y_data) + + # If device => schedule BEC updates + if config.source == "device": + if self.scan_item is None: + self.update_with_scan_history(-1) + if curve in self._async_curves: + self._setup_async_curve(curve) + self.async_signal_update.emit() + self.sync_signal_update.emit() + if config.source == "dap": + self.setup_dap_for_scan() + self.request_dap() # Request DAP update directly without blocking proxy + + return curve + + def _add_curve_object(self, name: str, config: CurveConfig) -> Curve: + """ + Low-level creation of the PlotDataItem (Curve) from a `CurveConfig`. + + Args: + name (str): The name/label of the curve. + config (CurveConfig): Configuration model describing the curve. + + Returns: + Curve: The newly created curve object, added to the plot. + """ + curve = Curve(config=config, name=name, parent_item=self) + self.plot_item.addItem(curve) + self._categorise_device_curves() + return curve + + def _generate_color_from_palette(self) -> str: + """ + Generate a color for the next new curve, based on the current number of curves. + """ + current_count = len(self.curves) + color_list = Colors.golden_angle_color( + colormap=self.config.color_palette, num=max(10, current_count + 1), format="HEX" + ) + return color_list[current_count] + + def _refresh_colors(self): + """ + Re-assign colors to all existing curves so they match the new count-based distribution. + """ + all_curves = self.curves + # Generate enough colors for the new total + color_list = Colors.golden_angle_color( + colormap=self.config.color_palette, num=max(10, len(all_curves)), format="HEX" + ) + for i, curve in enumerate(all_curves): + curve.set_color(color_list[i]) + + def clear_data(self): + """ + Clear all data from the plot widget, but keep the curve references. + """ + for c in self.curves: + c.clear_data() + + def clear_all(self): + """ + Clear all curves from the plot widget. + """ + curve_list = self.curves + for curve in curve_list: + self.remove_curve(curve.name()) + if self.crosshair is not None: + self.crosshair.clear_markers() + + def get_curve(self, curve: int | str) -> Curve | None: + """ + Get a curve from the plot widget. + + Args: + curve(int|str): The curve to get. It Can be the order of the curve or the name of the curve. + + Return(Curve|None): The curve object if found, None otherwise. + """ + if isinstance(curve, int): + if curve < len(self.curves): + return self.curves[curve] + elif isinstance(curve, str): + for c in self.curves: + if c.name() == curve: + return c + return None + + @SafeSlot(int, popup_error=True) + @SafeSlot(str, popup_error=True) + def remove_curve(self, curve: int | str): + """ + Remove a curve from the plot widget. + + Args: + curve(int|str): The curve to remove. It Can be the order of the curve or the name of the curve. + """ + if isinstance(curve, int): + self._remove_curve_by_order(curve) + elif isinstance(curve, str): + self._remove_curve_by_name(curve) + + self._refresh_colors() + self._categorise_device_curves() + + def _remove_curve_by_name(self, name: str): + """ + Remove a curve by its name from the plot widget. + + Args: + name(str): Name of the curve to be removed. + """ + for curve in self.curves: + if curve.name() == name: + self.plot_item.removeItem(curve) + self._curve_clean_up(curve) + return + + def _remove_curve_by_order(self, N: int): + """ + Remove a curve by its order from the plot widget. + + Args: + N(int): Order of the curve to be removed. + """ + if N < len(self.curves): + curve = self.curves[N] + self.plot_item.removeItem(curve) + self._curve_clean_up(curve) + + else: + logger.error(f"Curve order {N} out of range.") + raise IndexError(f"Curve order {N} out of range.") + + def _curve_clean_up(self, curve: Curve): + """ + Clean up the curve by disconnecting the async update signal (even for sync curves). + + Args: + curve(Curve): The curve to clean up. + """ + self.bec_dispatcher.disconnect_slot( + self.on_async_readback, + MessageEndpoints.device_async_readback(self.scan_id, curve.name()), + ) + + # Remove itself from the DAP summary only for side panels + if ( + curve.config.source == "dap" + and self.dap_summary is not None + and self.enable_side_panel is True + ): + self.dap_summary.remove_dap_data(curve.name()) + + # find a corresponding dap curve and remove it + for c in self.curves: + if c.config.parent_label == curve.name(): + self.plot_item.removeItem(c) + self._curve_clean_up(c) + + def _check_curve_id(self, curve_id: str) -> bool: + """ + Check if a curve ID exists in the plot widget. + + Args: + curve_id(str): The ID of the curve to check. + + Returns: + bool: True if the curve ID exists, False otherwise. + """ + curve_ids = [curve.name() for curve in self.curves] + if curve_id in curve_ids: + return True + return False + + def _find_curve_by_label(self, label: str) -> Curve | None: + """ + Find a curve by its label. + + Args: + label(str): The label of the curve to find. + + Returns: + Curve|None: The curve object if found, None otherwise. + """ + for c in self.curves: + if c.name() == label: + return c + return None + + ################################################################################ + # BEC Update Methods + ################################################################################ + @SafeSlot(dict, dict) + def on_scan_status(self, msg: dict, meta: dict): + """ + Initial scan status message handler, which is triggered at the begging and end of scan. + Used for triggering the update of the sync and async curves. + + Args: + msg(dict): The message content. + meta(dict): The message metadata. + """ + current_scan_id = msg.get("scan_id", None) + if current_scan_id is None: + return + + if current_scan_id != self.scan_id: + self.reset() + self.new_scan.emit() + self.new_scan_id.emit(current_scan_id) + self.auto_range_x = True + self.auto_range_y = True + self.old_scan_id = self.scan_id + self.scan_id = current_scan_id + self.scan_item = self.queue.scan_storage.find_scan_by_ID(self.scan_id) # live scan + + self._mode = self._categorise_device_curves() + + # First trigger to sync and async data + if self._mode == "sync": + self.sync_signal_update.emit() + logger.info("Scan status: Sync mode") + elif self._mode == "async": + for curve in self._async_curves: + self._setup_async_curve(curve) + self.async_signal_update.emit() + logger.info("Scan status: Async mode") + else: + self.sync_signal_update.emit() + for curve in self._async_curves: + self._setup_async_curve(curve) + self.async_signal_update.emit() + logger.info("Scan status: Mixed mode") + logger.warning("Mixed mode - integrity of x axis cannot be guaranteed.") + self.setup_dap_for_scan() + + @SafeSlot(dict, dict) + def on_scan_progress(self, msg: dict, meta: dict): + """ + Slot for handling scan progress messages. Used for triggering the update of the sync curves. + + Args: + msg(dict): The message content. + meta(dict): The message metadata. + """ + self.sync_signal_update.emit() + + def _fetch_scan_data_and_access(self): + """ + Decide whether the widget is in live or historical mode + and return the appropriate data dict and access key. + + Returns: + data_dict (dict): The data structure for the current scan. + access_key (str): Either 'val' (live) or 'value' (history). + """ + if self.scan_item is None: + # Optionally fetch the latest from history if nothing is set + self.update_with_scan_history(-1) + if self.scan_item is None: + logger.info("No scan executed so far; skipping device curves categorisation.") + return "none", "none" + + if hasattr(self.scan_item, "live_data"): + # Live scan + return self.scan_item.live_data, "val" + else: + # Historical + scan_devices = self.scan_item.devices + return (scan_devices, "value") + + def update_sync_curves(self): + """ + Update the sync curves with the latest data from the scan. + """ + if self.scan_item is None: + logger.info("No scan executed so far; skipping device curves categorisation.") + return "none" + data, access_key = self._fetch_scan_data_and_access() + for curve in self._sync_curves: + device_name = curve.config.signal.name + device_entry = curve.config.signal.entry + if access_key == "val": + device_data = data.get(device_name, {}).get(device_entry, {}).get(access_key, None) + else: + device_data = ( + data.get(device_name, {}).get(device_entry, {}).read().get("value", None) + ) + x_data = self._get_x_data(device_name, device_entry) + if x_data is not None: + if np.isscalar(x_data) and len(x_data) == 1: + self.clear_data() + return + if device_data is not None and x_data is not None: + curve.setData(x_data, device_data) + if device_data is not None and x_data is None: + curve.setData(device_data) + self.request_dap_update.emit() + + def update_async_curves(self): + """ + Updates asynchronously displayed curves with the latest scan data. + + Fetches the scan data and access key to update each curve in `_async_curves` with + new values. If the data is available for a specific curve, it sets the x and y + data for the curve. Emits a signal to request an update once all curves are updated. + + Raises: + The raised errors are dependent on the internal methods such as + `_fetch_scan_data_and_access`, `_get_x_data`, or `setData` used in this + function. + + """ + data, access_key = self._fetch_scan_data_and_access() + + for curve in self._async_curves: + device_name = curve.config.signal.name + device_entry = curve.config.signal.entry + if access_key == "val": # live access + device_data = data.get(device_name, {}).get(device_entry, {}).get(access_key, None) + else: # history access + device_data = ( + data.get(device_name, {}).get(device_entry, {}).read().get("value", None) + ) + + # if shape is 2D cast it into 1D and take the last waveform + if len(np.shape(device_data)) > 1: + device_data = device_data[-1, :] + + x_data = self._get_x_data(device_name, device_entry) + + # If there's actual data, set it + if device_data is not None: + if x_data is not None: + curve.setData(x_data, device_data) + else: + curve.setData(device_data) + self.request_dap_update.emit() + + def _setup_async_curve(self, curve: Curve): + """ + Setup async curve. + + Args: + curve(Curve): The curve to set up. + """ + name = curve.config.signal.name + self.bec_dispatcher.disconnect_slot( + self.on_async_readback, MessageEndpoints.device_async_readback(self.old_scan_id, name) + ) + try: + curve.clear_data() + except KeyError: + logger.warning(f"Curve {name} not found in plot item.") + pass + self.bec_dispatcher.connect_slot( + self.on_async_readback, + MessageEndpoints.device_async_readback(self.scan_id, name), + from_start=True, + ) + logger.info(f"Setup async curve {name}") + + @SafeSlot(dict, dict) + def on_async_readback(self, msg, metadata): + """ + Get async data readback. + + Args: + msg(dict): Message with the async data. + metadata(dict): Metadata of the message. + """ + y_data = None + x_data = None + instruction = metadata.get("async_update", {}).get("type") + max_shape = metadata.get("async_update", {}).get("max_shape", []) + for curve in self._async_curves: + y_entry = curve.config.signal.entry + x_name = self.x_axis_mode["name"] + for device, async_data in msg["signals"].items(): + if device == y_entry: + data_plot = async_data["value"] + if instruction == "add": + if len(max_shape) > 1: + if len(data_plot.shape) > 1: + data_plot = data_plot[-1, :] + else: + x_data, y_data = curve.get_data() + + if y_data is not None: + new_data = np.hstack((y_data, data_plot)) # TODO check performance + else: + new_data = data_plot + if x_name == "timestamp": + if x_data is not None: + x_data = np.hstack((x_data, async_data["timestamp"])) + else: + x_data = async_data["timestamp"] + # FIXME x axis wrong if timestamp switched during scan + curve.setData(x_data, new_data) + else: # this means index as x + curve.setData(new_data) + elif instruction == "replace": + if x_name == "timestamp": + x_data = async_data["timestamp"] + curve.setData(x_data, data_plot) + else: + curve.setData(data_plot) + self.request_dap_update.emit() + + def setup_dap_for_scan(self): + """Setup DAP updates for the new scan.""" + self.bec_dispatcher.disconnect_slot( + self.update_dap_curves, + MessageEndpoints.dap_response(f"{self.old_scan_id}-{self.gui_id}"), + ) + if len(self._dap_curves) > 0: + self.bec_dispatcher.connect_slot( + self.update_dap_curves, + MessageEndpoints.dap_response(f"{self.scan_id}-{self.gui_id}"), + ) + + # @SafeSlot() #FIXME type error + def request_dap(self): + """Request new fit for data""" + + for dap_curve in self._dap_curves: + parent_label = getattr(dap_curve.config, "parent_label", None) + if not parent_label: + continue + # find the device curve + parent_curve = self._find_curve_by_label(parent_label) + if parent_curve is None: + logger.warning(f"No device curve found for DAP curve '{dap_curve.name()}'!") + continue + + x_data, y_data = parent_curve.get_data() + model_name = dap_curve.config.signal.dap + model = getattr(self.dap, model_name) + try: + x_min, x_max = self.roi_region + x_data, y_data = self._crop_data(x_data, y_data, x_min, x_max) + except TypeError: + x_min = None + x_max = None + + msg = messages.DAPRequestMessage( + dap_cls="LmfitService1D", + dap_type="on_demand", + config={ + "args": [], + "kwargs": {"data_x": x_data, "data_y": y_data}, + "class_args": model._plugin_info["class_args"], + "class_kwargs": model._plugin_info["class_kwargs"], + "curve_label": dap_curve.name(), + }, + metadata={"RID": f"{self.scan_id}-{self.gui_id}"}, + ) + self.client.connector.set_and_publish(MessageEndpoints.dap_request(), msg) + + @SafeSlot(dict, dict) + def update_dap_curves(self, msg, metadata): + """ + Update the DAP curves with the new data. + + Args: + msg(dict): Message with the DAP data. + metadata(dict): Metadata of the message. + """ + self.unblock_dap_proxy.emit() + # Extract configuration from the message + msg_config = msg.get("dap_request", None).content.get("config", {}) + curve_id = msg_config.get("curve_label", None) + curve = self._find_curve_by_label(curve_id) + if not curve: + return + + # Get data from the parent (device) curve + parent_curve = self._find_curve_by_label(curve.config.parent_label) + if parent_curve is None: + return + x_parent, _ = parent_curve.get_data() + if x_parent is None or len(x_parent) == 0: + return + + # Retrieve and store the fit parameters and summary from the DAP server response + try: + curve.dap_params = msg["data"][1]["fit_parameters"] + curve.dap_summary = msg["data"][1]["fit_summary"] + except TypeError: + logger.warning(f"Failed to retrieve DAP data for curve '{curve.name()}'") + return + + # Render model according to the DAP model name and parameters + model_name = curve.config.signal.dap + model_function = getattr(lmfit.models, model_name)() + + x_min, x_max = x_parent.min(), x_parent.max() + oversample = curve.dap_oversample + new_x = np.linspace(x_min, x_max, int(len(x_parent) * oversample)) + + # Evaluate the model with the provided parameters to generate the y values + new_y = model_function.eval(**curve.dap_params, x=new_x) + + # Update the curve with the new data + curve.setData(new_x, new_y) + + metadata.update({"curve_id": curve_id}) + self.dap_params_update.emit(curve.dap_params, metadata) + self.dap_summary_update.emit(curve.dap_summary, metadata) + + def _refresh_dap_signals(self): + """ + Refresh the DAP signals for all curves. + """ + for curve in self._dap_curves: + self.dap_params_update.emit(curve.dap_params, {"curve_id": curve.name()}) + self.dap_summary_update.emit(curve.dap_summary, {"curve_id": curve.name()}) + + def _get_x_data(self, device_name: str, device_entry: str) -> list | np.ndarray | None: + """ + Get the x data for the curves with the decision logic based on the widget x mode configuration: + - If x is called 'timestamp', use the timestamp data from the scan item. + - If x is called 'index', use the rolling index. + - If x is a custom signal, use the data from the scan item. + - If x is not specified, use the first device from the scan report. + + Additionally, checks and updates the x label suffix. + + Args: + device_name(str): The name of the device. + device_entry(str): The entry of the device + + Returns: + list|np.ndarray|None: X data for the curve. + """ + x_data = None + new_suffix = None + data, access_key = self._fetch_scan_data_and_access() + + # 1 User wants custom signal + if self.x_axis_mode["name"] not in ["timestamp", "index", "auto"]: + x_name = self.x_axis_mode["name"] + x_entry = self.x_axis_mode.get("entry", None) + if x_entry is None: + x_entry = self.entry_validator.validate_signal(x_name, None) + # if the motor was not scanned, an empty list is returned and curves are not updated + if access_key == "val": # live data + x_data = data.get(x_name, {}).get(x_entry, {}).get(access_key, 0) + else: # history data + x_data = data.get(x_name, {}).get(x_entry, {}).read().get("value", 0) + new_suffix = f" [custom: {x_name}-{x_entry}]" + + # 2 User wants timestamp + if self.x_axis_mode["name"] == "timestamp": + if access_key == "val": # live + timestamps = data[device_name][device_entry].timestamps + else: # history data + timestamps = data[device_name][device_entry].read().get("timestamp", 0) + x_data = timestamps + new_suffix = " [timestamp]" + + # 3 User wants index + if self.x_axis_mode["name"] == "index": + x_data = None + new_suffix = " [index]" + + # 4 Best effort automatic mode + if self.x_axis_mode["name"] is None or self.x_axis_mode["name"] == "auto": + # 4.1 If there are async curves, use index + if len(self._async_curves) > 0: + x_data = None + new_suffix = " [auto: index]" + # 4.2 If there are sync curves, use the first device from the scan report + else: + try: + x_name = self._ensure_str_list( + self.scan_item.metadata["bec"]["scan_report_devices"] + )[0] + except: + x_name = self.scan_item.status_message.info["scan_report_devices"][0] + x_entry = self.entry_validator.validate_signal(x_name, None) + if access_key == "val": + x_data = data.get(x_name, {}).get(x_entry, {}).get(access_key, None) + else: + x_data = data.get(x_name, {}).get(x_entry, {}).read().get("value", None) + new_suffix = f" [auto: {x_name}-{x_entry}]" + self._update_x_label_suffix(new_suffix) + return x_data + + def _update_x_label_suffix(self, new_suffix: str): + """ + Update x_label so it ends with `new_suffix`, removing any old suffix. + + Args: + new_suffix(str): The new suffix to add to the x_label. + """ + if new_suffix == self.x_axis_mode["label_suffix"]: + return + + self.x_axis_mode["label_suffix"] = new_suffix + self.set_x_label_suffix(new_suffix) + + def _switch_x_axis_item(self, mode: str): + """ + Switch the x-axis mode between timestamp, index, the best effort and custom signal. + + Args: + mode(str): Mode of the x-axis. + - "timestamp": Use the timestamp signal. + - "index": Use the index signal. + - "best_effort": Use the best effort signal. + - Custom signal name of a device from BEC. + """ + logger.info(f'Switching x-axis mode to "{mode}"') + current_axis = self.plot_item.axes["bottom"]["item"] + # Only update the axis if the mode change requires it. + if mode == "timestamp": + # Only update if the current axis is not a DateAxisItem. + if not isinstance(current_axis, pg.graphicsItems.DateAxisItem.DateAxisItem): + date_axis = pg.graphicsItems.DateAxisItem.DateAxisItem(orientation="bottom") + self.plot_item.setAxisItems({"bottom": date_axis}) + else: + # For non-timestamp modes, only update if the current axis is a DateAxisItem. + if isinstance(current_axis, pg.graphicsItems.DateAxisItem.DateAxisItem): + default_axis = pg.AxisItem(orientation="bottom") + self.plot_item.setAxisItems({"bottom": default_axis}) + + if mode not in ["timestamp", "index", "auto"]: + self.x_axis_mode["entry"] = self.entry_validator.validate_signal(mode, None) + + self.set_x_label_suffix(self.x_axis_mode["label_suffix"]) + + def _categorise_device_curves(self) -> str: + """ + Categorise the device curves into sync and async based on the readout priority. + """ + if self.scan_item is None: + self.update_with_scan_history(-1) + if self.scan_item is None: + logger.info("No scan executed so far; skipping device curves categorisation.") + return "none" + + if hasattr(self.scan_item, "live_data"): + readout_priority = self.scan_item.status_message.info["readout_priority"] # live data + else: + readout_priority = self.scan_item.metadata["bec"]["readout_priority"] # history + + # Reset sync/async curve lists + self._async_curves.clear() + self._sync_curves.clear() + self._dap_curves.clear() + found_async = False + found_sync = False + found_dap = False + mode = "sync" + + readout_priority_async = self._ensure_str_list(readout_priority.get("async", [])) + readout_priority_sync = self._ensure_str_list(readout_priority.get("monitored", [])) + + # Iterate over all curves + for curve in self.curves: + # categorise dap curves firsts + if curve.config.source == "custom": + continue + if curve.config.source == "dap": + self._dap_curves.append(curve) + found_dap = True + continue + dev_name = curve.config.signal.name + if dev_name in readout_priority_async: + self._async_curves.append(curve) + found_async = True + elif dev_name in readout_priority_sync: + self._sync_curves.append(curve) + found_sync = True + else: + logger.warning("Device {dev_name} not found in readout priority list.") + + # Determine the mode of the scan + if found_async and found_sync: + mode = "mixed" + logger.warning( + f"Found both async and sync devices in the scan. X-axis integrity cannot be guaranteed." + ) + elif found_async: + mode = "async" + elif found_sync: + mode = "sync" + + self.roi_enable.emit(found_dap) + + logger.info(f"Scan {self.scan_id} => mode={self._mode}") + return mode + + @SafeSlot(int) + @SafeSlot(str) + @SafeSlot() + def update_with_scan_history(self, scan_index: int = None, scan_id: str = None): + """ + Update the scan curves with the data from the scan storage. + Provide only one of scan_id or scan_index. + + Args: + scan_id(str, optional): ScanID of the scan to be updated. Defaults to None. + scan_index(int, optional): Index of the scan to be updated. Defaults to None. + """ + if scan_index is not None and scan_id is not None: + raise ValueError("Only one of scan_id or scan_index can be provided.") + + if scan_index is None and scan_id is None: + logger.warning(f"Neither scan_id or scan_number was provided, fetching the latest scan") + scan_index = -1 + + if scan_index is not None: + if len(self.client.history) == 0: + logger.info("No scans executed so far. Skipping scan history update.") + return + + self.scan_item = self.client.history[scan_index] + metadata = self.scan_item.metadata + self.scan_id = metadata["bec"]["scan_id"] + else: + self.scan_id = scan_id + self.scan_item = self.client.history.get_by_scan_id(scan_id) + + self._categorise_device_curves() + + self.setup_dap_for_scan() + self.sync_signal_update.emit() + self.async_signal_update.emit() + + ################################################################################ + # Utility Methods + ################################################################################ + def _ensure_str_list(self, entries: list | tuple | np.ndarray): + """ + Convert a variety of possible inputs (string, bytes, list/tuple/ndarray of either) + into a list of Python strings. + + Args: + entries: + + Returns: + list[str]: A list of Python strings. + """ + + if isinstance(entries, (list, tuple, np.ndarray)): + return [self._to_str(e) for e in entries] + else: + return [self._to_str(entries)] + + @staticmethod + def _to_str(x): + """ + Convert a single object x (which may be a Python string, bytes, or something else) + into a plain Python string. + """ + if isinstance(x, bytes): + return x.decode("utf-8", errors="replace") + return str(x) + + @staticmethod + def _crop_data(x_data, y_data, x_min=None, x_max=None): + """ + Utility function to crop x_data and y_data based on x_min and x_max. + + Args: + x_data (np.ndarray): The array of x-values. + y_data (np.ndarray): The array of y-values corresponding to x_data. + x_min (float, optional): The lower bound for cropping. Defaults to None. + x_max (float, optional): The upper bound for cropping. Defaults to None. + + Returns: + tuple: (cropped_x_data, cropped_y_data) + """ + # If either bound is None, skip cropping + if x_min is None or x_max is None: + return x_data, y_data + + # Create a boolean mask to select only those points within [x_min, x_max] + mask = (x_data >= x_min) & (x_data <= x_max) + + return x_data[mask], y_data[mask] + + ################################################################################ + # Export Methods + ################################################################################ + def get_all_data(self, output: Literal["dict", "pandas"] = "dict") -> dict: # | pd.DataFrame: + """ + Extract all curve data into a dictionary or a pandas DataFrame. + + Args: + output (Literal["dict", "pandas"]): Format of the output data. + + Returns: + dict | pd.DataFrame: Data of all curves in the specified format. + """ + data = {} + if output == "pandas": # pragma: no cover + try: + import pandas as pd + except ModuleNotFoundError: + raise ModuleNotFoundError( + "Pandas is not installed. Please install pandas using 'pip install pandas'." + ) + + for curve in self.curves: + x_data, y_data = curve.get_data() + if x_data is not None or y_data is not None: + if output == "dict": + data[curve.name()] = {"x": x_data.tolist(), "y": y_data.tolist()} + elif output == "pandas" and pd is not None: + data[curve.name()] = pd.DataFrame({"x": x_data, "y": y_data}) + + if output == "pandas" and pd is not None: # pragma: no cover + combined_data = pd.concat( + [data[curve.name()] for curve in self.curves], + axis=1, + keys=[curve.name() for curve in self.curves], + ) + return combined_data + return data + + def export_to_matplotlib(self): # pragma: no cover + """ + Export current waveform to matplotlib gui. Available only if matplotlib is installed in the environment. + + """ + try: + import matplotlib as mpl + from pyqtgraph.exporters import MatplotlibExporter + + MatplotlibExporter(self.plot_item).export() + except ModuleNotFoundError: + logger.error("Matplotlib is not installed in the environment.") + + ################################################################################ + # Cleanup + ################################################################################ + def cleanup(self): + """ + Cleanup the widget by disconnecting signals and closing dialogs. + """ + self.proxy_dap_request.cleanup() + self.clear_all() + if self.curve_settings_dialog is not None: + self.curve_settings_dialog.close() + self.curve_settings_dialog = None + if self.dap_summary_dialog is not None: + self.dap_summary_dialog.close() + self.dap_summary_dialog = None + super().cleanup() + + +class DemoApp(QMainWindow): # pragma: no cover + def __init__(self): + super().__init__() + self.setWindowTitle("Waveform Demo") + self.resize(800, 600) + self.main_widget = QWidget() + self.layout = QHBoxLayout(self.main_widget) + self.setCentralWidget(self.main_widget) + + self.waveform_popup = Waveform(popups=True) + self.waveform_popup.plot(y_name="monitor_async") + + self.waveform_side = Waveform(popups=False) + self.waveform_side.plot(y_name="bpm4i", y_entry="bpm4i", dap="GaussianModel") + self.waveform_side.plot(y_name="bpm3a", y_entry="bpm3a") + + self.layout.addWidget(self.waveform_side) + self.layout.addWidget(self.waveform_popup) + + +if __name__ == "__main__": # pragma: no cover + import sys + + from qtpy.QtWidgets import QApplication + + app = QApplication(sys.argv) + set_theme("dark") + widget = DemoApp() + widget.show() + widget.resize(1400, 600) + sys.exit(app.exec_()) diff --git a/bec_widgets/widgets/plots_next_gen/waveform/waveform.pyproject b/bec_widgets/widgets/plots_next_gen/waveform/waveform.pyproject new file mode 100644 index 00000000..cde2c224 --- /dev/null +++ b/bec_widgets/widgets/plots_next_gen/waveform/waveform.pyproject @@ -0,0 +1 @@ +{'files': ['waveform.py']} \ No newline at end of file diff --git a/bec_widgets/widgets/plots_next_gen/waveform/waveform_plugin.py b/bec_widgets/widgets/plots_next_gen/waveform/waveform_plugin.py new file mode 100644 index 00000000..338eea67 --- /dev/null +++ b/bec_widgets/widgets/plots_next_gen/waveform/waveform_plugin.py @@ -0,0 +1,54 @@ +# Copyright (C) 2022 The Qt Company Ltd. +# SPDX-License-Identifier: LicenseRef-Qt-Commercial OR BSD-3-Clause + +from qtpy.QtDesigner import QDesignerCustomWidgetInterface + +from bec_widgets.utils.bec_designer import designer_material_icon +from bec_widgets.widgets.plots_next_gen.waveform.waveform import Waveform + +DOM_XML = """ + + + + +""" + + +class WaveformPlugin(QDesignerCustomWidgetInterface): # pragma: no cover + def __init__(self): + super().__init__() + self._form_editor = None + + def createWidget(self, parent): + t = Waveform(parent) + return t + + def domXml(self): + return DOM_XML + + def group(self): + return "Plot Widgets Next Gen" + + def icon(self): + return designer_material_icon(Waveform.ICON_NAME) + + def includeFile(self): + return "waveform" + + def initialize(self, form_editor): + self._form_editor = form_editor + + def isContainer(self): + return False + + def isInitialized(self): + return self._form_editor is not None + + def name(self): + return "Waveform" + + def toolTip(self): + return "Waveform" + + def whatsThis(self): + return self.toolTip() diff --git a/tests/unit_tests/client_mocks.py b/tests/unit_tests/client_mocks.py index 15885fb5..4c422d0d 100644 --- a/tests/unit_tests/client_mocks.py +++ b/tests/unit_tests/client_mocks.py @@ -1,8 +1,11 @@ # pylint: disable = no-name-in-module,missing-class-docstring, missing-module-docstring +from math import inf from unittest.mock import MagicMock, patch import fakeredis import pytest +from bec_lib.bec_service import messages +from bec_lib.endpoints import MessageEndpoints from bec_lib.redis_connector import RedisConnector from bec_widgets.tests.utils import DEVICES, DMMock, FakePositioner, Positioner @@ -50,3 +53,150 @@ def mocked_client(bec_dispatcher): with patch("builtins.isinstance", new=isinstance_mock): yield client connector.shutdown() # TODO change to real BECClient + + +################################################## +# Client Fixture with DAP +################################################## +@pytest.fixture(scope="function") +def dap_plugin_message(): + msg = messages.AvailableResourceMessage( + **{ + "resource": { + "GaussianModel": { + "class": "LmfitService1D", + "user_friendly_name": "GaussianModel", + "class_doc": "A model based on a Gaussian or normal distribution lineshape.\n\n The model has three Parameters: `amplitude`, `center`, and `sigma`.\n In addition, parameters `fwhm` and `height` are included as\n constraints to report full width at half maximum and maximum peak\n height, respectively.\n\n .. math::\n\n f(x; A, \\mu, \\sigma) = \\frac{A}{\\sigma\\sqrt{2\\pi}} e^{[{-{(x-\\mu)^2}/{{2\\sigma}^2}}]}\n\n where the parameter `amplitude` corresponds to :math:`A`, `center` to\n :math:`\\mu`, and `sigma` to :math:`\\sigma`. The full width at half\n maximum is :math:`2\\sigma\\sqrt{2\\ln{2}}`, approximately\n :math:`2.3548\\sigma`.\n\n For more information, see: https://en.wikipedia.org/wiki/Normal_distribution\n\n ", + "run_doc": "A model based on a Gaussian or normal distribution lineshape.\n\n The model has three Parameters: `amplitude`, `center`, and `sigma`.\n In addition, parameters `fwhm` and `height` are included as\n constraints to report full width at half maximum and maximum peak\n height, respectively.\n\n .. math::\n\n f(x; A, \\mu, \\sigma) = \\frac{A}{\\sigma\\sqrt{2\\pi}} e^{[{-{(x-\\mu)^2}/{{2\\sigma}^2}}]}\n\n where the parameter `amplitude` corresponds to :math:`A`, `center` to\n :math:`\\mu`, and `sigma` to :math:`\\sigma`. The full width at half\n maximum is :math:`2\\sigma\\sqrt{2\\ln{2}}`, approximately\n :math:`2.3548\\sigma`.\n\n For more information, see: https://en.wikipedia.org/wiki/Normal_distribution\n\n \n Args:\n scan_item (ScanItem): Scan item or scan ID\n device_x (DeviceBase | str): Device name for x\n signal_x (DeviceBase | str): Signal name for x\n device_y (DeviceBase | str): Device name for y\n signal_y (DeviceBase | str): Signal name for y\n parameters (dict): Fit parameters\n ", + "run_name": "fit", + "signature": [ + { + "name": "args", + "kind": "VAR_POSITIONAL", + "default": "_empty", + "annotation": "_empty", + }, + { + "name": "scan_item", + "kind": "KEYWORD_ONLY", + "default": None, + "annotation": "ScanItem | str", + }, + { + "name": "device_x", + "kind": "KEYWORD_ONLY", + "default": None, + "annotation": "DeviceBase | str", + }, + { + "name": "signal_x", + "kind": "KEYWORD_ONLY", + "default": None, + "annotation": "DeviceBase | str", + }, + { + "name": "device_y", + "kind": "KEYWORD_ONLY", + "default": None, + "annotation": "DeviceBase | str", + }, + { + "name": "signal_y", + "kind": "KEYWORD_ONLY", + "default": None, + "annotation": "DeviceBase | str", + }, + { + "name": "parameters", + "kind": "KEYWORD_ONLY", + "default": None, + "annotation": "dict", + }, + { + "name": "kwargs", + "kind": "VAR_KEYWORD", + "default": "_empty", + "annotation": "_empty", + }, + ], + "auto_fit_supported": True, + "params": { + "amplitude": { + "name": "amplitude", + "value": 1.0, + "vary": True, + "min": -inf, + "max": inf, + "expr": None, + "brute_step": None, + "user_data": None, + }, + "center": { + "name": "center", + "value": 0.0, + "vary": True, + "min": -inf, + "max": inf, + "expr": None, + "brute_step": None, + "user_data": None, + }, + "sigma": { + "name": "sigma", + "value": 1.0, + "vary": True, + "min": 0, + "max": inf, + "expr": None, + "brute_step": None, + "user_data": None, + }, + "fwhm": { + "name": "fwhm", + "value": 2.35482, + "vary": False, + "min": -inf, + "max": inf, + "expr": "2.3548200*sigma", + "brute_step": None, + "user_data": None, + }, + "height": { + "name": "height", + "value": 0.3989423, + "vary": False, + "min": -inf, + "max": inf, + "expr": "0.3989423*amplitude/max(1e-15, sigma)", + "brute_step": None, + "user_data": None, + }, + }, + "class_args": [], + "class_kwargs": {"model": "GaussianModel"}, + } + } + } + ) + yield msg + + +@pytest.fixture(scope="function") +def mocked_client_with_dap(mocked_client, dap_plugin_message): + dap_services = { + "BECClient": messages.StatusMessage(name="BECClient", status=1, info={}), + "DAPServer/LmfitService1D": messages.StatusMessage( + name="LmfitService1D", status=1, info={} + ), + } + client = mocked_client + client.service_status = dap_services + client.connector.set( + topic=MessageEndpoints.dap_available_plugins("dap"), msg=dap_plugin_message + ) + + # Patch the client's DAP attribute so that the available models include "GaussianModel" + patched_models = {"GaussianModel": {}, "LorentzModel": {}, "SineModel": {}} + client.dap._available_dap_plugins = patched_models + + yield client diff --git a/tests/unit_tests/test_curve_settings.py b/tests/unit_tests/test_curve_settings.py new file mode 100644 index 00000000..a39b5ae9 --- /dev/null +++ b/tests/unit_tests/test_curve_settings.py @@ -0,0 +1,367 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest +from qtpy.QtWidgets import QComboBox, QVBoxLayout + +from bec_widgets.widgets.plots_next_gen.waveform.settings.curve_settings.curve_setting import ( + CurveSetting, +) +from bec_widgets.widgets.plots_next_gen.waveform.settings.curve_settings.curve_tree import CurveTree +from bec_widgets.widgets.plots_next_gen.waveform.waveform import Waveform +from tests.unit_tests.client_mocks import dap_plugin_message, mocked_client, mocked_client_with_dap +from tests.unit_tests.conftest import create_widget + +################################################## +# CurveSetting +################################################## + + +@pytest.fixture +def curve_setting_fixture(qtbot, mocked_client): + """ + Creates a CurveSetting widget targeting a mock or real Waveform widget. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + wf.x_mode = "auto" + curve_setting = create_widget(qtbot, CurveSetting, parent=None, target_widget=wf) + return curve_setting, wf + + +def test_curve_setting_init(curve_setting_fixture): + """ + Ensure CurveSetting constructs properly, with a CurveTree inside + and an x-axis group box for modes. + """ + curve_setting, wf = curve_setting_fixture + + # Basic checks + assert curve_setting.objectName() == "CurveSetting" + # The layout should be QVBoxLayout + assert isinstance(curve_setting.layout, QVBoxLayout) + + # There's an x_axis_box group and a y_axis_box group + assert hasattr(curve_setting, "x_axis_box") + assert hasattr(curve_setting, "y_axis_box") + + # The x_axis_box should contain a QComboBox for mode + mode_combo = curve_setting.mode_combo + assert isinstance(mode_combo, QComboBox) + # Should contain these items: ["auto", "index", "timestamp", "device"] + expected_modes = ["auto", "index", "timestamp", "device"] + for m in expected_modes: + assert m in [ + curve_setting.mode_combo.itemText(i) for i in range(curve_setting.mode_combo.count()) + ] + + # Check that there's a curve_manager inside y_axis_box + assert hasattr(curve_setting, "curve_manager") + assert curve_setting.y_axis_box.layout.count() > 0 + + +def test_curve_setting_accept_changes(curve_setting_fixture, qtbot): + """ + Test that calling accept_changes() applies x-axis mode changes + and triggers the CurveTree to send its curve JSON to the target waveform. + """ + curve_setting, wf = curve_setting_fixture + + # Suppose user chooses "index" from the combo + curve_setting.mode_combo.setCurrentText("index") + # The device_x is disabled if not device mode + + # Spy on 'send_curve_json' from the curve_manager + send_spy = MagicMock() + curve_setting.curve_manager.send_curve_json = send_spy + + # Call accept_changes() + curve_setting.accept_changes() + + # Check that we updated the waveform + assert wf.x_mode == "index" + # Check that the manager send_curve_json was called + send_spy.assert_called_once() + + +def test_curve_setting_switch_device_mode(curve_setting_fixture, qtbot): + """ + If user chooses device mode from the combo, the device_x line edit should be enabled + and set to the current wavefrom.x_axis_mode["name"]. + """ + curve_setting, wf = curve_setting_fixture + + # Initially we assume "auto" + assert curve_setting.mode_combo.currentText() == "auto" + # Switch to device + curve_setting.mode_combo.setCurrentText("device") + assert curve_setting.device_x.isEnabled() + + # This line edit should reflect the waveform.x_axis_mode["name"], or be blank if none + assert curve_setting.device_x.text() == wf.x_axis_mode["name"] + + +def test_curve_setting_refresh(curve_setting_fixture, qtbot): + """ + Test that calling refresh() refreshes the embedded CurveTree + and re-reads the x axis mode from the waveform. + """ + curve_setting, wf = curve_setting_fixture + + # Suppose the waveform changed x_mode from "auto" to "timestamp" behind the scenes + wf.x_mode = "timestamp" + # Spy on the curve_manager + refresh_spy = MagicMock() + curve_setting.curve_manager.refresh_from_waveform = refresh_spy + + # Call refresh + curve_setting.refresh() + + refresh_spy.assert_called_once() + # The combo should now read "timestamp" + assert curve_setting.mode_combo.currentText() == "timestamp" + + +################################################## +# CurveTree +################################################## + + +@pytest.fixture +def curve_tree_fixture(qtbot, mocked_client_with_dap): + """ + Creates a CurveTree widget referencing a mocked or real Waveform. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client_with_dap) + wf.color_palette = "magma" + curve_tree = create_widget(qtbot, CurveTree, parent=None, waveform=wf) + return curve_tree, wf + + +def test_curve_tree_init(curve_tree_fixture): + """ + Test that the CurveTree initializes properly with references to the waveform, + sets up the toolbar, and an empty QTreeWidget. + """ + curve_tree, wf = curve_tree_fixture + assert curve_tree.waveform == wf + assert curve_tree.color_palette == "magma" + assert curve_tree.tree.columnCount() == 7 + + assert "add" in curve_tree.toolbar.widgets + assert "expand_all" in curve_tree.toolbar.widgets + assert "collapse_all" in curve_tree.toolbar.widgets + assert "renormalize_colors" in curve_tree.toolbar.widgets + + +def test_add_new_curve(curve_tree_fixture): + """ + Test that add_new_curve() adds a top-level item with a device curve config, + assigns it a color from the buffer, and doesn't modify existing rows. + """ + curve_tree, wf = curve_tree_fixture + curve_tree.color_buffer = ["#111111", "#222222", "#333333", "#444444", "#555555"] + + assert curve_tree.tree.topLevelItemCount() == 0 + + with patch.object(curve_tree, "_ensure_color_buffer_size") as ensure_spy: + new_item = curve_tree.add_new_curve(name="bpm4i", entry="bpm4i") + ensure_spy.assert_called_once() + + assert curve_tree.tree.topLevelItemCount() == 1 + last_item = curve_tree.all_items[-1] + assert last_item is new_item + assert new_item.config.source == "device" + assert new_item.config.signal.name == "bpm4i" + assert new_item.config.signal.entry == "bpm4i" + assert new_item.config.color in curve_tree.color_buffer + + +def test_renormalize_colors(curve_tree_fixture): + """ + Test that renormalize_colors overwrites colors for all items in creation order. + """ + curve_tree, wf = curve_tree_fixture + # Add multiple curves + c1 = curve_tree.add_new_curve(name="bpm4i", entry="bpm4i") + c2 = curve_tree.add_new_curve(name="bpm3a", entry="bpm3a") + curve_tree.color_buffer = [] + + set_color_spy_c1 = patch.object(c1.color_button, "set_color") + set_color_spy_c2 = patch.object(c2.color_button, "set_color") + + with set_color_spy_c1 as spy1, set_color_spy_c2 as spy2: + curve_tree.renormalize_colors() + spy1.assert_called_once() + spy2.assert_called_once() + + +def test_expand_collapse(curve_tree_fixture): + """ + Test expand_all_daps() and collapse_all_daps() calls expand/collapse on every top-level item. + """ + curve_tree, wf = curve_tree_fixture + c1 = curve_tree.add_new_curve(name="bpm4i", entry="bpm4i") + curve_tree.tree.expandAll() + expand_spy = patch.object(curve_tree.tree, "expandItem") + collapse_spy = patch.object(curve_tree.tree, "collapseItem") + + with expand_spy as e_spy: + curve_tree.expand_all_daps() + e_spy.assert_called_once_with(c1) + + with collapse_spy as c_spy: + curve_tree.collapse_all_daps() + c_spy.assert_called_once_with(c1) + + +def test_send_curve_json(curve_tree_fixture, monkeypatch): + """ + Test that send_curve_json sets the waveform's color_palette and curve_json + to the exported config from the tree. + """ + curve_tree, wf = curve_tree_fixture + # Add multiple curves + curve_tree.add_new_curve(name="bpm4i", entry="bpm4i") + curve_tree.add_new_curve(name="bpm3a", entry="bpm3a") + + curve_tree.color_palette = "viridis" + curve_tree.send_curve_json() + + assert wf.color_palette == "viridis" + data = json.loads(wf.curve_json) + assert len(data) == 2 + labels = [d["label"] for d in data] + assert "bpm4i-bpm4i" in labels + assert "bpm3a-bpm3a" in labels + + +def test_refresh_from_waveform(qtbot, mocked_client_with_dap, monkeypatch): + """ + Test that refresh_from_waveform() rebuilds the tree from the waveform's curve_json + """ + patched_models = {"GaussianModel": {}, "LorentzModel": {}, "SineModel": {}} + monkeypatch.setattr(mocked_client_with_dap.dap, "_available_dap_plugins", patched_models) + + wf = create_widget(qtbot, Waveform, client=mocked_client_with_dap) + wf.x_mode = "auto" + curve_tree = create_widget(qtbot, CurveTree, parent=None, waveform=wf) + + wf.plot(arg1="bpm4i", dap="GaussianModel") + wf.plot(arg1="bpm3a", dap="GaussianModel") + + # Clear the tree to simulate a fresh rebuild. + curve_tree.tree.clear() + curve_tree.all_items.clear() + assert curve_tree.tree.topLevelItemCount() == 0 + + # For DAP rows + curve_tree.refresh_from_waveform() + assert curve_tree.tree.topLevelItemCount() == 2 + + +def test_add_dap_row(curve_tree_fixture): + """ + Test that add_dap_row creates a new DAP curve as a child of a device curve, + with the correct configuration and parent-child relationship. + """ + curve_tree, wf = curve_tree_fixture + + # Add a device curve first + device_row = curve_tree.add_new_curve(name="bpm4i", entry="bpm4i") + assert device_row.source == "device" + assert curve_tree.tree.topLevelItemCount() == 1 + assert device_row.childCount() == 0 + + # Now add a DAP row to it + device_row.add_dap_row() + + # Check that child was added + assert device_row.childCount() == 1 + dap_child = device_row.child(0) + + # Verify the DAP child has the correct configuration + assert dap_child.source == "dap" + assert dap_child.config.parent_label == device_row.config.label + + # Check that the DAP inherits device name/entry from parent + assert dap_child.config.signal.name == "bpm4i" + assert dap_child.config.signal.entry == "bpm4i" + + # Check that the item is in the curve_tree's all_items list + assert dap_child in curve_tree.all_items + + +def test_remove_self_top_level(curve_tree_fixture): + """ + Test that remove_self removes a top-level device row from the tree. + """ + curve_tree, wf = curve_tree_fixture + + # Add two device curves + row1 = curve_tree.add_new_curve(name="bpm4i", entry="bpm4i") + row2 = curve_tree.add_new_curve(name="bpm3a", entry="bpm3a") + assert curve_tree.tree.topLevelItemCount() == 2 + assert len(curve_tree.all_items) == 2 + + # Remove the first row + row1.remove_self() + + # Check that only one row remains and it's the correct one + assert curve_tree.tree.topLevelItemCount() == 1 + assert curve_tree.tree.topLevelItem(0) == row2 + assert len(curve_tree.all_items) == 1 + assert curve_tree.all_items[0] == row2 + + +def test_remove_self_child(curve_tree_fixture): + """ + Test that remove_self removes a child DAP row while preserving the parent device row. + """ + curve_tree, wf = curve_tree_fixture + + # Add a device curve and a DAP child + device_row = curve_tree.add_new_curve(name="bpm4i", entry="bpm4i") + device_row.add_dap_row() + dap_child = device_row.child(0) + + assert curve_tree.tree.topLevelItemCount() == 1 + assert device_row.childCount() == 1 + assert len(curve_tree.all_items) == 2 + + # Remove the DAP child + dap_child.remove_self() + + # Check that the parent device row still exists but has no children + assert curve_tree.tree.topLevelItemCount() == 1 + assert device_row.childCount() == 0 + assert len(curve_tree.all_items) == 1 + assert curve_tree.all_items[0] == device_row + + +def test_export_data_dap(curve_tree_fixture): + """ + Test that export_data from a DAP row correctly includes parent relationship and DAP model. + """ + curve_tree, wf = curve_tree_fixture + + # Add a device curve with specific parameters + device_row = curve_tree.add_new_curve(name="bpm4i", entry="bpm4i") + device_row.config.label = "bpm4i-main" + + # Add a DAP child + device_row.add_dap_row() + dap_child = device_row.child(0) + + # Set a specific model in the DAP combobox + dap_child.dap_combo.fit_model_combobox.setCurrentText("GaussianModel") + + # Export data from the DAP row + exported = dap_child.export_data() + + # Check the exported data + assert exported["source"] == "dap" + assert exported["parent_label"] == "bpm4i-main" + assert exported["signal"]["name"] == "bpm4i" + assert exported["signal"]["entry"] == "bpm4i" + assert exported["signal"]["dap"] == "GaussianModel" + assert exported["label"] == "bpm4i-main-GaussianModel" diff --git a/tests/unit_tests/test_waveform_next_gen.py b/tests/unit_tests/test_waveform_next_gen.py new file mode 100644 index 00000000..456cf5cd --- /dev/null +++ b/tests/unit_tests/test_waveform_next_gen.py @@ -0,0 +1,787 @@ +import json +from unittest.mock import MagicMock + +import numpy as np +import pyqtgraph as pg +import pytest +from pyqtgraph.graphicsItems.DateAxisItem import DateAxisItem + +from bec_widgets.widgets.plots_next_gen.plot_base import UIMode +from bec_widgets.widgets.plots_next_gen.waveform.curve import DeviceSignal +from bec_widgets.widgets.plots_next_gen.waveform.waveform import Waveform +from tests.unit_tests.client_mocks import dap_plugin_message, mocked_client, mocked_client_with_dap + +from .conftest import create_widget + +################################################## +# Waveform widget base functionality tests +################################################## + + +def test_waveform_initialization(qtbot, mocked_client): + """ + Test that a new Waveform widget initializes with the correct defaults. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + assert wf.objectName() == "Waveform" + # Inherited from PlotBase + assert wf.title == "" + assert wf.x_label == "" + assert wf.y_label == "" + # No crosshair or FPS monitor by default + assert wf.crosshair is None + assert wf.fps_monitor is None + # No curves initially + assert len(wf.plot_item.curves) == 0 + + +def test_waveform_with_side_menu(qtbot, mocked_client): + wf = create_widget(qtbot, Waveform, client=mocked_client, popups=False) + + assert wf.ui_mode == UIMode.SIDE + + +def test_plot_custom_curve(qtbot, mocked_client): + """ + Test that calling plot with explicit x and y data creates a custom curve. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + curve = wf.plot(x=[1, 2, 3], y=[4, 5, 6], label="custom_curve") + assert curve is not None + assert curve.config.source == "custom" + assert curve.config.label == "custom_curve" + x_data, y_data = curve.get_data() + np.testing.assert_array_equal(x_data, np.array([1, 2, 3])) + np.testing.assert_array_equal(y_data, np.array([4, 5, 6])) + + +def test_plot_single_arg_input_1d(qtbot, mocked_client): + """ + Test that when a single 1D numpy array is passed, the curve is created with + x-data as a generated index. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + data = np.array([10, 20, 30]) + curve = wf.plot(data, label="curve_1d") + x_data, y_data = curve.get_data() + np.testing.assert_array_equal(x_data, np.arange(len(data))) + np.testing.assert_array_equal(y_data, data) + + +def test_plot_single_arg_input_2d(qtbot, mocked_client): + """ + Test that when a single 2D numpy array (N x 2) is passed, + x and y data are extracted from the first and second columns. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + data = np.array([[1, 4], [2, 5], [3, 6]]) + curve = wf.plot(data, label="curve_2d") + x_data, y_data = curve.get_data() + np.testing.assert_array_equal(x_data, data[:, 0]) + np.testing.assert_array_equal(y_data, data[:, 1]) + + +def test_plot_single_arg_input_sync(qtbot, mocked_client): + wf = create_widget(qtbot, Waveform, client=mocked_client) + + c1 = wf.plot(arg1="bpm4i") + c2 = wf.plot(arg1="bpm3a") + + assert c1.config.source == "device" + assert c2.config.source == "device" + assert c1.config.signal == DeviceSignal(name="bpm4i", entry="bpm4i", dap=None) + assert c2.config.signal == DeviceSignal(name="bpm3a", entry="bpm3a", dap=None) + + # Check that the curve is added to the plot + assert len(wf.plot_item.curves) == 2 + + +def test_plot_single_arg_input_async(qtbot, mocked_client): + wf = create_widget(qtbot, Waveform, client=mocked_client) + + c1 = wf.plot(arg1="eiger") + c2 = wf.plot(arg1="async_device") + + assert c1.config.source == "device" + assert c2.config.source == "device" + assert c1.config.signal == DeviceSignal(name="eiger", entry="eiger", dap=None) + assert c2.config.signal == DeviceSignal(name="async_device", entry="async_device", dap=None) + + # Check that the curve is added to the plot + assert len(wf.plot_item.curves) == 2 + + +def test_curve_access_pattern(qtbot, mocked_client): + wf = create_widget(qtbot, Waveform, client=mocked_client) + + c1 = wf.plot(arg1="bpm4i") + c2 = wf.plot(arg1="bpm3a") + + # Check that the curve is added to the plot + assert len(wf.plot_item.curves) == 2 + + # Check that the curve is accessible by label + assert wf.get_curve("bpm4i-bpm4i") == c1 + assert wf.get_curve("bpm3a-bpm3a") == c2 + + # Check that the curve is accessible by index + assert wf.get_curve(0) == c1 + assert wf.get_curve(1) == c2 + + # Check that the curve is accessible by label + assert wf["bpm4i-bpm4i"] == c1 + assert wf["bpm3a-bpm3a"] == c2 + assert wf[0] == c1 + assert wf[1] == c2 + + assert wf.curves[0] == c1 + assert wf.curves[1] == c2 + + +def test_find_curve_by_label(qtbot, mocked_client): + """ + Test the _find_curve_by_label method returns the correct curve or None if not found. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + c1 = wf.plot(arg1="bpm4i", label="c1_label") + c2 = wf.plot(arg1="bpm3a", label="c2_label") + + found = wf._find_curve_by_label("c1_label") + assert found == c1, "Should return the first curve" + missing = wf._find_curve_by_label("bogus_label") + assert missing is None, "Should return None if not found" + + +def test_set_x_mode(qtbot, mocked_client): + """ + Test that setting x_mode updates the internal x-axis mode state and switches + the bottom axis of the plot. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + wf.x_mode = "timestamp" + assert wf.x_axis_mode["name"] == "timestamp" + # When x_mode is 'timestamp', the bottom axis should be a DateAxisItem. + assert isinstance(wf.plot_item.axes["bottom"]["item"], DateAxisItem) + + wf.x_mode = "index" + # For other modes, the bottom axis becomes the default AxisItem. + assert isinstance(wf.plot_item.axes["bottom"]["item"], pg.AxisItem) + + wf.x_mode = "samx" + assert wf.x_axis_mode["name"] == "samx" + assert isinstance(wf.plot_item.axes["bottom"]["item"], pg.AxisItem) + + +def test_color_palette_update(qtbot, mocked_client): + """ + Test that updating the color_palette property changes the color of existing curves. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + curve = wf.plot(x=[1, 2, 3], y=[4, 5, 6], label="test_curve") + original_color = curve.config.color + # Change to a different valid palette + wf.color_palette = "plasma" + assert wf.config.color_palette == "plasma" + # After updating the palette, the curve's color should be re-generated. + assert curve.config.color != original_color + + +def test_curve_json_property(qtbot, mocked_client): + """ + Test that the curve_json property returns a JSON string representing + non-custom curves. Since custom curves are not serialized, if only a custom + curve is added, an empty list should be returned. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + wf.plot(x=[1, 2, 3], y=[4, 5, 6], label="custom_curve") + json_str = wf.curve_json + data = json.loads(json_str) + assert isinstance(data, list) + # Only custom curves exist so none should be serialized. + assert len(data) == 0 + + +def test_remove_curve_waveform(qtbot, mocked_client): + """ + Test that curves can be removed from the waveform using either their label or index. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + wf.plot(x=[1, 2, 3], y=[4, 5, 6], label="curve1") + wf.plot(x=[4, 5, 6], y=[7, 8, 9], label="curve2") + num_before = len(wf.plot_item.curves) + wf.remove_curve("curve1") + num_after = len(wf.plot_item.curves) + assert num_after == num_before - 1 + + wf.remove_curve(0) + assert len(wf.plot_item.curves) == num_after - 1 + + +def test_get_all_data_empty(qtbot, mocked_client): + """ + Test that get_all_data returns an empty dictionary when no curves have been added. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + all_data = wf.get_all_data(output="dict") + assert all_data == {} + + +def test_get_all_data_dict(qtbot, mocked_client): + """ + Test that get_all_data returns a dictionary with the expected x and y data for each curve. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + wf.plot(x=[1, 2, 3], y=[4, 5, 6], label="curve1") + wf.plot(x=[7, 8, 9], y=[10, 11, 12], label="curve2") + + all_data = wf.get_all_data(output="dict") + + expected = { + "curve1": {"x": [1, 2, 3], "y": [4, 5, 6]}, + "curve2": {"x": [7, 8, 9], "y": [10, 11, 12]}, + } + assert all_data == expected + + +def test_curve_json_getter_setter(qtbot, mocked_client): + """ + Test that the curve_json getter returns a JSON string representing device curves + and that setting curve_json re-creates the curves. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + + # These curves should be in JSON + wf.plot(arg1="bpm4i") + wf.plot(arg1="bpm3a") + # Custom curves should be ignored + wf.plot(x=[1, 2, 3], y=[4, 5, 6], label="custom_curve") + wf.plot([1, 2, 3, 4]) + + # Get JSON from the getter. + json_str = wf.curve_json + curve_configs = json.loads(json_str) + # Only device curves are serialized; expect two configurations. + assert isinstance(curve_configs, list) + assert len(curve_configs) == 2 + labels = [cfg["label"] for cfg in curve_configs] + assert "bpm4i-bpm4i" in labels + assert "bpm3a-bpm3a" in labels + + # Clear all curves. + wf.clear_all() + assert len(wf.plot_item.curves) == 0 + + # Use the JSON setter to re-create the curves. + wf.curve_json = json_str + # After setting, the waveform should have two curves. + assert len(wf.plot_item.curves) == 2 + new_labels = [curve.name() for curve in wf.plot_item.curves] + for lab in labels: + assert lab in new_labels + + +def test_curve_json_setter_ignores_custom(qtbot, mocked_client): + """ + Test that when curve_json setter is given a JSON string containing a + curve with source "custom", that curve is not added. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + + device_curve_config = { + "widget_class": "Curve", + "parent_id": wf.gui_id, + "label": "device_curve", + "color": "#ff0000", + "source": "device", + "signal": {"name": "bpm4i", "entry": "bpm4i", "dap": None}, + } + custom_curve_config = { + "widget_class": "Curve", + "parent_id": wf.gui_id, + "label": "custom_curve", + "color": "#00ff00", + "source": "custom", + # No signal for custom curves. + } + json_str = json.dumps([device_curve_config, custom_curve_config], indent=2) + wf.curve_json = json_str + # Only the device curve should be added. + curves = wf.plot_item.curves + assert len(curves) == 1 + assert curves[0].name() == "device_curve" + + +################################################## +# Waveform widget scan logic tests +################################################## + + +class DummyData: + def __init__(self, val, timestamps): + self.val = val + self.timestamps = timestamps + + def get(self, key, default=None): + if key == "val": + return self.val + return default + + +def create_dummy_scan_item(): + """ + Helper to create a dummy scan item with both live_data and metadata/status_message info. + """ + dummy_live_data = { + "samx": {"samx": DummyData(val=[10, 20, 30], timestamps=[100, 200, 300])}, + "bpm4i": {"bpm4i": DummyData(val=[5, 6, 7], timestamps=[101, 201, 301])}, + "async_device": {"async_device": DummyData(val=[1, 2, 3], timestamps=[11, 21, 31])}, + } + dummy_scan = MagicMock() + dummy_scan.live_data = dummy_live_data + dummy_scan.metadata = { + "bec": { + "scan_id": "dummy", + "scan_report_devices": ["samx"], + "readout_priority": {"monitored": ["bpm4i"], "async": ["async_device"]}, + } + } + dummy_scan.status_message = MagicMock() + dummy_scan.status_message.info = { + "readout_priority": {"monitored": ["bpm4i"], "async": ["async_device"]}, + "scan_report_devices": ["samx"], + } + return dummy_scan + + +def test_update_sync_curves(monkeypatch, qtbot, mocked_client): + """ + Test that update_sync_curves retrieves live data correctly and calls setData on sync curves. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + c = wf.plot(arg1="bpm4i") + wf._sync_curves = [c] + wf.x_mode = "timestamp" + dummy_scan = create_dummy_scan_item() + wf.scan_item = dummy_scan + + recorded = {} + + def fake_setData(x, y): + recorded["x"] = x + recorded["y"] = y + + monkeypatch.setattr(c, "setData", fake_setData) + + wf.update_sync_curves() + np.testing.assert_array_equal(recorded.get("x"), [101, 201, 301]) + np.testing.assert_array_equal(recorded.get("y"), [5, 6, 7]) + + +def test_update_async_curves(monkeypatch, qtbot, mocked_client): + """ + Test that update_async_curves retrieves live data correctly and calls setData on async curves. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + c = wf.plot(arg1="async_device", label="async_device-async_device") + wf._async_curves = [c] + wf.x_mode = "timestamp" + dummy_scan = create_dummy_scan_item() + wf.scan_item = dummy_scan + + recorded = {} + + def fake_setData(x, y): + recorded["x"] = x + recorded["y"] = y + + monkeypatch.setattr(c, "setData", fake_setData) + + wf.update_async_curves() + np.testing.assert_array_equal(recorded.get("x"), [11, 21, 31]) + np.testing.assert_array_equal(recorded.get("y"), [1, 2, 3]) + + +def test_get_x_data_custom(monkeypatch, qtbot, mocked_client): + """ + Test that _get_x_data returns the correct custom signal data. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + # Set x_mode to a custom mode. + wf.x_axis_mode["name"] = "custom_signal" + wf.x_axis_mode["entry"] = "custom_entry" + dummy_data = DummyData(val=[50, 60, 70], timestamps=[150, 160, 170]) + dummy_live = {"custom_signal": {"custom_entry": dummy_data}} + monkeypatch.setattr(wf, "_fetch_scan_data_and_access", lambda: (dummy_live, "val")) + x_data = wf._get_x_data("irrelevant", "irrelevant") + np.testing.assert_array_equal(x_data, [50, 60, 70]) + + +def test_get_x_data_timestamp(monkeypatch, qtbot, mocked_client): + """ + Test that _get_x_data returns the correct timestamp data. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + wf.x_axis_mode["name"] = "timestamp" + dummy_data = DummyData(val=[50, 60, 70], timestamps=[101, 202, 303]) + dummy_live = {"deviceX": {"entryX": dummy_data}} + monkeypatch.setattr(wf, "_fetch_scan_data_and_access", lambda: (dummy_live, "val")) + x_data = wf._get_x_data("deviceX", "entryX") + np.testing.assert_array_equal(x_data, [101, 202, 303]) + + +def test_categorise_device_curves(monkeypatch, qtbot, mocked_client): + """ + Test that _categorise_device_curves correctly categorizes curves. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + dummy_scan = create_dummy_scan_item() + wf.scan_item = dummy_scan + + c_sync = wf.plot(arg1="bpm4i", label="bpm4i-bpm4i") + c_async = wf.plot(arg1="async_device", label="async_device-async_device") + + mode = wf._categorise_device_curves() + + assert mode == "mixed" + assert c_sync in wf._sync_curves + assert c_async in wf._async_curves + + +@pytest.mark.parametrize( + ["mode", "calls"], [("sync", (1, 0)), ("async", (0, 1)), ("mixed", (1, 1))] +) +def test_on_scan_status(qtbot, mocked_client, monkeypatch, mode, calls): + """ + Test that on_scan_status sets up a new scan correctly, + categorizes curves, and triggers sync/async updates as needed. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + # Force creation of a couple of device curves + if mode == "sync": + wf.plot(arg1="bpm4i") + elif mode == "async": + wf.plot(arg1="async_device") + else: + wf.plot(arg1="bpm4i") + wf.plot(arg1="async_device") + + # We mock out the scan_item, pretending we found a new scan. + dummy_scan = create_dummy_scan_item() + dummy_scan.metadata["bec"]["scan_id"] = "1234" + monkeypatch.setattr(wf.queue.scan_storage, "find_scan_by_ID", lambda scan_id: dummy_scan) + + # We'll track calls to sync_signal_update and async_signal_update + sync_spy = MagicMock() + async_spy = MagicMock() + wf.sync_signal_update.connect(sync_spy) + wf.async_signal_update.connect(async_spy) + + # Prepare fake message data + msg = {"scan_id": "1234"} + meta = {} + wf.on_scan_status(msg, meta) + + assert wf.scan_id == "1234" + assert wf.scan_item == dummy_scan + assert wf._mode == mode + + assert sync_spy.call_count == calls[0], "sync_signal_update should be called exactly once" + assert async_spy.call_count == calls[1], "async_signal_update should be called exactly once" + + +def test_add_dap_curve(qtbot, mocked_client_with_dap, monkeypatch): + """ + Test add_dap_curve creates a new DAP curve from an existing device curve + and verifies that the DAP call doesn't fail due to mock-based plugin_info. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client_with_dap) + wf.plot(arg1="bpm4i", label="bpm4i-bpm4i") + + dap_curve = wf.add_dap_curve(device_label="bpm4i-bpm4i", dap_name="GaussianModel") + assert dap_curve is not None + assert dap_curve.config.source == "dap" + assert dap_curve.config.signal.name == "bpm4i" + assert dap_curve.config.signal.dap == "GaussianModel" + + +def test_fetch_scan_data_and_access(qtbot, mocked_client, monkeypatch): + """ + Test the _fetch_scan_data_and_access method returns live_data/val if in a live scan, + or device dict/value if in a historical scan. Also test fallback if no scan_item. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + + wf.scan_item = None + + hist_mock = MagicMock() + monkeypatch.setattr(wf, "update_with_scan_history", hist_mock) + + wf._fetch_scan_data_and_access() + hist_mock.assert_called_once_with(-1) + + # Ckeck live mode + dummy_scan = create_dummy_scan_item() + wf.scan_item = dummy_scan + data_dict, access_key = wf._fetch_scan_data_and_access() + assert data_dict == dummy_scan.live_data + assert access_key == "val" + + # Check history mode + del dummy_scan.live_data + dummy_scan.devices = {"some_device": {"some_entry": "some_value"}} + data_dict, access_key = wf._fetch_scan_data_and_access() + assert "some_device" in data_dict # from dummy_scan.devices + assert access_key == "value" + + +def test_setup_async_curve(qtbot, mocked_client, monkeypatch): + """ + Test that _setup_async_curve properly disconnects old signals + and re-connects the async readback for a new scan ID. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + wf.old_scan_id = "111" + wf.scan_id = "222" + + c = wf.plot(arg1="async_device", label="async_device-async_device") + # check that it was placed in _async_curves or so + wf._async_curves = [c] + + # We'll spy on connect_slot + connect_spy = MagicMock() + monkeypatch.setattr(wf.bec_dispatcher, "connect_slot", connect_spy) + + wf._setup_async_curve(c) + connect_spy.assert_called_once() + endpoint_called = connect_spy.call_args[0][1].endpoint + # We expect MessageEndpoints.device_async_readback('222', 'async_device') + assert "222" in endpoint_called + assert "async_device" in endpoint_called + + +@pytest.mark.parametrize("x_mode", ("timestamp", "index")) +def test_on_async_readback(qtbot, mocked_client, x_mode): + """ + Test that on_async_readback extends or replaces async data depending on metadata instruction. + For 'timestamp' mode, new timestamps are appended to x_data. + For 'index' mode, x_data simply increases by integer index. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + dummy_scan = create_dummy_scan_item() + wf.scan_item = dummy_scan + + c = wf.plot(arg1="async_device", label="async_device-async_device") + wf._async_curves = [c] + # Suppose existing data + c.setData([0, 1, 2], [10, 11, 12]) + + # Set the x_axis_mode + wf.x_axis_mode["name"] = x_mode + + # Extend readback + msg = {"signals": {"async_device": {"value": [100, 200], "timestamp": [1001, 1002]}}} + metadata = {"async_update": {"max_shape": [None], "type": "add"}} + wf.on_async_readback(msg, metadata) + + x_data, y_data = c.get_data() + assert len(x_data) == 5 + # Check x_data based on x_mode + if x_mode == "timestamp": + np.testing.assert_array_equal(x_data, [0, 1, 2, 1001, 1002]) + else: # x_mode == "index" + np.testing.assert_array_equal(x_data, [0, 1, 2, 3, 4]) + + np.testing.assert_array_equal(y_data, [10, 11, 12, 100, 200]) + + # instruction='replace' + msg2 = {"signals": {"async_device": {"value": [999], "timestamp": [555]}}} + metadata2 = {"async_update": {"max_shape": [None], "type": "replace"}} + wf.on_async_readback(msg2, metadata2) + x_data2, y_data2 = c.get_data() + if x_mode == "timestamp": + np.testing.assert_array_equal(x_data2, [555]) + else: + + np.testing.assert_array_equal(x_data2, [0]) + + np.testing.assert_array_equal(y_data2, [999]) + + +def test_get_x_data(qtbot, mocked_client, monkeypatch): + """ + Test _get_x_data logic for multiple modes: 'timestamp', 'index', 'custom', 'auto'. + Use a dummy scan_item that returns specific data for the requested signal. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + dummy_scan = create_dummy_scan_item() + wf.scan_item = dummy_scan + + # 1) x_mode == 'timestamp' + wf.x_axis_mode["name"] = "timestamp" + x_data = wf._get_x_data("bpm4i", "bpm4i") + np.testing.assert_array_equal(x_data, [101, 201, 301]) + + # 2) x_mode == 'index' => returns None => means use Y data indexing + wf.x_axis_mode["name"] = "index" + x_data2 = wf._get_x_data("bpm4i", "bpm4i") + assert x_data2 is None + + # 3) custom x => e.g. "samx" + wf.x_axis_mode["name"] = "samx" + x_custom = wf._get_x_data("bpm4i", "bpm4i") + # because dummy_scan.live_data["samx"]["samx"].val => [10,20,30] + np.testing.assert_array_equal(x_custom, [10, 20, 30]) + + # 4) auto + wf._async_curves.clear() + wf._sync_curves = [MagicMock()] # pretend we have a sync device + wf.x_axis_mode["name"] = "auto" + x_auto = wf._get_x_data("bpm4i", "bpm4i") + # By default it tries the "scan_report_devices" => "samx" => same as custom above + np.testing.assert_array_equal(x_auto, [10, 20, 30]) + + +################################################## +# The following tests are for the Curve class +################################################## + + +def test_curve_set_appearance_methods(qtbot, mocked_client): + """ + Test that the Curve appearance setter methods update the configuration properly. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + c = wf.plot(x=[1, 2, 3], y=[4, 5, 6], label="appearance_curve") + c.set_color("#0000ff") + c.set_symbol("x") + c.set_symbol_color("#ff0000") + c.set_symbol_size(10) + c.set_pen_width(3) + c.set_pen_style("dashdot") + assert c.config.color == "#0000ff" + assert c.config.symbol == "x" + assert c.config.symbol_color == "#ff0000" + assert c.config.symbol_size == 10 + assert c.config.pen_width == 3 + assert c.config.pen_style == "dashdot" + + +def test_curve_set_custom_data(qtbot, mocked_client): + """ + Test that custom curves allow setting new data via set_data. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + c = wf.plot(x=[1, 2, 3], y=[4, 5, 6], label="custom_data_curve") + # Change data + c.set_data([7, 8, 9], [10, 11, 12]) + x_data, y_data = c.get_data() + np.testing.assert_array_equal(x_data, np.array([7, 8, 9])) + np.testing.assert_array_equal(y_data, np.array([10, 11, 12])) + + +def test_curve_set_data_error_non_custom(qtbot, mocked_client): + """ + Test that calling set_data on a non-custom (device) curve raises a ValueError. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + # Create a device curve by providing y_name (which makes source 'device') + # Assume that entry_validator returns a valid entry. + c = wf.plot(arg1="bpm4i", label="device_curve") + with pytest.raises(ValueError): + c.set_data([1, 2, 3], [4, 5, 6]) + + +def test_curve_remove(qtbot, mocked_client): + """ + Test that calling remove() on a Curve calls its parent's remove_curve method. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + c1 = wf.plot(x=[1, 2, 3], y=[4, 5, 6], label="curve_1") + c2 = wf.plot(x=[1, 2, 3], y=[4, 5, 6], label="curve_2") + + assert len(wf.plot_item.curves) == 2 + c1.remove() + assert len(wf.plot_item.curves) == 1 + assert c1 not in wf.plot_item.curves + assert c2 in wf.plot_item.curves + + +def test_curve_dap_params_and_summary(qtbot, mocked_client): + """ + Test that dap_params and dap_summary properties work as expected. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + c = wf.plot(x=[1, 2, 3], y=[4, 5, 6], label="dap_curve") + c.dap_params = {"param": 1} + c.dap_summary = {"summary": "test"} + assert c.dap_params == {"param": 1} + assert c.dap_summary == {"summary": "test"} + + +def test_curve_set_method(qtbot, mocked_client): + """ + Test the convenience set(...) method of the Curve for updating appearance properties. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + c = wf.plot(x=[1, 2, 3], y=[4, 5, 6], label="set_method_curve") + c.set( + color="#123456", + symbol="d", + symbol_color="#654321", + symbol_size=12, + pen_width=5, + pen_style="dot", + ) + assert c.config.color == "#123456" + assert c.config.symbol == "d" + assert c.config.symbol_color == "#654321" + assert c.config.symbol_size == 12 + assert c.config.pen_width == 5 + assert c.config.pen_style == "dot" + + +################################################## +# Settings and popups +################################################## + + +def test_show_curve_settings_popup(qtbot, mocked_client): + """ + Test that show_curve_settings_popup displays the settings dialog and toggles the toolbar icon. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client) + + curve_action = wf.toolbar.widgets["curve"].action + assert not curve_action.isChecked(), "Should start unchecked" + + wf.show_curve_settings_popup() + + assert wf.curve_settings_dialog is not None + assert wf.curve_settings_dialog.isVisible() + assert curve_action.isChecked() + + wf.curve_settings_dialog.close() + assert wf.curve_settings_dialog is None + assert not curve_action.isChecked(), "Should be unchecked after closing dialog" + + +def test_show_dap_summary_popup(qtbot, mocked_client): + """ + Test that show_dap_summary_popup displays the DAP summary dialog and toggles the 'fit_params' toolbar icon. + """ + wf = create_widget(qtbot, Waveform, client=mocked_client, popups=True) + + assert "fit_params" in wf.toolbar.widgets + + fit_action = wf.toolbar.widgets["fit_params"].action + assert fit_action.isChecked() is False + + wf.show_dap_summary_popup() + + assert wf.dap_summary_dialog is not None + assert wf.dap_summary_dialog.isVisible() + assert fit_action.isChecked() is True + + wf.dap_summary_dialog.close() + assert wf.dap_summary_dialog is None + assert fit_action.isChecked() is False