diff --git a/bec_widgets/widgets/plots/heatmap/heatmap.py b/bec_widgets/widgets/plots/heatmap/heatmap.py index 05501b9f..bb43852a 100644 --- a/bec_widgets/widgets/plots/heatmap/heatmap.py +++ b/bec_widgets/widgets/plots/heatmap/heatmap.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +from dataclasses import dataclass from typing import Literal import numpy as np @@ -8,7 +9,7 @@ import pyqtgraph as pg from bec_lib import bec_logger, messages from bec_lib.endpoints import MessageEndpoints from pydantic import BaseModel, Field, field_validator -from qtpy.QtCore import QTimer, Signal +from qtpy.QtCore import QObject, Qt, QThread, QTimer, Signal from qtpy.QtGui import QTransform from scipy.interpolate import ( CloughTocher2DInterpolator, @@ -78,6 +79,85 @@ class HeatmapConfig(ConnectionConfig): _validate_color_palette = field_validator("color_map")(Colors.validate_color_map) +@dataclass +class _InterpolationRequest: + x_data: list[float] + y_data: list[float] + z_data: list[float] + scan_id: str + interpolation: str + oversampling_factor: float + + +class _StepInterpolationWorker(QObject): + """Worker for performing step-scan interpolation in a background thread. + + This worker computes the interpolated heatmap image using the provided data + and settings, then emits the result or a failure signal. + + Args: + x_data (list[float] or np.ndarray): The x-coordinates of the data points. + y_data (list[float] or np.ndarray): The y-coordinates of the data points. + z_data (list[float] or np.ndarray): The z-values (intensity) of the data points. + interpolation (str): The interpolation method to use. + oversampling_factor (float): The oversampling factor for the interpolation grid. + generation (int): The generation number for tracking requests. + scan_id (str): The scan identifier. + parent (QObject | None, optional): The parent QObject. Defaults to None. + + Signals: + finished(image, transform, generation, scan_id): + Emitted when interpolation is successful. + - image: The resulting image (numpy array or similar). + - transform: The QTransform for the image. + - generation: The generation number. + - scan_id: The scan identifier. + failed(error_message, generation, scan_id): + Emitted when interpolation fails. + - error_message: The error message string. + - generation: The generation number. + - scan_id: The scan identifier. + """ + + finished = Signal(object, object, int, str) + failed = Signal(str, int, str) + + def __init__( + self, + x_data: list[float] | np.ndarray, + y_data: list[float] | np.ndarray, + z_data: list[float] | np.ndarray, + interpolation: str, + oversampling_factor: float, + generation: int, + scan_id: str, + parent: QObject | None = None, + ): + super().__init__(parent=parent) + self._x_data = np.asarray(x_data, dtype=float) + self._y_data = np.asarray(y_data, dtype=float) + self._z_data = np.asarray(z_data, dtype=float) + self._interpolation = interpolation + self._oversampling_factor = oversampling_factor + self._generation = generation + self._scan_id = scan_id + + def run(self): + try: + image, transform = Heatmap.compute_step_scan_image( + x_data=self._x_data, + y_data=self._y_data, + z_data=self._z_data, + oversampling_factor=self._oversampling_factor, + interpolation_method=self._interpolation, + ) + except Exception as exc: # pragma: no cover - defensive + logger.warning(f"Step-scan interpolation failed with: {exc}") + self.failed.emit(str(exc), self._generation, self._scan_id) + return + self.finished.emit(image, transform, self._generation, self._scan_id) + + class Heatmap(ImageBase): """ Heatmap widget for visualizing 2d grid data with color mapping for the z-axis. @@ -150,6 +230,10 @@ class Heatmap(ImageBase): self.scan_item = None self.status_message = None self._grid_index = None + self._interpolation_generation = 0 + self._interpolation_thread: QThread | None = None + self._interpolation_worker: _StepInterpolationWorker | None = None + self._pending_interpolation_request: _InterpolationRequest | None = None self.heatmap_dialog = None bg_color = pg.mkColor((240, 240, 240, 150)) self.config_label = pg.LegendItem( @@ -426,6 +510,7 @@ class Heatmap(ImageBase): if current_scan_id is None: return if current_scan_id != self.scan_id: + self._invalidate_interpolation_generation() self.reset() self.new_scan.emit() self.new_scan_id.emit(current_scan_id) @@ -531,13 +616,38 @@ class Heatmap(ImageBase): if self._image_config.show_config_label: self.redraw_config_label() - img, transform = self.get_image_data(x_data=x_data, y_data=y_data, z_data=z_data) - if img is None: + if self._is_grid_scan_supported(scan_msg): + img, transform = self.get_grid_scan_image(z_data, scan_msg) + self._apply_image_update(img, transform) + return + + if len(z_data) < 4: + # LinearNDInterpolator requires at least 4 points to interpolate + logger.warning("Not enough data points to interpolate; skipping update.") + return + + self._request_step_scan_interpolation(x_data, y_data, z_data, scan_msg) + + def _apply_image_update(self, img: np.ndarray | None, transform: QTransform | None): + """Apply interpolated image and transform to the heatmap display. + + This method updates the main image with the computed data and emits + the image_updated signal. Color bar signals are temporarily blocked + during the update to prevent cascading updates. + + Args: + img(np.ndarray): The interpolated image data, or None if unavailable + transform(QTransform): QTransform mapping pixel to world coordinates, or None if unavailable + """ + if img is None or transform is None: logger.warning("Image data is None; skipping update.") return if self._color_bar is not None: self._color_bar.blockSignals(True) + if self.main_image is None: + logger.warning("Main image item is None; cannot update image.") + return self.main_image.set_data(img, transform=transform) if self._color_bar is not None: self._color_bar.blockSignals(False) @@ -545,6 +655,122 @@ class Heatmap(ImageBase): if self.crosshair is not None: self.crosshair.update_markers_on_image_change() + def _request_step_scan_interpolation( + self, + x_data: list[float], + y_data: list[float], + z_data: list[float], + msg: messages.ScanStatusMessage, + ): + """Request step-scan interpolation in a background thread. + + If a thread is already running, the request is queued as a pending request + and will be processed when the current interpolation completes. + + Args: + x_data(list[float]): X coordinates of data points + y_data(list[float]): Y coordinates of data points + z_data(list[float]): Z values at each point + msg(messages.ScanStatusMessage): Scan status message containing scan metadata + """ + request = _InterpolationRequest( + x_data=list(x_data), + y_data=list(y_data), + z_data=list(z_data), + scan_id=msg.scan_id, + interpolation=self._image_config.interpolation, + oversampling_factor=self._image_config.oversampling_factor, + ) + + if self._interpolation_thread is not None: + self._pending_interpolation_request = request + return + + self._start_step_scan_interpolation(request) + + def _start_step_scan_interpolation(self, request: _InterpolationRequest): + self._interpolation_generation += 1 + generation = self._interpolation_generation + self._interpolation_thread = QThread() + self._interpolation_worker = _StepInterpolationWorker( + x_data=request.x_data, + y_data=request.y_data, + z_data=request.z_data, + interpolation=request.interpolation, + oversampling_factor=request.oversampling_factor, + generation=generation, + scan_id=request.scan_id, + ) + self._interpolation_worker.moveToThread(self._interpolation_thread) + self._interpolation_thread.started.connect( + self._interpolation_worker.run, Qt.ConnectionType.QueuedConnection + ) + self._interpolation_worker.finished.connect( + self._on_interpolation_finished, Qt.ConnectionType.QueuedConnection + ) + self._interpolation_worker.failed.connect( + self._on_interpolation_failed, Qt.ConnectionType.QueuedConnection + ) + self._interpolation_thread.start() + + def _on_interpolation_finished( + self, img: np.ndarray, transform: QTransform, generation: int, scan_id: str + ): + if generation == self._interpolation_generation and scan_id == self.scan_id: + self._apply_image_update(img, transform) + else: + logger.debug("Discarding outdated interpolation result.") + self._finish_interpolation_thread() + self._maybe_start_pending_interpolation() + + def _on_interpolation_failed(self, error: str, generation: int, scan_id: str): + logger.warning( + "Interpolation failed for scan %s (generation %s): %s", scan_id, generation, error + ) + self._finish_interpolation_thread() + self._maybe_start_pending_interpolation() + + def _finish_interpolation_thread(self): + if self._interpolation_worker is not None: + self._interpolation_worker.deleteLater() + self._interpolation_worker = None + if self._interpolation_thread is not None: + self._interpolation_thread.quit() + self._interpolation_thread.wait() + self._interpolation_thread.deleteLater() + self._interpolation_thread = None + + def _maybe_start_pending_interpolation(self): + if self._pending_interpolation_request is None: + return + if self._pending_interpolation_request.scan_id != self.scan_id: + self._pending_interpolation_request = None + return + + pending = self._pending_interpolation_request + self._pending_interpolation_request = None + self._start_step_scan_interpolation(pending) + + def _cancel_interpolation(self): + """Cancel any pending interpolation request without invalidating in-flight work. + + This clears the pending request queue but does not increment the generation + counter, allowing any currently running interpolation to complete and update + the display if it matches the current scan. + """ + self._pending_interpolation_request = None + # Do not bump generation so an in-flight worker can still deliver the latest scan image. + + def _invalidate_interpolation_generation(self): + """Invalidate all in-flight and pending interpolation results. + + Increments the generation counter so that any currently running or + queued interpolation work will be discarded when it completes. + This is typically called when starting a new scan. + """ + # Bump the generation so any in-flight worker results are ignored. + self._interpolation_generation += 1 + def redraw_config_label(self): scan_msg = self.status_message if scan_msg is None: @@ -590,21 +816,35 @@ class Heatmap(ImageBase): logger.warning("x, y, or z data is None; skipping update.") return None, None - if msg.scan_name == "grid_scan" and not self._image_config.enforce_interpolation: - # We only support the grid scan mode if both scanning motors - # are configured in the heatmap config. - device_x = self._image_config.x_device.entry - device_y = self._image_config.y_device.entry - if ( - device_x in msg.request_inputs["arg_bundle"] - and device_y in msg.request_inputs["arg_bundle"] - ): - return self.get_grid_scan_image(z_data, msg) + if self._is_grid_scan_supported(msg): + return self.get_grid_scan_image(z_data, msg) if len(z_data) < 4: # LinearNDInterpolator requires at least 4 points to interpolate return None, None return self.get_step_scan_image(x_data, y_data, z_data, msg) + def _is_grid_scan_supported(self, msg: messages.ScanStatusMessage) -> bool: + """Check if the scan can use optimized grid_scan rendering. + + Grid scans can avoid interpolation if both X and Y devices match + the configured devices and interpolation is not enforced. + + Args: + msg(messages.ScanStatusMessage): Scan status message containing scan metadata + + Returns: + True if grid_scan optimization is applicable, False otherwise + """ + if msg.scan_name != "grid_scan" or self._image_config.enforce_interpolation: + return False + + device_x = self._image_config.x_device.entry + device_y = self._image_config.y_device.entry + return ( + device_x in msg.request_inputs["arg_bundle"] + and device_y in msg.request_inputs["arg_bundle"] + ) + def get_grid_scan_image( self, z_data: list[float], msg: messages.ScanStatusMessage ) -> tuple[np.ndarray, QTransform]: @@ -704,17 +944,49 @@ class Heatmap(ImageBase): Returns: tuple[np.ndarray, QTransform]: The image data and the QTransform. """ - xy_data = np.column_stack((x_data, y_data)) - grid_x, grid_y, transform = self.get_image_grid(xy_data) + return self.compute_step_scan_image( + x_data=x_data, + y_data=y_data, + z_data=z_data, + oversampling_factor=self._image_config.oversampling_factor, + interpolation_method=self._image_config.interpolation, + ) - # Interpolate the z data onto the grid - if self._image_config.interpolation == "linear": + @staticmethod + def compute_step_scan_image( + x_data: list[float] | np.ndarray, + y_data: list[float] | np.ndarray, + z_data: list[float] | np.ndarray, + oversampling_factor: float, + interpolation_method: str, + ) -> tuple[np.ndarray, QTransform]: + """Compute interpolated heatmap image from step-scan data. + + This static method is suitable for execution in a background thread + as it doesn't access any instance state. + + Args: + x_data(list[float]): X coordinates of data points + y_data(list[float]): Y coordinates of data points + z_data(list[float]): Z values at each point + oversampling_factor(float): Grid resolution multiplier (>1.0 for higher resolution) + interpolation_method(str): One of 'linear', 'nearest', or 'clough' + + Returns: + (tuple[np.ndarray, QTransform]):Tuple of (interpolated_grid, transform) where transform maps pixel to world coordinates + """ + xy_data = np.column_stack((x_data, y_data)) + grid_x, grid_y, transform = Heatmap.build_image_grid( + positions=xy_data, oversampling_factor=oversampling_factor + ) + + if interpolation_method == "linear": interp = LinearNDInterpolator(xy_data, z_data) - elif self._image_config.interpolation == "nearest": + elif interpolation_method == "nearest": interp = NearestNDInterpolator(xy_data, z_data) - elif self._image_config.interpolation == "clough": + elif interpolation_method == "clough": interp = CloughTocher2DInterpolator(xy_data, z_data) - else: + else: # pragma: no cover - guarded by validation raise ValueError( "Interpolation method must be either 'linear', 'nearest', or 'clough'." ) @@ -733,22 +1005,33 @@ class Heatmap(ImageBase): Returns: tuple[np.ndarray, np.ndarray, QTransform]: The grid x and y coordinates and the QTransform. """ - base_width, base_height = self.estimate_image_resolution(positions) + return self.build_image_grid( + positions=positions, oversampling_factor=self._image_config.oversampling_factor + ) - # Apply oversampling factor - factor = self._image_config.oversampling_factor + @staticmethod + def build_image_grid( + positions: np.ndarray, oversampling_factor: float + ) -> tuple[np.ndarray, np.ndarray, QTransform]: + """Build an interpolation grid covering the data positions. - # Apply oversampling - width = int(base_width * factor) - height = int(base_height * factor) + Args: + positions: (N, 2) array of (x, y) coordinates + oversampling_factor: Grid resolution multiplier (>1.0 for higher resolution) + + Returns: + Tuple of (grid_x, grid_y, transform) where grid_x/grid_y are meshgrids + for interpolation and transform maps pixel to world coordinates + """ + base_width, base_height = Heatmap.estimate_image_resolution(positions) + width = max(1, int(base_width * oversampling_factor)) + height = max(1, int(base_height * oversampling_factor)) - # Create grid grid_x, grid_y = np.mgrid[ min(positions[:, 0]) : max(positions[:, 0]) : width * 1j, min(positions[:, 1]) : max(positions[:, 1]) : height * 1j, ] - # Calculate transform x_min, x_max = min(positions[:, 0]), max(positions[:, 0]) y_min, y_max = min(positions[:, 1]), max(positions[:, 1]) x_range = x_max - x_min @@ -832,6 +1115,7 @@ class Heatmap(ImageBase): return scan_devices, "value" def reset(self): + self._cancel_interpolation() self._grid_index = None self.main_image.clear() if self.crosshair is not None: @@ -966,6 +1250,10 @@ class Heatmap(ImageBase): """ self.main_image.transpose = enable + def cleanup(self): + self._finish_interpolation_thread() + super().cleanup() + if __name__ == "__main__": # pragma: no cover import sys diff --git a/tests/unit_tests/test_heatmap_widget.py b/tests/unit_tests/test_heatmap_widget.py index ff2d4274..7cf83251 100644 --- a/tests/unit_tests/test_heatmap_widget.py +++ b/tests/unit_tests/test_heatmap_widget.py @@ -4,9 +4,15 @@ import numpy as np import pytest from bec_lib import messages from bec_lib.scan_history import ScanHistory +from qtpy.QtGui import QTransform from qtpy.QtCore import QPointF -from bec_widgets.widgets.plots.heatmap.heatmap import Heatmap, HeatmapConfig, HeatmapDeviceSignal +from bec_widgets.widgets.plots.heatmap.heatmap import ( + Heatmap, + HeatmapConfig, + HeatmapDeviceSignal, + _StepInterpolationWorker, +) # pytest: disable=unused-import from tests.unit_tests.client_mocks import mocked_client @@ -448,12 +454,16 @@ def test_heatmap_widget_reset(heatmap_widget): """ Test that the reset method clears the plot. """ + heatmap_widget._pending_interpolation_request = object() + heatmap_widget._interpolation_generation = 5 heatmap_widget.scan_item = create_dummy_scan_item() heatmap_widget.plot(x_name="samx", y_name="samy", z_name="bpm4i") heatmap_widget.reset() assert heatmap_widget._grid_index is None assert heatmap_widget.main_image.raw_data is None + assert heatmap_widget._pending_interpolation_request is None + assert heatmap_widget._interpolation_generation == 5 def test_heatmap_widget_update_plot_with_scan_history(heatmap_widget, grid_scan_history_msg, qtbot): @@ -478,3 +488,108 @@ def test_heatmap_widget_update_plot_with_scan_history(heatmap_widget, grid_scan_ heatmap_widget.enforce_interpolation = True heatmap_widget.oversampling_factor = 2.0 qtbot.waitUntil(lambda: heatmap_widget.main_image.raw_data.shape == (20, 20)) + + +def test_step_interpolation_worker_emits_finished(qtbot): + worker = _StepInterpolationWorker( + x_data=[0.0, 1.0, 0.5, 0.2], + y_data=[0.0, 0.0, 1.0, 1.0], + z_data=[1.0, 2.0, 3.0, 4.0], + interpolation="linear", + oversampling_factor=1.0, + generation=1, + scan_id="scan-1", + ) + with qtbot.waitSignal(worker.finished, timeout=1000) as blocker: + worker.run() + img, transform, generation, scan_id = blocker.args + assert img.shape[0] > 0 + assert isinstance(transform, QTransform) + assert generation == 1 + assert scan_id == "scan-1" + + +def test_step_interpolation_worker_emits_failed(qtbot, monkeypatch): + def _scan_goes_boom(**kwargs): + raise RuntimeError("crash") + + monkeypatch.setattr( + "bec_widgets.widgets.plots.heatmap.heatmap.Heatmap.compute_step_scan_image", _scan_goes_boom + ) + worker = _StepInterpolationWorker( + x_data=[0.0, 1.0, 0.5, 0.2], + y_data=[0.0, 0.0, 1.0, 1.0], + z_data=[1.0, 2.0, 3.0, 4.0], + interpolation="linear", + oversampling_factor=1.0, + generation=99, + scan_id="scan-err", + ) + with qtbot.waitSignal(worker.failed, timeout=1000) as blocker: + worker.run() + error, generation, scan_id = blocker.args + assert "crash" in error + assert generation == 99 + assert scan_id == "scan-err" + + +def test_interpolation_generation_invalidation(heatmap_widget): + heatmap_widget.scan_id = "scan-1" + heatmap_widget._interpolation_generation = 2 + with ( + mock.patch.object(heatmap_widget, "_apply_image_update") as apply_mock, + mock.patch.object(heatmap_widget, "_finish_interpolation_thread") as finish_mock, + mock.patch.object(heatmap_widget, "_maybe_start_pending_interpolation") as maybe_mock, + ): + heatmap_widget._on_interpolation_finished( + np.zeros((2, 2)), QTransform(), generation=1, scan_id="scan-1" + ) + apply_mock.assert_not_called() + finish_mock.assert_called_once() + maybe_mock.assert_called_once() + + +def test_pending_request_queueing_and_start(heatmap_widget): + heatmap_widget.scan_id = "scan-queue" + heatmap_widget.status_message = messages.ScanStatusMessage( + scan_id="scan-queue", + status="open", + scan_name="step_scan", + scan_type="step", + metadata={}, + info={"positions": [[0, 0], [1, 1], [2, 2], [3, 3]]}, + ) + heatmap_widget._interpolation_thread = object() # simulate running thread + + with mock.patch.object(heatmap_widget, "_start_step_scan_interpolation") as start_mock: + heatmap_widget._request_step_scan_interpolation( + x_data=[0, 1, 2, 3], + y_data=[0, 1, 2, 3], + z_data=[0, 1, 2, 3], + msg=heatmap_widget.status_message, + ) + assert heatmap_widget._pending_interpolation_request is not None + + # Now simulate worker finished and thread cleaned up + heatmap_widget._interpolation_thread = None + pending = heatmap_widget._pending_interpolation_request + heatmap_widget._pending_interpolation_request = pending + heatmap_widget._maybe_start_pending_interpolation() + + start_mock.assert_called_once() + + +def test_finish_interpolation_thread_cleans_references(heatmap_widget): + worker_mock = mock.Mock() + thread_mock = mock.Mock() + heatmap_widget._interpolation_worker = worker_mock + heatmap_widget._interpolation_thread = thread_mock + + heatmap_widget._finish_interpolation_thread() + + worker_mock.deleteLater.assert_called_once() + thread_mock.quit.assert_called_once() + thread_mock.wait.assert_called_once() + thread_mock.deleteLater.assert_called_once() + assert heatmap_widget._interpolation_worker is None + assert heatmap_widget._interpolation_thread is None