1
0
mirror of https://github.com/bec-project/bec_widgets.git synced 2026-04-08 01:37:53 +02:00

Compare commits

..

12 Commits

Author SHA1 Message Date
semantic-release
ca600b057e 2.22.1
Automatically generated by python-semantic-release
2025-07-11 11:57:47 +00:00
6c494258f8 fix(heatmap): fix pixel size calculation for arbitrary shapes 2025-07-11 13:57:01 +02:00
63a8da680d fix(crosshair): crosshair mouse_moved can be set manually 2025-07-11 13:57:01 +02:00
semantic-release
0f2bde1a0a 2.22.0
Automatically generated by python-semantic-release
2025-07-10 12:23:05 +00:00
0c76b0c495 feat: add heatmap widget 2025-07-10 14:22:15 +02:00
e594de3ca3 fix(image): reset crosshair on new scan 2025-07-10 14:22:15 +02:00
adaad4f4d5 fix(crosshair): add slot to reset mouse markers 2025-07-10 14:22:15 +02:00
39c316d6ea fix(image item): fix processor for nans in images 2025-07-10 14:22:15 +02:00
3ba0fc4b44 fix(crosshair): fix crosshair support for transformations 2025-07-10 14:22:15 +02:00
a6fc7993a3 fix(image_processor): support for nans in nd arrays 2025-07-10 14:22:15 +02:00
324a5bd3d9 feat(image_item): add support for qtransform 2025-07-10 14:22:15 +02:00
8929778f07 fix(image_base): move cbar init to image base 2025-07-10 14:22:15 +02:00
25 changed files with 2516 additions and 80 deletions

View File

@@ -1,6 +1,48 @@
# CHANGELOG
## v2.22.1 (2025-07-11)
### Bug Fixes
- **crosshair**: Crosshair mouse_moved can be set manually
([`63a8da6`](https://github.com/bec-project/bec_widgets/commit/63a8da680d263a50102aacf463ec6f6252562f9d))
- **heatmap**: Fix pixel size calculation for arbitrary shapes
([`6c49425`](https://github.com/bec-project/bec_widgets/commit/6c494258f82059a2472f43bb8287390ce1aba704))
## v2.22.0 (2025-07-10)
### Bug Fixes
- **crosshair**: Add slot to reset mouse markers
([`adaad4f`](https://github.com/bec-project/bec_widgets/commit/adaad4f4d5ebf775a337e23a944ba9eb289d01a0))
- **crosshair**: Fix crosshair support for transformations
([`3ba0fc4`](https://github.com/bec-project/bec_widgets/commit/3ba0fc4b442e5926f27a13f09d628c30987f2cf8))
- **image**: Reset crosshair on new scan
([`e594de3`](https://github.com/bec-project/bec_widgets/commit/e594de3ca39970f91f5842693eeb1fac393eaa34))
- **image item**: Fix processor for nans in images
([`39c316d`](https://github.com/bec-project/bec_widgets/commit/39c316d6eadfdfbd483661b67720a7e224a46712))
- **image_base**: Move cbar init to image base
([`8929778`](https://github.com/bec-project/bec_widgets/commit/8929778f073c40a9eabba7eda2415fc9af1072bb))
- **image_processor**: Support for nans in nd arrays
([`a6fc799`](https://github.com/bec-project/bec_widgets/commit/a6fc7993a3d22cfd086310c8e6dad3f9f3d1e9fe))
### Features
- Add heatmap widget
([`0c76b0c`](https://github.com/bec-project/bec_widgets/commit/0c76b0c49598d1456aab266b483de327788028fd))
- **image_item**: Add support for qtransform
([`324a5bd`](https://github.com/bec-project/bec_widgets/commit/324a5bd3d9ed278495c6ba62453b02061900ae32))
## v2.21.4 (2025-07-08)
### Bug Fixes

View File

@@ -37,6 +37,7 @@ _Widgets = {
"DeviceBrowser": "DeviceBrowser",
"DeviceComboBox": "DeviceComboBox",
"DeviceLineEdit": "DeviceLineEdit",
"Heatmap": "Heatmap",
"Image": "Image",
"LogPanel": "LogPanel",
"Minesweeper": "Minesweeper",
@@ -1181,6 +1182,482 @@ class EllipticalROI(RPCBase):
"""
class Heatmap(RPCBase):
"""Heatmap widget for visualizing 2d grid data with color mapping for the z-axis."""
@property
@rpc_call
def enable_toolbar(self) -> "bool":
"""
Show Toolbar.
"""
@enable_toolbar.setter
@rpc_call
def enable_toolbar(self) -> "bool":
"""
Show Toolbar.
"""
@property
@rpc_call
def enable_side_panel(self) -> "bool":
"""
Show Side Panel
"""
@enable_side_panel.setter
@rpc_call
def enable_side_panel(self) -> "bool":
"""
Show Side Panel
"""
@property
@rpc_call
def enable_fps_monitor(self) -> "bool":
"""
Enable the FPS monitor.
"""
@enable_fps_monitor.setter
@rpc_call
def enable_fps_monitor(self) -> "bool":
"""
Enable the FPS monitor.
"""
@rpc_call
def set(self, **kwargs):
"""
Set the properties of the plot widget.
Args:
**kwargs: Keyword arguments for the properties to be set.
Possible properties:
- title: str
- x_label: str
- y_label: str
- x_scale: Literal["linear", "log"]
- y_scale: Literal["linear", "log"]
- x_lim: tuple
- y_lim: tuple
- legend_label_size: int
"""
@property
@rpc_call
def title(self) -> "str":
"""
Set title of the plot.
"""
@title.setter
@rpc_call
def title(self) -> "str":
"""
Set title of the plot.
"""
@property
@rpc_call
def x_label(self) -> "str":
"""
The set label for the x-axis.
"""
@x_label.setter
@rpc_call
def x_label(self) -> "str":
"""
The set label for the x-axis.
"""
@property
@rpc_call
def y_label(self) -> "str":
"""
The set label for the y-axis.
"""
@y_label.setter
@rpc_call
def y_label(self) -> "str":
"""
The set label for the y-axis.
"""
@property
@rpc_call
def x_limits(self) -> "QPointF":
"""
Get the x limits of the plot.
"""
@x_limits.setter
@rpc_call
def x_limits(self) -> "QPointF":
"""
Get the x limits of the plot.
"""
@property
@rpc_call
def y_limits(self) -> "QPointF":
"""
Get the y limits of the plot.
"""
@y_limits.setter
@rpc_call
def y_limits(self) -> "QPointF":
"""
Get the y limits of the plot.
"""
@property
@rpc_call
def x_grid(self) -> "bool":
"""
Show grid on the x-axis.
"""
@x_grid.setter
@rpc_call
def x_grid(self) -> "bool":
"""
Show grid on the x-axis.
"""
@property
@rpc_call
def y_grid(self) -> "bool":
"""
Show grid on the y-axis.
"""
@y_grid.setter
@rpc_call
def y_grid(self) -> "bool":
"""
Show grid on the y-axis.
"""
@property
@rpc_call
def inner_axes(self) -> "bool":
"""
Show inner axes of the plot widget.
"""
@inner_axes.setter
@rpc_call
def inner_axes(self) -> "bool":
"""
Show inner axes of the plot widget.
"""
@property
@rpc_call
def outer_axes(self) -> "bool":
"""
Show the outer axes of the plot widget.
"""
@outer_axes.setter
@rpc_call
def outer_axes(self) -> "bool":
"""
Show the outer axes of the plot widget.
"""
@property
@rpc_call
def auto_range_x(self) -> "bool":
"""
Set auto range for the x-axis.
"""
@auto_range_x.setter
@rpc_call
def auto_range_x(self) -> "bool":
"""
Set auto range for the x-axis.
"""
@property
@rpc_call
def auto_range_y(self) -> "bool":
"""
Set auto range for the y-axis.
"""
@auto_range_y.setter
@rpc_call
def auto_range_y(self) -> "bool":
"""
Set auto range for the y-axis.
"""
@property
@rpc_call
def minimal_crosshair_precision(self) -> "int":
"""
Minimum decimal places for crosshair when dynamic precision is enabled.
"""
@minimal_crosshair_precision.setter
@rpc_call
def minimal_crosshair_precision(self) -> "int":
"""
Minimum decimal places for crosshair when dynamic precision is enabled.
"""
@property
@rpc_call
def color_map(self) -> "str":
"""
Set the color map of the image.
"""
@color_map.setter
@rpc_call
def color_map(self) -> "str":
"""
Set the color map of the image.
"""
@property
@rpc_call
def v_range(self) -> "QPointF":
"""
Set the v_range of the main image.
"""
@v_range.setter
@rpc_call
def v_range(self) -> "QPointF":
"""
Set the v_range of the main image.
"""
@property
@rpc_call
def v_min(self) -> "float":
"""
Get the minimum value of the v_range.
"""
@v_min.setter
@rpc_call
def v_min(self) -> "float":
"""
Get the minimum value of the v_range.
"""
@property
@rpc_call
def v_max(self) -> "float":
"""
Get the maximum value of the v_range.
"""
@v_max.setter
@rpc_call
def v_max(self) -> "float":
"""
Get the maximum value of the v_range.
"""
@property
@rpc_call
def lock_aspect_ratio(self) -> "bool":
"""
Whether the aspect ratio is locked.
"""
@lock_aspect_ratio.setter
@rpc_call
def lock_aspect_ratio(self) -> "bool":
"""
Whether the aspect ratio is locked.
"""
@property
@rpc_call
def autorange(self) -> "bool":
"""
Whether autorange is enabled.
"""
@autorange.setter
@rpc_call
def autorange(self) -> "bool":
"""
Whether autorange is enabled.
"""
@property
@rpc_call
def autorange_mode(self) -> "str":
"""
Autorange mode.
Options:
- "max": Use the maximum value of the image for autoranging.
- "mean": Use the mean value of the image for autoranging.
"""
@autorange_mode.setter
@rpc_call
def autorange_mode(self) -> "str":
"""
Autorange mode.
Options:
- "max": Use the maximum value of the image for autoranging.
- "mean": Use the mean value of the image for autoranging.
"""
@rpc_call
def enable_colorbar(
self,
enabled: "bool",
style: "Literal['full', 'simple']" = "full",
vrange: "tuple[int, int] | None" = None,
):
"""
Enable the colorbar and switch types of colorbars.
Args:
enabled(bool): Whether to enable the colorbar.
style(Literal["full", "simple"]): The type of colorbar to enable.
vrange(tuple): The range of values to use for the colorbar.
"""
@property
@rpc_call
def enable_simple_colorbar(self) -> "bool":
"""
Enable the simple colorbar.
"""
@enable_simple_colorbar.setter
@rpc_call
def enable_simple_colorbar(self) -> "bool":
"""
Enable the simple colorbar.
"""
@property
@rpc_call
def enable_full_colorbar(self) -> "bool":
"""
Enable the full colorbar.
"""
@enable_full_colorbar.setter
@rpc_call
def enable_full_colorbar(self) -> "bool":
"""
Enable the full colorbar.
"""
@property
@rpc_call
def fft(self) -> "bool":
"""
Whether FFT postprocessing is enabled.
"""
@fft.setter
@rpc_call
def fft(self) -> "bool":
"""
Whether FFT postprocessing is enabled.
"""
@property
@rpc_call
def log(self) -> "bool":
"""
Whether logarithmic scaling is applied.
"""
@log.setter
@rpc_call
def log(self) -> "bool":
"""
Whether logarithmic scaling is applied.
"""
@property
@rpc_call
def main_image(self) -> "ImageItem":
"""
Access the main image item.
"""
@rpc_call
def add_roi(
self,
kind: "Literal['rect', 'circle', 'ellipse']" = "rect",
name: "str | None" = None,
line_width: "int | None" = 5,
pos: "tuple[float, float] | None" = (10, 10),
size: "tuple[float, float] | None" = (50, 50),
movable: "bool" = True,
**pg_kwargs,
) -> "RectangularROI | CircularROI":
"""
Add a ROI to the image.
Args:
kind(str): The type of ROI to add. Options are "rect" or "circle".
name(str): The name of the ROI.
line_width(int): The line width of the ROI.
pos(tuple): The position of the ROI.
size(tuple): The size of the ROI.
movable(bool): Whether the ROI is movable.
**pg_kwargs: Additional arguments for the ROI.
Returns:
RectangularROI | CircularROI: The created ROI object.
"""
@rpc_call
def remove_roi(self, roi: "int | str"):
"""
Remove an ROI by index or label via the ROIController.
"""
@property
@rpc_call
def rois(self) -> "list[BaseROI]":
"""
Get the list of ROIs.
"""
@rpc_call
def plot(
self,
x_name: "str",
y_name: "str",
z_name: "str",
x_entry: "None | str" = None,
y_entry: "None | str" = None,
z_entry: "None | str" = None,
color_map: "str | None" = "plasma",
label: "str | None" = None,
validate_bec: "bool" = True,
reload: "bool" = False,
):
"""
Plot the heatmap with the given x, y, and z data.
"""
class Image(RPCBase):
"""Image widget for displaying 2D data."""

View File

@@ -5,9 +5,13 @@ from typing import Any
import numpy as np
import pyqtgraph as pg
from qtpy.QtCore import QObject, Qt, Signal, Slot
from qtpy.QtCore import QObject, QPointF, Qt, Signal
from qtpy.QtGui import QCursor, QTransform
from qtpy.QtWidgets import QApplication
from bec_widgets.utils.error_popups import SafeSlot
from bec_widgets.widgets.plots.image.image_item import ImageItem
class CrosshairScatterItem(pg.ScatterPlotItem):
def setDownsampling(self, ds=None, auto=None, method=None):
@@ -160,7 +164,7 @@ class Crosshair(QObject):
qapp.theme_signal.theme_updated.connect(self._update_theme)
self._update_theme()
@Slot(str)
@SafeSlot(str)
def _update_theme(self, theme: str | None = None):
"""Update the theme."""
if theme is None:
@@ -187,7 +191,7 @@ class Crosshair(QObject):
self.coord_label.fill = pg.mkBrush(label_bg_color)
self.coord_label.border = pg.mkPen(None)
@Slot(int)
@SafeSlot(int)
def update_highlighted_curve(self, curve_index: int):
"""
Update the highlighted curve in the case of multiple curves in a plot item.
@@ -265,15 +269,47 @@ class Crosshair(QObject):
[0, 0], size=[item.image.shape[0], 1], pen=pg.mkPen("r", width=2), movable=False
)
self.marker_2d_row.skip_auto_range = True
if item.image_transform is not None:
self.marker_2d_row.setTransform(item.image_transform)
self.plot_item.addItem(self.marker_2d_row)
# Create vertical ROI for column highlighting
self.marker_2d_col = pg.ROI(
[0, 0], size=[1, item.image.shape[1]], pen=pg.mkPen("r", width=2), movable=False
)
if item.image_transform is not None:
self.marker_2d_col.setTransform(item.image_transform)
self.marker_2d_col.skip_auto_range = True
self.plot_item.addItem(self.marker_2d_col)
@SafeSlot()
def update_markers_on_image_change(self):
"""
Update markers when the image changes, e.g. when the
image shape or transformation changes.
"""
for item in self.items:
if not isinstance(item, pg.ImageItem):
continue
if self.marker_2d_row is not None:
self.marker_2d_row.setSize([item.image.shape[0], 1])
self.marker_2d_row.setTransform(item.image_transform)
if self.marker_2d_col is not None:
self.marker_2d_col.setSize([1, item.image.shape[1]])
self.marker_2d_col.setTransform(item.image_transform)
# Get the current mouse position
views = self.plot_item.vb.scene().views()
if not views:
return
view = views[0]
global_pos = QCursor.pos()
view_pos = view.mapFromGlobal(global_pos)
scene_pos = view.mapToScene(view_pos)
if self.plot_item.vb.sceneBoundingRect().contains(scene_pos):
plot_pt = self.plot_item.vb.mapSceneToView(scene_pos)
self.mouse_moved(manual_pos=(plot_pt.x(), plot_pt.y()))
def snap_to_data(
self, x: float, y: float
) -> tuple[None, None] | tuple[defaultdict[Any, list], defaultdict[Any, list]]:
@@ -316,9 +352,25 @@ class Crosshair(QObject):
image_2d = item.image
if image_2d is None:
continue
# Clip the x and y values to the image dimensions to avoid out of bounds errors
y_values[name] = int(np.clip(y, 0, image_2d.shape[1] - 1))
x_values[name] = int(np.clip(x, 0, image_2d.shape[0] - 1))
# Map scene coordinates (plot units) back to image pixel coordinates
if item.image_transform is not None:
inv_transform, _ = item.image_transform.inverted()
xy_trans = inv_transform.map(QPointF(x, y))
else:
xy_trans = QPointF(x, y)
# Define valid pixel coordinate bounds
min_x_px, min_y_px = 0, 0
max_x_px = image_2d.shape[0] - 1 # columns
max_y_px = image_2d.shape[1] - 1 # rows
# Clip the mapped coordinates to the image bounds
px = int(np.clip(xy_trans.x(), min_x_px, max_x_px))
py = int(np.clip(xy_trans.y(), min_y_px, max_y_px))
# Store snapped pixel positions
x_values[name] = px
y_values[name] = py
if x_values and y_values:
if all(v is None for v in x_values.values()) or all(
@@ -358,60 +410,74 @@ class Crosshair(QObject):
return list_x[original_index], list_y[original_index]
def mouse_moved(self, event):
"""Handles the mouse moved event, updating the crosshair position and emitting signals.
@SafeSlot(object, tuple)
def mouse_moved(self, event=None, manual_pos=None):
"""
Handles the mouse moved event, updating the crosshair position and emitting signals.
Args:
event: The mouse moved event
event(object): The mouse moved event, which contains the scene position.
manual_pos(tuple, optional): A tuple containing the (x, y) coordinates to manually set the crosshair position.
"""
pos = event[0]
# Determine target (x, y) in *plot* coordinates
if manual_pos is not None:
x, y = manual_pos
else:
if event is None:
return # nothing to do
scene_pos = event[0] # SignalProxy bundle
if not self.plot_item.vb.sceneBoundingRect().contains(scene_pos):
return
view_pos = self.plot_item.vb.mapSceneToView(scene_pos)
x, y = view_pos.x(), view_pos.y()
# Update crosshair visuals
self.v_line.setPos(x)
self.h_line.setPos(y)
self.update_markers()
if self.plot_item.vb.sceneBoundingRect().contains(pos):
mouse_point = self.plot_item.vb.mapSceneToView(pos)
x, y = mouse_point.x(), mouse_point.y()
self.v_line.setPos(x)
self.h_line.setPos(y)
scaled_x, scaled_y = self.scale_emitted_coordinates(mouse_point.x(), mouse_point.y())
self.crosshairChanged.emit((scaled_x, scaled_y))
self.positionChanged.emit((x, y))
scaled_x, scaled_y = self.scale_emitted_coordinates(x, y)
self.crosshairChanged.emit((scaled_x, scaled_y))
self.positionChanged.emit((x, y))
x_snap_values, y_snap_values = self.snap_to_data(x, y)
if x_snap_values is None or y_snap_values is None:
return
if all(v is None for v in x_snap_values.values()) or all(
v is None for v in y_snap_values.values()
):
# not sure how we got here, but just to be safe...
return
snap_x_vals, snap_y_vals = self.snap_to_data(x, y)
if snap_x_vals is None or snap_y_vals is None:
return
if all(v is None for v in snap_x_vals.values()) or all(
v is None for v in snap_y_vals.values()
):
return
precision = self._current_precision()
for item in self.items:
if isinstance(item, pg.PlotDataItem):
name = item.name() or str(id(item))
x, y = x_snap_values[name], y_snap_values[name]
if x is None or y is None:
continue
self.marker_moved_1d[name].setData([x], [y])
x_snapped_scaled, y_snapped_scaled = self.scale_emitted_coordinates(x, y)
coordinate_to_emit = (
name,
round(x_snapped_scaled, precision),
round(y_snapped_scaled, precision),
)
self.coordinatesChanged1D.emit(coordinate_to_emit)
elif isinstance(item, pg.ImageItem):
name = item.objectName() or str(id(item))
x, y = x_snap_values[name], y_snap_values[name]
if x is None or y is None:
continue
# Set position of horizontal ROI (row)
self.marker_2d_row.setPos([0, y])
# Set position of vertical ROI (column)
self.marker_2d_col.setPos([x, 0])
coordinate_to_emit = (name, x, y)
self.coordinatesChanged2D.emit(coordinate_to_emit)
else:
precision = self._current_precision()
for item in self.items:
if isinstance(item, pg.PlotDataItem):
name = item.name() or str(id(item))
sx, sy = snap_x_vals[name], snap_y_vals[name]
if sx is None or sy is None:
continue
self.marker_moved_1d[name].setData([sx], [sy])
sx_s, sy_s = self.scale_emitted_coordinates(sx, sy)
self.coordinatesChanged1D.emit(
(name, round(sx_s, precision), round(sy_s, precision))
)
elif isinstance(item, pg.ImageItem):
name = item.objectName() or str(id(item))
px, py = snap_x_vals[name], snap_y_vals[name]
if px is None or py is None:
continue
# Respect image transforms
if isinstance(item, ImageItem) and item.image_transform is not None:
row, col = self._get_transformed_position(px, py, item.image_transform)
self.marker_2d_row.setPos(row)
self.marker_2d_col.setPos(col)
else:
self.marker_2d_row.setPos([0, py])
self.marker_2d_col.setPos([px, 0])
self.coordinatesChanged2D.emit((name, px, py))
def mouse_clicked(self, event):
"""Handles the mouse clicked event, updating the crosshair position and emitting signals.
@@ -462,15 +528,35 @@ class Crosshair(QObject):
x, y = x_snap_values[name], y_snap_values[name]
if x is None or y is None:
continue
# Set position of horizontal ROI (row)
self.marker_2d_row.setPos([0, y])
# Set position of vertical ROI (column)
self.marker_2d_col.setPos([x, 0])
if isinstance(item, ImageItem) and item.image_transform is not None:
row, col = self._get_transformed_position(x, y, item.image_transform)
self.marker_2d_row.setPos(row)
self.marker_2d_col.setPos(col)
else:
self.marker_2d_row.setPos([0, y])
self.marker_2d_col.setPos([x, 0])
coordinate_to_emit = (name, x, y)
self.coordinatesClicked2D.emit(coordinate_to_emit)
else:
continue
def _get_transformed_position(
self, x: float, y: float, transform: QTransform
) -> tuple[QPointF, QPointF]:
"""
Maps the given x and y coordinates to the transformed position using the provided transform.
Args:
x (float): The x-coordinate to transform.
y (float): The y-coordinate to transform.
transform (QTransform): The transformation to apply.
"""
origin = transform.map(QPointF(0, 0))
row = transform.map(QPointF(0, y)) - origin
col = transform.map(QPointF(x, 0)) - origin
return row, col
def clear_markers(self):
"""Clears the markers from the plot."""
for marker in self.marker_moved_1d.values():
@@ -512,8 +598,18 @@ class Crosshair(QObject):
image = item.image
if image is None:
continue
ix = int(np.clip(x, 0, image.shape[0] - 1))
iy = int(np.clip(y, 0, image.shape[1] - 1))
if item.image_transform is not None:
inv_transform, _ = item.image_transform.inverted()
pt = inv_transform.map(QPointF(x, y))
px, py = pt.x(), pt.y()
else:
px, py = x, y
# Clip to valid pixel indices
ix = int(np.clip(px, 0, image.shape[0] - 1)) # column
iy = int(np.clip(py, 0, image.shape[1] - 1)) # row
intensity = image[ix, iy]
text += f"\nIntensity: {intensity:.{precision}f}"
break
@@ -533,15 +629,19 @@ class Crosshair(QObject):
self.is_derivative = self.plot_item.ctrl.derivativeCheck.isChecked()
self.clear_markers()
def cleanup(self):
@SafeSlot()
def reset(self):
"""Resets the crosshair to its initial state."""
if self.marker_2d_row is not None:
self.plot_item.removeItem(self.marker_2d_row)
self.marker_2d_row = None
if self.marker_2d_col is not None:
self.plot_item.removeItem(self.marker_2d_col)
self.marker_2d_col = None
self.clear_markers()
def cleanup(self):
self.reset()
self.plot_item.removeItem(self.v_line)
self.plot_item.removeItem(self.h_line)
self.plot_item.removeItem(self.coord_label)
self.clear_markers()

View File

@@ -37,6 +37,7 @@ def get_plugin_widgets() -> dict[str, BECConnector]:
"""
modules = _get_available_plugins("bec.widgets.user_widgets")
loaded_plugins = {}
print(modules)
for module in modules:
mods = inspect.getmembers(module, predicate=_filter_plugins)
for name, mod_cls in mods:

View File

@@ -28,6 +28,7 @@ from bec_widgets.widgets.containers.main_window.main_window import BECMainWindow
from bec_widgets.widgets.control.device_control.positioner_box import PositionerBox
from bec_widgets.widgets.control.scan_control.scan_control import ScanControl
from bec_widgets.widgets.editors.vscode.vscode import VSCodeEditor
from bec_widgets.widgets.plots.heatmap.heatmap import Heatmap
from bec_widgets.widgets.plots.image.image import Image
from bec_widgets.widgets.plots.motor_map.motor_map import MotorMap
from bec_widgets.widgets.plots.multi_waveform.multi_waveform import MultiWaveform
@@ -154,6 +155,9 @@ class BECDockArea(BECWidget, QWidget):
filled=True,
parent=self,
),
"heatmap": MaterialIconAction(
icon_name=Heatmap.ICON_NAME, tooltip="Add Heatmap", filled=True, parent=self
),
},
),
)
@@ -291,6 +295,9 @@ class BECDockArea(BECWidget, QWidget):
menu_plots.actions["motor_map"].action.triggered.connect(
lambda: self._create_widget_from_toolbar(widget_name="MotorMap")
)
menu_plots.actions["heatmap"].action.triggered.connect(
lambda: self._create_widget_from_toolbar(widget_name="Heatmap")
)
# Menu Devices
menu_devices.actions["scan_control"].action.triggered.connect(

View File

@@ -149,6 +149,7 @@ _web_console_registry = WebConsoleRegistry()
def suppress_qt_messages(type_, context, msg):
if context.category in ["js", "default"]:
return
print(msg)
qInstallMessageHandler(suppress_qt_messages)

View File

@@ -8,6 +8,7 @@ from bec_widgets.utils.bec_widget import BECWidget
def suppress_qt_messages(type_, context, msg):
if context.category in ["js", "default"]:
return
print(msg)
qInstallMessageHandler(suppress_qt_messages)

View File

@@ -0,0 +1,773 @@
from __future__ import annotations
import functools
import json
from typing import Literal
import numpy as np
import pyqtgraph as pg
from bec_lib import bec_logger, messages
from bec_lib.endpoints import MessageEndpoints
from pydantic import BaseModel, Field, field_validator
from qtpy.QtCore import QTimer, Signal
from qtpy.QtGui import QTransform
from scipy.interpolate import LinearNDInterpolator
from scipy.spatial import cKDTree
from toolz import partition
from bec_widgets.utils import Colors
from bec_widgets.utils.bec_connector import ConnectionConfig
from bec_widgets.utils.error_popups import SafeProperty, SafeSlot
from bec_widgets.utils.settings_dialog import SettingsDialog
from bec_widgets.utils.toolbars.actions import MaterialIconAction
from bec_widgets.widgets.plots.heatmap.settings.heatmap_setting import HeatmapSettings
from bec_widgets.widgets.plots.image.image_base import ImageBase
from bec_widgets.widgets.plots.image.image_item import ImageItem
logger = bec_logger.logger
class HeatmapDeviceSignal(BaseModel):
"""The configuration of a signal in the scatter waveform widget."""
name: str
entry: str
model_config: dict = {"validate_assignment": True}
class HeatmapConfig(ConnectionConfig):
parent_id: str | None = Field(None, description="The parent plot of the curve.")
color_map: str | None = Field(
"plasma", description="The color palette of the heatmap widget.", validate_default=True
)
color_bar: Literal["full", "simple"] | None = Field(
None, description="The type of the color bar."
)
lock_aspect_ratio: bool = Field(
False, description="Whether to lock the aspect ratio of the image."
)
x_device: HeatmapDeviceSignal | None = Field(
None, description="The x device signal of the heatmap."
)
y_device: HeatmapDeviceSignal | None = Field(
None, description="The y device signal of the heatmap."
)
z_device: HeatmapDeviceSignal | None = Field(
None, description="The z device signal of the heatmap."
)
model_config: dict = {"validate_assignment": True}
_validate_color_palette = field_validator("color_map")(Colors.validate_color_map)
class Heatmap(ImageBase):
"""
Heatmap widget for visualizing 2d grid data with color mapping for the z-axis.
"""
USER_ACCESS = [
# General PlotBase Settings
"enable_toolbar",
"enable_toolbar.setter",
"enable_side_panel",
"enable_side_panel.setter",
"enable_fps_monitor",
"enable_fps_monitor.setter",
"set",
"title",
"title.setter",
"x_label",
"x_label.setter",
"y_label",
"y_label.setter",
"x_limits",
"x_limits.setter",
"y_limits",
"y_limits.setter",
"x_grid",
"x_grid.setter",
"y_grid",
"y_grid.setter",
"inner_axes",
"inner_axes.setter",
"outer_axes",
"outer_axes.setter",
"auto_range_x",
"auto_range_x.setter",
"auto_range_y",
"auto_range_y.setter",
"minimal_crosshair_precision",
"minimal_crosshair_precision.setter",
# ImageView Specific Settings
"color_map",
"color_map.setter",
"v_range",
"v_range.setter",
"v_min",
"v_min.setter",
"v_max",
"v_max.setter",
"lock_aspect_ratio",
"lock_aspect_ratio.setter",
"autorange",
"autorange.setter",
"autorange_mode",
"autorange_mode.setter",
"enable_colorbar",
"enable_simple_colorbar",
"enable_simple_colorbar.setter",
"enable_full_colorbar",
"enable_full_colorbar.setter",
"fft",
"fft.setter",
"log",
"log.setter",
"main_image",
"add_roi",
"remove_roi",
"rois",
"plot",
]
PLUGIN = True
RPC = True
ICON_NAME = "dataset"
new_scan = Signal()
new_scan_id = Signal(str)
sync_signal_update = Signal()
heatmap_property_changed = Signal()
def __init__(self, parent=None, config: HeatmapConfig | None = None, **kwargs):
if config is None:
config = HeatmapConfig(widget_class=self.__class__.__name__)
super().__init__(parent=parent, config=config, **kwargs)
self._image_config = config
self.scan_id = None
self.old_scan_id = None
self.scan_item = None
self.status_message = None
self._grid_index = None
self.heatmap_dialog = None
self.reload = False
self.bec_dispatcher.connect_slot(self.on_scan_status, MessageEndpoints.scan_status())
self.bec_dispatcher.connect_slot(self.on_scan_progress, MessageEndpoints.scan_progress())
self.proxy_update_sync = pg.SignalProxy(
self.sync_signal_update, rateLimit=5, slot=self.update_plot
)
self._init_toolbar_heatmap()
self.toolbar.show_bundles(
[
"heatmap_settings",
"plot_export",
"image_crosshair",
"mouse_interaction",
"image_autorange",
"image_colorbar",
"image_processing",
"axis_popup",
]
)
@property
def main_image(self) -> ImageItem:
"""Access the main image item."""
return self.layer_manager["main"].image
################################################################################
# Widget Specific GUI interactions
################################################################################
@SafeSlot(popup_error=True)
def plot(
self,
x_name: str,
y_name: str,
z_name: str,
x_entry: None | str = None,
y_entry: None | str = None,
z_entry: None | str = None,
color_map: str | None = "plasma",
label: str | None = None,
validate_bec: bool = True,
reload: bool = False,
):
"""
Plot the heatmap with the given x, y, and z data.
"""
if validate_bec:
x_entry = self.entry_validator.validate_signal(x_name, x_entry)
y_entry = self.entry_validator.validate_signal(y_name, y_entry)
z_entry = self.entry_validator.validate_signal(z_name, z_entry)
if x_entry is None or y_entry is None or z_entry is None:
raise ValueError("x, y, and z entries must be provided.")
if x_name is None or y_name is None or z_name is None:
raise ValueError("x, y, and z names must be provided.")
self._image_config = HeatmapConfig(
parent_id=self.gui_id,
x_device=HeatmapDeviceSignal(name=x_name, entry=x_entry),
y_device=HeatmapDeviceSignal(name=y_name, entry=y_entry),
z_device=HeatmapDeviceSignal(name=z_name, entry=z_entry),
color_map=color_map,
)
self.color_map = color_map
self.reload = reload
self.update_labels()
self._fetch_running_scan()
self.sync_signal_update.emit()
def _fetch_running_scan(self):
scan = self.client.queue.scan_storage.current_scan
if scan is not None:
self.scan_item = scan
self.scan_id = scan.scan_id
elif self.client.history and len(self.client.history) > 0:
self.scan_item = self.client.history[-1]
self.scan_id = self.client.history._scan_ids[-1]
self.old_scan_id = None
self.update_plot()
def update_labels(self):
"""
Update the labels of the x, y, and z axes.
"""
if self._image_config is None:
return
x_name = self._image_config.x_device.name
y_name = self._image_config.y_device.name
z_name = self._image_config.z_device.name
if x_name is not None:
self.x_label = x_name # type: ignore
x_dev = self.dev.get(x_name)
if x_dev and hasattr(x_dev, "egu"):
self.x_label_units = x_dev.egu()
if y_name is not None:
self.y_label = y_name # type: ignore
y_dev = self.dev.get(y_name)
if y_dev and hasattr(y_dev, "egu"):
self.y_label_units = y_dev.egu()
if z_name is not None:
self.title = z_name
def _init_toolbar_heatmap(self):
"""
Initialize the toolbar for the heatmap widget, adding actions for heatmap settings.
"""
self.toolbar.add_action(
"heatmap_settings",
MaterialIconAction(
icon_name="scatter_plot",
tooltip="Show Heatmap Settings",
checkable=True,
parent=self,
),
)
self.toolbar.components.get_action("heatmap_settings").action.triggered.connect(
self.show_heatmap_settings
)
# disable all processing actions except for the fft and log
bundle = self.toolbar.get_bundle("image_processing")
for name, action in bundle.bundle_actions.items():
if name not in ["image_processing_fft", "image_processing_log"]:
action().action.setVisible(False)
def show_heatmap_settings(self):
"""
Show the heatmap settings dialog.
"""
heatmap_settings_action = self.toolbar.components.get_action("heatmap_settings").action
if self.heatmap_dialog is None or not self.heatmap_dialog.isVisible():
heatmap_settings = HeatmapSettings(parent=self, target_widget=self, popup=True)
self.heatmap_dialog = SettingsDialog(
self, settings_widget=heatmap_settings, window_title="Heatmap Settings", modal=False
)
self.heatmap_dialog.resize(620, 200)
# When the dialog is closed, update the toolbar icon and clear the reference
self.heatmap_dialog.finished.connect(self._heatmap_dialog_closed)
self.heatmap_dialog.show()
heatmap_settings_action.setChecked(True)
else:
# If already open, bring it to the front
self.heatmap_dialog.raise_()
self.heatmap_dialog.activateWindow()
heatmap_settings_action.setChecked(True) # keep it toggled
def _heatmap_dialog_closed(self):
"""
Slot for when the heatmap settings dialog is closed.
"""
self.heatmap_dialog = None
self.toolbar.components.get_action("heatmap_settings").action.setChecked(False)
@SafeSlot(dict, dict)
def on_scan_status(self, msg: dict, meta: dict):
"""
Initial scan status message handler, which is triggered at the begging and end of scan.
Args:
msg(dict): The message content.
meta(dict): The message metadata.
"""
current_scan_id = msg.get("scan_id", None)
if current_scan_id is None:
return
if current_scan_id != self.scan_id:
self.reset()
self.new_scan.emit()
self.new_scan_id.emit(current_scan_id)
self.old_scan_id = self.scan_id
self.scan_id = current_scan_id
self.scan_item = self.queue.scan_storage.find_scan_by_ID(self.scan_id) # type: ignore
# First trigger to update the scan curves
self.sync_signal_update.emit()
@SafeSlot(dict, dict)
def on_scan_progress(self, msg: dict, meta: dict):
self.sync_signal_update.emit()
status = msg.get("done")
if status:
QTimer.singleShot(100, self.update_plot)
QTimer.singleShot(300, self.update_plot)
@SafeSlot(verify_sender=True)
def update_plot(self, _=None) -> None:
"""
Update the plot with the current data.
"""
if self.scan_item is None:
logger.info("No scan executed so far; skipping update.")
return
data, access_key = self._fetch_scan_data_and_access()
if data == "none":
logger.info("No scan executed so far; skipping update.")
return
if self._image_config is None:
return
try:
x_name = self._image_config.x_device.name
x_entry = self._image_config.x_device.entry
y_name = self._image_config.y_device.name
y_entry = self._image_config.y_device.entry
z_name = self._image_config.z_device.name
z_entry = self._image_config.z_device.entry
except AttributeError:
return
if access_key == "val":
x_data = data.get(x_name, {}).get(x_entry, {}).get(access_key, None)
y_data = data.get(y_name, {}).get(y_entry, {}).get(access_key, None)
z_data = data.get(z_name, {}).get(z_entry, {}).get(access_key, None)
else:
x_data = data.get(x_name, {}).get(x_entry, {}).read().get("value", None)
y_data = data.get(y_name, {}).get(y_entry, {}).read().get("value", None)
z_data = data.get(z_name, {}).get(z_entry, {}).read().get("value", None)
if not isinstance(x_data, list):
x_data = x_data.tolist() if isinstance(x_data, np.ndarray) else None
if not isinstance(y_data, list):
y_data = y_data.tolist() if isinstance(y_data, np.ndarray) else None
if not isinstance(z_data, list):
z_data = z_data.tolist() if isinstance(z_data, np.ndarray) else None
if x_data is None or y_data is None or z_data is None:
logger.warning("x, y, or z data is None; skipping update.")
return
if len(x_data) != len(y_data) or len(x_data) != len(z_data):
logger.warning(
"x, y, and z data lengths do not match; skipping update. "
f"Lengths: x={len(x_data)}, y={len(y_data)}, z={len(z_data)}"
)
return
if hasattr(self.scan_item, "status_message"):
scan_msg = self.scan_item.status_message
elif hasattr(self.scan_item, "metadata"):
metadata = self.scan_item.metadata["bec"]
status = metadata["exit_status"]
scan_id = metadata["scan_id"]
scan_name = metadata["scan_name"]
scan_type = metadata["scan_type"]
request_inputs = metadata["request_inputs"]
if "arg_bundle" in request_inputs and isinstance(request_inputs["arg_bundle"], str):
# Convert the arg_bundle from a JSON string to a dictionary
request_inputs["arg_bundle"] = json.loads(request_inputs["arg_bundle"])
positions = metadata.get("positions", [])
positions = positions.tolist() if isinstance(positions, np.ndarray) else positions
scan_msg = messages.ScanStatusMessage(
status=status,
scan_id=scan_id,
scan_name=scan_name,
scan_type=scan_type,
request_inputs=request_inputs,
info={"positions": positions},
)
else:
scan_msg = None
if scan_msg is None:
logger.warning("Scan message is None; skipping update.")
return
self.status_message = scan_msg
img, transform = self.get_image_data(x_data=x_data, y_data=y_data, z_data=z_data)
if img is None:
logger.warning("Image data is None; skipping update.")
return
if self._color_bar is not None:
self._color_bar.blockSignals(True)
self.main_image.set_data(img, transform=transform)
if self._color_bar is not None:
self._color_bar.blockSignals(False)
self.image_updated.emit()
if self.crosshair is not None:
self.crosshair.update_markers_on_image_change()
def get_image_data(
self,
x_data: list[float] | None = None,
y_data: list[float] | None = None,
z_data: list[float] | None = None,
) -> tuple[np.ndarray | None, QTransform | None]:
"""
Get the image data for the heatmap. Depending on the scan type, it will
either pre-allocate the grid (grid_scan) or interpolate the data (step scan).
Args:
x_data (np.ndarray): The x data.
y_data (np.ndarray): The y data.
z_data (np.ndarray): The z data.
msg (messages.ScanStatusMessage): The scan status message.
Returns:
tuple[np.ndarray, QTransform]: The image data and the QTransform.
"""
msg = self.status_message
if x_data is None or y_data is None or z_data is None or msg is None:
logger.warning("x, y, or z data is None; skipping update.")
return None, None
if msg.scan_name == "grid_scan":
# We only support the grid scan mode if both scanning motors
# are configured in the heatmap config.
device_x = self._image_config.x_device.entry
device_y = self._image_config.y_device.entry
if (
device_x in msg.request_inputs["arg_bundle"]
and device_y in msg.request_inputs["arg_bundle"]
):
return self.get_grid_scan_image(z_data, msg)
if len(z_data) < 4:
# LinearNDInterpolator requires at least 4 points to interpolate
return None, None
return self.get_step_scan_image(x_data, y_data, z_data, msg)
def get_grid_scan_image(
self, z_data: list[float], msg: messages.ScanStatusMessage
) -> tuple[np.ndarray, QTransform]:
"""
Get the image data for a grid scan.
Args:
z_data (np.ndarray): The z data.
msg (messages.ScanStatusMessage): The scan status message.
Returns:
tuple[np.ndarray, QTransform]: The image data and the QTransform.
"""
args = self.arg_bundle_to_dict(4, msg.request_inputs["arg_bundle"])
shape = (
args[self._image_config.x_device.entry][-1],
args[self._image_config.y_device.entry][-1],
)
data = self.main_image.raw_data
if data is None or data.shape != shape:
data = np.empty(shape)
data.fill(np.nan)
def _get_grid_data(axis, snaked=True):
x_grid, y_grid = np.meshgrid(axis[0], axis[1])
if snaked:
y_grid.T[::2] = np.fliplr(y_grid.T[::2])
x_flat = x_grid.T.ravel()
y_flat = y_grid.T.ravel()
positions = np.vstack((x_flat, y_flat)).T
return positions
snaked = msg.request_inputs["kwargs"].get("snaked", True)
# If the scan's fast axis is x, we need to swap the x and y axes
swap = bool(msg.request_inputs["arg_bundle"][4] == self._image_config.x_device.entry)
# calculate the QTransform to put (0,0) at the axis origin
scan_pos = np.asarray(msg.info["positions"])
x_min = min(scan_pos[:, 0])
x_max = max(scan_pos[:, 0])
y_min = min(scan_pos[:, 1])
y_max = max(scan_pos[:, 1])
x_range = x_max - x_min
y_range = y_max - y_min
pixel_size_x = x_range / (shape[0] - 1)
pixel_size_y = y_range / (shape[1] - 1)
transform = QTransform()
if swap:
transform.scale(pixel_size_y, pixel_size_x)
transform.translate(y_min / pixel_size_y - 0.5, x_min / pixel_size_x - 0.5)
else:
transform.scale(pixel_size_x, pixel_size_y)
transform.translate(x_min / pixel_size_x - 0.5, y_min / pixel_size_y - 0.5)
target_positions = _get_grid_data(
(np.arange(shape[int(swap)]), np.arange(shape[int(not swap)])), snaked=snaked
)
# Fill the data array with the z values
if self._grid_index is None or self.reload:
self._grid_index = 0
self.reload = False
for i in range(self._grid_index, len(z_data)):
data[target_positions[i, int(swap)], target_positions[i, int(not swap)]] = z_data[i]
self._grid_index = len(z_data)
return data, transform
def get_step_scan_image(
self,
x_data: list[float],
y_data: list[float],
z_data: list[float],
msg: messages.ScanStatusMessage,
) -> tuple[np.ndarray, QTransform]:
"""
Get the image data for an arbitrary step scan.
Args:
x_data (list[float]): The x data.
y_data (list[float]): The y data.
z_data (list[float]): The z data.
msg (messages.ScanStatusMessage): The scan status message.
Returns:
tuple[np.ndarray, QTransform]: The image data and the QTransform.
"""
xy_data = np.column_stack((x_data, y_data))
grid_x, grid_y, transform = self.get_image_grid(xy_data)
# Interpolate the z data onto the grid
interp = LinearNDInterpolator(xy_data, z_data)
grid_z = interp(grid_x, grid_y)
return grid_z, transform
def get_image_grid(self, positions) -> tuple[np.ndarray, np.ndarray, QTransform]:
"""
LRU-cached calculation of the grid for the image. The lru cache is indexed by the scan_id
to avoid recalculating the grid for the same scan.
Args:
_scan_id (str): The scan ID. Needed for caching but not used in the function.
Returns:
tuple[np.ndarray, np.ndarray, QTransform]: The grid x and y coordinates and the QTransform.
"""
width, height = self.estimate_image_resolution(positions)
# Create a grid of points for interpolation
grid_x, grid_y = np.mgrid[
min(positions[:, 0]) : max(positions[:, 0]) : width * 1j,
min(positions[:, 1]) : max(positions[:, 1]) : height * 1j,
]
# Calculate the QTransform to put (0,0) at the axis origin
x_min = min(positions[:, 0])
y_min = min(positions[:, 1])
x_max = max(positions[:, 0])
y_max = max(positions[:, 1])
x_range = x_max - x_min
y_range = y_max - y_min
x_scale = x_range / width
y_scale = y_range / height
transform = QTransform()
transform.scale(x_scale, y_scale)
transform.translate(x_min / x_scale - 0.5, y_min / y_scale - 0.5)
return grid_x, grid_y, transform
@staticmethod
def estimate_image_resolution(coords: np.ndarray) -> tuple[int, int]:
"""
Estimate the number of pixels needed for the image based on the coordinates.
Args:
coords (np.ndarray): The coordinates of the points.
Returns:
tuple[int, int]: The estimated width and height of the image."""
if coords.ndim != 2 or coords.shape[1] != 2:
raise ValueError("Input must be an (m x 2) array of (x, y) coordinates.")
x_min, x_max = coords[:, 0].min(), coords[:, 0].max()
y_min, y_max = coords[:, 1].min(), coords[:, 1].max()
tree = cKDTree(coords)
distances, _ = tree.query(coords, k=2)
distances = distances[:, 1] # Get the second nearest neighbor distance
avg_distance = np.mean(distances)
width_extent = x_max - x_min
height_extent = y_max - y_min
# Calculate the number of pixels needed based on the average distance
width_pixels = int(np.ceil(width_extent / avg_distance))
height_pixels = int(np.ceil(height_extent / avg_distance))
return max(1, width_pixels), max(1, height_pixels)
def arg_bundle_to_dict(self, bundle_size: int, args: list) -> dict:
"""
Convert the argument bundle to a dictionary.
Args:
args (list): The argument bundle.
Returns:
dict: The dictionary representation of the argument bundle.
"""
params = {}
for cmds in partition(bundle_size, args):
params[cmds[0]] = list(cmds[1:])
return params
def _fetch_scan_data_and_access(self):
"""
Decide whether the widget is in live or historical mode
and return the appropriate data dict and access key.
Returns:
data_dict (dict): The data structure for the current scan.
access_key (str): Either 'val' (live) or 'value' (history).
"""
if self.scan_item is None:
# Optionally fetch the latest from history if nothing is set
# self.update_with_scan_history(-1)
if self.scan_item is None:
logger.info("No scan executed so far; skipping device curves categorisation.")
return "none", "none"
if hasattr(self.scan_item, "live_data"):
# Live scan
return self.scan_item.live_data, "val"
# Historical
scan_devices = self.scan_item.devices
return scan_devices, "value"
def reset(self):
self._grid_index = None
self.main_image.clear()
if self.crosshair is not None:
self.crosshair.reset()
super().reset()
################################################################################
# Post Processing
################################################################################
@SafeProperty(bool)
def fft(self) -> bool:
"""
Whether FFT postprocessing is enabled.
"""
return self.main_image.fft
@fft.setter
def fft(self, enable: bool):
"""
Set FFT postprocessing.
Args:
enable(bool): Whether to enable FFT postprocessing.
"""
self.main_image.fft = enable
@SafeProperty(bool)
def log(self) -> bool:
"""
Whether logarithmic scaling is applied.
"""
return self.main_image.log
@log.setter
def log(self, enable: bool):
"""
Set logarithmic scaling.
Args:
enable(bool): Whether to enable logarithmic scaling.
"""
self.main_image.log = enable
@SafeProperty(int)
def num_rotation_90(self) -> int:
"""
The number of 90° rotations to apply counterclockwise.
"""
return self.main_image.num_rotation_90
@num_rotation_90.setter
def num_rotation_90(self, value: int):
"""
Set the number of 90° rotations to apply counterclockwise.
Args:
value(int): The number of 90° rotations to apply.
"""
self.main_image.num_rotation_90 = value
@SafeProperty(bool)
def transpose(self) -> bool:
"""
Whether the image is transposed.
"""
return self.main_image.transpose
@transpose.setter
def transpose(self, enable: bool):
"""
Set the image to be transposed.
Args:
enable(bool): Whether to enable transposing the image.
"""
self.main_image.transpose = enable
if __name__ == "__main__": # pragma: no cover
import sys
from qtpy.QtWidgets import QApplication
app = QApplication(sys.argv)
heatmap = Heatmap()
heatmap.plot(x_name="samx", y_name="samy", z_name="bpm4i")
heatmap.show()
sys.exit(app.exec_())

View File

@@ -0,0 +1 @@
{'files': ['heatmap.py']}

View File

@@ -0,0 +1,54 @@
# Copyright (C) 2022 The Qt Company Ltd.
# SPDX-License-Identifier: LicenseRef-Qt-Commercial OR BSD-3-Clause
from qtpy.QtDesigner import QDesignerCustomWidgetInterface
from bec_widgets.utils.bec_designer import designer_material_icon
from bec_widgets.widgets.plots.heatmap.heatmap import Heatmap
DOM_XML = """
<ui language='c++'>
<widget class='Heatmap' name='heatmap'>
</widget>
</ui>
"""
class HeatmapPlugin(QDesignerCustomWidgetInterface): # pragma: no cover
def __init__(self):
super().__init__()
self._form_editor = None
def createWidget(self, parent):
t = Heatmap(parent)
return t
def domXml(self):
return DOM_XML
def group(self):
return ""
def icon(self):
return designer_material_icon(Heatmap.ICON_NAME)
def includeFile(self):
return "heatmap"
def initialize(self, form_editor):
self._form_editor = form_editor
def isContainer(self):
return False
def isInitialized(self):
return self._form_editor is not None
def name(self):
return "Heatmap"
def toolTip(self):
return ""
def whatsThis(self):
return self.toolTip()

View File

@@ -0,0 +1,15 @@
def main(): # pragma: no cover
from qtpy import PYSIDE6
if not PYSIDE6:
print("PYSIDE6 is not available in the environment. Cannot patch designer.")
return
from PySide6.QtDesigner import QPyDesignerCustomWidgetCollection
from bec_widgets.widgets.plots.heatmap.heatmap_plugin import HeatmapPlugin
QPyDesignerCustomWidgetCollection.addCustomWidget(HeatmapPlugin())
if __name__ == "__main__": # pragma: no cover
main()

View File

@@ -0,0 +1,138 @@
import os
from qtpy.QtWidgets import QFrame, QScrollArea, QVBoxLayout
from bec_widgets.utils import UILoader
from bec_widgets.utils.error_popups import SafeSlot
from bec_widgets.utils.settings_dialog import SettingWidget
class HeatmapSettings(SettingWidget):
def __init__(self, parent=None, target_widget=None, popup=False, *args, **kwargs):
super().__init__(parent=parent, *args, **kwargs)
# This is a settings widget that depends on the target widget
# and should mirror what is in the target widget.
# Saving settings for this widget could result in recursively setting the target widget.
self.setProperty("skip_settings", True)
current_path = os.path.dirname(__file__)
if popup:
form = UILoader().load_ui(
os.path.join(current_path, "heatmap_settings_horizontal.ui"), self
)
else:
form = UILoader().load_ui(
os.path.join(current_path, "heatmap_settings_vertical.ui"), self
)
self.target_widget = target_widget
self.popup = popup
# # Scroll area
self.scroll_area = QScrollArea(self)
self.scroll_area.setWidgetResizable(True)
self.scroll_area.setFrameShape(QFrame.NoFrame)
self.scroll_area.setWidget(form)
self.layout = QVBoxLayout(self)
self.layout.setContentsMargins(0, 0, 0, 0)
self.layout.addWidget(self.scroll_area)
self.ui = form
self.fetch_all_properties()
self.target_widget.heatmap_property_changed.connect(self.fetch_all_properties)
if popup is False:
self.ui.button_apply.clicked.connect(self.accept_changes)
@SafeSlot()
def fetch_all_properties(self):
"""
Fetch all properties from the target widget and update the settings widget.
"""
if not self.target_widget:
return
# Get properties from the target widget
color_map = getattr(self.target_widget, "color_map", None)
# Default values for device properties
x_name, x_entry = None, None
y_name, y_entry = None, None
z_name, z_entry = None, None
# Safely access device properties
if hasattr(self.target_widget, "_image_config") and self.target_widget._image_config:
config = self.target_widget._image_config
if hasattr(config, "x_device") and config.x_device:
x_name = getattr(config.x_device, "name", None)
x_entry = getattr(config.x_device, "entry", None)
if hasattr(config, "y_device") and config.y_device:
y_name = getattr(config.y_device, "name", None)
y_entry = getattr(config.y_device, "entry", None)
if hasattr(config, "z_device") and config.z_device:
z_name = getattr(config.z_device, "name", None)
z_entry = getattr(config.z_device, "entry", None)
# Apply the properties to the settings widget
if hasattr(self.ui, "color_map"):
self.ui.color_map.colormap = color_map
if hasattr(self.ui, "x_name"):
self.ui.x_name.set_device(x_name)
if hasattr(self.ui, "x_entry") and x_entry is not None:
self.ui.x_entry.setText(x_entry)
if hasattr(self.ui, "y_name"):
self.ui.y_name.set_device(y_name)
if hasattr(self.ui, "y_entry") and y_entry is not None:
self.ui.y_entry.setText(y_entry)
if hasattr(self.ui, "z_name"):
self.ui.z_name.set_device(z_name)
if hasattr(self.ui, "z_entry") and z_entry is not None:
self.ui.z_entry.setText(z_entry)
@SafeSlot()
def accept_changes(self):
"""
Apply all properties from the settings widget to the target widget.
"""
x_name = self.ui.x_name.text()
x_entry = self.ui.x_entry.text()
y_name = self.ui.y_name.text()
y_entry = self.ui.y_entry.text()
z_name = self.ui.z_name.text()
z_entry = self.ui.z_entry.text()
validate_bec = self.ui.validate_bec.checked
color_map = self.ui.color_map.colormap
self.target_widget.plot(
x_name=x_name,
y_name=y_name,
z_name=z_name,
x_entry=x_entry,
y_entry=y_entry,
z_entry=z_entry,
color_map=color_map,
validate_bec=validate_bec,
reload=True,
)
def cleanup(self):
self.ui.x_name.close()
self.ui.x_name.deleteLater()
self.ui.x_entry.close()
self.ui.x_entry.deleteLater()
self.ui.y_name.close()
self.ui.y_name.deleteLater()
self.ui.y_entry.close()
self.ui.y_entry.deleteLater()
self.ui.z_name.close()
self.ui.z_name.deleteLater()
self.ui.z_entry.close()
self.ui.z_entry.deleteLater()

View File

@@ -0,0 +1,203 @@
<?xml version="1.0" encoding="UTF-8"?>
<ui version="4.0">
<class>Form</class>
<widget class="QWidget" name="Form">
<property name="geometry">
<rect>
<x>0</x>
<y>0</y>
<width>604</width>
<height>166</height>
</rect>
</property>
<property name="windowTitle">
<string>Form</string>
</property>
<layout class="QVBoxLayout" name="verticalLayout">
<item>
<layout class="QHBoxLayout" name="horizontalLayout_2">
<item>
<widget class="QLabel" name="label_7">
<property name="text">
<string>Validate BEC</string>
</property>
</widget>
</item>
<item>
<widget class="ToggleSwitch" name="validate_bec"/>
</item>
<item>
<widget class="BECColorMapWidget" name="color_map"/>
</item>
</layout>
</item>
<item>
<layout class="QHBoxLayout" name="horizontalLayout">
<item>
<widget class="QGroupBox" name="groupBox">
<property name="title">
<string>X Device</string>
</property>
<layout class="QGridLayout" name="gridLayout">
<item row="0" column="0">
<widget class="QLabel" name="label">
<property name="text">
<string>Name</string>
</property>
</widget>
</item>
<item row="0" column="1">
<widget class="DeviceLineEdit" name="x_name"/>
</item>
<item row="1" column="0">
<widget class="QLabel" name="label_2">
<property name="text">
<string>Signal</string>
</property>
</widget>
</item>
<item row="1" column="1">
<widget class="QLineEdit" name="x_entry"/>
</item>
</layout>
</widget>
</item>
<item>
<widget class="QGroupBox" name="groupBox_2">
<property name="title">
<string>Y Device</string>
</property>
<layout class="QGridLayout" name="gridLayout_2">
<item row="0" column="0">
<widget class="QLabel" name="label_3">
<property name="text">
<string>Name</string>
</property>
</widget>
</item>
<item row="0" column="1">
<widget class="DeviceLineEdit" name="y_name"/>
</item>
<item row="1" column="0">
<widget class="QLabel" name="label_4">
<property name="text">
<string>Signal</string>
</property>
</widget>
</item>
<item row="1" column="1">
<widget class="QLineEdit" name="y_entry"/>
</item>
</layout>
</widget>
</item>
<item>
<widget class="QGroupBox" name="groupBox_3">
<property name="title">
<string>Z Device</string>
</property>
<layout class="QGridLayout" name="gridLayout_3">
<item row="0" column="0">
<widget class="QLabel" name="label_5">
<property name="text">
<string>Name</string>
</property>
</widget>
</item>
<item row="1" column="0">
<widget class="QLabel" name="label_6">
<property name="text">
<string>Signal</string>
</property>
</widget>
</item>
<item row="1" column="1">
<widget class="QLineEdit" name="z_entry"/>
</item>
<item row="0" column="1">
<widget class="DeviceLineEdit" name="z_name"/>
</item>
</layout>
</widget>
</item>
</layout>
</item>
</layout>
</widget>
<customwidgets>
<customwidget>
<class>DeviceLineEdit</class>
<extends>QLineEdit</extends>
<header>device_line_edit</header>
</customwidget>
<customwidget>
<class>ToggleSwitch</class>
<extends>QWidget</extends>
<header>toggle_switch</header>
</customwidget>
<customwidget>
<class>BECColorMapWidget</class>
<extends>QWidget</extends>
<header>bec_color_map_widget</header>
</customwidget>
</customwidgets>
<tabstops>
<tabstop>x_name</tabstop>
<tabstop>x_entry</tabstop>
<tabstop>y_name</tabstop>
<tabstop>y_entry</tabstop>
<tabstop>z_name</tabstop>
<tabstop>z_entry</tabstop>
</tabstops>
<resources/>
<connections>
<connection>
<sender>x_name</sender>
<signal>textChanged(QString)</signal>
<receiver>x_entry</receiver>
<slot>clear()</slot>
<hints>
<hint type="sourcelabel">
<x>134</x>
<y>95</y>
</hint>
<hint type="destinationlabel">
<x>138</x>
<y>128</y>
</hint>
</hints>
</connection>
<connection>
<sender>y_name</sender>
<signal>textChanged(QString)</signal>
<receiver>y_entry</receiver>
<slot>clear()</slot>
<hints>
<hint type="sourcelabel">
<x>351</x>
<y>91</y>
</hint>
<hint type="destinationlabel">
<x>349</x>
<y>121</y>
</hint>
</hints>
</connection>
<connection>
<sender>z_name</sender>
<signal>textChanged(QString)</signal>
<receiver>z_entry</receiver>
<slot>clear()</slot>
<hints>
<hint type="sourcelabel">
<x>520</x>
<y>98</y>
</hint>
<hint type="destinationlabel">
<x>522</x>
<y>127</y>
</hint>
</hints>
</connection>
</connections>
</ui>

View File

@@ -0,0 +1,204 @@
<?xml version="1.0" encoding="UTF-8"?>
<ui version="4.0">
<class>Form</class>
<widget class="QWidget" name="Form">
<property name="geometry">
<rect>
<x>0</x>
<y>0</y>
<width>233</width>
<height>427</height>
</rect>
</property>
<property name="maximumSize">
<size>
<width>16777215</width>
<height>427</height>
</size>
</property>
<property name="windowTitle">
<string>Form</string>
</property>
<layout class="QVBoxLayout" name="verticalLayout">
<item>
<widget class="QPushButton" name="button_apply">
<property name="text">
<string>Apply</string>
</property>
</widget>
</item>
<item>
<widget class="BECColorMapWidget" name="color_map"/>
</item>
<item>
<layout class="QHBoxLayout" name="horizontalLayout">
<item>
<widget class="QLabel" name="label_7">
<property name="text">
<string>Validate BEC</string>
</property>
</widget>
</item>
<item>
<widget class="ToggleSwitch" name="validate_bec"/>
</item>
</layout>
</item>
<item>
<widget class="QGroupBox" name="groupBox">
<property name="title">
<string>X Device</string>
</property>
<layout class="QGridLayout" name="gridLayout">
<item row="0" column="0">
<widget class="QLabel" name="label">
<property name="text">
<string>Name</string>
</property>
</widget>
</item>
<item row="0" column="1">
<widget class="DeviceLineEdit" name="x_name"/>
</item>
<item row="1" column="0">
<widget class="QLabel" name="label_2">
<property name="text">
<string>Signal</string>
</property>
</widget>
</item>
<item row="1" column="1">
<widget class="QLineEdit" name="x_entry"/>
</item>
</layout>
</widget>
</item>
<item>
<widget class="QGroupBox" name="groupBox_2">
<property name="title">
<string>Y Device</string>
</property>
<layout class="QGridLayout" name="gridLayout_2">
<item row="0" column="0">
<widget class="QLabel" name="label_3">
<property name="text">
<string>Name</string>
</property>
</widget>
</item>
<item row="0" column="1">
<widget class="DeviceLineEdit" name="y_name"/>
</item>
<item row="1" column="0">
<widget class="QLabel" name="label_4">
<property name="text">
<string>Signal</string>
</property>
</widget>
</item>
<item row="1" column="1">
<widget class="QLineEdit" name="y_entry"/>
</item>
</layout>
</widget>
</item>
<item>
<widget class="QGroupBox" name="groupBox_3">
<property name="title">
<string>Z Device</string>
</property>
<layout class="QGridLayout" name="gridLayout_3">
<item row="0" column="0">
<widget class="QLabel" name="label_5">
<property name="text">
<string>Name</string>
</property>
</widget>
</item>
<item row="0" column="1">
<widget class="DeviceLineEdit" name="z_name"/>
</item>
<item row="1" column="0">
<widget class="QLabel" name="label_6">
<property name="text">
<string>Signal</string>
</property>
</widget>
</item>
<item row="1" column="1">
<widget class="QLineEdit" name="z_entry"/>
</item>
</layout>
</widget>
</item>
</layout>
</widget>
<customwidgets>
<customwidget>
<class>DeviceLineEdit</class>
<extends>QLineEdit</extends>
<header>device_line_edit</header>
</customwidget>
<customwidget>
<class>ToggleSwitch</class>
<extends>QWidget</extends>
<header>toggle_switch</header>
</customwidget>
<customwidget>
<class>BECColorMapWidget</class>
<extends>QWidget</extends>
<header>bec_color_map_widget</header>
</customwidget>
</customwidgets>
<resources/>
<connections>
<connection>
<sender>x_name</sender>
<signal>textChanged(QString)</signal>
<receiver>x_entry</receiver>
<slot>clear()</slot>
<hints>
<hint type="sourcelabel">
<x>156</x>
<y>123</y>
</hint>
<hint type="destinationlabel">
<x>158</x>
<y>157</y>
</hint>
</hints>
</connection>
<connection>
<sender>y_name</sender>
<signal>textChanged(QString)</signal>
<receiver>y_entry</receiver>
<slot>clear()</slot>
<hints>
<hint type="sourcelabel">
<x>116</x>
<y>229</y>
</hint>
<hint type="destinationlabel">
<x>116</x>
<y>251</y>
</hint>
</hints>
</connection>
<connection>
<sender>z_name</sender>
<signal>textChanged(QString)</signal>
<receiver>z_entry</receiver>
<slot>clear()</slot>
<hints>
<hint type="sourcelabel">
<x>110</x>
<y>326</y>
</hint>
<hint type="destinationlabel">
<x>110</x>
<y>352</y>
</hint>
</hints>
</connection>
</connections>
</ui>

View File

@@ -140,7 +140,6 @@ class Image(ImageBase):
if config is None:
config = ImageConfig(widget_class=self.__class__.__name__)
self.gui_id = config.gui_id
self._color_bar = None
self.subscriptions: defaultdict[str, ImageLayerConfig] = defaultdict(
lambda: ImageLayerConfig(monitor=None, monitor_type="auto", source="auto")
)
@@ -566,6 +565,8 @@ class Image(ImageBase):
self.main_image.clear()
self.main_image.buffer = []
self.main_image.max_len = 0
if self.crosshair is not None:
self.crosshair.reset()
image_buffer = self.adjust_image_buffer(self.main_image, data)
if self._color_bar is not None:
self._color_bar.blockSignals(True)

View File

@@ -260,6 +260,7 @@ class ImageBase(PlotBase):
"""
self.x_roi = None
self.y_roi = None
self._color_bar = None
super().__init__(*args, **kwargs)
self.roi_controller = ROIController(colormap="viridis")
@@ -566,7 +567,9 @@ class ImageBase(PlotBase):
"""
# Create ROI plot widgets
self.x_roi = ImageROIPlot(parent=self)
self.x_roi.plot_item.setXLink(self.plot_item)
self.y_roi = ImageROIPlot(parent=self)
self.y_roi.plot_item.setYLink(self.plot_item)
self.x_roi.apply_theme("dark")
self.y_roi.apply_theme("dark")
@@ -637,7 +640,8 @@ class ImageBase(PlotBase):
else:
x = coordinates[1]
y = coordinates[2]
image = self.layer_manager["main"].image.image
image_item = self.layer_manager["main"].image
image = image_item.image
if image is None:
return
max_row, max_col = image.shape[0] - 1, image.shape[1] - 1
@@ -646,14 +650,27 @@ class ImageBase(PlotBase):
return
# Horizontal slice
h_slice = image[:, col]
x_axis = np.arange(h_slice.shape[0])
x_pixel_indices = np.arange(h_slice.shape[0])
if image_item.image_transform is None:
h_world_x = np.arange(h_slice.shape[0])
else:
h_world_x = [
image_item.image_transform.map(xi + 0.5, col + 0.5)[0] for xi in x_pixel_indices
]
self.x_roi.plot_item.clear()
self.x_roi.plot_item.plot(x_axis, h_slice, pen=pg.mkPen(self.x_roi.curve_color, width=3))
self.x_roi.plot_item.plot(h_world_x, h_slice, pen=pg.mkPen(self.x_roi.curve_color, width=3))
# Vertical slice
v_slice = image[row, :]
y_axis = np.arange(v_slice.shape[0])
y_pixel_indices = np.arange(v_slice.shape[0])
if image_item.image_transform is None:
v_world_y = np.arange(v_slice.shape[0])
else:
v_world_y = [
image_item.image_transform.map(row + 0.5, yi + 0.5)[1] for yi in y_pixel_indices
]
self.y_roi.plot_item.clear()
self.y_roi.plot_item.plot(v_slice, y_axis, pen=pg.mkPen(self.y_roi.curve_color, width=3))
self.y_roi.plot_item.plot(v_slice, v_world_y, pen=pg.mkPen(self.y_roi.curve_color, width=3))
################################################################################
# Widget Specific Properties
@@ -865,6 +882,7 @@ class ImageBase(PlotBase):
enabled(bool): Whether to enable autorange.
sync(bool): Whether to synchronize the autorange state across all layers.
"""
print(f"Setting autorange to {enabled}")
for layer in self.layer_manager:
if not layer.sync.autorange:
continue
@@ -874,6 +892,7 @@ class ImageBase(PlotBase):
# if sync:
self._sync_colorbar_levels()
self._sync_autorange_switch()
print(f"Autorange set to {enabled}")
@SafeProperty(str)
def autorange_mode(self) -> str:
@@ -895,6 +914,7 @@ class ImageBase(PlotBase):
Args:
mode(str): The autorange mode. Options are "max" or "mean".
"""
print(f"Setting autorange mode to {mode}")
# for qt Designer
if mode not in ["max", "mean"]:
return

View File

@@ -7,6 +7,7 @@ import pyqtgraph as pg
from bec_lib.logger import bec_logger
from pydantic import Field, ValidationError, field_validator
from qtpy.QtCore import Signal
from qtpy.QtGui import QTransform
from bec_widgets.utils import BECConnector, Colors, ConnectionConfig
from bec_widgets.widgets.plots.image.image_processor import (
@@ -85,6 +86,7 @@ class ImageItem(BECConnector, pg.ImageItem):
self.set_parent(parent_image)
else:
self.parent_image = None
self.image_transform = None
super().__init__(config=config, gui_id=gui_id, **kwargs)
self.raw_data = None
@@ -100,8 +102,9 @@ class ImageItem(BECConnector, pg.ImageItem):
def parent(self):
return self.parent_image
def set_data(self, data: np.ndarray):
def set_data(self, data: np.ndarray, transform: QTransform | None = None):
self.raw_data = data
self.image_transform = transform
self._process_image()
################################################################################
@@ -210,12 +213,19 @@ class ImageItem(BECConnector, pg.ImageItem):
"""
Reprocess the current raw data and update the image display.
"""
if self.raw_data is not None:
autorange = self.config.autorange
self._image_processor.set_config(self.config.processing)
processed_data = self._image_processor.process_image(self.raw_data)
self.setImage(processed_data, autoLevels=False)
self.autorange = autorange
if self.raw_data is None:
return
if np.all(np.isnan(self.raw_data)):
return
autorange = self.config.autorange
self._image_processor.set_config(self.config.processing)
processed_data = self._image_processor.process_image(self.raw_data)
self.setImage(processed_data, autoLevels=False)
if self.image_transform is not None:
self.setTransform(self.image_transform)
self.autorange = autorange
@property
def fft(self) -> bool:

View File

@@ -27,7 +27,12 @@ class ImageStats:
Returns:
ImageStats: The statistics of the image data.
"""
return cls(maximum=np.max(data), minimum=np.min(data), mean=np.mean(data), std=np.std(data))
return cls(
maximum=np.nanmax(data),
minimum=np.nanmin(data),
mean=np.nanmean(data),
std=np.nanstd(data),
)
# noinspection PyDataclass
@@ -81,7 +86,7 @@ class ImageProcessor(QObject):
Returns:
np.ndarray: The processed data.
"""
return np.abs(np.fft.fftshift(np.fft.fft2(data)))
return np.abs(np.fft.fftshift(np.fft.fft2(np.nan_to_num(data))))
def rotation(self, data: np.ndarray, rotate_90: int) -> np.ndarray:
"""

View File

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project]
name = "bec_widgets"
version = "2.21.4"
version = "2.22.1"
description = "BEC Widgets"
requires-python = ">=3.10"
classifiers = [

View File

@@ -2,8 +2,10 @@ import numpy as np
import pyqtgraph as pg
import pytest
from qtpy.QtCore import QPointF, Qt
from qtpy.QtGui import QTransform
from bec_widgets.utils import Crosshair
from bec_widgets.widgets.plots.image.image_item import ImageItem
# pylint: disable = redefined-outer-name
@@ -27,7 +29,7 @@ def image_widget_with_crosshair(qtbot):
qtbot.addWidget(widget)
qtbot.waitExposed(widget)
image_item = pg.ImageItem()
image_item = ImageItem()
image_item.setImage(np.random.rand(100, 100))
widget.addItem(image_item)
@@ -113,7 +115,7 @@ def test_mouse_moved_signals_2D(image_widget_with_crosshair):
crosshair.mouse_moved(event_mock)
assert emitted_values_2D == [(str(id(image_item)), 21, 55)]
assert emitted_values_2D == [("ImageItem", 21, 55)]
def test_mouse_moved_signals_2D_outside(image_widget_with_crosshair):
@@ -311,3 +313,53 @@ def test_crosshair_precision_properties_image(image_widget_with_crosshair):
crosshair.precision = 2
assert crosshair._current_precision() == 2
def test_get_transformed_position(plot_widget_with_crosshair):
"""Test that _get_transformed_position correctly transforms coordinates."""
crosshair, _ = plot_widget_with_crosshair
# Create a simple transform
transform = QTransform()
transform.translate(10, 20) # Origin is now at (10, 20)
# Test coordinates
x, y = 5, 8
# Get the transformed position
row, col = crosshair._get_transformed_position(x, y, transform)
# Calculate expected values:
# row should be the y-offset from origin after transform
# col should be the x-offset from origin after transform
expected_row = QPointF(0, 8) # y direction offset
expected_col = QPointF(5, 0) # x direction offset
# Check that the results match expectations
assert row == expected_row
assert col == expected_col
def test_get_transformed_position_with_scale(plot_widget_with_crosshair):
"""Test that _get_transformed_position correctly handles scaling transformations."""
crosshair, _ = plot_widget_with_crosshair
# Create a transform with scaling
transform = QTransform()
transform.translate(10, 20) # Origin is now at (10, 20)
transform.scale(2, 3) # Scale x by 2 and y by 3
# Test coordinates
x, y = 5, 8
# Get the transformed position
row, col = crosshair._get_transformed_position(x, y, transform)
# Calculate expected values with scaling applied:
# For a scale transform, the offsets should be multiplied by the scale factors
expected_row = QPointF(0, 8 * 3) # y direction offset with scale factor 3
expected_col = QPointF(5 * 2, 0) # x direction offset with scale factor 2
# Check that the results match expectations
assert row == expected_row
assert col == expected_col

View File

@@ -0,0 +1,329 @@
from unittest import mock
import numpy as np
import pytest
from bec_lib import messages
from bec_widgets.widgets.plots.heatmap.heatmap import Heatmap, HeatmapConfig, HeatmapDeviceSignal
# pytest: disable=unused-import
from tests.unit_tests.client_mocks import mocked_client
from .client_mocks import create_dummy_scan_item
@pytest.fixture
def heatmap_widget(qtbot, mocked_client):
widget = Heatmap(client=mocked_client)
qtbot.addWidget(widget)
qtbot.waitExposed(widget)
yield widget
def test_heatmap_plot(heatmap_widget):
heatmap_widget.plot(x_name="samx", y_name="samy", z_name="bpm4i")
assert heatmap_widget._image_config.x_device.name == "samx"
assert heatmap_widget._image_config.y_device.name == "samy"
assert heatmap_widget._image_config.z_device.name == "bpm4i"
def test_heatmap_on_scan_status_no_scan_id(heatmap_widget):
scan_msg = messages.ScanStatusMessage(scan_id=None, status="open", metadata={}, info={})
with mock.patch.object(heatmap_widget, "reset") as mock_reset:
heatmap_widget.on_scan_status(scan_msg.content, scan_msg.metadata)
mock_reset.assert_not_called()
def test_heatmap_on_scan_status_same_scan_id(heatmap_widget):
scan_msg = messages.ScanStatusMessage(scan_id="123", status="open", metadata={}, info={})
heatmap_widget.scan_id = "123"
with mock.patch.object(heatmap_widget, "reset") as mock_reset:
heatmap_widget.on_scan_status(scan_msg.content, scan_msg.metadata)
mock_reset.assert_not_called()
def test_heatmap_widget_on_scan_status_different_scan_id(heatmap_widget):
scan_msg = messages.ScanStatusMessage(scan_id="123", status="open", metadata={}, info={})
heatmap_widget.scan_id = "456"
with mock.patch.object(heatmap_widget, "reset") as mock_reset:
heatmap_widget.on_scan_status(scan_msg.content, scan_msg.metadata)
mock_reset.assert_called_once()
def test_heatmap_get_image_data_missing_data(heatmap_widget):
"""
If the data is missing or incomplete, the method should return None.
"""
assert heatmap_widget.get_image_data() == (None, None)
def test_heatmap_get_image_data_grid_scan(heatmap_widget):
scan_msg = messages.ScanStatusMessage(
scan_id="123",
status="open",
scan_name="grid_scan",
metadata={},
info={},
request_inputs={"arg_bundle": ["samx", -5, 5, 10, "samy", -5, 5, 10], "kwargs": {}},
)
heatmap_widget.plot(x_name="samx", y_name="samy", z_name="bpm4i")
heatmap_widget.status_message = scan_msg
with mock.patch.object(heatmap_widget, "get_grid_scan_image") as mock_get_grid_scan_image:
heatmap_widget.get_image_data(x_data=[1, 2], y_data=[3, 4], z_data=[5, 6])
mock_get_grid_scan_image.assert_called_once()
def test_heatmap_get_image_data_step_scan(heatmap_widget):
"""
If the step scan has too few points, it should return None.
"""
scan_msg = messages.ScanStatusMessage(
scan_id="123",
status="open",
scan_name="step_scan",
scan_type="step",
metadata={},
info={"positions": [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]},
)
with mock.patch.object(heatmap_widget, "get_step_scan_image") as mock_get_step_scan_image:
heatmap_widget.status_message = scan_msg
heatmap_widget.get_image_data(x_data=[1, 2, 3, 4], y_data=[1, 2, 3, 4], z_data=[1, 2, 5, 6])
mock_get_step_scan_image.assert_called_once()
def test_heatmap_get_image_data_step_scan_too_few_points(heatmap_widget):
"""
If the step scan has too few points, it should return None.
"""
scan_msg = messages.ScanStatusMessage(
scan_id="123",
status="open",
scan_name="step_scan",
scan_type="step",
metadata={},
info={"positions": [[1, 2], [3, 4]]},
)
heatmap_widget.status_message = scan_msg
out = heatmap_widget.get_image_data(x_data=[1, 2], y_data=[3, 4], z_data=[5, 6])
assert out == (None, None)
def test_heatmap_get_image_data_unsupported_scan(heatmap_widget):
scan_msg = messages.ScanStatusMessage(
scan_id="123", status="open", scan_type="fly", metadata={}, info={}
)
heatmap_widget.status_message = scan_msg
assert heatmap_widget.get_image_data(x_data=[1, 2], y_data=[3, 4], z_data=[5, 6]) == (
None,
None,
)
def test_heatmap_get_grid_scan_image(heatmap_widget):
scan_msg = messages.ScanStatusMessage(
scan_id="123",
status="open",
scan_name="grid_scan",
metadata={},
info={"positions": np.random.rand(100, 2).tolist()},
request_inputs={"arg_bundle": ["samx", -5, 5, 10, "samy", -5, 5, 10], "kwargs": {}},
)
heatmap_widget._image_config = HeatmapConfig(
parent_id="parent_id",
x_device=HeatmapDeviceSignal(name="samx", entry="samx"),
y_device=HeatmapDeviceSignal(name="samy", entry="samy"),
z_device=HeatmapDeviceSignal(name="bpm4i", entry="bpm4i"),
color_map="viridis",
)
img, _ = heatmap_widget.get_grid_scan_image(list(range(100)), msg=scan_msg)
assert img.shape == (10, 10)
assert sorted(np.asarray(img, dtype=int).flatten().tolist()) == list(range(100))
def test_heatmap_get_step_scan_image(heatmap_widget):
scan_msg = messages.ScanStatusMessage(
scan_id="123",
status="open",
scan_name="step_scan",
scan_type="step",
metadata={},
info={"positions": np.random.rand(100, 2).tolist()},
)
heatmap_widget.status_message = scan_msg
heatmap_widget.scan_item = create_dummy_scan_item()
heatmap_widget.scan_item.status_message = scan_msg
heatmap_widget._image_config = HeatmapConfig(
parent_id="parent_id",
x_device=HeatmapDeviceSignal(name="samx", entry="samx"),
y_device=HeatmapDeviceSignal(name="samy", entry="samy"),
z_device=HeatmapDeviceSignal(name="bpm4i", entry="bpm4i"),
color_map="viridis",
)
img, _ = heatmap_widget.get_step_scan_image(
list(np.random.rand(100)), list(np.random.rand(100)), list(range(100)), msg=scan_msg
)
assert img.shape > (10, 10)
def test_heatmap_update_plot_no_scan_item(heatmap_widget):
heatmap_widget._image_config = HeatmapConfig(
parent_id="parent_id",
x_device=HeatmapDeviceSignal(name="samx", entry="samx"),
y_device=HeatmapDeviceSignal(name="samy", entry="samy"),
z_device=HeatmapDeviceSignal(name="bpm4i", entry="bpm4i"),
color_map="viridis",
)
with mock.patch.object(heatmap_widget.main_image, "setImage") as mock_set_image:
heatmap_widget.update_plot(_override_slot_params={"verify_sender": False})
mock_set_image.assert_not_called()
def test_heatmap_update_plot(heatmap_widget):
heatmap_widget._image_config = HeatmapConfig(
parent_id="parent_id",
x_device=HeatmapDeviceSignal(name="samx", entry="samx"),
y_device=HeatmapDeviceSignal(name="samy", entry="samy"),
z_device=HeatmapDeviceSignal(name="bpm4i", entry="bpm4i"),
color_map="viridis",
)
heatmap_widget.scan_item = create_dummy_scan_item()
heatmap_widget.scan_item.status_message = messages.ScanStatusMessage(
scan_id="123",
status="open",
scan_name="grid_scan",
metadata={},
info={"positions": np.random.rand(100, 2).tolist()},
request_inputs={"arg_bundle": ["samx", -5, 5, 10, "samy", -5, 5, 10], "kwargs": {}},
)
with mock.patch.object(heatmap_widget.main_image, "setImage") as mock_set_image:
heatmap_widget.update_plot(_override_slot_params={"verify_sender": False})
img = mock_set_image.mock_calls[0].args[0]
assert img.shape == (10, 10)
def test_heatmap_update_plot_without_status_message(heatmap_widget):
heatmap_widget._image_config = HeatmapConfig(
parent_id="parent_id",
x_device=HeatmapDeviceSignal(name="samx", entry="samx"),
y_device=HeatmapDeviceSignal(name="samy", entry="samy"),
z_device=HeatmapDeviceSignal(name="bpm4i", entry="bpm4i"),
color_map="viridis",
)
heatmap_widget.scan_item = create_dummy_scan_item()
heatmap_widget.scan_item.status_message = None
with mock.patch.object(heatmap_widget.main_image, "setImage") as mock_set_image:
heatmap_widget.update_plot(_override_slot_params={"verify_sender": False})
mock_set_image.assert_not_called()
def test_heatmap_update_plot_no_img_data(heatmap_widget):
heatmap_widget._image_config = HeatmapConfig(
parent_id="parent_id",
x_device=HeatmapDeviceSignal(name="samx", entry="samx"),
y_device=HeatmapDeviceSignal(name="samy", entry="samy"),
z_device=HeatmapDeviceSignal(name="bpm4i", entry="bpm4i"),
color_map="viridis",
)
heatmap_widget.scan_item = create_dummy_scan_item()
heatmap_widget.scan_item.status_message = messages.ScanStatusMessage(
scan_id="123",
status="open",
scan_name="grid_scan",
metadata={},
info={},
request_inputs={"arg_bundle": ["samx", -5, 5, 10, "samy", -5, 5, 10], "kwargs": {}},
)
with mock.patch.object(heatmap_widget, "get_image_data", return_value=None):
with mock.patch.object(heatmap_widget.main_image, "setImage") as mock_set_image:
heatmap_widget.update_plot(_override_slot_params={"verify_sender": False})
mock_set_image.assert_not_called()
def test_heatmap_settings_popup(heatmap_widget, qtbot):
"""
Test that the settings popup opens and contains the expected elements.
"""
settings_action = heatmap_widget.toolbar.components.get_action("heatmap_settings").action
heatmap_widget.show_heatmap_settings()
qtbot.waitUntil(lambda: heatmap_widget.heatmap_dialog is not None)
assert heatmap_widget.heatmap_dialog.isVisible()
assert settings_action.isChecked()
heatmap_widget.heatmap_dialog.reject()
qtbot.waitUntil(lambda: heatmap_widget.heatmap_dialog is None)
assert not settings_action.isChecked()
def test_heatmap_settings_popup_already_open(heatmap_widget, qtbot):
"""
Test that if the settings dialog is already open, it is brought to the front.
"""
heatmap_widget.show_heatmap_settings()
qtbot.waitUntil(lambda: heatmap_widget.heatmap_dialog is not None)
initial_dialog = heatmap_widget.heatmap_dialog
heatmap_widget.show_heatmap_settings()
qtbot.waitUntil(lambda: heatmap_widget.heatmap_dialog is initial_dialog)
assert heatmap_widget.heatmap_dialog.isVisible() # Dialog should still be visible
assert heatmap_widget.heatmap_dialog is initial_dialog # Should be the same dialog
heatmap_widget.heatmap_dialog.reject()
qtbot.waitUntil(lambda: heatmap_widget.heatmap_dialog is None)
def test_heatmap_settings_popup_accept_changes(heatmap_widget, qtbot):
"""
Test that changes made in the settings dialog are applied correctly.
"""
heatmap_widget.plot(x_name="samx", y_name="samy", z_name="bpm4i")
assert heatmap_widget.color_map == "plasma" # Default colormap
heatmap_widget.show_heatmap_settings()
qtbot.waitUntil(lambda: heatmap_widget.heatmap_dialog is not None)
dialog = heatmap_widget.heatmap_dialog
assert dialog.widget.isVisible()
# Simulate changing a setting
dialog.widget.ui.color_map.colormap = "viridis"
# Accept changes
dialog.accept()
qtbot.waitUntil(lambda: heatmap_widget.heatmap_dialog is None)
# Verify that the setting was applied
assert heatmap_widget.color_map == "viridis"
def test_heatmap_settings_popup_show_settings(heatmap_widget, qtbot):
"""
Test that the settings dialog opens and contains the expected elements.
"""
heatmap_widget.plot(x_name="samx", y_name="samy", z_name="bpm4i")
heatmap_widget.show_heatmap_settings()
qtbot.waitUntil(lambda: heatmap_widget.heatmap_dialog is not None)
dialog = heatmap_widget.heatmap_dialog
assert dialog.isVisible()
assert dialog.widget is not None
assert hasattr(dialog.widget.ui, "color_map")
assert hasattr(dialog.widget.ui, "x_name")
assert hasattr(dialog.widget.ui, "y_name")
assert hasattr(dialog.widget.ui, "z_name")
# Check that the ui elements are correctly initialized
assert dialog.widget.ui.color_map.colormap == heatmap_widget.color_map
assert dialog.widget.ui.x_name.text() == heatmap_widget._image_config.x_device.name
dialog.reject()
qtbot.waitUntil(lambda: heatmap_widget.heatmap_dialog is None)

View File

@@ -93,6 +93,7 @@ def test_logpanel_output(qtbot, log_panel: LogPanel):
assert log_panel.plain_text == TEST_COMBINED_PLAINTEXT
def display_queue_empty():
print(log_panel._log_manager._display_queue)
return len(log_panel._log_manager._display_queue) == 0
next_text = "datetime | error | test log message"

View File

@@ -296,6 +296,7 @@ def test_on_scan_selected(scan_control, scan_name):
# Check kwargs boxes
kwargs_group = [param for param in expected_scan_info["gui_config"]["kwarg_groups"]]
print(kwargs_group)
for kwarg_box, kwarg_group in zip(scan_control.kwarg_boxes, kwargs_group):
assert kwarg_box.title() == kwarg_group["name"]