diff --git a/bec_widgets/utils/crosshair.py b/bec_widgets/utils/crosshair.py index 932a4d11..54e34aaf 100644 --- a/bec_widgets/utils/crosshair.py +++ b/bec_widgets/utils/crosshair.py @@ -50,9 +50,15 @@ class Crosshair(QObject): self.v_line.skip_auto_range = True self.h_line = pg.InfiniteLine(angle=0, movable=False) self.h_line.skip_auto_range = True + # Add custom attribute to identify crosshair lines + self.v_line.is_crosshair = True + self.h_line.is_crosshair = True self.plot_item.addItem(self.v_line, ignoreBounds=True) self.plot_item.addItem(self.h_line, ignoreBounds=True) + # Initialize highlighted curve in a case of multiple curves + self.highlighted_curve_index = None + # Add TextItem to display coordinates self.coord_label = pg.TextItem("", anchor=(1, 1), fill=(0, 0, 0, 100)) self.coord_label.setVisible(False) # Hide initially @@ -73,6 +79,7 @@ class Crosshair(QObject): self.plot_item.ctrl.downsampleSpin.valueChanged.connect(self.clear_markers) # Initialize markers + self.items = [] self.marker_moved_1d = {} self.marker_clicked_1d = {} self.marker_2d = None @@ -116,34 +123,74 @@ class Crosshair(QObject): self.coord_label.fill = pg.mkBrush(label_bg_color) self.coord_label.border = pg.mkPen(None) + @Slot(int) + def update_highlighted_curve(self, curve_index: int): + """ + Update the highlighted curve in the case of multiple curves in a plot item. + + Args: + curve_index(int): The index of curve to highlight + """ + self.highlighted_curve_index = curve_index + self.clear_markers() + self.update_markers() + def update_markers(self): """Update the markers for the crosshair, creating new ones if necessary.""" - # Create new markers - for item in self.plot_item.items: + if self.highlighted_curve_index is not None and hasattr(self.plot_item, "visible_curves"): + # Focus on the highlighted curve only + self.items = [self.plot_item.visible_curves[self.highlighted_curve_index]] + else: + # Handle all curves + self.items = self.plot_item.items + + # Create or update markers + for item in self.items: if isinstance(item, pg.PlotDataItem): # 1D plot - if item.name() in self.marker_moved_1d: - continue pen = item.opts["pen"] color = pen.color() if hasattr(pen, "color") else pg.mkColor(pen) - marker_moved = CrosshairScatterItem( - size=10, pen=pg.mkPen(color), brush=pg.mkBrush(None) - ) - marker_moved.skip_auto_range = True - self.marker_moved_1d[item.name()] = marker_moved - self.plot_item.addItem(marker_moved) - - # Create glowing effect markers for clicked events - for size, alpha in [(18, 64), (14, 128), (10, 255)]: - marker_clicked = CrosshairScatterItem( - size=size, - pen=pg.mkPen(None), - brush=pg.mkBrush(color.red(), color.green(), color.blue(), alpha), + name = item.name() or str(id(item)) + if name in self.marker_moved_1d: + # Update existing markers + marker_moved = self.marker_moved_1d[name] + marker_moved.setPen(pg.mkPen(color)) + # Update clicked markers' brushes + for marker_clicked in self.marker_clicked_1d[name]: + alpha = marker_clicked.opts["brush"].color().alpha() + marker_clicked.setBrush( + pg.mkBrush(color.red(), color.green(), color.blue(), alpha) + ) + # Update z-values + marker_moved.setZValue(item.zValue() + 1) + for marker_clicked in self.marker_clicked_1d[name]: + marker_clicked.setZValue(item.zValue() + 1) + else: + # Create new markers + marker_moved = CrosshairScatterItem( + size=10, pen=pg.mkPen(color), brush=pg.mkBrush(None) ) - marker_clicked.skip_auto_range = True - self.marker_clicked_1d[item.name()] = marker_clicked - self.plot_item.addItem(marker_clicked) + marker_moved.skip_auto_range = True + marker_moved.is_crosshair = True + self.marker_moved_1d[name] = marker_moved + self.plot_item.addItem(marker_moved) + # Set marker z-value higher than the curve + marker_moved.setZValue(item.zValue() + 1) + # Create glowing effect markers for clicked events + marker_clicked_list = [] + for size, alpha in [(18, 64), (14, 128), (10, 255)]: + marker_clicked = CrosshairScatterItem( + size=size, + pen=pg.mkPen(None), + brush=pg.mkBrush(color.red(), color.green(), color.blue(), alpha), + ) + marker_clicked.skip_auto_range = True + marker_clicked.is_crosshair = True + self.plot_item.addItem(marker_clicked) + marker_clicked.setZValue(item.zValue() + 1) + marker_clicked_list.append(marker_clicked) + self.marker_clicked_1d[name] = marker_clicked_list elif isinstance(item, pg.ImageItem): # 2D plot if self.marker_2d is not None: continue @@ -165,12 +212,11 @@ class Crosshair(QObject): """ y_values = defaultdict(list) x_values = defaultdict(list) - image_2d = None # Iterate through items in the plot - for item in self.plot_item.items: + for item in self.items: if isinstance(item, pg.PlotDataItem): # 1D plot - name = item.name() + name = item.name() or str(id(item)) plot_data = item._getDisplayDataset() if plot_data is None: continue @@ -191,7 +237,7 @@ class Crosshair(QObject): elif isinstance(item, pg.ImageItem): # 2D plot name = item.config.monitor image_2d = item.image - # clip the x and y values to the image dimensions to avoid out of bounds errors + # Clip the x and y values to the image dimensions to avoid out of bounds errors y_values[name] = int(np.clip(y, 0, image_2d.shape[1] - 1)) x_values[name] = int(np.clip(x, 0, image_2d.shape[0] - 1)) @@ -259,9 +305,9 @@ class Crosshair(QObject): # not sure how we got here, but just to be safe... return - for item in self.plot_item.items: + for item in self.items: if isinstance(item, pg.PlotDataItem): - name = item.name() + name = item.name() or str(id(item)) x, y = x_snap_values[name], y_snap_values[name] if x is None or y is None: continue @@ -312,13 +358,14 @@ class Crosshair(QObject): # not sure how we got here, but just to be safe... return - for item in self.plot_item.items: + for item in self.items: if isinstance(item, pg.PlotDataItem): - name = item.name() + name = item.name() or str(id(item)) x, y = x_snap_values[name], y_snap_values[name] if x is None or y is None: continue - self.marker_clicked_1d[name].setData([x], [y]) + for marker_clicked in self.marker_clicked_1d[name]: + marker_clicked.setData([x], [y]) x_snapped_scaled, y_snapped_scaled = self.scale_emitted_coordinates(x, y) coordinate_to_emit = ( name, @@ -340,9 +387,12 @@ class Crosshair(QObject): def clear_markers(self): """Clears the markers from the plot.""" for marker in self.marker_moved_1d.values(): - marker.clear() - for marker in self.marker_clicked_1d.values(): - marker.clear() + self.plot_item.removeItem(marker) + for markers in self.marker_clicked_1d.values(): + for marker in markers: + self.plot_item.removeItem(marker) + self.marker_moved_1d.clear() + self.marker_clicked_1d.clear() def scale_emitted_coordinates(self, x, y): """Scales the emitted coordinates if the axes are in log scale. @@ -369,7 +419,7 @@ class Crosshair(QObject): x, y = pos x_scaled, y_scaled = self.scale_emitted_coordinates(x, y) - # # Update coordinate label + # Update coordinate label self.coord_label.setText(f"({x_scaled:.{self.precision}g}, {y_scaled:.{self.precision}g})") self.coord_label.setPos(x, y) self.coord_label.setVisible(True) diff --git a/bec_widgets/widgets/figure/plots/multi_waveform/multi_waveform.py b/bec_widgets/widgets/figure/plots/multi_waveform/multi_waveform.py index 451c76d4..1c5b4b8e 100644 --- a/bec_widgets/widgets/figure/plots/multi_waveform/multi_waveform.py +++ b/bec_widgets/widgets/figure/plots/multi_waveform/multi_waveform.py @@ -40,6 +40,7 @@ class BECMultiWaveformConfig(SubplotConfig): class BECMultiWaveform(BECPlotBase): monitor_signal_updated = Signal() + highlighted_curve_index_changed = Signal(int) USER_ACCESS = [ "_rpc_id", "_config_dict", @@ -85,6 +86,7 @@ class BECMultiWaveform(BECPlotBase): self.connected = False self.current_highlight_index = 0 self._curves = deque() + self.visible_curves = [] self.number_of_visible_curves = 0 # Get bec shortcuts dev, scans, queue, scan_storage, dap @@ -159,8 +161,10 @@ class BECMultiWaveform(BECPlotBase): if current_scan_id != self.scan_id: self.scan_id = current_scan_id - self.plot_item.clear() + self.clear_curves() self.curves.clear() + if self.crosshair: + self.crosshair.clear_markers() # Always create a new curve and add it curve = pg.PlotDataItem() @@ -181,8 +185,8 @@ class BECMultiWaveform(BECPlotBase): Args: index (int): The index of the curve to highlight among visible curves. """ - visible_curves = [curve for curve in self.curves if curve.isVisible()] - num_visible_curves = len(visible_curves) + self.plot_item.visible_curves = [curve for curve in self.curves if curve.isVisible()] + num_visible_curves = len(self.plot_item.visible_curves) self.number_of_visible_curves = num_visible_curves if num_visible_curves == 0: @@ -197,7 +201,7 @@ class BECMultiWaveform(BECPlotBase): colors = Colors.evenly_spaced_colors( colormap=self.config.color_palette, num=num_colors, format="HEX" ) - for i, curve in enumerate(visible_curves): + for i, curve in enumerate(self.plot_item.visible_curves): curve.setPen() if i == self.current_highlight_index: curve.setPen(pg.mkPen(color=colors[i], width=5)) @@ -208,6 +212,8 @@ class BECMultiWaveform(BECPlotBase): curve.setAlpha(alpha=self.config.opacity / 100, auto=False) curve.setZValue(0) + self.highlighted_curve_index_changed.emit(self.current_highlight_index) + @Slot(int) def set_opacity(self, opacity: int): """ @@ -269,6 +275,13 @@ class BECMultiWaveform(BECPlotBase): self.config.color_palette = colormap self.set_curve_highlight(self.current_highlight_index) + def hook_crosshair(self) -> None: + super().hook_crosshair() + if self.crosshair: + self.highlighted_curve_index_changed.connect(self.crosshair.update_highlighted_curve) + if self.curves: + self.crosshair.update_highlighted_curve(self.current_highlight_index) + def get_all_data(self, output: Literal["dict", "pandas"] = "dict") -> dict: """ Extract all curve data into a dictionary or a pandas DataFrame. @@ -309,6 +322,17 @@ class BECMultiWaveform(BECPlotBase): return combined_data return data + def clear_curves(self): + """ + Remove all curves from the plot, excluding crosshair items. + """ + items_to_remove = [] + for item in self.plot_item.items: + if not getattr(item, "is_crosshair", False) and isinstance(item, pg.PlotDataItem): + items_to_remove.append(item) + for item in items_to_remove: + self.plot_item.removeItem(item) + def export_to_matplotlib(self): """ Export current waveform to matplotlib GUI. Available only if matplotlib is installed in the environment. @@ -325,6 +349,5 @@ if __name__ == "__main__": app = QApplication(sys.argv) widget = BECFigure() - widget.multi_waveform(monitor="waveform") widget.show() sys.exit(app.exec_())