Files
bec_widgets/bec_widgets/widgets/plots/heatmap/heatmap.py
T

1546 lines
58 KiB
Python

from __future__ import annotations
import json
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal
import numpy as np
import pyqtgraph as pg
from bec_lib import bec_logger, messages
from bec_lib.endpoints import MessageEndpoints
from bec_lib.utils.import_utils import lazy_import, lazy_import_from
from pydantic import BaseModel, Field, field_validator
from qtpy.QtCore import QObject, Qt, QThread, QTimer, Signal
from qtpy.QtGui import QTransform
from toolz import partition
from bec_widgets.utils.bec_connector import ConnectionConfig
from bec_widgets.utils.colors import Colors
from bec_widgets.utils.error_popups import SafeProperty, SafeSlot
from bec_widgets.utils.settings_dialog import SettingsDialog
from bec_widgets.utils.toolbars.actions import MaterialIconAction
from bec_widgets.widgets.plots.heatmap.settings.heatmap_setting import HeatmapSettings
from bec_widgets.widgets.plots.image.image_base import ImageBase
from bec_widgets.widgets.plots.image.image_item import ImageItem
from bec_widgets.widgets.plots.plot_base import PlotBase
logger = bec_logger.logger
if TYPE_CHECKING:
from scipy.interpolate import (
CloughTocher2DInterpolator,
LinearNDInterpolator,
NearestNDInterpolator,
)
from scipy.spatial import cKDTree
else:
CloughTocher2DInterpolator, LinearNDInterpolator, NearestNDInterpolator = lazy_import_from(
"scipy.interpolate",
["CloughTocher2DInterpolator", "LinearNDInterpolator", "NearestNDInterpolator"],
)
cKDTree = lazy_import_from("scipy.spatial", ["cKDTree"])
class HeatmapDeviceSignal(BaseModel):
    """The configuration of one device signal (device name plus signal entry) used by the heatmap widget."""

    # Name of the device.
    device: str
    # Name of the signal entry of that device.
    signal: str
    # Pydantic setting: validate values again when a field is assigned after construction.
    model_config: dict = {"validate_assignment": True}
class HeatmapConfig(ConnectionConfig):
    """Configuration model of the heatmap widget.

    Holds the display options (color map, color bar, aspect ratio, info label), the
    interpolation options (method, oversampling, enforcement for grid scans) and the
    x/y/z device signals that are plotted.
    """

    parent_id: str | None = Field(None, description="The parent plot of the curve.")
    # validate_default=True makes the validator below run on the default value as well.
    color_map: str | None = Field(
        "plasma", description="The color palette of the heatmap widget.", validate_default=True
    )
    color_bar: Literal["full", "simple"] | None = Field(
        None, description="The type of the color bar."
    )
    interpolation: Literal["linear", "nearest", "clough"] = Field(
        "linear", description="The interpolation method for the heatmap."
    )
    oversampling_factor: float = Field(
        1.0,
        description="Factor to oversample the grid resolution (1.0 = no oversampling, 2.0 = 2x resolution).",
    )
    show_config_label: bool = Field(
        True, description="Whether to show the configuration label in the heatmap."
    )
    enforce_interpolation: bool = Field(
        False, description="Whether to use the interpolation mode even for grid scans."
    )
    lock_aspect_ratio: bool = Field(
        False, description="Whether to lock the aspect ratio of the image."
    )
    device_x: HeatmapDeviceSignal | None = Field(
        None, description="The x device signal of the heatmap."
    )
    device_y: HeatmapDeviceSignal | None = Field(
        None, description="The y device signal of the heatmap."
    )
    device_z: HeatmapDeviceSignal | None = Field(
        None, description="The z device signal of the heatmap."
    )
    # Pydantic setting: validate values again when a field is assigned after construction.
    model_config: dict = {"validate_assignment": True}
    # Route the color_map field through Colors.validate_color_map.
    _validate_color_palette = field_validator("color_map")(Colors.validate_color_map)
@dataclass(frozen=True)
class _InterpolationRequest:
    """Immutable payload describing an interpolation request for the worker thread.

    The dataclass is frozen, so fields cannot be rebound once the request has been
    handed to the worker. The lists themselves are still ordinary mutable lists;
    callers are expected to pass copies.

    Args:
        x_data: X coordinates collected so far.
        y_data: Y coordinates collected so far.
        z_data: Z values associated with x/y.
        data_version: Number of points at request time (len(z_data)); used to reject stale results.
        scan_id: Identifier for the scan that produced the data.
        interpolation: Interpolation method to apply.
        oversampling_factor: Oversampling factor for the interpolation grid.
    """

    x_data: list[float]
    y_data: list[float]
    z_data: list[float]
    data_version: int
    scan_id: str
    interpolation: str
    oversampling_factor: float
class _StepInterpolationWorker(QObject):
    """Worker for performing step-scan interpolation in a background thread.

    This worker computes the interpolated heatmap image using the provided data
    and settings, then emits the result or a failure signal.

    Signals:
        finished(image, transform, data_version, scan_id):
            Emitted when interpolation is successful.
            - image: The resulting image (numpy array or similar).
            - transform: The QTransform for the image.
            - data_version: The data version for the request.
            - scan_id: The scan identifier.
        failed(error_message, data_version, scan_id):
            Emitted when interpolation fails.
            - error_message: The error message string.
            - data_version: The data version for the request.
            - scan_id: The scan identifier.
    """

    finished = Signal(object, object, int, str)
    failed = Signal(str, int, str)

    def __init__(self, parent: QObject | None = None):
        super().__init__(parent=parent)
        self._active_request: _InterpolationRequest | None = None
        self._processing = False

    @property
    def is_processing(self) -> bool:
        """Return whether the worker is currently processing a request."""
        return self._processing

    @SafeSlot(object, int)
    def process(self, request: _InterpolationRequest, data_version: int):
        """
        Process an interpolation request in the worker thread.

        The processing flag is always cleared *before* either result signal is
        emitted, so a receiver that reacts to the signal never observes a worker
        that still reports ``is_processing`` for the request that just ended.

        Args:
            request(_InterpolationRequest): The interpolation request payload.
            data_version(int): The data version for the request.
        """
        self._active_request = request
        self._processing = True
        error_message: str | None = None
        image = transform = None
        try:
            image, transform = Heatmap.compute_step_scan_image(
                x_data=np.asarray(request.x_data, dtype=float),
                y_data=np.asarray(request.y_data, dtype=float),
                z_data=np.asarray(request.z_data, dtype=float),
                oversampling_factor=request.oversampling_factor,
                interpolation_method=request.interpolation,
            )
        except Exception as exc:  # pragma: no cover - defensive
            logger.warning(f"Step-scan interpolation failed with: {exc}")
            error_message = str(exc)
        finally:
            # Clear the flag on both paths before notifying anyone.
            self._processing = False

        if error_message is not None:
            self.failed.emit(error_message, data_version, request.scan_id)
            return
        self.finished.emit(image, transform, data_version, request.scan_id)
class Heatmap(ImageBase):
"""
Heatmap widget for visualizing 2d grid data with color mapping for the z-axis.
"""
USER_ACCESS = [
*PlotBase.USER_ACCESS,
# ImageView Specific Settings
"color_map",
"color_map.setter",
"v_range",
"v_range.setter",
"v_min",
"v_min.setter",
"v_max",
"v_max.setter",
"autorange",
"autorange.setter",
"autorange_mode",
"autorange_mode.setter",
"enable_colorbar",
"enable_simple_colorbar",
"enable_simple_colorbar.setter",
"enable_full_colorbar",
"enable_full_colorbar.setter",
"interpolation_method",
"interpolation_method.setter",
"oversampling_factor",
"oversampling_factor.setter",
"enforce_interpolation",
"enforce_interpolation.setter",
"fft",
"fft.setter",
"log",
"log.setter",
"main_image",
"add_roi",
"remove_roi",
"rois",
"plot",
# Device properties
"device_x",
"device_x.setter",
"signal_x",
"signal_x.setter",
"device_y",
"device_y.setter",
"signal_y",
"signal_y.setter",
"device_z",
"device_z.setter",
"signal_z",
"signal_z.setter",
]
PLUGIN = True
RPC = True
ICON_NAME = "dataset"
new_scan = Signal()
new_scan_id = Signal(str)
sync_signal_update = Signal()
heatmap_property_changed = Signal()
interpolation_requested = Signal(object, int)
def __init__(self, parent=None, config: HeatmapConfig | None = None, **kwargs):
    """
    Initialize the heatmap widget.

    Args:
        parent: The parent widget, forwarded to the base class.
        config (HeatmapConfig | None): The widget configuration. A default
            configuration without devices is created when None.
        **kwargs: Additional keyword arguments forwarded to the base class.
    """
    if config is None:
        config = HeatmapConfig(
            widget_class=self.__class__.__name__,
            parent_id=None,
            color_map="plasma",
            color_bar=None,
            interpolation="linear",
            oversampling_factor=1.0,
            lock_aspect_ratio=False,
            device_x=None,
            device_y=None,
            device_z=None,
        )
    super().__init__(parent=parent, config=config, theme_update=True, **kwargs)
    self._image_config = config
    # Scan bookkeeping: the scan currently shown and the one shown before it.
    self.scan_id = None
    self.old_scan_id = None
    self.scan_item = None
    self.status_message = None
    # Number of z values already written into the grid-scan image (None = nothing written yet).
    self._grid_index = None
    # Highest data_version we have dispatched for the current scan; used to drop stale results.
    # Initialized to -1 so the first real request (len(z_data) >= 0) always supersedes it.
    self._latest_interpolation_version = -1
    # Background interpolation machinery; created lazily on the first step-scan request.
    self._interpolation_thread: QThread | None = None
    self._interpolation_worker: _StepInterpolationWorker | None = None
    self._pending_interpolation_request: _InterpolationRequest | None = None
    self.heatmap_dialog = None
    # Info label drawn inside the view box; hidden until a scan message is available.
    bg_color = pg.mkColor((240, 240, 240, 150))
    self.config_label = pg.LegendItem(
        labelTextColor=(0, 0, 0), offset=(-30, 1), brush=pg.mkBrush(bg_color), horSpacing=0
    )
    self.config_label.setParentItem(self.plot_item.vb)
    self.config_label.setVisible(False)
    self.reload = False
    self.bec_dispatcher.connect_slot(self.on_scan_status, MessageEndpoints.scan_status())
    self.bec_dispatcher.connect_slot(self.on_scan_progress, MessageEndpoints.scan_progress())
    self.heatmap_property_changed.connect(lambda: self.sync_signal_update.emit())
    # Rate-limit redraw requests: update_plot is invoked through the proxy (rateLimit=5).
    self.proxy_update_sync = pg.SignalProxy(
        self.sync_signal_update, rateLimit=5, slot=self.update_plot
    )
    self._init_toolbar_heatmap()
    self.toolbar.show_bundles(
        [
            "heatmap_settings",
            "plot_export",
            "image_crosshair",
            "mouse_interaction",
            "image_autorange",
            "image_colorbar",
            "image_processing",
            "axis_popup",
            "interpolation_info",
        ]
    )
@property
def main_image(self) -> ImageItem:
    """Return the image item of the "main" layer."""
    main_layer = self.layer_manager["main"]
    return main_layer.image
################################################################################
# Widget Specific GUI interactions
################################################################################
@SafeSlot(str)
def apply_theme(self, theme: str):
    """
    Apply the given theme to the heatmap widget and restyle the config label.

    Args:
        theme(str): The theme name; "dark" selects the dark palette, anything else the light one.
    """
    super().apply_theme(theme)
    is_dark = theme == "dark"
    background_rgba = (50, 50, 50, 150) if is_dark else (240, 240, 240, 150)
    text_rgb = (255, 255, 255) if is_dark else (0, 0, 0)
    label_brush = pg.mkBrush(pg.mkColor(*background_rgba))
    label_text_color = pg.mkColor(*text_rgb)
    # The label may not exist yet when this is called, hence the attribute check.
    if not hasattr(self, "config_label"):
        return
    self.config_label.setBrush(label_brush)
    self.config_label.setLabelTextColor(label_text_color)
    self.redraw_config_label()
@SafeSlot(popup_error=True)
def plot(
    self,
    device_x: str,
    device_y: str,
    device_z: str,
    signal_x: None | str = None,
    signal_y: None | str = None,
    signal_z: None | str = None,
    color_map: str | None = "plasma",
    validate_bec: bool = True,
    interpolation: Literal["linear", "nearest", "clough"] | None = None,
    enforce_interpolation: bool | None = None,
    oversampling_factor: float | None = None,
    lock_aspect_ratio: bool | None = None,
    show_config_label: bool | None = None,
    reload: bool = False,
):
    """
    Plot the heatmap with the given x, y, and z data.

    Options passed as None fall back to the value stored in the current configuration.

    Args:
        device_x (str): The name of the x-axis device signal.
        device_y (str): The name of the y-axis device signal.
        device_z (str): The name of the z-axis device signal.
        signal_x (str | None): The entry for the x-axis device signal.
        signal_y (str | None): The entry for the y-axis device signal.
        signal_z (str | None): The entry for the z-axis device signal.
        color_map (str | None): The color map to use for the heatmap.
        validate_bec (bool): Whether to validate the entries against BEC signals.
        interpolation (Literal["linear", "nearest", "clough"] | None): The interpolation method to use.
        enforce_interpolation (bool | None): Whether to enforce interpolation even for grid scans.
        oversampling_factor (float | None): Factor to oversample the grid resolution.
        lock_aspect_ratio (bool | None): Whether to lock the aspect ratio of the image.
        show_config_label (bool | None): Whether to show the configuration label in the heatmap.
        reload (bool): Whether to reload the heatmap with new data.

    Raises:
        ValueError: If a device name is missing, or if a signal entry is still
            missing after the optional validation.
    """
    # Check the device names first so a missing name is reported with the intended
    # message instead of being handed to the signal validator.
    if device_x is None or device_y is None or device_z is None:
        raise ValueError("x, y, and z names must be provided.")

    if validate_bec:
        signal_x = self.entry_validator.validate_signal(device_x, signal_x)
        signal_y = self.entry_validator.validate_signal(device_y, signal_y)
        signal_z = self.entry_validator.validate_signal(device_z, signal_z)

    if signal_x is None or signal_y is None or signal_z is None:
        raise ValueError("x, y, and z entries must be provided.")

    if interpolation is None:
        interpolation = self._image_config.interpolation
    if oversampling_factor is None:
        oversampling_factor = self._image_config.oversampling_factor
    if enforce_interpolation is None:
        enforce_interpolation = self._image_config.enforce_interpolation
    if lock_aspect_ratio is None:
        lock_aspect_ratio = self._image_config.lock_aspect_ratio
    if show_config_label is None:
        show_config_label = self._image_config.show_config_label

    def _device_key(device: HeatmapDeviceSignal | None) -> tuple[str | None, str | None]:
        return (device.device if device else None, device.signal if device else None)

    # A change of any (device, signal) pair relative to a fully configured previous
    # setup forces a reload and clears the image.
    prev_cfg = getattr(self, "_image_config", None)
    config_changed = False
    if prev_cfg and prev_cfg.device_x and prev_cfg.device_y and prev_cfg.device_z:
        config_changed = any(
            (
                _device_key(prev_cfg.device_x) != (device_x, signal_x),
                _device_key(prev_cfg.device_y) != (device_y, signal_y),
                _device_key(prev_cfg.device_z) != (device_z, signal_z),
            )
        )

    self._image_config = HeatmapConfig(
        parent_id=self.gui_id,
        device_x=HeatmapDeviceSignal(device=device_x, signal=signal_x),
        device_y=HeatmapDeviceSignal(device=device_y, signal=signal_y),
        device_z=HeatmapDeviceSignal(device=device_z, signal=signal_z),
        color_map=color_map,
        color_bar=None,
        interpolation=interpolation,
        oversampling_factor=oversampling_factor,
        enforce_interpolation=enforce_interpolation,
        lock_aspect_ratio=lock_aspect_ratio,
        show_config_label=show_config_label,
    )
    self.color_map = color_map
    self.reload = reload or config_changed
    if config_changed:
        self._grid_index = None
        self.main_image.clear()
    self.update_labels()
    self._fetch_running_scan()
    self.sync_signal_update.emit()
def _fetch_running_scan(self):
    """
    Select the scan to display: the current scan from the queue's scan storage if
    there is one, otherwise the last entry of the client history (if any).
    """
    scan = self.client.queue.scan_storage.current_scan
    if scan is not None:
        self.scan_item = scan
        self.scan_id = scan.scan_id
    elif self.client.history and len(self.client.history) > 0:
        # Fall back to the most recent scan in the history.
        self.scan_item = self.client.history[-1]
        self.scan_id = self.client.history._scan_ids[-1]
        # NOTE(review): old_scan_id is reset on the history fallback path only — confirm intent.
        self.old_scan_id = None
def update_labels(self):
    """
    Refresh the axis labels and the title from the configured devices.

    The x and y labels are set to the device names, with units taken from the
    device's ``egu()`` when the device is known and offers it. The title is set to
    the z device name. Axes without a configured device are left untouched.
    """
    cfg = self._image_config
    if cfg is None:
        return

    for axis, device_signal in (("x", cfg.device_x), ("y", cfg.device_y)):
        # Device entries may still be None if the widget is not configured yet.
        name = device_signal.device if device_signal else None
        if name is None:
            continue
        setattr(self, f"{axis}_label", name)  # type: ignore
        device = self.dev.get(name)
        if device and hasattr(device, "egu"):
            setattr(self, f"{axis}_label_units", device.egu())

    z_name = cfg.device_z.device if cfg.device_z else None
    if z_name is not None:
        self.title = z_name
def _init_toolbar_heatmap(self):
    """
    Set up the heatmap-specific toolbar entries: the settings action, the
    interpolation-info toggle, and the reduced set of image processing actions.
    """
    toolbar = self.toolbar

    settings_action = MaterialIconAction(
        icon_name="scatter_plot",
        tooltip="Show Heatmap Settings",
        checkable=True,
        parent=self,
    )
    toolbar.add_action("heatmap_settings", settings_action)
    toolbar.components.get_action("heatmap_settings").action.triggered.connect(
        self.show_heatmap_settings
    )

    # Only FFT and log stay visible among the image processing actions.
    keep_visible = ["image_processing_fft", "image_processing_log"]
    processing_bundle = toolbar.get_bundle("image_processing")
    for action_name, action_ref in processing_bundle.bundle_actions.items():
        if action_name in keep_visible:
            continue
        action_ref().action.setVisible(False)

    info_action = MaterialIconAction(
        icon_name="info", tooltip="Show Interpolation Info", checkable=True, parent=self
    )
    toolbar.add_action("interpolation_info", info_action)
    toolbar.components.get_action("interpolation_info").action.triggered.connect(
        self.toggle_interpolation_info
    )
    # Start with the toggle reflecting the configured label visibility.
    toolbar.components.get_action("interpolation_info").action.setChecked(
        self._image_config.show_config_label
    )
def show_heatmap_settings(self):
    """
    Open the heatmap settings dialog, or bring it to the front if it is already visible.
    """
    settings_toggle = self.toolbar.components.get_action("heatmap_settings").action
    dialog = self.heatmap_dialog
    if dialog is not None and dialog.isVisible():
        # Already open: just raise and focus it.
        dialog.raise_()
        dialog.activateWindow()
    else:
        settings_widget = HeatmapSettings(parent=self, target_widget=self, popup=True)
        self.heatmap_dialog = SettingsDialog(
            self, settings_widget=settings_widget, window_title="Heatmap Settings", modal=False
        )
        self.heatmap_dialog.resize(700, 350)
        # Closing the dialog resets the toolbar toggle and drops the reference.
        self.heatmap_dialog.finished.connect(self._heatmap_dialog_closed)
        self.heatmap_dialog.show()
    # Keep the toolbar action checked while the dialog is shown.
    settings_toggle.setChecked(True)
def toggle_interpolation_info(self):
    """
    Flip the visibility setting of the interpolation info label, sync the toolbar
    toggle with the new state, and redraw the label.
    """
    cfg = self._image_config
    now_visible = not cfg.show_config_label
    cfg.show_config_label = now_visible
    info_toggle = self.toolbar.components.get_action("interpolation_info").action
    info_toggle.setChecked(cfg.show_config_label)
    self.redraw_config_label()
def _heatmap_dialog_closed(self):
    """
    React to the settings dialog being closed: drop the dialog reference and
    uncheck the toolbar action.
    """
    settings_toggle = self.toolbar.components.get_action("heatmap_settings").action
    self.heatmap_dialog = None
    settings_toggle.setChecked(False)
@SafeSlot(dict, dict)
def on_scan_status(self, msg: dict, meta: dict):
    """
    Handle a scan status message, which is triggered at the beginning and end of a scan.

    Messages without a scan id, or for the scan already tracked, are ignored. For a
    new scan id the widget is reset, the new-scan signals are emitted, the scan
    bookkeeping is updated and a plot update is requested.

    Args:
        msg(dict): The message content.
        meta(dict): The message metadata.
    """
    incoming_scan_id = msg.get("scan_id", None)
    if incoming_scan_id is None or incoming_scan_id == self.scan_id:
        return

    # Invalidate any pending interpolation work when a new scan starts
    self._invalidate_interpolation_generation()
    self.reset()
    self.new_scan.emit()
    self.new_scan_id.emit(incoming_scan_id)
    self.old_scan_id, self.scan_id = self.scan_id, incoming_scan_id
    self.scan_item = self.queue.scan_storage.find_scan_by_ID(self.scan_id)  # type: ignore

    # First trigger to update the scan curves
    self.sync_signal_update.emit()
@SafeSlot(dict, dict)
def on_scan_progress(self, msg: dict, meta: dict):
    """
    Handle a scan progress message by requesting a plot update.

    When the message reports the scan as done, two additional delayed updates
    (after 100 ms and 300 ms) are scheduled.

    Args:
        msg(dict): The message content.
        meta(dict): The message metadata.
    """
    self.sync_signal_update.emit()
    if not msg.get("done"):
        return
    for delay_ms in (100, 300):
        QTimer.singleShot(delay_ms, self.update_plot)
@SafeSlot(verify_sender=True)
def update_plot(self, _=None) -> None:
    """
    Update the plot with the current data.

    Reads the x, y and z values of the configured signals from the current scan
    item, resolves the scan status message, and then either fills the grid-scan
    image directly or requests a background interpolation for step scans.
    The update is skipped (with a log message) whenever required data is missing.
    """
    if self.scan_item is None:
        logger.info("No scan executed so far; skipping update.")
        return
    data, access_key = self._fetch_scan_data_and_access()
    if data == "none":
        logger.info("No scan executed so far; skipping update.")
        return

    if self._image_config is None:
        return
    try:
        device_x = self._image_config.device_x.device
        signal_x = self._image_config.device_x.signal
        device_y = self._image_config.device_y.device
        signal_y = self._image_config.device_y.signal
        device_z = self._image_config.device_z.device
        signal_z = self._image_config.device_z.signal
    except AttributeError:
        # At least one device is not configured yet.
        return

    def _read_values(device: str, signal: str):
        """Return the raw values of one signal, or None if the entry is missing."""
        entry = data.get(device, {}).get(signal, {})
        if access_key == "val":
            return entry.get(access_key, None)
        # A missing device/signal yields the empty-dict default, which has no
        # read(); report it as missing data instead of raising AttributeError.
        reader = getattr(entry, "read", None)
        if reader is None:
            return None
        return reader().get("value", None)

    def _as_list(values):
        """Return lists unchanged, convert numpy arrays, map everything else to None."""
        if isinstance(values, list):
            return values
        return values.tolist() if isinstance(values, np.ndarray) else None

    x_data = _read_values(device_x, signal_x)
    y_data = _read_values(device_y, signal_y)
    z_data = _read_values(device_z, signal_z)

    x_data = _as_list(x_data)
    y_data = _as_list(y_data)
    z_data = _as_list(z_data)

    if x_data is None or y_data is None or z_data is None:
        logger.warning("x, y, or z data is None; skipping update.")
        return

    if len(x_data) != len(y_data) or len(x_data) != len(z_data):
        logger.warning(
            "x, y, and z data lengths do not match; skipping update. "
            f"Lengths: x={len(x_data)}, y={len(y_data)}, z={len(z_data)}"
        )
        return

    if hasattr(self.scan_item, "status_message"):
        scan_msg = self.scan_item.status_message
    elif hasattr(self.scan_item, "metadata"):
        # No status message on the item: rebuild one from the stored "bec" metadata.
        metadata = self.scan_item.metadata["bec"]
        status = metadata["status"]
        scan_id = metadata["scan_id"]
        scan_name = metadata["scan_name"]
        scan_type = metadata["scan_type"]
        scan_number = metadata["scan_number"]
        request_inputs = metadata["request_inputs"]
        if "arg_bundle" in request_inputs and isinstance(request_inputs["arg_bundle"], str):
            # Convert the arg_bundle from a JSON string to a dictionary
            request_inputs["arg_bundle"] = json.loads(request_inputs["arg_bundle"])
        positions = metadata.get("positions", [])
        positions = positions.tolist() if isinstance(positions, np.ndarray) else positions

        scan_msg = messages.ScanStatusMessage(
            status=status,
            scan_id=scan_id,
            scan_name=scan_name,
            scan_number=scan_number,
            scan_type=scan_type,
            request_inputs=request_inputs,
            info={"positions": positions},
        )
    else:
        scan_msg = None

    if scan_msg is None:
        logger.warning("Scan message is None; skipping update.")
        return
    self.status_message = scan_msg

    if self._image_config.show_config_label:
        self.redraw_config_label()

    if self._is_grid_scan_supported(scan_msg):
        img, transform = self.get_grid_scan_image(z_data, scan_msg)
        self._apply_image_update(img, transform)
        return

    if len(z_data) < 4:
        # LinearNDInterpolator requires at least 4 points to interpolate
        logger.warning("Not enough data points to interpolate; skipping update.")
        return

    self._request_step_scan_interpolation(x_data, y_data, z_data, scan_msg)
def _apply_image_update(self, img: np.ndarray | None, transform: QTransform | None):
    """Apply interpolated image and transform to the heatmap display.

    This method updates the main image with the computed data and emits
    the image_updated signal. Color bar signals are temporarily blocked
    during the update to prevent cascading updates; they are always
    unblocked again, even if setting the image data raises.

    Args:
        img(np.ndarray): The interpolated image data, or None if unavailable
        transform(QTransform): QTransform mapping pixel to world coordinates, or None if unavailable
    """
    if img is None or transform is None:
        logger.warning("Image data is None; skipping update.")
        return

    # Check the target before blocking signals so an early return cannot leave
    # the color bar muted.
    main_image = self.main_image
    if main_image is None:
        logger.warning("Main image item is None; cannot update image.")
        return

    color_bar = self._color_bar
    if color_bar is not None:
        color_bar.blockSignals(True)
    try:
        main_image.set_data(img, transform=transform)
    finally:
        if color_bar is not None:
            color_bar.blockSignals(False)

    self.image_updated.emit()
    if self.crosshair is not None:
        self.crosshair.update_markers_on_image_change()
def _request_step_scan_interpolation(
    self,
    x_data: list[float],
    y_data: list[float],
    z_data: list[float],
    msg: messages.ScanStatusMessage,
):
    """Request step-scan interpolation in a background thread.

    A request is built from copies of the data and the current interpolation
    settings. While the worker is busy, the request replaces any previously
    pending one and is picked up once the running interpolation completes;
    otherwise it is dispatched immediately.

    Args:
        x_data(list[float]): X coordinates of data points
        y_data(list[float]): Y coordinates of data points
        z_data(list[float]): Z values at each point
        msg(messages.ScanStatusMessage): Scan status message containing scan metadata
    """
    cfg = self._image_config
    request = _InterpolationRequest(
        x_data=list(x_data),
        y_data=list(y_data),
        z_data=list(z_data),
        data_version=len(z_data),
        scan_id=msg.scan_id,
        interpolation=cfg.interpolation,
        oversampling_factor=cfg.oversampling_factor,
    )
    worker = self._interpolation_worker
    worker_busy = worker is not None and worker.is_processing
    if worker_busy:
        self._pending_interpolation_request = request
    else:
        self._start_step_scan_interpolation(request)
def _ensure_interpolation_thread(self):
    """
    Create the interpolation thread and worker on first use and make sure the
    thread is running.
    """
    if self._interpolation_thread is None:
        self._interpolation_thread = QThread()
        worker = _StepInterpolationWorker()
        self._interpolation_worker = worker
        worker.moveToThread(self._interpolation_thread)

        # All three links use queued connections.
        queued = Qt.ConnectionType.QueuedConnection
        self.interpolation_requested.connect(worker.process, queued)
        worker.finished.connect(self._on_interpolation_finished, queued)
        worker.failed.connect(self._on_interpolation_failed, queued)

    thread = self._interpolation_thread
    if thread is not None and not thread.isRunning():
        thread.start()
def _start_step_scan_interpolation(self, request: _InterpolationRequest):
    """
    Dispatch an interpolation request to the worker.

    The request's data_version (len(z_data) at request time) is remembered as the
    latest dispatched version so that older results can be rejected later.

    Args:
        request(_InterpolationRequest): The request to dispatch.
    """
    self._ensure_interpolation_thread()
    thread = self._interpolation_thread
    if thread is not None and not thread.isRunning():
        thread.start()
    version = request.data_version
    self._latest_interpolation_version = version
    self.interpolation_requested.emit(request, version)
def _on_interpolation_finished(
    self, img: np.ndarray, transform: QTransform, data_version: int, scan_id: str
):
    """
    Handle a finished interpolation.

    The result is applied only if it matches the latest dispatched version and the
    active scan; otherwise it is discarded. A pending request, if any, is started
    afterwards in both cases.

    Args:
        img(np.ndarray): The interpolated image.
        transform(QTransform): The transform belonging to the image.
        data_version(int): The data version the result was computed for.
        scan_id(str): The scan the result belongs to.
    """
    is_current = (
        data_version == self._latest_interpolation_version and scan_id == self.scan_id
    )
    if not is_current:
        logger.info("Discarding outdated interpolation result.")
    else:
        self._apply_image_update(img, transform)
    self._maybe_start_pending_interpolation()
def _on_interpolation_failed(self, error: str, data_version: int, scan_id: str):
    """
    Log a failed interpolation and start the pending request, if there is one.

    Args:
        error(str): The error message reported by the worker.
        data_version(int): The data version of the failed request.
        scan_id(str): The scan the failed request belongs to.
    """
    failure_report = (
        f"Interpolation failed for scan {scan_id} (version {data_version}): {error}"
    )
    logger.warning(failure_report)
    self._maybe_start_pending_interpolation()
def _finish_interpolation_thread(self):
    """
    Tear down the interpolation worker and thread.

    Drops any pending request, disconnects and schedules deletion of the worker,
    then asks the thread to quit and waits up to 3 s for it before scheduling its
    deletion.
    """
    self._pending_interpolation_request = None
    if self._interpolation_worker is not None:
        try:
            self.interpolation_requested.disconnect(self._interpolation_worker.process)
        except (TypeError, RuntimeError) as ext:
            # Disconnecting may fail if the connection no longer exists; only log it.
            logger.warning(f"Processing thread already disconnected: {ext}")
            pass
        self._interpolation_worker.deleteLater()
        self._interpolation_worker = None
    if self._interpolation_thread is not None:
        if self._interpolation_thread.isRunning():
            self._interpolation_thread.quit()
            if not self._interpolation_thread.wait(3000):  # 3s timeout
                logger.error(
                    f"Interpolation thread of widget {self.gui_id} did not stop within timeout 3s; leaving it dangling."
                )
        # NOTE(review): deleteLater() is also reached when the wait above timed out,
        # i.e. possibly while the thread is still running — confirm this is intended.
        self._interpolation_thread.deleteLater()
        self._interpolation_thread = None
        logger.info(f"Interpolation thread finished of widget {self.gui_id}")
def _maybe_start_pending_interpolation(self):
    """
    Dispatch the pending interpolation request, if one exists and is still relevant.

    A pending request for a different scan is dropped. While the worker is still
    processing, the request stays pending. Otherwise it is removed from the
    pending slot and started.
    """
    queued_request = self._pending_interpolation_request
    if queued_request is None:
        return
    if queued_request.scan_id != self.scan_id:
        self._pending_interpolation_request = None
        return
    worker = self._interpolation_worker
    if worker is not None and worker.is_processing:
        return
    self._pending_interpolation_request = None
    self._start_step_scan_interpolation(queued_request)
def _cancel_interpolation(self):
    """Drop the pending interpolation request, if any.

    Only the pending slot is cleared. The latest dispatched data version is left
    untouched, so the result of an interpolation that is already running can
    still be accepted when it arrives.
    """
    self._pending_interpolation_request = None
    # Do not change the active data version so an in-flight worker can still deliver.
def _invalidate_interpolation_generation(self):
    """Drop the pending request and reset the latest version so in-flight results are ignored."""
    self._pending_interpolation_request, self._latest_interpolation_version = None, -1
def redraw_config_label(self):
    """
    Rebuild the info label from the last scan status message.

    Nothing happens while no scan message is known. The label is hidden when it is
    disabled in the configuration. Otherwise it lists scan number and scan name,
    plus interpolation method and oversampling factor unless the scan is a
    grid_scan without enforced interpolation.
    """
    scan_msg = self.status_message
    if scan_msg is None:
        return

    cfg = self._image_config
    label = self.config_label
    if not cfg.show_config_label:
        label.setVisible(False)
        return

    label.setOffset((-30, 1))
    label.setVisible(True)
    label.clear()

    entries = [f"Scan: {scan_msg.scan_number}", f"Scan Name: {scan_msg.scan_name}"]
    if scan_msg.scan_name != "grid_scan" or cfg.enforce_interpolation:
        entries.append(f"Interpolation: {cfg.interpolation}")
        entries.append(f"Oversampling: {cfg.oversampling_factor}x")
    for text in entries:
        label.addItem(self.plot_item, text)
def get_image_data(
    self,
    x_data: list[float] | None = None,
    y_data: list[float] | None = None,
    z_data: list[float] | None = None,
) -> tuple[np.ndarray | None, QTransform | None]:
    """
    Build the heatmap image for the given data using the last scan status message.

    Scans eligible for grid rendering are filled into the grid image directly;
    all other scans are interpolated, which needs at least 4 points.

    Args:
        x_data (list[float] | None): The x data.
        y_data (list[float] | None): The y data.
        z_data (list[float] | None): The z data.

    Returns:
        tuple[np.ndarray | None, QTransform | None]: The image data and the QTransform,
            or (None, None) if no image can be produced.
    """
    status = self.status_message
    if any(item is None for item in (x_data, y_data, z_data, status)):
        logger.warning("x, y, or z data is None; skipping update.")
        return None, None

    if self._is_grid_scan_supported(status):
        return self.get_grid_scan_image(z_data, status)

    # LinearNDInterpolator requires at least 4 points to interpolate
    if len(z_data) >= 4:
        return self.get_step_scan_image(x_data, y_data, z_data)
    return None, None
def _is_grid_scan_supported(self, msg: messages.ScanStatusMessage) -> bool:
    """Check if the scan can use optimized grid_scan rendering.

    Interpolation can be skipped when the scan is a grid_scan, interpolation is
    not enforced, and both the configured x and y signals appear in the scan's
    argument bundle.

    Args:
        msg(messages.ScanStatusMessage): Scan status message containing scan metadata

    Returns:
        True if grid_scan optimization is applicable, False otherwise
    """
    cfg = self._image_config
    plain_grid_scan = msg.scan_name == "grid_scan" and not cfg.enforce_interpolation
    if not plain_grid_scan:
        return False
    wanted_signals = (cfg.device_x.signal, cfg.device_y.signal)
    arg_bundle = msg.request_inputs["arg_bundle"]
    return all(signal in arg_bundle for signal in wanted_signals)
def get_grid_scan_image(
    self, z_data: list[float], msg: messages.ScanStatusMessage
) -> tuple[np.ndarray, QTransform]:
    """
    Get the image data for a grid scan.

    The image buffer of the main image is reused and filled incrementally: only
    z values that arrived since the previous call are written. Whenever a fresh
    buffer has to be allocated, or a reload was requested, filling restarts from
    the first point so that no previously written value is lost.

    Args:
        z_data (list[float]): The z data, in acquisition order.
        msg (messages.ScanStatusMessage): The scan status message.

    Returns:
        tuple[np.ndarray, QTransform]: The image data and the QTransform.
    """
    request_inputs = msg.request_inputs
    arg_bundle = request_inputs["arg_bundle"]
    args = self.arg_bundle_to_dict(4, arg_bundle)

    signal_x = self._image_config.device_x.signal
    signal_y = self._image_config.device_y.signal
    # The last entry of each per-axis argument list is used as the number of points.
    shape = (args[signal_x][-1], args[signal_y][-1])

    data = self.main_image.raw_data
    if data is None or data.shape != shape:
        data = np.full(shape, np.nan)
        # A fresh buffer holds no points yet: force the fill below to start at 0
        # instead of resuming from a stale index.
        self._grid_index = None
    elif self.reload:
        data.fill(np.nan)

    scan_kwargs = request_inputs["kwargs"]
    snaked = scan_kwargs.get("snaked", True)

    # The first bundle entry is the slow axis, the fifth one the fast axis.
    slow_entry, fast_entry = arg_bundle[0], arg_bundle[4]

    scan_pos = np.asarray(msg.info["positions"], dtype=float)
    relative = bool(scan_kwargs.get("relative", False))

    def _axis_column(entry: str) -> int:
        return 0 if entry == slow_entry else 1

    def _axis_levels(entry: str, npts: int) -> np.ndarray:
        start, stop = args[entry][:2]
        if relative:
            # Relative scans: shift the requested range by the offset between the
            # first recorded position and the requested start.
            origin = float(scan_pos[0, _axis_column(entry)] - start)
            return origin + np.linspace(start, stop, npts)
        return np.linspace(start, stop, npts)

    x_levels = _axis_levels(signal_x, shape[0])
    y_levels = _axis_levels(signal_y, shape[1])

    pixel_size_x = (
        float(x_levels[-1] - x_levels[0]) / max(shape[0] - 1, 1) if shape[0] > 1 else 1.0
    )
    pixel_size_y = (
        float(y_levels[-1] - y_levels[0]) / max(shape[1] - 1, 1) if shape[1] > 1 else 1.0
    )

    # Scale pixels to world units and shift by half a pixel so that pixel centers
    # sit on the axis levels.
    transform = QTransform()
    transform.scale(pixel_size_x, pixel_size_y)
    transform.translate(x_levels[0] / pixel_size_x - 0.5, y_levels[0] / pixel_size_y - 0.5)

    # Fill the data array with the z values
    if self._grid_index is None or self.reload:
        self._grid_index = 0
        self.reload = False

    fast_len = args[fast_entry][-1]
    x_is_fast = signal_x == fast_entry
    for i in range(self._grid_index, len(z_data)):
        slow_i, fast_i = divmod(i, fast_len)
        if snaked and (slow_i % 2 == 1):
            # Every second row of a snaked scan runs in the opposite direction.
            fast_i = fast_len - 1 - fast_i
        if x_is_fast:
            x_i, y_i = fast_i, slow_i
        else:
            x_i, y_i = slow_i, fast_i
        data[x_i, y_i] = z_data[i]
    self._grid_index = len(z_data)
    return data, transform
def get_step_scan_image(
    self, x_data: list[float], y_data: list[float], z_data: list[float]
) -> tuple[np.ndarray, QTransform]:
    """
    Interpolate an arbitrary step scan into an image using the configured
    oversampling factor and interpolation method.

    Args:
        x_data (list[float]): The x data.
        y_data (list[float]): The y data.
        z_data (list[float]): The z data.

    Returns:
        tuple[np.ndarray, QTransform]: The image data and the QTransform.
    """
    compute = self.compute_step_scan_image
    cfg = self._image_config
    interpolation_options = {
        "oversampling_factor": cfg.oversampling_factor,
        "interpolation_method": cfg.interpolation,
    }
    return compute(x_data=x_data, y_data=y_data, z_data=z_data, **interpolation_options)
@staticmethod
def compute_step_scan_image(
    x_data: list[float] | np.ndarray,
    y_data: list[float] | np.ndarray,
    z_data: list[float] | np.ndarray,
    oversampling_factor: float,
    interpolation_method: str,
) -> tuple[np.ndarray, QTransform]:
    """Interpolate scattered step-scan samples onto a regular image grid.

    The method reads no instance state; everything it needs is passed in.

    Args:
        x_data(list[float]): X coordinates of data points
        y_data(list[float]): Y coordinates of data points
        z_data(list[float]): Z values at each point
        oversampling_factor(float): Grid resolution multiplier forwarded to
            ``Heatmap.build_image_grid``
        interpolation_method(str): One of 'linear', 'nearest', or 'clough'

    Returns:
        (tuple[np.ndarray, QTransform]): The interpolated grid and the transform
        produced by ``Heatmap.build_image_grid`` for that grid.

    Raises:
        ValueError: If ``interpolation_method`` is not one of the supported names.
    """
    # Stack the coordinates into an (N, 2) array of sample positions.
    points = np.column_stack((x_data, y_data))
    mesh_x, mesh_y, transform = Heatmap.build_image_grid(
        positions=points, oversampling_factor=oversampling_factor
    )

    # Map the method name onto the scipy interpolator class implementing it.
    interpolators = {
        "linear": LinearNDInterpolator,
        "nearest": NearestNDInterpolator,
        "clough": CloughTocher2DInterpolator,
    }
    interpolator_cls = interpolators.get(interpolation_method)
    if interpolator_cls is None:  # pragma: no cover - guarded by validation
        raise ValueError(
            "Interpolation method must be either 'linear', 'nearest', or 'clough'."
        )

    interpolate = interpolator_cls(points, z_data)
    return interpolate(mesh_x, mesh_y), transform
def get_image_grid(self, positions) -> tuple[np.ndarray, np.ndarray, QTransform]:
    """
    Build the interpolation grid for the given positions.

    Delegates to ``build_image_grid`` using the oversampling factor from the
    widget's image configuration. No caching is done in this method itself.

    Args:
        positions: positions of the data points.

    Returns:
        tuple[np.ndarray, np.ndarray, QTransform]: The grid x and y coordinates and the QTransform.
    """
    factor = self._image_config.oversampling_factor
    return self.build_image_grid(positions=positions, oversampling_factor=factor)
@staticmethod
def build_image_grid(
    positions: np.ndarray, oversampling_factor: float
) -> tuple[np.ndarray, np.ndarray, QTransform]:
    """Build an interpolation grid covering the data positions.

    The base resolution comes from ``Heatmap.estimate_image_resolution`` and is
    multiplied by ``oversampling_factor`` (at least one pixel per axis). The grid
    spans the bounding box of ``positions``, end points included.

    Args:
        positions: (N, 2) array of (x, y) coordinates
        oversampling_factor: Grid resolution multiplier (>1.0 for higher resolution)

    Returns:
        Tuple of (grid_x, grid_y, transform) where grid_x/grid_y are meshgrids
        for interpolation and transform maps pixel to world coordinates
    """
    base_width, base_height = Heatmap.estimate_image_resolution(positions)
    width = max(1, int(base_width * oversampling_factor))
    height = max(1, int(base_height * oversampling_factor))

    # Compute the bounding box once, using vectorised reductions instead of the
    # Python builtins (which iterate the array element by element).
    x_min, x_max = np.min(positions[:, 0]), np.max(positions[:, 0])
    y_min, y_max = np.min(positions[:, 1]), np.max(positions[:, 1])

    # A complex step makes np.mgrid treat it as a point count, end point included.
    grid_x, grid_y = np.mgrid[x_min : x_max : width * 1j, y_min : y_max : height * 1j]

    x_scale = (x_max - x_min) / width
    y_scale = (y_max - y_min) / height
    # A degenerate axis (all points share the same coordinate) has zero extent.
    # Fall back to a unit pixel size so the translation below does not divide by
    # zero and produce a non-finite transform.
    if x_scale == 0:
        x_scale = 1.0
    if y_scale == 0:
        y_scale = 1.0

    transform = QTransform()
    transform.scale(x_scale, y_scale)
    # The half-pixel shift places pixel (0, 0) centred on (x_min, y_min).
    transform.translate(x_min / x_scale - 0.5, y_min / y_scale - 0.5)

    return grid_x, grid_y, transform
@staticmethod
def estimate_image_resolution(coords: np.ndarray) -> tuple[int, int]:
"""
Estimate the number of pixels needed for the image based on the coordinates.
Args:
coords (np.ndarray): The coordinates of the points.
Returns:
tuple[int, int]: The estimated width and height of the image."""
if coords.ndim != 2 or coords.shape[1] != 2:
raise ValueError("Input must be an (m x 2) array of (x, y) coordinates.")
x_min, x_max = coords[:, 0].min(), coords[:, 0].max()
y_min, y_max = coords[:, 1].min(), coords[:, 1].max()
tree = cKDTree(coords)
distances, _ = tree.query(coords, k=2)
distances = distances[:, 1] # Get the second nearest neighbor distance
avg_distance = np.mean(distances)
width_extent = x_max - x_min
height_extent = y_max - y_min
# Calculate the number of pixels needed based on the average distance
width_pixels = int(np.ceil(width_extent / avg_distance))
height_pixels = int(np.ceil(height_extent / avg_distance))
return max(1, width_pixels), max(1, height_pixels)
@staticmethod
def arg_bundle_to_dict(bundle_size: int, args: list) -> dict:
    """
    Turn a flat argument bundle into a dictionary keyed by each bundle's first item.

    ``args`` is split into consecutive groups of ``bundle_size`` items; the first
    item of every group becomes the key and the remaining items become the value.

    Args:
        bundle_size (int): The size of each argument bundle.
        args (list): The argument bundle.

    Returns:
        dict: Mapping of each bundle's first element to a list of its other elements.
    """
    return {bundle[0]: list(bundle[1:]) for bundle in partition(bundle_size, args)}
def _fetch_scan_data_and_access(self):
    """
    Pick the data container and value key for the current scan item.

    Returns:
        tuple: ``("none", "none")`` when no scan item is set; otherwise the scan
        item's ``live_data`` with key ``"val"`` when that attribute exists (live
        scan), or its ``devices`` with key ``"value"`` (history).
    """
    scan = self.scan_item
    if scan is None:
        logger.info("No scan executed so far; skipping update.")
        return "none", "none"

    # Only live scan items expose ``live_data``; anything else is treated as history.
    if hasattr(scan, "live_data"):
        return scan.live_data, "val"

    return scan.devices, "value"
def reset(self):
    """
    Reset the heatmap to an empty state.

    Cancels interpolation via ``_cancel_interpolation``, forgets the grid fill
    index, clears the main image, resets the crosshair when one exists, and then
    runs the parent class reset.
    """
    self._cancel_interpolation()
    self._grid_index = None
    self.main_image.clear()

    crosshair = self.crosshair
    if crosshair is not None:
        crosshair.reset()

    super().reset()
################################################################################
# Widget Specific Properties
################################################################################
@SafeProperty(str)
def device_x(self) -> str:
    """Device name for the X axis (empty string when no X device is configured)."""
    if self._image_config.device_x is None:
        return ""
    return self._image_config.device_x.device or ""

@device_x.setter
def device_x(self, device_name: str) -> None:
    """
    Set the X device name.

    A non-empty name is validated through ``entry_validator``; on success the
    config is updated, ``property_changed`` is emitted, labels are refreshed and
    an auto-plot is attempted. An empty name clears the X device.

    Args:
        device_name(str): Device name for the X axis
    """
    device_name = device_name or ""

    if device_name:
        try:
            # Passing None as the entry lets the validator pick the signal.
            signal = self.entry_validator.validate_signal(device_name, None)
            self._image_config.device_x = HeatmapDeviceSignal(device=device_name, signal=signal)
            self.property_changed.emit("device_x", device_name)
            self.update_labels()  # Update axis labels
            self._try_auto_plot()
        except Exception as e:
            # Best effort: the device may not be available yet. Log instead of
            # swallowing silently, matching the other axis setters.
            logger.debug(f"X device name validation failed: {e}")
    else:
        self._image_config.device_x = None
        self.property_changed.emit("device_x", "")
        self.update_labels()  # Clear axis labels
@SafeProperty(str)
def signal_x(self) -> str:
    """Signal entry for the X axis device (empty string when no X device is configured)."""
    if self._image_config.device_x is None:
        return ""
    return self._image_config.device_x.signal or ""

@signal_x.setter
def signal_x(self, entry: str) -> None:
    """
    Set the X device entry.

    Does nothing for an empty entry, and only warns when no X device has been
    set yet. Otherwise the entry is validated for the current X device and, on
    success, stored, announced via ``property_changed`` and followed by a label
    refresh and an auto-plot attempt.

    Args:
        entry(str): Signal entry for the X axis device
    """
    if not entry:
        return

    if self._image_config.device_x is None:
        logger.warning("Cannot set signal_x without device_x set first.")
        return

    device_name = self._image_config.device_x.device

    try:
        # Validate the entry for this device
        validated_signal = self.entry_validator.validate_signal(device_name, entry)
        self._image_config.device_x = HeatmapDeviceSignal(
            device=device_name, signal=validated_signal
        )
        self.property_changed.emit("signal_x", validated_signal)
        self.update_labels()  # Update axis labels
        self._try_auto_plot()
    except Exception as e:
        # Best effort: keep the previous configuration, but leave a trace in the
        # log like the Y/Z entry setters do.
        logger.debug(f"X device entry validation failed: {e}")
@SafeProperty(str)
def device_y(self) -> str:
    """Device name for the Y axis (empty string when no Y device is configured)."""
    if self._image_config.device_y is None:
        return ""
    return self._image_config.device_y.device or ""

@device_y.setter
def device_y(self, device_name: str) -> None:
    """
    Set the Y device name.

    A non-empty name is validated through ``entry_validator``; on success the
    config is updated, ``property_changed`` is emitted, labels are refreshed and
    an auto-plot is attempted. An empty name clears the Y device.

    Args:
        device_name(str): Device name for the Y axis
    """
    device_name = device_name or ""

    if device_name:
        try:
            # Passing None as the entry lets the validator pick the signal.
            signal = self.entry_validator.validate_signal(device_name, None)
            self._image_config.device_y = HeatmapDeviceSignal(device=device_name, signal=signal)
            self.property_changed.emit("device_y", device_name)
            self.update_labels()  # Update axis labels
            self._try_auto_plot()
        except Exception as e:
            # Best effort: the device may not be available yet. Log instead of
            # swallowing silently, matching the other axis setters.
            logger.debug(f"Y device name validation failed: {e}")
    else:
        self._image_config.device_y = None
        self.property_changed.emit("device_y", "")
        self.update_labels()  # Clear axis labels
@SafeProperty(str)
def signal_y(self) -> str:
    """Signal entry of the configured Y axis device, or "" when none is set."""
    y_config = self._image_config.device_y
    return "" if y_config is None else (y_config.signal or "")

@signal_y.setter
def signal_y(self, entry: str) -> None:
    """
    Change the signal entry used for the Y axis.

    Empty entries are ignored. Without a configured Y device only a warning is
    logged. Validation failures are logged at debug level and leave the
    configuration untouched.

    Args:
        entry(str): Signal entry for the Y axis device
    """
    if not entry:
        return

    y_config = self._image_config.device_y
    if y_config is None:
        logger.warning("Cannot set signal_y without device_y set first.")
        return

    device = y_config.device
    try:
        # Check the requested entry against the current Y device.
        resolved = self.entry_validator.validate_signal(device, entry)
        self._image_config.device_y = HeatmapDeviceSignal(device=device, signal=resolved)
        self.property_changed.emit("signal_y", resolved)
        self.update_labels()  # axis labels depend on the signal
        self._try_auto_plot()
    except Exception as e:
        logger.debug(f"Y device entry validation failed: {e}")
@SafeProperty(str)
def device_z(self) -> str:
    """Name of the configured Z (color) axis device, or "" when none is set."""
    z_config = self._image_config.device_z
    return "" if z_config is None else (z_config.device or "")

@device_z.setter
def device_z(self, device_name: str) -> None:
    """
    Change the device used for the Z (color) axis.

    An empty name clears the Z device. Otherwise the name is validated; failures
    are logged at debug level and leave the configuration untouched.

    Args:
        device_name(str): Device name for the Z axis
    """
    device_name = device_name or ""

    if not device_name:
        # Clearing the device: drop the config and refresh the labels.
        self._image_config.device_z = None
        self.property_changed.emit("device_z", "")
        self.update_labels()
        return

    try:
        # No explicit entry is given, so the validator chooses the signal.
        resolved = self.entry_validator.validate_signal(device_name, None)
        self._image_config.device_z = HeatmapDeviceSignal(device=device_name, signal=resolved)
        self.property_changed.emit("device_z", device_name)
        self.update_labels()  # title/labels depend on the device
        self._try_auto_plot()
    except Exception as e:
        logger.debug(f"Z device name validation failed: {e}")
@SafeProperty(str)
def signal_z(self) -> str:
    """Signal entry of the configured Z (color) axis device, or "" when none is set."""
    z_config = self._image_config.device_z
    return "" if z_config is None else (z_config.signal or "")

@signal_z.setter
def signal_z(self, entry: str) -> None:
    """
    Change the signal entry used for the Z (color) axis.

    Empty entries are ignored. Without a configured Z device only a warning is
    logged. Validation failures are logged at debug level and leave the
    configuration untouched.

    Args:
        entry(str): Signal entry for the Z axis device
    """
    if not entry:
        return

    z_config = self._image_config.device_z
    if z_config is None:
        logger.warning("Cannot set signal_z without device_z set first.")
        return

    device = z_config.device
    try:
        # Check the requested entry against the current Z device.
        resolved = self.entry_validator.validate_signal(device, entry)
        self._image_config.device_z = HeatmapDeviceSignal(device=device, signal=resolved)
        self.property_changed.emit("signal_z", resolved)
        self.update_labels()  # title/labels depend on the signal
        self._try_auto_plot()
    except Exception as e:
        logger.debug(f"Z device entry validation failed: {e}")
def _try_auto_plot(self) -> None:
    """
    Call ``plot()`` once X, Y and Z devices are all configured.

    Nothing happens while any of the three devices is missing. Errors raised by
    ``plot()`` are logged at debug level and otherwise ignored.
    """
    config = self._image_config
    axes = (config.device_x, config.device_y, config.device_z)
    if any(axis is None for axis in axes):
        return

    x_cfg, y_cfg, z_cfg = axes
    plot_kwargs = {
        "device_x": x_cfg.device,
        "device_y": y_cfg.device,
        "device_z": z_cfg.device,
        "signal_x": x_cfg.signal,
        "signal_y": y_cfg.signal,
        "signal_z": z_cfg.signal,
        # Entries were already validated by the property setters.
        "validate_bec": False,
    }
    try:
        self.plot(**plot_kwargs)
    except Exception as e:
        logger.debug(f"Auto-plot failed: {e}")
@SafeProperty(str)
def interpolation_method(self) -> str:
    """
    The interpolation method used for the heatmap.
    """
    return self._image_config.interpolation

@interpolation_method.setter
def interpolation_method(self, value: str):
    """
    Set the interpolation method for the heatmap.

    The accepted names are the ones ``compute_step_scan_image`` knows how to
    handle. Emits ``heatmap_property_changed`` after updating the configuration.

    Args:
        value(str): The interpolation method: 'linear', 'nearest', or 'clough'.

    Raises:
        ValueError: If ``value`` is not one of the supported methods.
    """
    # 'clough' is implemented by compute_step_scan_image, so it must be accepted
    # here as well; otherwise it could never be selected through this property.
    if value not in ("linear", "nearest", "clough"):
        raise ValueError(
            "Interpolation method must be either 'linear', 'nearest', or 'clough'."
        )
    self._image_config.interpolation = value
    self.heatmap_property_changed.emit()
@SafeProperty(float)
def oversampling_factor(self) -> float:
    """
    Multiplier applied to the grid resolution, as stored in the image config.
    """
    config = self._image_config
    return config.oversampling_factor

@oversampling_factor.setter
def oversampling_factor(self, value: float):
    """
    Store a new oversampling factor and announce the change.

    Args:
        value(float): The oversampling factor (1.0 = no oversampling, 2.0 = 2x resolution).

    Raises:
        ValueError: If ``value`` is zero or negative.
    """
    if value <= 0:
        raise ValueError("Oversampling factor must be greater than 0.")

    config = self._image_config
    config.oversampling_factor = value
    self.heatmap_property_changed.emit()
@SafeProperty(bool)
def enforce_interpolation(self) -> bool:
    """
    Flag from the image config: interpolate even when the scan is a grid scan.
    """
    config = self._image_config
    return config.enforce_interpolation

@enforce_interpolation.setter
def enforce_interpolation(self, value: bool):
    """
    Store the enforce-interpolation flag and announce the change.

    Args:
        value(bool): Whether to enforce interpolation.
    """
    config = self._image_config
    config.enforce_interpolation = value
    self.heatmap_property_changed.emit()
################################################################################
# Post Processing
################################################################################
@SafeProperty(bool, auto_emit=True)
def fft(self) -> bool:
    """
    FFT postprocessing flag, read from the main image item.
    """
    image = self.main_image
    return image.fft

@fft.setter
def fft(self, enable: bool):
    """
    Forward the FFT postprocessing flag to the main image item.

    Args:
        enable(bool): Whether to enable FFT postprocessing.
    """
    image = self.main_image
    image.fft = enable
@SafeProperty(bool, auto_emit=True)
def log(self) -> bool:
    """
    Logarithmic scaling flag, read from the main image item.
    """
    image = self.main_image
    return image.log

@log.setter
def log(self, enable: bool):
    """
    Forward the logarithmic scaling flag to the main image item.

    Args:
        enable(bool): Whether to enable logarithmic scaling.
    """
    image = self.main_image
    image.log = enable
@SafeProperty(int)
def num_rotation_90(self) -> int:
    """
    Number of counterclockwise 90° rotations, read from the main image item.
    """
    image = self.main_image
    return image.num_rotation_90

@num_rotation_90.setter
def num_rotation_90(self, value: int):
    """
    Forward the number of counterclockwise 90° rotations to the main image item.

    Args:
        value(int): The number of 90° rotations to apply.
    """
    image = self.main_image
    image.num_rotation_90 = value
@SafeProperty(bool, auto_emit=True)
def transpose(self) -> bool:
    """
    Transpose flag, read from the main image item.
    """
    image = self.main_image
    return image.transpose

@transpose.setter
def transpose(self, enable: bool):
    """
    Forward the transpose flag to the main image item.

    Args:
        enable(bool): Whether to enable transposing the image.
    """
    image = self.main_image
    image.transpose = enable
def cleanup(self):
    """
    Clean up the widget.

    Calls ``_finish_interpolation_thread()`` first and then delegates to the
    parent class ``cleanup()``.
    """
    self._finish_interpolation_thread()
    super().cleanup()
if __name__ == "__main__":  # pragma: no cover
    # Manual demo: show a standalone Heatmap widget when the module is run as a script.
    import sys

    from qtpy.QtWidgets import QApplication

    app = QApplication(sys.argv)
    heatmap = Heatmap()
    # NOTE(review): "samx", "samy" and "bpm4i" are presumably devices of a
    # simulated/demo BEC session - confirm they exist in the target environment.
    heatmap.plot(device_x="samx", device_y="samy", device_z="bpm4i", oversampling_factor=5.0)
    heatmap.show()
    # Run the Qt event loop and hand its return code back to the shell.
    sys.exit(app.exec_())