From 70c4e9bc5ebba2445480a56d2bf6721840cfd170 Mon Sep 17 00:00:00 2001 From: wyzula-jan <133381102+wyzula-jan@users.noreply.github.com> Date: Sun, 25 Feb 2024 18:41:31 +0100 Subject: [PATCH] refactor(plots/plot_base): BECPlotBase inherits from pg.GraphicalLayout instead of pg.PlotItem, this will allow us to add multiple plots into each coordinate of BECFigure. --- bec_widgets/cli/client.py | 10 +- bec_widgets/widgets/figure/figure.py | 58 ++++-- bec_widgets/widgets/plots/__init__.py | 1 + bec_widgets/widgets/plots/image.py | 246 ++++++++++++++++++++++++ bec_widgets/widgets/plots/plot_base.py | 39 ++-- bec_widgets/widgets/plots/waveform1d.py | 36 ++-- tests/test_waveform1d.py | 4 +- 7 files changed, 348 insertions(+), 46 deletions(-) create mode 100644 bec_widgets/widgets/plots/image.py diff --git a/bec_widgets/cli/client.py b/bec_widgets/cli/client.py index efae38ac..7806b8c7 100644 --- a/bec_widgets/cli/client.py +++ b/bec_widgets/cli/client.py @@ -101,7 +101,15 @@ class BECPlotBase(RPCBase): """ @rpc_call - def plot_data(self, data_x: "list | np.ndarray", data_y: "list | np.ndarray", **kwargs): + def lock_aspect_ratio(self, lock): + """ + Lock aspect ratio. + Args: + lock(bool): True to lock, False to unlock. + """ + + @rpc_call + def plot(self, data_x: "list | np.ndarray", data_y: "list | np.ndarray", **kwargs): """ Plot custom data on the plot widget. These data are not saved in config. Args: diff --git a/bec_widgets/widgets/figure/figure.py b/bec_widgets/widgets/figure/figure.py index 8bf423d5..45bbb02d 100644 --- a/bec_widgets/widgets/figure/figure.py +++ b/bec_widgets/widgets/figure/figure.py @@ -15,7 +15,14 @@ from qtpy.QtWidgets import QApplication, QWidget from qtpy.QtWidgets import QVBoxLayout, QMainWindow from bec_widgets.utils import BECConnector, BECDispatcher, ConnectionConfig -from bec_widgets.widgets.plots import BECPlotBase, BECWaveform1D, Waveform1DConfig, WidgetConfig +from bec_widgets.widgets.plots import ( + BECPlotBase, + BECWaveform1D, + Waveform1DConfig, + WidgetConfig, + BECImageShow, +) +from bec_widgets.widgets.plots.image import BECImageShowWithHistogram class FigureConfig(ConnectionConfig): @@ -36,6 +43,7 @@ class WidgetHandler: self.widget_factory = { "PlotBase": (BECPlotBase, WidgetConfig), "Waveform1D": (BECWaveform1D, Waveform1DConfig), + "ImShow": (BECImageShow, WidgetConfig), } def create_widget( @@ -140,9 +148,21 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget): **axis_kwargs, ) + def add_image( + self, widget_id: str = None, row: int = None, col: int = None, config=None, **axis_kwargs + ) -> BECImageShow: + return self.add_widget( + widget_type="ImShow", + widget_id=widget_id, + row=row, + col=col, + config=config, + **axis_kwargs, + ) + def add_widget( self, - widget_type: Literal["PlotBase", "Waveform1D"] = "PlotBase", + widget_type: Literal["PlotBase", "Waveform1D", "ImShow"] = "PlotBase", widget_id: str = None, row: int = None, col: int = None, @@ -401,6 +421,16 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget): sys.exit(app.exec_()) + def add_image_with_histogram(self, image_data, widget_id=None, row=None, col=None): + # Create the custom image show widget + image_widget = BECImageShowWithHistogram() + + # Set the image data + image_widget.setImage(image_data) + + # Add the widget to BECFigure + self.addItem(image_widget, row=row, col=col) + ################################################## ################################################## @@ -443,7 +473,7 @@ class DebugWindow(QWidget): # pragma: no cover: # console push self.console.kernel_manager.kernel.shell.push( - {"fig": self.figure, "w1": self.w1, "w2": self.w2} + {"fig": self.figure, "w1": self.w1, "w2": self.w2, "w3": self.w3, "w4": self.w4} ) def _init_ui(self): @@ -454,6 +484,7 @@ class DebugWindow(QWidget): # pragma: no cover: # add stuff to figure self._init_figure() + # self.add_debug_histo() self.console_layout = QVBoxLayout(self.widget_console) self.console = JupyterConsoleWidget() @@ -465,6 +496,7 @@ class DebugWindow(QWidget): # pragma: no cover: self.figure.add_widget(widget_type="Waveform1D", row=1, col=0, title="Widget 2") self.figure.add_widget(widget_type="Waveform1D", row=0, col=1, title="Widget 3") self.figure.add_widget(widget_type="Waveform1D", row=1, col=1, title="Widget 4") + # self.figure.add_image(title="Image", row=1, col=1) self.w1 = self.figure[0, 0] self.w2 = self.figure[1, 0] @@ -500,14 +532,18 @@ class DebugWindow(QWidget): # pragma: no cover: ) # curves for w4 - self.w4.add_curve_scan("samx", "bpm4i", pen_style="dash") - self.w4.add_curve_custom( - x=[1, 2, 3, 4, 5], - y=[1, 2, 3, 4, 5], - label="curve-custom", - color="blue", - pen_style="dashdot", - ) + # self.w4.add_curve_scan("samx", "bpm4i", pen_style="dash") + # self.w4.add_curve_custom( + # x=[1, 2, 3, 4, 5], + # y=[1, 2, 3, 4, 5], + # label="curve-custom", + # color="blue", + # pen_style="dashdot", + # ) + + def add_debug_histo(self): + image_data = np.random.normal(loc=100, scale=50, size=(100, 100)) # Example image data + self.figure.add_image_with_histogram(image_data, row=2, col=0) if __name__ == "__main__": # pragma: no cover diff --git a/bec_widgets/widgets/plots/__init__.py b/bec_widgets/widgets/plots/__init__.py index a0cea61f..28bc0e0a 100644 --- a/bec_widgets/widgets/plots/__init__.py +++ b/bec_widgets/widgets/plots/__init__.py @@ -1,2 +1,3 @@ from .plot_base import AxisConfig, WidgetConfig, BECPlotBase from .waveform1d import Waveform1DConfig, BECWaveform1D, BECCurve +from .image import BECImageShow, BECImageShowConfig, BECImageItem diff --git a/bec_widgets/widgets/plots/image.py b/bec_widgets/widgets/plots/image.py new file mode 100644 index 00000000..2556e004 --- /dev/null +++ b/bec_widgets/widgets/plots/image.py @@ -0,0 +1,246 @@ +from __future__ import annotations +import scipy as sp + +from collections import defaultdict +from typing import Literal, Optional, Any + +import numpy as np +import pyqtgraph as pg +from PyQt6.QtWidgets import QMainWindow +from qtpy.QtCore import QThread +from pydantic import Field, BaseModel, ValidationError +from pyqtgraph import mkBrush +from qtpy import QtCore +from qtpy.QtCore import Signal as pyqtSignal +from qtpy.QtCore import Slot as pyqtSlot +from qtpy.QtWidgets import QWidget + +from bec_lib import MessageEndpoints, RedisConnector +from bec_lib.scan_data import ScanData +from bec_widgets.utils import Colors, ConnectionConfig, BECConnector, EntryValidator, BECDispatcher +from bec_widgets.widgets.plots import BECPlotBase, WidgetConfig + + +class ImageConfig(ConnectionConfig): + pass + + +class BECImageShowConfig(WidgetConfig): + pass + + +class BECImageItem(BECConnector, pg.ImageItem): + USER_ACCESS = [] + + def __init__( + self, + config: Optional[ImageConfig] = None, + gui_id: Optional[str] = None, + **kwargs, + ): + if config is None: + config = ImageConfig(widget_class=self.__class__.__name__) + self.config = config + else: + self.config = config + # config.widget_class = self.__class__.__name__ + super().__init__(config=config, gui_id=gui_id) + pg.ImageItem.__init__(self) + + self.apply_config() + if kwargs: + self.set(**kwargs) + + def apply_config(self): + pass + + def set(self, **kwargs): + pass + + +class BECImageShow(BECPlotBase): + USER_ACCESS = ["show_image"] + + def __init__( + self, + parent: Optional[QWidget] = None, + parent_figure=None, + config: Optional[WidgetConfig] = None, + client=None, + gui_id: Optional[str] = None, + ): + if config is None: + config = BECImageShowConfig(widget_class=self.__class__.__name__) + super().__init__( + parent=parent, parent_figure=parent_figure, config=config, client=client, gui_id=gui_id + ) + + self.image = BECImageItem() + self.addItem(self.image) + self.addColorBar(self.image, values=(0, 100)) + # self.add_histogram() + + # set mock data + # self.image.setImage(np.random.rand(100, 100)) + # self.image.setOpts(axisOrder="row-major") + + self.debug_stream() + + def debug_stream(self): + device = "eiger" + self.image_thread = ImageThread(client=self.client, monitor=device) + # self.image_thread.start() + self.image_thread.image_updated.connect(self.on_image_update) + + def add_color_bar(self, vmap: tuple[int, int] = (0, 100)): + self.addColorBar(self.image, values=vmap) + + def add_histogram(self): + # Create HistogramLUTWidget + self.histogram = pg.HistogramLUTWidget() + + # Link HistogramLUTWidget to ImageItem + self.histogram.setImageItem(self.image) + + # def show_image( + # self, + # image: np.ndarray, + # scale: Optional[tuple] = None, + # pos: Optional[tuple] = None, + # auto_levels: Optional[bool] = True, + # auto_range: Optional[bool] = True, + # lut: Optional[list] = None, + # opacity: Optional[float] = 1.0, + # auto_downsample: Optional[bool] = True, + # ): + # self.image.setImage( + # image, + # scale=scale, + # pos=pos, + # autoLevels=auto_levels, + # autoRange=auto_range, + # lut=lut, + # opacity=opacity, + # autoDownsample=auto_downsample, + # ) + # + # def remove(self): + # self.image.clear() + # self.removeItem(self.image) + # self.image = None + # super().remove() + + def set_monitor(self, monitor: str = None): ... + + def set_zmq(self, address: str = None): ... + + @pyqtSlot(np.ndarray) # TODO specify format + def on_image_update(self, image): + self.image.updateImage(image) + + +class ImageThread(QThread): + image_updated = pyqtSignal(np.ndarray) # TODO add type + + def __init__(self, parent=None, client=None, monitor: str = None, port: int = None): + super().__init__() + + bec_dispatcher = BECDispatcher() + self.client = bec_dispatcher.client if client is None else client + self.dev = self.client.device_manager.devices + self.scans = self.client.scans + self.queue = self.client.queue + + # Monitor Device + self.monitor = monitor + + # Connection + self.port = port + if self.port is None: + self.port = self.client.connector.host + # self.connector = RedisConnector(self.port) + self.connector = RedisConnector("localhost:6379") + self.stream_consumer = None + + if self.monitor is not None: + self.connect_stream_consumer(self.monitor) + + def set_monitor(self, monitor: str = None) -> None: + """ + Set/update monitor device. + Args: + monitor(str): Name of the monitor. + """ + self.monitor = monitor + + def connect_stream_consumer(self, device): + if self.stream_consumer is not None: + self.stream_consumer.shutdown() + + self.stream_consumer = self.connector.stream_consumer( + topics=MessageEndpoints.device_monitor(device=device), + cb=self._streamer_cb, + parent=self, + ) + + self.stream_consumer.start() + + print(f"Stream consumer started for device: {device}") + + def process_FFT(self, data: np.ndarray) -> np.ndarray: + return np.fft.fft2(data) + + def center_of_mass(self, data: np.ndarray) -> tuple: + return np.unravel_index(np.argmax(data), data.shape) + + @staticmethod + def _streamer_cb(msg, *, parent, **_kwargs) -> None: + msg_device = msg.value + metadata = msg_device.metadata + + data = msg_device.content["data"] + parent.image_updated.emit(data) + + +class BECImageShowWithHistogram(pg.GraphicsLayoutWidget): + def __init__(self, parent=None): + super().__init__(parent=parent) + + # Create ImageItem and HistogramLUTItem + self.imageItem = pg.ImageItem() + self.histogram = pg.HistogramLUTItem() + + # Link Histogram to ImageItem + self.histogram.setImageItem(self.imageItem) + + # Create a layout within the GraphicsLayoutWidget + self.layout = self + + # Add ViewBox and Histogram to the layout + self.viewBox = self.addViewBox(row=0, col=0) + self.viewBox.addItem(self.imageItem) + self.viewBox.setAspectLocked(True) # Lock the aspect ratio + + # Add Histogram to the layout in the same cell + self.addItem(self.histogram, row=0, col=1) + self.histogram.setMaximumWidth(200) # Adjust the width of the histogram to fit + + def setImage(self, image): + """Set the image to be displayed.""" + self.imageItem.setImage(image) + + +# if __name__ == "__main__": +# import sys +# from qtpy.QtWidgets import QApplication +# +# bec_dispatcher = BECDispatcher() +# client = bec_dispatcher.client +# client.start() +# +# app = QApplication(sys.argv) +# win = QMainWindow() +# img = BECImageShow(client=client) +# win.setCentralWidget(img) +# win.show() +# sys.exit(app.exec_()) diff --git a/bec_widgets/widgets/plots/plot_base.py b/bec_widgets/widgets/plots/plot_base.py index cb4b98c1..7d462aa0 100644 --- a/bec_widgets/widgets/plots/plot_base.py +++ b/bec_widgets/widgets/plots/plot_base.py @@ -36,7 +36,7 @@ class WidgetConfig(ConnectionConfig): ) -class BECPlotBase(BECConnector, pg.PlotItem): +class BECPlotBase(BECConnector, pg.GraphicsLayout): USER_ACCESS = [ "set", "set_title", @@ -47,7 +47,8 @@ class BECPlotBase(BECConnector, pg.PlotItem): "set_x_lim", "set_y_lim", "set_grid", - "plot_data", + "lock_aspect_ratio", + "plot", "remove", ] @@ -62,9 +63,10 @@ class BECPlotBase(BECConnector, pg.PlotItem): if config is None: config = WidgetConfig(widget_class=self.__class__.__name__) super().__init__(client=client, config=config, gui_id=gui_id) - pg.PlotItem.__init__(self, parent) + pg.GraphicsLayout.__init__(self, parent) self.figure = parent_figure + self.plot_item = self.addPlot() self.add_legend() @@ -118,7 +120,7 @@ class BECPlotBase(BECConnector, pg.PlotItem): Args: title(str): Title of the plot widget. """ - self.setTitle(title) + self.plot_item.setTitle(title) self.config.axis.title = title def set_x_label(self, label: str): @@ -127,7 +129,7 @@ class BECPlotBase(BECConnector, pg.PlotItem): Args: label(str): Label of the x-axis. """ - self.setLabel("bottom", label) + self.plot_item.setLabel("bottom", label) self.config.axis.x_label = label def set_y_label(self, label: str): @@ -136,7 +138,7 @@ class BECPlotBase(BECConnector, pg.PlotItem): Args: label(str): Label of the y-axis. """ - self.setLabel("left", label) + self.plot_item.setLabel("left", label) self.config.axis.y_label = label def set_x_scale(self, scale: Literal["linear", "log"] = "linear"): @@ -145,7 +147,7 @@ class BECPlotBase(BECConnector, pg.PlotItem): Args: scale(Literal["linear", "log"]): Scale of the x-axis. """ - self.setLogMode(x=(scale == "log")) + self.plot_item.setLogMode(x=(scale == "log")) self.config.axis.x_scale = scale def set_y_scale(self, scale: Literal["linear", "log"] = "linear"): @@ -154,7 +156,7 @@ class BECPlotBase(BECConnector, pg.PlotItem): Args: scale(Literal["linear", "log"]): Scale of the y-axis. """ - self.setLogMode(y=(scale == "log")) + self.plot_item.setLogMode(y=(scale == "log")) self.config.axis.y_scale = scale def set_x_lim(self, *args) -> None: @@ -177,7 +179,7 @@ class BECPlotBase(BECConnector, pg.PlotItem): else: raise ValueError("set_x_lim expects either two separate arguments or a single tuple") - self.setXRange(x_min, x_max) + self.plot_item.setXRange(x_min, x_max) self.config.axis.x_lim = (x_min, x_max) def set_y_lim(self, *args) -> None: @@ -200,7 +202,7 @@ class BECPlotBase(BECConnector, pg.PlotItem): else: raise ValueError("set_y_lim expects either two separate arguments or a single tuple") - self.setYRange(y_min, y_max) + self.plot_item.setYRange(y_min, y_max) self.config.axis.y_lim = (y_min, y_max) def set_grid(self, x: bool = False, y: bool = False): @@ -210,14 +212,23 @@ class BECPlotBase(BECConnector, pg.PlotItem): x(bool): Show grid on the x-axis. y(bool): Show grid on the y-axis. """ - self.showGrid(x, y) + self.plot_item.showGrid(x, y) self.config.axis.x_grid = x self.config.axis.y_grid = y def add_legend(self): - self.addLegend() + """Add legend to the plot""" + self.plot_item.addLegend() - def plot_data(self, data_x: list | np.ndarray, data_y: list | np.ndarray, **kwargs): + def lock_aspect_ratio(self, lock): + """ + Lock aspect ratio. + Args: + lock(bool): True to lock, False to unlock. + """ + self.plot_item.setAspectLocked(lock) + + def plot(self, data_x: list | np.ndarray, data_y: list | np.ndarray, **kwargs): """ Plot custom data on the plot widget. These data are not saved in config. Args: @@ -227,7 +238,7 @@ class BECPlotBase(BECConnector, pg.PlotItem): """ # TODO very basic so far, add more options # TODO decide name of the method - self.plot(data_x, data_y, **kwargs) + self.plot_item.plot(data_x, data_y, **kwargs) def remove(self): """Remove the plot widget from the figure.""" diff --git a/bec_widgets/widgets/plots/waveform1d.py b/bec_widgets/widgets/plots/waveform1d.py index 33201601..da7b7587 100644 --- a/bec_widgets/widgets/plots/waveform1d.py +++ b/bec_widgets/widgets/plots/waveform1d.py @@ -260,7 +260,7 @@ class BECWaveform1D(BECPlotBase): self.entry_validator = EntryValidator(self.dev) - self.addLegend() + self.add_legend() self.apply_config(self.config) # TODO check config assigning @@ -268,7 +268,7 @@ class BECWaveform1D(BECPlotBase): def find_widget_by_id( self, item_id: str ): # TODO implement this on level of BECConnector and all other widgets - for curve in self.curves: + for curve in self.plot_item.curves: if curve.gui_id == item_id: return curve @@ -287,12 +287,12 @@ class BECWaveform1D(BECPlotBase): return self.config = config - self.clear() + self.plot_item.clear() # TODO not sure if on the plot or layout level self.apply_axis_config() # Reset curves self._curves_data = defaultdict(dict) - self._curves = [] + self._curves = self.plot_item.curves for curve_id, curve_config in self.config.curves.items(): self.add_curve_by_config(curve_config) if replot_last_scan: @@ -377,7 +377,7 @@ class BECWaveform1D(BECPlotBase): BECCurve: The curve object. """ if isinstance(identifier, int): - return self.curves[identifier] + return self.plot_item.curves[identifier] elif isinstance(identifier, str): for source_type, curves in self.curves_data.items(): if identifier in curves: @@ -407,7 +407,7 @@ class BECWaveform1D(BECPlotBase): BECCurve: The curve object. """ curve_source = "custom" - curve_id = label or f"Curve {len(self.curves) + 1}" + curve_id = label or f"Curve {len(self.plot_item.curves) + 1}" curve_exits = self._check_curve_id(curve_id, self.curves_data) if curve_exits: @@ -418,7 +418,7 @@ class BECWaveform1D(BECPlotBase): color = ( color or Colors.golden_angle_color( - colormap=self.config.color_palette, num=len(self.curves) + 1, format="HEX" + colormap=self.config.color_palette, num=len(self.plot_item.curves) + 1, format="HEX" )[-1] ) @@ -456,7 +456,7 @@ class BECWaveform1D(BECPlotBase): """ curve = BECCurve(config=config, name=name) self.curves_data[source][name] = curve - self.addItem(curve) + self.plot_item.addItem(curve) self.config.curves[name] = curve.config if data is not None: curve.setData(data[0], data[1]) @@ -504,7 +504,7 @@ class BECWaveform1D(BECPlotBase): color = ( color or Colors.golden_angle_color( - colormap=self.config.color_palette, num=len(self.curves) + 1, format="HEX" + colormap=self.config.color_palette, num=len(self.plot_item.curves) + 1, format="HEX" )[-1] ) @@ -595,10 +595,10 @@ class BECWaveform1D(BECPlotBase): for source, curves in self.curves_data.items(): if curve_id in curves: curve = curves.pop(curve_id) - self.removeItem(curve) + self.plot_item.removeItem(curve) del self.config.curves[curve_id] - if curve in self.curves: - self.curves.remove(curve) + if curve in self.plot_item.curves: + self.plot_item.curves.remove(curve) return raise KeyError(f"Curve with ID '{curve_id}' not found.") @@ -608,10 +608,10 @@ class BECWaveform1D(BECPlotBase): Args: N(int): Order of the curve to be removed. """ - if N < len(self.curves): - curve = self.curves[N] + if N < len(self.plot_item.curves): + curve = self.plot_item.curves[N] curve_id = curve.name() # Assuming curve's name is used as its ID - self.removeItem(curve) + self.plot_item.removeItem(curve) del self.config.curves[curve_id] # Remove from self.curve_data for source, curves in self.curves_data.items(): @@ -709,7 +709,7 @@ class BECWaveform1D(BECPlotBase): ) output = "dict" - for curve in self.curves: + for curve in self.plot_item.curves: x_data, y_data = curve.get_data() if x_data is not None or y_data is not None: if output == "dict": @@ -719,9 +719,9 @@ class BECWaveform1D(BECPlotBase): if output == "pandas" and pd is not None: combined_data = pd.concat( - [data[curve.name()] for curve in self.curves], + [data[curve.name()] for curve in self.plot_item.curves], axis=1, - keys=[curve.name() for curve in self.curves], + keys=[curve.name() for curve in self.plot_item.curves], ) return combined_data return data diff --git a/tests/test_waveform1d.py b/tests/test_waveform1d.py index 1205ba53..35017a29 100644 --- a/tests/test_waveform1d.py +++ b/tests/test_waveform1d.py @@ -106,7 +106,7 @@ def test_create_waveform1D_by_config(bec_figure): w1_config_output = w1.get_config() assert w1_config_input == w1_config_output - assert w1.titleLabel.text == "Widget 1" + assert w1.plot_item.titleLabel.text == "Widget 1" assert w1.config.axis.title == "Widget 1" @@ -185,7 +185,7 @@ def test_remove_curve(bec_figure): w1.remove_curve(0) w1.remove_curve("bpm3a-bpm3a") - assert len(w1.curves) == 0 + assert len(w1.plot_item.curves) == 0 assert w1.curves_data["scan_segment"] == {} with pytest.raises(ValueError) as excinfo: