From bb806c149dee88023ecb647b523cbd5189ea9001 Mon Sep 17 00:00:00 2001 From: wyzula-jan <133381102+wyzula-jan@users.noreply.github.com> Date: Tue, 8 Aug 2023 18:25:34 +0200 Subject: [PATCH] feat: plotting from streamer --- bec_widgets/bec_dispatcher.py | 60 ++++++ bec_widgets/line_plot.py | 81 ++++++- bec_widgets/line_plot_legacy.py | 360 ++++++++++++++++++++++++++++++++ 3 files changed, 490 insertions(+), 11 deletions(-) create mode 100644 bec_widgets/line_plot_legacy.py diff --git a/bec_widgets/bec_dispatcher.py b/bec_widgets/bec_dispatcher.py index 4a89cf5e..28851912 100644 --- a/bec_widgets/bec_dispatcher.py +++ b/bec_widgets/bec_dispatcher.py @@ -20,6 +20,9 @@ class _BECDispatcher(QObject): scan_segment = pyqtSignal(dict, dict) new_dap_data = pyqtSignal(dict) + new_projection_id = pyqtSignal(dict) + new_projection_data = pyqtSignal(dict) + def __init__(self): super().__init__() self.client = BECClient() @@ -34,6 +37,8 @@ class _BECDispatcher(QObject): self._scan_id = None scan_lock = RLock() + # self.new_projection_id.connect(self.new_projection_data) + def _scan_segment_cb(scan_segment, metadata): with scan_lock: # TODO: use ScanStatusMessage instead? @@ -89,5 +94,60 @@ class _BECDispatcher(QObject): self._daps[dap_name].consumer.shutdown() del self._daps[dap_name] + # def connect_proj_data(self, slot): + # keys = self.client.producer.keys("px_stream/projection_*") + # keys = keys or [] + # + # def _dap_cb(msg): + # msg = BECMessage.DeviceMessage.loads(msg.value) + # self.new_projection_data.emit(msg.content["data"]) + # + # proj_numbers = set(key.decode().split("px_stream/projection_")[1].split("/")[0] for key in keys) + # last_proj_id = sorted(proj_numbers)[-1] + # dap_ep = MessageEndpoints.processed_data(f"px_stream/projection_{last_proj_id}/") + # + # consumer = self.client.connector.consumer(topics=dap_ep, cb=_dap_cb) + # consumer.start() + # + # self.new_projection_data.connect(slot) + + def connect_proj_id(self, slot): + def _dap_cb(msg): + msg = BECMessage.DeviceMessage.loads(msg.value) + self.new_projection_id.emit(msg.content["signals"]) + + dap_ep = "px_stream/proj_nr" + consumer = self.client.connector.consumer(topics=dap_ep, cb=_dap_cb) + consumer.start() + + self.new_projection_id.connect(slot) + + def connect_proj_data(self, slot: object, data_ep: str) -> object: + def _dap_cb(msg): + msg = BECMessage.DeviceMessage.loads(msg.value) + self.new_projection_data.emit(msg.content["signals"]) + + consumer = self.client.connector.consumer(topics=data_ep, cb=_dap_cb) + consumer.start() + self._daps[data_ep] = _BECDap(consumer) + self._daps[data_ep].slots.add(slot) + + self.new_projection_data.connect(slot) + + def disconnect_proj_data(self, slot, data_ep): + if data_ep not in self._daps: + return + + if slot not in self._daps[data_ep].slots: + return + + self.new_projection_data.disconnect(slot) + self._daps[data_ep].slots.remove(slot) + + if not self._daps[data_ep].slots: + # shutdown consumer if there are no more connected slots + self._daps[data_ep].consumer.shutdown() + del self._daps[data_ep] + bec_dispatcher = _BECDispatcher() diff --git a/bec_widgets/line_plot.py b/bec_widgets/line_plot.py index 6a96b872..10cd4256 100644 --- a/bec_widgets/line_plot.py +++ b/bec_widgets/line_plot.py @@ -1,7 +1,8 @@ import os import warnings +import time from typing import Any - +import threading import numpy as np import pyqtgraph import pyqtgraph as pg @@ -13,6 +14,8 @@ from pyqtgraph import mkBrush, mkColor, mkPen from pyqtgraph.Qt import QtCore, QtWidgets, uic from pyqtgraph.Qt.QtCore import pyqtSignal +from bec_lib.core import BECMessage + class BasicPlot(QtWidgets.QWidget): update_signal = pyqtSignal() @@ -108,6 +111,12 @@ class BasicPlot(QtWidgets.QWidget): self.roi_selector.sigRegionChangeFinished.connect(self.get_roi_region) self.pushButton_debug.clicked.connect(self.debug) + self._current_proj = None + self._current_metadata_ep = "px_stream/projection_{}/metadata" + + self.data_retriever = threading.Thread(target=self.on_projection, daemon=True) + self.data_retriever.start() + def debug(self): """ Debug button just for quick testing @@ -177,20 +186,21 @@ class BasicPlot(QtWidgets.QWidget): def update(self): """Update the plot with the new data.""" # check if roi selector is in the plot - if self.roi_selector not in self.plot.items: - self.plot.addItem(self.roi_selector) + # if self.roi_selector not in self.plot.items: + # self.plot.addItem(self.roi_selector) # check if QTable was initialised and if list of devices was changed if self.y_value_list != self.previous_y_value_list: self.setup_cursor_table() self.previous_y_value_list = self.y_value_list.copy() if self.y_value_list else None - if len(self.plotter_data_x) <= 1: - return - self.plot.setLabel("bottom", self.label_bottom) - self.plot.setLabel("left", self.label_left) - for ii in range(len(self.y_value_list)): - self.curves[ii].setData(self.plotter_data_x, self.plotter_data_y[ii]) + self.curves[0].setData(self.plotter_data_x[0], self.plotter_data_y[0]) + # if len(self.plotter_data_x[0]) <= 1: + # return + # self.plot.setLabel("bottom", self.label_bottom) + # self.plot.setLabel("left", self.label_left) + # for ii in range(len(self.y_value_list)): + # self.curves[0].setData(self.plotter_data_x[0], self.plotter_data_y[0]) @pyqtSlot(dict, dict) def on_scan_segment(self, data: dict, metadata: dict) -> None: @@ -333,6 +343,53 @@ class BasicPlot(QtWidgets.QWidget): ] return colors + def on_projection(self): + while True: + if self._current_proj is None: + time.sleep(0.1) + continue + endpoint_key = f"px_stream/projection_{self._current_proj}/data" + # from_pnt = + # to_pnt = + msgs = client.producer.lrange(topic=endpoint_key, start=0, end=2) + data = [BECMessage.DeviceMessage.loads(msg) for msg in msgs] + if not data: + continue + + self.plotter_data_y = [ + np.sum(data[-1].content["signals"]["data"][75, ...], axis=-2).squeeze() + ] + self.update_signal.emit() + time.sleep(0.3) + + @pyqtSlot(dict) + def new_proj_metadata(self, data): + self.current_q = data["q"] + self._current_norm = data["norm_sum"] + self._current_metadata = data["metadata"] + + self.plotter_data_x = [np.arange(0, 3206, 1)] # [self.current_q] + # while True: + # endpoint_key = self._current_proj + # + # self.plotter_data_x = list(range(data["data"].shape[-1])) + # self.plotter_data_y[0] = np.sum(data["data"][-1,75,...], axis=-2).squeeze() + # self.update_signal.emit() + # time.sleep() + + @pyqtSlot(dict) + def new_proj(self, data): + if self._current_proj is not None: + bec_dispatcher.disconnect_proj_data( + self.new_proj_metadata, self._current_metadata_ep.format(self._current_proj) + ) + + self.plotter_data_x = [np.arange(0, 3206, 1)] + self._current_proj = data["proj_nr"] + bec_dispatcher.connect_proj_data( + self.new_proj_metadata, self._current_metadata_ep.format(self._current_proj) + ) + if __name__ == "__main__": import argparse @@ -345,8 +402,9 @@ if __name__ == "__main__": "--signals", help="specify recorded signals", nargs="+", - default=["gauss_bpm", "bpm4i", "bpm5i", "bpm6i", "xert"], + default=["gauss_bpm"], ) + # default = ["gauss_bpm", "bpm4i", "bpm5i", "bpm6i", "xert"], value = parser.parse_args() print(f"Plotting signals for: {', '.join(value.signals)}") client = bec_dispatcher.client @@ -354,7 +412,8 @@ if __name__ == "__main__": app = QtWidgets.QApplication([]) ctrl_c.setup(app) plot = BasicPlot(y_value_list=value.signals) - bec_dispatcher.connect(plot) + # bec_dispatcher.connect(plot) + bec_dispatcher.connect_proj_id(plot.new_proj) plot.show() # client.callbacks.register("scan_segment", plot, sync=False) app.exec_() diff --git a/bec_widgets/line_plot_legacy.py b/bec_widgets/line_plot_legacy.py new file mode 100644 index 00000000..6a96b872 --- /dev/null +++ b/bec_widgets/line_plot_legacy.py @@ -0,0 +1,360 @@ +import os +import warnings +from typing import Any + +import numpy as np +import pyqtgraph +import pyqtgraph as pg +from PyQt5.QtCore import pyqtSlot +from PyQt5.QtWidgets import QTableWidgetItem, QCheckBox + +from bec_lib import BECClient +from pyqtgraph import mkBrush, mkColor, mkPen +from pyqtgraph.Qt import QtCore, QtWidgets, uic +from pyqtgraph.Qt.QtCore import pyqtSignal + + +class BasicPlot(QtWidgets.QWidget): + update_signal = pyqtSignal() + roi_signal = pyqtSignal(tuple) + + def __init__(self, name="", y_value_list=["gauss_bpm"]) -> None: + """ + Basic plot widget for displaying scan data. + + Args: + name (str, optional): Name of the plot. Defaults to "". + y_value_list (list, optional): List of signals to be plotted. Defaults to ["gauss_bpm"]. + """ + + super(BasicPlot, self).__init__() + # Set style for pyqtgraph plots + pg.setConfigOption("background", "w") + pg.setConfigOption("foreground", "k") + current_path = os.path.dirname(__file__) + uic.loadUi(os.path.join(current_path, "line_plot.ui"), self) + + # Set splitter distribution of widgets + self.splitter.setSizes([5, 2]) + + self._idle_time = 100 + self.title = "" + self.label_bottom = "" + self.label_left = "" + + self.scan_motors = [] + self.y_value_list = y_value_list + self.previous_y_value_list = None + self.plotter_data_x = [] + self.plotter_data_y = [] + self.curves = [] + self.pens = [] + self.brushs = [] + + self.plotter_scan_id = None + + # TODO to be moved to utils function + plotstyles = { + "symbol": "o", + "symbolSize": 10, + } + color_list = ["#384c6b", "#e28a2b", "#5E3023", "#e41a1c", "#984e83", "#4daf4a"] + color_list = BasicPlot.golden_angle_color(colormap="CET-R2", num=len(self.y_value_list)) + + # setup plots - GraphicsLayoutWidget + # LabelItem + self.label = pg.LabelItem(justify="center") + self.glw.addItem(self.label) + self.label.setText("test label") + + # PlotItem - main window + self.glw.nextRow() + self.plot = pg.PlotItem() + self.glw.addItem(self.plot) + self.plot.addLegend() + + # PlotItem - ROI window - disabled for now #TODO add 2D plot for ROI and 1D plot for mouse click + # self.glw.nextRow() + # self.plot_roi = pg.PlotItem() + # self.glw.addItem(self.plot_roi) + + # ROI selector - so far from [-1,1] #TODO update to scale with xrange + self.roi_selector = pg.LinearRegionItem([-1, 1]) + + for ii, y_value in enumerate(self.y_value_list): + pen = mkPen(color=color_list[ii], width=2, style=QtCore.Qt.DashLine) + brush = mkBrush(color=color_list[ii]) + curve = pg.PlotDataItem( + **plotstyles, symbolBrush=brush, pen=pen, skipFiniteCheck=True, name=y_value + ) + self.plot.addItem(curve) + self.curves.append(curve) + self.pens.append(pen) + self.brushs.append(brush) + + self.crosshair_v = pg.InfiniteLine(angle=90, movable=False) + self.crosshair_h = pg.InfiniteLine(angle=0, movable=False) + self.plot.addItem(self.crosshair_v, ignoreBounds=True) + self.plot.addItem(self.crosshair_h, ignoreBounds=True) + + # Add textItems + self.add_text_items() + + # Manage signals + self.proxy = pg.SignalProxy( + self.plot.scene().sigMouseMoved, rateLimit=60, slot=self.mouse_moved + ) + self.proxy_update = pg.SignalProxy(self.update_signal, rateLimit=25, slot=self.update) + self.roi_selector.sigRegionChangeFinished.connect(self.get_roi_region) + self.pushButton_debug.clicked.connect(self.debug) + + def debug(self): + """ + Debug button just for quick testing + """ + + def get_roi_region(self): + """For testing purpose now, get roi region and print it to self.label as tuple""" + region = self.roi_selector.getRegion() + self.label.setText(f"x = {region[0]:.4f}, y ={region[1]:.4f}") + self.roi_signal.emit(region) + + def add_text_items(self): # TODO probably can be removed + """Add text items to the plot""" + + # self.mouse_box_data.setText("Mouse cursor") + # TODO Via StyleSheet, one may set the color of the full QLabel + # self.mouse_box_data.setStyleSheet(f"QLabel {{color : rgba{self.pens[0].color().getRgb()}}}") + + def mouse_moved(self, event: tuple) -> None: + """ + Update the mouse table with the current mouse position and the corresponding data. + + Args: + event (tuple): Mouse event containing the position of the mouse cursor. + The position is stored in first entry as horizontal, vertical pixel. + """ + pos = event[0] + if not self.plot.sceneBoundingRect().contains(pos): + return + mousePoint = self.plot.vb.mapSceneToView(pos) + self.crosshair_v.setPos(mousePoint.x()) + self.crosshair_h.setPos(mousePoint.y()) + if not self.plotter_data_x: + return + + for ii, y_value in enumerate(self.y_value_list): + closest_point = self.closest_x_y_value( + mousePoint.x(), self.plotter_data_x, self.plotter_data_y[ii] + ) + # TODO fix text wobble in plot, see plot when it crosses 0 + x_data = f"{closest_point[0]:.{self.precision}f}" + y_data = f"{closest_point[1]:.{self.precision}f}" + + # Write coordinate to QTable + self.mouse_table.setItem(ii, 1, QTableWidgetItem(str(y_value))) + self.mouse_table.setItem(ii, 2, QTableWidgetItem(str(x_data))) + self.mouse_table.setItem(ii, 3, QTableWidgetItem(str(y_data))) + + self.mouse_table.resizeColumnsToContents() + + def closest_x_y_value(self, input_value, list_x, list_y) -> tuple: + """ + Find the closest x and y value to the input value. + + Args: + input_value (float): Input value + list_x (list): List of x values + list_y (list): List of y values + + Returns: + tuple: Closest x and y value + """ + arr = np.asarray(list_x) + i = (np.abs(arr - input_value)).argmin() + return list_x[i], list_y[i] + + def update(self): + """Update the plot with the new data.""" + # check if roi selector is in the plot + if self.roi_selector not in self.plot.items: + self.plot.addItem(self.roi_selector) + + # check if QTable was initialised and if list of devices was changed + if self.y_value_list != self.previous_y_value_list: + self.setup_cursor_table() + self.previous_y_value_list = self.y_value_list.copy() if self.y_value_list else None + + if len(self.plotter_data_x) <= 1: + return + self.plot.setLabel("bottom", self.label_bottom) + self.plot.setLabel("left", self.label_left) + for ii in range(len(self.y_value_list)): + self.curves[ii].setData(self.plotter_data_x, self.plotter_data_y[ii]) + + @pyqtSlot(dict, dict) + def on_scan_segment(self, data: dict, metadata: dict) -> None: + """Update function that is called during the scan callback. To avoid + too many renderings, the GUI is only processing events every <_idle_time> ms. + + Args: + data (dict): Dictionary containing a new scan segment + metadata (dict): Scan metadata + + """ + if metadata["scanID"] != self.plotter_scan_id: + self.plotter_scan_id = metadata["scanID"] + self._reset_plot_data() + + self.title = f"Scan {metadata['scan_number']}" + + self.scan_motors = scan_motors = metadata.get("scan_report_devices") + # client = BECClient() + remove_y_value_index = [ + index + for index, y_value in enumerate(self.y_value_list) + if y_value not in client.device_manager.devices + ] + if remove_y_value_index: + for ii in sorted(remove_y_value_index, reverse=True): + # TODO Use bec warning message??? to be discussed with Klaus + warnings.warn( + f"Warning: no matching signal for {self.y_value_list[ii]} found in list of devices. Removing from plot." + ) + self.remove_curve_by_name(self.plot, self.y_value_list[ii]) + self.y_value_list.pop(ii) + + self.precision = client.device_manager.devices[scan_motors[0]]._info["describe"][ + scan_motors[0] + ]["precision"] + # TODO after update of bec_lib, this will be new way to access data + # self.precision = client.device_manager.devices[scan_motors[0]].precision + + x = data["data"][scan_motors[0]][scan_motors[0]]["value"] + self.plotter_data_x.append(x) + for ii, y_value in enumerate(self.y_value_list): + y = data["data"][y_value][y_value]["value"] + self.plotter_data_y[ii].append(y) + self.label_bottom = scan_motors[0] + self.label_left = f"{', '.join(self.y_value_list)}" + + # print(f'metadata scan N{metadata["scan_number"]}') #TODO put as label on top of plot + # print(f'Data point = {data["point_id"]}') #TODO can be used for progress bar + + if len(self.plotter_data_x) <= 1: + return + self.update_signal.emit() + + def _reset_plot_data(self): + """Reset the plot data.""" + self.plotter_data_x = [] + self.plotter_data_y = [] + for ii in range(len(self.y_value_list)): + self.curves[ii].setData([], []) + self.plotter_data_y.append([]) + + def setup_cursor_table(self): + """QTable formatting according to N of devices displayed in plot.""" + + # Init number of rows in table according to n of devices + self.mouse_table.setRowCount(len(self.y_value_list)) + + for ii, y_value in enumerate(self.y_value_list): + checkbox = QCheckBox() + checkbox.setChecked(True) + # TODO just for testing, will be replaced by removing/adding curve + checkbox.stateChanged.connect(lambda: print("status Changed")) + # checkbox.stateChanged.connect(lambda: self.remove_curve_by_name(plot=self.plot, checkbox=checkbox, name=y_value)) + self.mouse_table.setCellWidget(ii, 0, checkbox) + self.mouse_table.setItem(ii, 1, QTableWidgetItem(str(y_value))) + + self.mouse_table.resizeColumnsToContents() + + @staticmethod + def remove_curve_by_name(plot: pyqtgraph.PlotItem, name: str) -> None: + # def remove_curve_by_name(plot: pyqtgraph.PlotItem, checkbox: QtWidgets.QCheckBox, name: str) -> None: + """Removes a curve from the given plot by the specified name. + + Args: + plot (pyqtgraph.PlotItem): The plot from which to remove the curve. + name (str): The name of the curve to remove. + """ + # if checkbox.isChecked(): + for item in plot.items: + if isinstance(item, pg.PlotDataItem) and getattr(item, "opts", {}).get("name") == name: + plot.removeItem(item) + return + + # else: + # return + + @staticmethod + def golden_ratio(num: int) -> list: + """Calculate the golden ratio for a given number of angles. + + Args: + num (int): Number of angles + """ + phi = 2 * np.pi * ((1 + np.sqrt(5)) / 2) + angles = [] + for ii in range(num): + x = np.cos(ii * phi) + y = np.sin(ii * phi) + angle = np.arctan2(y, x) + angles.append(angle) + return angles + + @staticmethod + def golden_angle_color(colormap: str, num: int) -> list: + """ + Extract num colors for from the specified colormap following golden angle distribution. + + Args: + colormap (str): Name of the colormap + num (int): Number of requested colors + + Returns: + list: List of colors with length + + Raises: + ValueError: If the number of requested colors is greater than the number of colors in the colormap. + """ + + cmap = pg.colormap.get(colormap) + cmap_colors = cmap.color + if num > len(cmap_colors): + raise ValueError( + f"Number of colors requested ({num}) is greater than the number of colors in the colormap ({len(cmap_colors)})" + ) + angles = BasicPlot.golden_ratio(len(cmap_colors)) + color_selection = np.round(np.interp(angles, (-np.pi, np.pi), (0, len(cmap_colors)))) + colors = [ + mkColor(tuple((cmap_colors[int(ii)] * 255).astype(int))) for ii in color_selection[:num] + ] + return colors + + +if __name__ == "__main__": + import argparse + from bec_widgets.bec_dispatcher import bec_dispatcher + + from bec_widgets import ctrl_c + + parser = argparse.ArgumentParser() + parser.add_argument( + "--signals", + help="specify recorded signals", + nargs="+", + default=["gauss_bpm", "bpm4i", "bpm5i", "bpm6i", "xert"], + ) + value = parser.parse_args() + print(f"Plotting signals for: {', '.join(value.signals)}") + client = bec_dispatcher.client + # client.start() + app = QtWidgets.QApplication([]) + ctrl_c.setup(app) + plot = BasicPlot(y_value_list=value.signals) + bec_dispatcher.connect(plot) + plot.show() + # client.callbacks.register("scan_segment", plot, sync=False) + app.exec_()