1
0
mirror of https://github.com/bec-project/bec_widgets.git synced 2026-03-04 16:02:51 +01:00

fix(heatmap): interpolation of the image moved to separate thread

This commit is contained in:
2025-11-26 21:32:31 +01:00
parent caa4e449e4
commit c354a9b249
2 changed files with 432 additions and 29 deletions

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Literal
import numpy as np
@@ -8,7 +9,7 @@ import pyqtgraph as pg
from bec_lib import bec_logger, messages
from bec_lib.endpoints import MessageEndpoints
from pydantic import BaseModel, Field, field_validator
from qtpy.QtCore import QTimer, Signal
from qtpy.QtCore import QObject, Qt, QThread, QTimer, Signal
from qtpy.QtGui import QTransform
from scipy.interpolate import (
CloughTocher2DInterpolator,
@@ -78,6 +79,85 @@ class HeatmapConfig(ConnectionConfig):
_validate_color_palette = field_validator("color_map")(Colors.validate_color_map)
@dataclass
class _InterpolationRequest:
x_data: list[float]
y_data: list[float]
z_data: list[float]
scan_id: str
interpolation: str
oversampling_factor: float
class _StepInterpolationWorker(QObject):
"""Worker for performing step-scan interpolation in a background thread.
This worker computes the interpolated heatmap image using the provided data
and settings, then emits the result or a failure signal.
Args:
x_data (list[float] or np.ndarray): The x-coordinates of the data points.
y_data (list[float] or np.ndarray): The y-coordinates of the data points.
z_data (list[float] or np.ndarray): The z-values (intensity) of the data points.
interpolation (str): The interpolation method to use.
oversampling_factor (float): The oversampling factor for the interpolation grid.
generation (int): The generation number for tracking requests.
scan_id (str): The scan identifier.
parent (QObject | None, optional): The parent QObject. Defaults to None.
Signals:
finished(image, transform, generation, scan_id):
Emitted when interpolation is successful.
- image: The resulting image (numpy array or similar).
- transform: The QTransform for the image.
- generation: The generation number.
- scan_id: The scan identifier.
failed(error_message, generation, scan_id):
Emitted when interpolation fails.
- error_message: The error message string.
- generation: The generation number.
- scan_id: The scan identifier.
"""
finished = Signal(object, object, int, str)
failed = Signal(str, int, str)
def __init__(
self,
x_data: list[float] | np.ndarray,
y_data: list[float] | np.ndarray,
z_data: list[float] | np.ndarray,
interpolation: str,
oversampling_factor: float,
generation: int,
scan_id: str,
parent: QObject | None = None,
):
super().__init__(parent=parent)
self._x_data = np.asarray(x_data, dtype=float)
self._y_data = np.asarray(y_data, dtype=float)
self._z_data = np.asarray(z_data, dtype=float)
self._interpolation = interpolation
self._oversampling_factor = oversampling_factor
self._generation = generation
self._scan_id = scan_id
def run(self):
try:
image, transform = Heatmap.compute_step_scan_image(
x_data=self._x_data,
y_data=self._y_data,
z_data=self._z_data,
oversampling_factor=self._oversampling_factor,
interpolation_method=self._interpolation,
)
except Exception as exc: # pragma: no cover - defensive
logger.warning(f"Step-scan interpolation failed with: {exc}")
self.failed.emit(str(exc), self._generation, self._scan_id)
return
self.finished.emit(image, transform, self._generation, self._scan_id)
class Heatmap(ImageBase):
"""
Heatmap widget for visualizing 2d grid data with color mapping for the z-axis.
@@ -150,6 +230,10 @@ class Heatmap(ImageBase):
self.scan_item = None
self.status_message = None
self._grid_index = None
self._interpolation_generation = 0
self._interpolation_thread: QThread | None = None
self._interpolation_worker: _StepInterpolationWorker | None = None
self._pending_interpolation_request: _InterpolationRequest | None = None
self.heatmap_dialog = None
bg_color = pg.mkColor((240, 240, 240, 150))
self.config_label = pg.LegendItem(
@@ -426,6 +510,7 @@ class Heatmap(ImageBase):
if current_scan_id is None:
return
if current_scan_id != self.scan_id:
self._invalidate_interpolation_generation()
self.reset()
self.new_scan.emit()
self.new_scan_id.emit(current_scan_id)
@@ -531,13 +616,38 @@ class Heatmap(ImageBase):
if self._image_config.show_config_label:
self.redraw_config_label()
img, transform = self.get_image_data(x_data=x_data, y_data=y_data, z_data=z_data)
if img is None:
if self._is_grid_scan_supported(scan_msg):
img, transform = self.get_grid_scan_image(z_data, scan_msg)
self._apply_image_update(img, transform)
return
if len(z_data) < 4:
# LinearNDInterpolator requires at least 4 points to interpolate
logger.warning("Not enough data points to interpolate; skipping update.")
return
self._request_step_scan_interpolation(x_data, y_data, z_data, scan_msg)
def _apply_image_update(self, img: np.ndarray | None, transform: QTransform | None):
"""Apply interpolated image and transform to the heatmap display.
This method updates the main image with the computed data and emits
the image_updated signal. Color bar signals are temporarily blocked
during the update to prevent cascading updates.
Args:
img(np.ndarray): The interpolated image data, or None if unavailable
transform(QTransform): QTransform mapping pixel to world coordinates, or None if unavailable
"""
if img is None or transform is None:
logger.warning("Image data is None; skipping update.")
return
if self._color_bar is not None:
self._color_bar.blockSignals(True)
if self.main_image is None:
logger.warning("Main image item is None; cannot update image.")
return
self.main_image.set_data(img, transform=transform)
if self._color_bar is not None:
self._color_bar.blockSignals(False)
@@ -545,6 +655,122 @@ class Heatmap(ImageBase):
if self.crosshair is not None:
self.crosshair.update_markers_on_image_change()
def _request_step_scan_interpolation(
self,
x_data: list[float],
y_data: list[float],
z_data: list[float],
msg: messages.ScanStatusMessage,
):
"""Request step-scan interpolation in a background thread.
If a thread is already running, the request is queued as a pending request
and will be processed when the current interpolation completes.
Args:
x_data(list[float]): X coordinates of data points
y_data(list[float]): Y coordinates of data points
z_data(list[float]): Z values at each point
msg(messages.ScanStatusMessage): Scan status message containing scan metadata
"""
request = _InterpolationRequest(
x_data=list(x_data),
y_data=list(y_data),
z_data=list(z_data),
scan_id=msg.scan_id,
interpolation=self._image_config.interpolation,
oversampling_factor=self._image_config.oversampling_factor,
)
if self._interpolation_thread is not None:
self._pending_interpolation_request = request
return
self._start_step_scan_interpolation(request)
def _start_step_scan_interpolation(self, request: _InterpolationRequest):
self._interpolation_generation += 1
generation = self._interpolation_generation
self._interpolation_thread = QThread()
self._interpolation_worker = _StepInterpolationWorker(
x_data=request.x_data,
y_data=request.y_data,
z_data=request.z_data,
interpolation=request.interpolation,
oversampling_factor=request.oversampling_factor,
generation=generation,
scan_id=request.scan_id,
)
self._interpolation_worker.moveToThread(self._interpolation_thread)
self._interpolation_thread.started.connect(
self._interpolation_worker.run, Qt.ConnectionType.QueuedConnection
)
self._interpolation_worker.finished.connect(
self._on_interpolation_finished, Qt.ConnectionType.QueuedConnection
)
self._interpolation_worker.failed.connect(
self._on_interpolation_failed, Qt.ConnectionType.QueuedConnection
)
self._interpolation_thread.start()
def _on_interpolation_finished(
self, img: np.ndarray, transform: QTransform, generation: int, scan_id: str
):
if generation == self._interpolation_generation and scan_id == self.scan_id:
self._apply_image_update(img, transform)
else:
logger.debug("Discarding outdated interpolation result.")
self._finish_interpolation_thread()
self._maybe_start_pending_interpolation()
def _on_interpolation_failed(self, error: str, generation: int, scan_id: str):
logger.warning(
"Interpolation failed for scan %s (generation %s): %s", scan_id, generation, error
)
self._finish_interpolation_thread()
self._maybe_start_pending_interpolation()
def _finish_interpolation_thread(self):
if self._interpolation_worker is not None:
self._interpolation_worker.deleteLater()
self._interpolation_worker = None
if self._interpolation_thread is not None:
self._interpolation_thread.quit()
self._interpolation_thread.wait()
self._interpolation_thread.deleteLater()
self._interpolation_thread = None
def _maybe_start_pending_interpolation(self):
if self._pending_interpolation_request is None:
return
if self._pending_interpolation_request.scan_id != self.scan_id:
self._pending_interpolation_request = None
return
pending = self._pending_interpolation_request
self._pending_interpolation_request = None
self._start_step_scan_interpolation(pending)
def _cancel_interpolation(self):
"""Cancel any pending interpolation request without invalidating in-flight work.
This clears the pending request queue but does not increment the generation
counter, allowing any currently running interpolation to complete and update
the display if it matches the current scan.
"""
self._pending_interpolation_request = None
# Do not bump generation so an in-flight worker can still deliver the latest scan image.
def _invalidate_interpolation_generation(self):
"""Invalidate all in-flight and pending interpolation results.
Increments the generation counter so that any currently running or
queued interpolation work will be discarded when it completes.
This is typically called when starting a new scan.
"""
# Bump the generation so any in-flight worker results are ignored.
self._interpolation_generation += 1
def redraw_config_label(self):
scan_msg = self.status_message
if scan_msg is None:
@@ -590,21 +816,35 @@ class Heatmap(ImageBase):
logger.warning("x, y, or z data is None; skipping update.")
return None, None
if msg.scan_name == "grid_scan" and not self._image_config.enforce_interpolation:
# We only support the grid scan mode if both scanning motors
# are configured in the heatmap config.
device_x = self._image_config.x_device.entry
device_y = self._image_config.y_device.entry
if (
device_x in msg.request_inputs["arg_bundle"]
and device_y in msg.request_inputs["arg_bundle"]
):
return self.get_grid_scan_image(z_data, msg)
if self._is_grid_scan_supported(msg):
return self.get_grid_scan_image(z_data, msg)
if len(z_data) < 4:
# LinearNDInterpolator requires at least 4 points to interpolate
return None, None
return self.get_step_scan_image(x_data, y_data, z_data, msg)
def _is_grid_scan_supported(self, msg: messages.ScanStatusMessage) -> bool:
"""Check if the scan can use optimized grid_scan rendering.
Grid scans can avoid interpolation if both X and Y devices match
the configured devices and interpolation is not enforced.
Args:
msg(messages.ScanStatusMessage): Scan status message containing scan metadata
Returns:
True if grid_scan optimization is applicable, False otherwise
"""
if msg.scan_name != "grid_scan" or self._image_config.enforce_interpolation:
return False
device_x = self._image_config.x_device.entry
device_y = self._image_config.y_device.entry
return (
device_x in msg.request_inputs["arg_bundle"]
and device_y in msg.request_inputs["arg_bundle"]
)
def get_grid_scan_image(
self, z_data: list[float], msg: messages.ScanStatusMessage
) -> tuple[np.ndarray, QTransform]:
@@ -704,17 +944,49 @@ class Heatmap(ImageBase):
Returns:
tuple[np.ndarray, QTransform]: The image data and the QTransform.
"""
xy_data = np.column_stack((x_data, y_data))
grid_x, grid_y, transform = self.get_image_grid(xy_data)
return self.compute_step_scan_image(
x_data=x_data,
y_data=y_data,
z_data=z_data,
oversampling_factor=self._image_config.oversampling_factor,
interpolation_method=self._image_config.interpolation,
)
# Interpolate the z data onto the grid
if self._image_config.interpolation == "linear":
@staticmethod
def compute_step_scan_image(
x_data: list[float] | np.ndarray,
y_data: list[float] | np.ndarray,
z_data: list[float] | np.ndarray,
oversampling_factor: float,
interpolation_method: str,
) -> tuple[np.ndarray, QTransform]:
"""Compute interpolated heatmap image from step-scan data.
This static method is suitable for execution in a background thread
as it doesn't access any instance state.
Args:
x_data(list[float]): X coordinates of data points
y_data(list[float]): Y coordinates of data points
z_data(list[float]): Z values at each point
oversampling_factor(float): Grid resolution multiplier (>1.0 for higher resolution)
interpolation_method(str): One of 'linear', 'nearest', or 'clough'
Returns:
(tuple[np.ndarray, QTransform]):Tuple of (interpolated_grid, transform) where transform maps pixel to world coordinates
"""
xy_data = np.column_stack((x_data, y_data))
grid_x, grid_y, transform = Heatmap.build_image_grid(
positions=xy_data, oversampling_factor=oversampling_factor
)
if interpolation_method == "linear":
interp = LinearNDInterpolator(xy_data, z_data)
elif self._image_config.interpolation == "nearest":
elif interpolation_method == "nearest":
interp = NearestNDInterpolator(xy_data, z_data)
elif self._image_config.interpolation == "clough":
elif interpolation_method == "clough":
interp = CloughTocher2DInterpolator(xy_data, z_data)
else:
else: # pragma: no cover - guarded by validation
raise ValueError(
"Interpolation method must be either 'linear', 'nearest', or 'clough'."
)
@@ -733,22 +1005,33 @@ class Heatmap(ImageBase):
Returns:
tuple[np.ndarray, np.ndarray, QTransform]: The grid x and y coordinates and the QTransform.
"""
base_width, base_height = self.estimate_image_resolution(positions)
return self.build_image_grid(
positions=positions, oversampling_factor=self._image_config.oversampling_factor
)
# Apply oversampling factor
factor = self._image_config.oversampling_factor
@staticmethod
def build_image_grid(
positions: np.ndarray, oversampling_factor: float
) -> tuple[np.ndarray, np.ndarray, QTransform]:
"""Build an interpolation grid covering the data positions.
# Apply oversampling
width = int(base_width * factor)
height = int(base_height * factor)
Args:
positions: (N, 2) array of (x, y) coordinates
oversampling_factor: Grid resolution multiplier (>1.0 for higher resolution)
Returns:
Tuple of (grid_x, grid_y, transform) where grid_x/grid_y are meshgrids
for interpolation and transform maps pixel to world coordinates
"""
base_width, base_height = Heatmap.estimate_image_resolution(positions)
width = max(1, int(base_width * oversampling_factor))
height = max(1, int(base_height * oversampling_factor))
# Create grid
grid_x, grid_y = np.mgrid[
min(positions[:, 0]) : max(positions[:, 0]) : width * 1j,
min(positions[:, 1]) : max(positions[:, 1]) : height * 1j,
]
# Calculate transform
x_min, x_max = min(positions[:, 0]), max(positions[:, 0])
y_min, y_max = min(positions[:, 1]), max(positions[:, 1])
x_range = x_max - x_min
@@ -832,6 +1115,7 @@ class Heatmap(ImageBase):
return scan_devices, "value"
def reset(self):
self._cancel_interpolation()
self._grid_index = None
self.main_image.clear()
if self.crosshair is not None:
@@ -966,6 +1250,10 @@ class Heatmap(ImageBase):
"""
self.main_image.transpose = enable
def cleanup(self):
self._finish_interpolation_thread()
super().cleanup()
if __name__ == "__main__": # pragma: no cover
import sys

View File

@@ -4,9 +4,15 @@ import numpy as np
import pytest
from bec_lib import messages
from bec_lib.scan_history import ScanHistory
from qtpy.QtGui import QTransform
from qtpy.QtCore import QPointF
from bec_widgets.widgets.plots.heatmap.heatmap import Heatmap, HeatmapConfig, HeatmapDeviceSignal
from bec_widgets.widgets.plots.heatmap.heatmap import (
Heatmap,
HeatmapConfig,
HeatmapDeviceSignal,
_StepInterpolationWorker,
)
# pytest: disable=unused-import
from tests.unit_tests.client_mocks import mocked_client
@@ -448,12 +454,16 @@ def test_heatmap_widget_reset(heatmap_widget):
"""
Test that the reset method clears the plot.
"""
heatmap_widget._pending_interpolation_request = object()
heatmap_widget._interpolation_generation = 5
heatmap_widget.scan_item = create_dummy_scan_item()
heatmap_widget.plot(x_name="samx", y_name="samy", z_name="bpm4i")
heatmap_widget.reset()
assert heatmap_widget._grid_index is None
assert heatmap_widget.main_image.raw_data is None
assert heatmap_widget._pending_interpolation_request is None
assert heatmap_widget._interpolation_generation == 5
def test_heatmap_widget_update_plot_with_scan_history(heatmap_widget, grid_scan_history_msg, qtbot):
@@ -478,3 +488,108 @@ def test_heatmap_widget_update_plot_with_scan_history(heatmap_widget, grid_scan_
heatmap_widget.enforce_interpolation = True
heatmap_widget.oversampling_factor = 2.0
qtbot.waitUntil(lambda: heatmap_widget.main_image.raw_data.shape == (20, 20))
def test_step_interpolation_worker_emits_finished(qtbot):
worker = _StepInterpolationWorker(
x_data=[0.0, 1.0, 0.5, 0.2],
y_data=[0.0, 0.0, 1.0, 1.0],
z_data=[1.0, 2.0, 3.0, 4.0],
interpolation="linear",
oversampling_factor=1.0,
generation=1,
scan_id="scan-1",
)
with qtbot.waitSignal(worker.finished, timeout=1000) as blocker:
worker.run()
img, transform, generation, scan_id = blocker.args
assert img.shape[0] > 0
assert isinstance(transform, QTransform)
assert generation == 1
assert scan_id == "scan-1"
def test_step_interpolation_worker_emits_failed(qtbot, monkeypatch):
def _scan_goes_boom(**kwargs):
raise RuntimeError("crash")
monkeypatch.setattr(
"bec_widgets.widgets.plots.heatmap.heatmap.Heatmap.compute_step_scan_image", _scan_goes_boom
)
worker = _StepInterpolationWorker(
x_data=[0.0, 1.0, 0.5, 0.2],
y_data=[0.0, 0.0, 1.0, 1.0],
z_data=[1.0, 2.0, 3.0, 4.0],
interpolation="linear",
oversampling_factor=1.0,
generation=99,
scan_id="scan-err",
)
with qtbot.waitSignal(worker.failed, timeout=1000) as blocker:
worker.run()
error, generation, scan_id = blocker.args
assert "crash" in error
assert generation == 99
assert scan_id == "scan-err"
def test_interpolation_generation_invalidation(heatmap_widget):
heatmap_widget.scan_id = "scan-1"
heatmap_widget._interpolation_generation = 2
with (
mock.patch.object(heatmap_widget, "_apply_image_update") as apply_mock,
mock.patch.object(heatmap_widget, "_finish_interpolation_thread") as finish_mock,
mock.patch.object(heatmap_widget, "_maybe_start_pending_interpolation") as maybe_mock,
):
heatmap_widget._on_interpolation_finished(
np.zeros((2, 2)), QTransform(), generation=1, scan_id="scan-1"
)
apply_mock.assert_not_called()
finish_mock.assert_called_once()
maybe_mock.assert_called_once()
def test_pending_request_queueing_and_start(heatmap_widget):
heatmap_widget.scan_id = "scan-queue"
heatmap_widget.status_message = messages.ScanStatusMessage(
scan_id="scan-queue",
status="open",
scan_name="step_scan",
scan_type="step",
metadata={},
info={"positions": [[0, 0], [1, 1], [2, 2], [3, 3]]},
)
heatmap_widget._interpolation_thread = object() # simulate running thread
with mock.patch.object(heatmap_widget, "_start_step_scan_interpolation") as start_mock:
heatmap_widget._request_step_scan_interpolation(
x_data=[0, 1, 2, 3],
y_data=[0, 1, 2, 3],
z_data=[0, 1, 2, 3],
msg=heatmap_widget.status_message,
)
assert heatmap_widget._pending_interpolation_request is not None
# Now simulate worker finished and thread cleaned up
heatmap_widget._interpolation_thread = None
pending = heatmap_widget._pending_interpolation_request
heatmap_widget._pending_interpolation_request = pending
heatmap_widget._maybe_start_pending_interpolation()
start_mock.assert_called_once()
def test_finish_interpolation_thread_cleans_references(heatmap_widget):
worker_mock = mock.Mock()
thread_mock = mock.Mock()
heatmap_widget._interpolation_worker = worker_mock
heatmap_widget._interpolation_thread = thread_mock
heatmap_widget._finish_interpolation_thread()
worker_mock.deleteLater.assert_called_once()
thread_mock.quit.assert_called_once()
thread_mock.wait.assert_called_once()
thread_mock.deleteLater.assert_called_once()
assert heatmap_widget._interpolation_worker is None
assert heatmap_widget._interpolation_thread is None