mirror of
https://github.com/bec-project/bec_widgets.git
synced 2026-03-04 16:02:51 +01:00
fix(heatmap): interpolation of the image moved to separate thread
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
@@ -8,7 +9,7 @@ import pyqtgraph as pg
|
||||
from bec_lib import bec_logger, messages
|
||||
from bec_lib.endpoints import MessageEndpoints
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from qtpy.QtCore import QTimer, Signal
|
||||
from qtpy.QtCore import QObject, Qt, QThread, QTimer, Signal
|
||||
from qtpy.QtGui import QTransform
|
||||
from scipy.interpolate import (
|
||||
CloughTocher2DInterpolator,
|
||||
@@ -78,6 +79,85 @@ class HeatmapConfig(ConnectionConfig):
|
||||
_validate_color_palette = field_validator("color_map")(Colors.validate_color_map)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _InterpolationRequest:
|
||||
x_data: list[float]
|
||||
y_data: list[float]
|
||||
z_data: list[float]
|
||||
scan_id: str
|
||||
interpolation: str
|
||||
oversampling_factor: float
|
||||
|
||||
|
||||
class _StepInterpolationWorker(QObject):
|
||||
"""Worker for performing step-scan interpolation in a background thread.
|
||||
|
||||
This worker computes the interpolated heatmap image using the provided data
|
||||
and settings, then emits the result or a failure signal.
|
||||
|
||||
Args:
|
||||
x_data (list[float] or np.ndarray): The x-coordinates of the data points.
|
||||
y_data (list[float] or np.ndarray): The y-coordinates of the data points.
|
||||
z_data (list[float] or np.ndarray): The z-values (intensity) of the data points.
|
||||
interpolation (str): The interpolation method to use.
|
||||
oversampling_factor (float): The oversampling factor for the interpolation grid.
|
||||
generation (int): The generation number for tracking requests.
|
||||
scan_id (str): The scan identifier.
|
||||
parent (QObject | None, optional): The parent QObject. Defaults to None.
|
||||
|
||||
Signals:
|
||||
finished(image, transform, generation, scan_id):
|
||||
Emitted when interpolation is successful.
|
||||
- image: The resulting image (numpy array or similar).
|
||||
- transform: The QTransform for the image.
|
||||
- generation: The generation number.
|
||||
- scan_id: The scan identifier.
|
||||
failed(error_message, generation, scan_id):
|
||||
Emitted when interpolation fails.
|
||||
- error_message: The error message string.
|
||||
- generation: The generation number.
|
||||
- scan_id: The scan identifier.
|
||||
"""
|
||||
|
||||
finished = Signal(object, object, int, str)
|
||||
failed = Signal(str, int, str)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
x_data: list[float] | np.ndarray,
|
||||
y_data: list[float] | np.ndarray,
|
||||
z_data: list[float] | np.ndarray,
|
||||
interpolation: str,
|
||||
oversampling_factor: float,
|
||||
generation: int,
|
||||
scan_id: str,
|
||||
parent: QObject | None = None,
|
||||
):
|
||||
super().__init__(parent=parent)
|
||||
self._x_data = np.asarray(x_data, dtype=float)
|
||||
self._y_data = np.asarray(y_data, dtype=float)
|
||||
self._z_data = np.asarray(z_data, dtype=float)
|
||||
self._interpolation = interpolation
|
||||
self._oversampling_factor = oversampling_factor
|
||||
self._generation = generation
|
||||
self._scan_id = scan_id
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
image, transform = Heatmap.compute_step_scan_image(
|
||||
x_data=self._x_data,
|
||||
y_data=self._y_data,
|
||||
z_data=self._z_data,
|
||||
oversampling_factor=self._oversampling_factor,
|
||||
interpolation_method=self._interpolation,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.warning(f"Step-scan interpolation failed with: {exc}")
|
||||
self.failed.emit(str(exc), self._generation, self._scan_id)
|
||||
return
|
||||
self.finished.emit(image, transform, self._generation, self._scan_id)
|
||||
|
||||
|
||||
class Heatmap(ImageBase):
|
||||
"""
|
||||
Heatmap widget for visualizing 2d grid data with color mapping for the z-axis.
|
||||
@@ -150,6 +230,10 @@ class Heatmap(ImageBase):
|
||||
self.scan_item = None
|
||||
self.status_message = None
|
||||
self._grid_index = None
|
||||
self._interpolation_generation = 0
|
||||
self._interpolation_thread: QThread | None = None
|
||||
self._interpolation_worker: _StepInterpolationWorker | None = None
|
||||
self._pending_interpolation_request: _InterpolationRequest | None = None
|
||||
self.heatmap_dialog = None
|
||||
bg_color = pg.mkColor((240, 240, 240, 150))
|
||||
self.config_label = pg.LegendItem(
|
||||
@@ -426,6 +510,7 @@ class Heatmap(ImageBase):
|
||||
if current_scan_id is None:
|
||||
return
|
||||
if current_scan_id != self.scan_id:
|
||||
self._invalidate_interpolation_generation()
|
||||
self.reset()
|
||||
self.new_scan.emit()
|
||||
self.new_scan_id.emit(current_scan_id)
|
||||
@@ -531,13 +616,38 @@ class Heatmap(ImageBase):
|
||||
if self._image_config.show_config_label:
|
||||
self.redraw_config_label()
|
||||
|
||||
img, transform = self.get_image_data(x_data=x_data, y_data=y_data, z_data=z_data)
|
||||
if img is None:
|
||||
if self._is_grid_scan_supported(scan_msg):
|
||||
img, transform = self.get_grid_scan_image(z_data, scan_msg)
|
||||
self._apply_image_update(img, transform)
|
||||
return
|
||||
|
||||
if len(z_data) < 4:
|
||||
# LinearNDInterpolator requires at least 4 points to interpolate
|
||||
logger.warning("Not enough data points to interpolate; skipping update.")
|
||||
return
|
||||
|
||||
self._request_step_scan_interpolation(x_data, y_data, z_data, scan_msg)
|
||||
|
||||
def _apply_image_update(self, img: np.ndarray | None, transform: QTransform | None):
|
||||
"""Apply interpolated image and transform to the heatmap display.
|
||||
|
||||
This method updates the main image with the computed data and emits
|
||||
the image_updated signal. Color bar signals are temporarily blocked
|
||||
during the update to prevent cascading updates.
|
||||
|
||||
Args:
|
||||
img(np.ndarray): The interpolated image data, or None if unavailable
|
||||
transform(QTransform): QTransform mapping pixel to world coordinates, or None if unavailable
|
||||
"""
|
||||
if img is None or transform is None:
|
||||
logger.warning("Image data is None; skipping update.")
|
||||
return
|
||||
|
||||
if self._color_bar is not None:
|
||||
self._color_bar.blockSignals(True)
|
||||
if self.main_image is None:
|
||||
logger.warning("Main image item is None; cannot update image.")
|
||||
return
|
||||
self.main_image.set_data(img, transform=transform)
|
||||
if self._color_bar is not None:
|
||||
self._color_bar.blockSignals(False)
|
||||
@@ -545,6 +655,122 @@ class Heatmap(ImageBase):
|
||||
if self.crosshair is not None:
|
||||
self.crosshair.update_markers_on_image_change()
|
||||
|
||||
def _request_step_scan_interpolation(
|
||||
self,
|
||||
x_data: list[float],
|
||||
y_data: list[float],
|
||||
z_data: list[float],
|
||||
msg: messages.ScanStatusMessage,
|
||||
):
|
||||
"""Request step-scan interpolation in a background thread.
|
||||
|
||||
If a thread is already running, the request is queued as a pending request
|
||||
and will be processed when the current interpolation completes.
|
||||
|
||||
Args:
|
||||
x_data(list[float]): X coordinates of data points
|
||||
y_data(list[float]): Y coordinates of data points
|
||||
z_data(list[float]): Z values at each point
|
||||
msg(messages.ScanStatusMessage): Scan status message containing scan metadata
|
||||
"""
|
||||
request = _InterpolationRequest(
|
||||
x_data=list(x_data),
|
||||
y_data=list(y_data),
|
||||
z_data=list(z_data),
|
||||
scan_id=msg.scan_id,
|
||||
interpolation=self._image_config.interpolation,
|
||||
oversampling_factor=self._image_config.oversampling_factor,
|
||||
)
|
||||
|
||||
if self._interpolation_thread is not None:
|
||||
self._pending_interpolation_request = request
|
||||
return
|
||||
|
||||
self._start_step_scan_interpolation(request)
|
||||
|
||||
def _start_step_scan_interpolation(self, request: _InterpolationRequest):
|
||||
self._interpolation_generation += 1
|
||||
generation = self._interpolation_generation
|
||||
self._interpolation_thread = QThread()
|
||||
self._interpolation_worker = _StepInterpolationWorker(
|
||||
x_data=request.x_data,
|
||||
y_data=request.y_data,
|
||||
z_data=request.z_data,
|
||||
interpolation=request.interpolation,
|
||||
oversampling_factor=request.oversampling_factor,
|
||||
generation=generation,
|
||||
scan_id=request.scan_id,
|
||||
)
|
||||
self._interpolation_worker.moveToThread(self._interpolation_thread)
|
||||
self._interpolation_thread.started.connect(
|
||||
self._interpolation_worker.run, Qt.ConnectionType.QueuedConnection
|
||||
)
|
||||
self._interpolation_worker.finished.connect(
|
||||
self._on_interpolation_finished, Qt.ConnectionType.QueuedConnection
|
||||
)
|
||||
self._interpolation_worker.failed.connect(
|
||||
self._on_interpolation_failed, Qt.ConnectionType.QueuedConnection
|
||||
)
|
||||
self._interpolation_thread.start()
|
||||
|
||||
def _on_interpolation_finished(
|
||||
self, img: np.ndarray, transform: QTransform, generation: int, scan_id: str
|
||||
):
|
||||
if generation == self._interpolation_generation and scan_id == self.scan_id:
|
||||
self._apply_image_update(img, transform)
|
||||
else:
|
||||
logger.debug("Discarding outdated interpolation result.")
|
||||
self._finish_interpolation_thread()
|
||||
self._maybe_start_pending_interpolation()
|
||||
|
||||
def _on_interpolation_failed(self, error: str, generation: int, scan_id: str):
|
||||
logger.warning(
|
||||
"Interpolation failed for scan %s (generation %s): %s", scan_id, generation, error
|
||||
)
|
||||
self._finish_interpolation_thread()
|
||||
self._maybe_start_pending_interpolation()
|
||||
|
||||
def _finish_interpolation_thread(self):
|
||||
if self._interpolation_worker is not None:
|
||||
self._interpolation_worker.deleteLater()
|
||||
self._interpolation_worker = None
|
||||
if self._interpolation_thread is not None:
|
||||
self._interpolation_thread.quit()
|
||||
self._interpolation_thread.wait()
|
||||
self._interpolation_thread.deleteLater()
|
||||
self._interpolation_thread = None
|
||||
|
||||
def _maybe_start_pending_interpolation(self):
|
||||
if self._pending_interpolation_request is None:
|
||||
return
|
||||
if self._pending_interpolation_request.scan_id != self.scan_id:
|
||||
self._pending_interpolation_request = None
|
||||
return
|
||||
|
||||
pending = self._pending_interpolation_request
|
||||
self._pending_interpolation_request = None
|
||||
self._start_step_scan_interpolation(pending)
|
||||
|
||||
def _cancel_interpolation(self):
|
||||
"""Cancel any pending interpolation request without invalidating in-flight work.
|
||||
|
||||
This clears the pending request queue but does not increment the generation
|
||||
counter, allowing any currently running interpolation to complete and update
|
||||
the display if it matches the current scan.
|
||||
"""
|
||||
self._pending_interpolation_request = None
|
||||
# Do not bump generation so an in-flight worker can still deliver the latest scan image.
|
||||
|
||||
def _invalidate_interpolation_generation(self):
|
||||
"""Invalidate all in-flight and pending interpolation results.
|
||||
|
||||
Increments the generation counter so that any currently running or
|
||||
queued interpolation work will be discarded when it completes.
|
||||
This is typically called when starting a new scan.
|
||||
"""
|
||||
# Bump the generation so any in-flight worker results are ignored.
|
||||
self._interpolation_generation += 1
|
||||
|
||||
def redraw_config_label(self):
|
||||
scan_msg = self.status_message
|
||||
if scan_msg is None:
|
||||
@@ -590,21 +816,35 @@ class Heatmap(ImageBase):
|
||||
logger.warning("x, y, or z data is None; skipping update.")
|
||||
return None, None
|
||||
|
||||
if msg.scan_name == "grid_scan" and not self._image_config.enforce_interpolation:
|
||||
# We only support the grid scan mode if both scanning motors
|
||||
# are configured in the heatmap config.
|
||||
device_x = self._image_config.x_device.entry
|
||||
device_y = self._image_config.y_device.entry
|
||||
if (
|
||||
device_x in msg.request_inputs["arg_bundle"]
|
||||
and device_y in msg.request_inputs["arg_bundle"]
|
||||
):
|
||||
return self.get_grid_scan_image(z_data, msg)
|
||||
if self._is_grid_scan_supported(msg):
|
||||
return self.get_grid_scan_image(z_data, msg)
|
||||
if len(z_data) < 4:
|
||||
# LinearNDInterpolator requires at least 4 points to interpolate
|
||||
return None, None
|
||||
return self.get_step_scan_image(x_data, y_data, z_data, msg)
|
||||
|
||||
def _is_grid_scan_supported(self, msg: messages.ScanStatusMessage) -> bool:
|
||||
"""Check if the scan can use optimized grid_scan rendering.
|
||||
|
||||
Grid scans can avoid interpolation if both X and Y devices match
|
||||
the configured devices and interpolation is not enforced.
|
||||
|
||||
Args:
|
||||
msg(messages.ScanStatusMessage): Scan status message containing scan metadata
|
||||
|
||||
Returns:
|
||||
True if grid_scan optimization is applicable, False otherwise
|
||||
"""
|
||||
if msg.scan_name != "grid_scan" or self._image_config.enforce_interpolation:
|
||||
return False
|
||||
|
||||
device_x = self._image_config.x_device.entry
|
||||
device_y = self._image_config.y_device.entry
|
||||
return (
|
||||
device_x in msg.request_inputs["arg_bundle"]
|
||||
and device_y in msg.request_inputs["arg_bundle"]
|
||||
)
|
||||
|
||||
def get_grid_scan_image(
|
||||
self, z_data: list[float], msg: messages.ScanStatusMessage
|
||||
) -> tuple[np.ndarray, QTransform]:
|
||||
@@ -704,17 +944,49 @@ class Heatmap(ImageBase):
|
||||
Returns:
|
||||
tuple[np.ndarray, QTransform]: The image data and the QTransform.
|
||||
"""
|
||||
xy_data = np.column_stack((x_data, y_data))
|
||||
grid_x, grid_y, transform = self.get_image_grid(xy_data)
|
||||
return self.compute_step_scan_image(
|
||||
x_data=x_data,
|
||||
y_data=y_data,
|
||||
z_data=z_data,
|
||||
oversampling_factor=self._image_config.oversampling_factor,
|
||||
interpolation_method=self._image_config.interpolation,
|
||||
)
|
||||
|
||||
# Interpolate the z data onto the grid
|
||||
if self._image_config.interpolation == "linear":
|
||||
@staticmethod
|
||||
def compute_step_scan_image(
|
||||
x_data: list[float] | np.ndarray,
|
||||
y_data: list[float] | np.ndarray,
|
||||
z_data: list[float] | np.ndarray,
|
||||
oversampling_factor: float,
|
||||
interpolation_method: str,
|
||||
) -> tuple[np.ndarray, QTransform]:
|
||||
"""Compute interpolated heatmap image from step-scan data.
|
||||
|
||||
This static method is suitable for execution in a background thread
|
||||
as it doesn't access any instance state.
|
||||
|
||||
Args:
|
||||
x_data(list[float]): X coordinates of data points
|
||||
y_data(list[float]): Y coordinates of data points
|
||||
z_data(list[float]): Z values at each point
|
||||
oversampling_factor(float): Grid resolution multiplier (>1.0 for higher resolution)
|
||||
interpolation_method(str): One of 'linear', 'nearest', or 'clough'
|
||||
|
||||
Returns:
|
||||
(tuple[np.ndarray, QTransform]):Tuple of (interpolated_grid, transform) where transform maps pixel to world coordinates
|
||||
"""
|
||||
xy_data = np.column_stack((x_data, y_data))
|
||||
grid_x, grid_y, transform = Heatmap.build_image_grid(
|
||||
positions=xy_data, oversampling_factor=oversampling_factor
|
||||
)
|
||||
|
||||
if interpolation_method == "linear":
|
||||
interp = LinearNDInterpolator(xy_data, z_data)
|
||||
elif self._image_config.interpolation == "nearest":
|
||||
elif interpolation_method == "nearest":
|
||||
interp = NearestNDInterpolator(xy_data, z_data)
|
||||
elif self._image_config.interpolation == "clough":
|
||||
elif interpolation_method == "clough":
|
||||
interp = CloughTocher2DInterpolator(xy_data, z_data)
|
||||
else:
|
||||
else: # pragma: no cover - guarded by validation
|
||||
raise ValueError(
|
||||
"Interpolation method must be either 'linear', 'nearest', or 'clough'."
|
||||
)
|
||||
@@ -733,22 +1005,33 @@ class Heatmap(ImageBase):
|
||||
Returns:
|
||||
tuple[np.ndarray, np.ndarray, QTransform]: The grid x and y coordinates and the QTransform.
|
||||
"""
|
||||
base_width, base_height = self.estimate_image_resolution(positions)
|
||||
return self.build_image_grid(
|
||||
positions=positions, oversampling_factor=self._image_config.oversampling_factor
|
||||
)
|
||||
|
||||
# Apply oversampling factor
|
||||
factor = self._image_config.oversampling_factor
|
||||
@staticmethod
|
||||
def build_image_grid(
|
||||
positions: np.ndarray, oversampling_factor: float
|
||||
) -> tuple[np.ndarray, np.ndarray, QTransform]:
|
||||
"""Build an interpolation grid covering the data positions.
|
||||
|
||||
# Apply oversampling
|
||||
width = int(base_width * factor)
|
||||
height = int(base_height * factor)
|
||||
Args:
|
||||
positions: (N, 2) array of (x, y) coordinates
|
||||
oversampling_factor: Grid resolution multiplier (>1.0 for higher resolution)
|
||||
|
||||
Returns:
|
||||
Tuple of (grid_x, grid_y, transform) where grid_x/grid_y are meshgrids
|
||||
for interpolation and transform maps pixel to world coordinates
|
||||
"""
|
||||
base_width, base_height = Heatmap.estimate_image_resolution(positions)
|
||||
width = max(1, int(base_width * oversampling_factor))
|
||||
height = max(1, int(base_height * oversampling_factor))
|
||||
|
||||
# Create grid
|
||||
grid_x, grid_y = np.mgrid[
|
||||
min(positions[:, 0]) : max(positions[:, 0]) : width * 1j,
|
||||
min(positions[:, 1]) : max(positions[:, 1]) : height * 1j,
|
||||
]
|
||||
|
||||
# Calculate transform
|
||||
x_min, x_max = min(positions[:, 0]), max(positions[:, 0])
|
||||
y_min, y_max = min(positions[:, 1]), max(positions[:, 1])
|
||||
x_range = x_max - x_min
|
||||
@@ -832,6 +1115,7 @@ class Heatmap(ImageBase):
|
||||
return scan_devices, "value"
|
||||
|
||||
def reset(self):
|
||||
self._cancel_interpolation()
|
||||
self._grid_index = None
|
||||
self.main_image.clear()
|
||||
if self.crosshair is not None:
|
||||
@@ -966,6 +1250,10 @@ class Heatmap(ImageBase):
|
||||
"""
|
||||
self.main_image.transpose = enable
|
||||
|
||||
def cleanup(self):
|
||||
self._finish_interpolation_thread()
|
||||
super().cleanup()
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover
|
||||
import sys
|
||||
|
||||
@@ -4,9 +4,15 @@ import numpy as np
|
||||
import pytest
|
||||
from bec_lib import messages
|
||||
from bec_lib.scan_history import ScanHistory
|
||||
from qtpy.QtGui import QTransform
|
||||
from qtpy.QtCore import QPointF
|
||||
|
||||
from bec_widgets.widgets.plots.heatmap.heatmap import Heatmap, HeatmapConfig, HeatmapDeviceSignal
|
||||
from bec_widgets.widgets.plots.heatmap.heatmap import (
|
||||
Heatmap,
|
||||
HeatmapConfig,
|
||||
HeatmapDeviceSignal,
|
||||
_StepInterpolationWorker,
|
||||
)
|
||||
|
||||
# pytest: disable=unused-import
|
||||
from tests.unit_tests.client_mocks import mocked_client
|
||||
@@ -448,12 +454,16 @@ def test_heatmap_widget_reset(heatmap_widget):
|
||||
"""
|
||||
Test that the reset method clears the plot.
|
||||
"""
|
||||
heatmap_widget._pending_interpolation_request = object()
|
||||
heatmap_widget._interpolation_generation = 5
|
||||
heatmap_widget.scan_item = create_dummy_scan_item()
|
||||
heatmap_widget.plot(x_name="samx", y_name="samy", z_name="bpm4i")
|
||||
|
||||
heatmap_widget.reset()
|
||||
assert heatmap_widget._grid_index is None
|
||||
assert heatmap_widget.main_image.raw_data is None
|
||||
assert heatmap_widget._pending_interpolation_request is None
|
||||
assert heatmap_widget._interpolation_generation == 5
|
||||
|
||||
|
||||
def test_heatmap_widget_update_plot_with_scan_history(heatmap_widget, grid_scan_history_msg, qtbot):
|
||||
@@ -478,3 +488,108 @@ def test_heatmap_widget_update_plot_with_scan_history(heatmap_widget, grid_scan_
|
||||
heatmap_widget.enforce_interpolation = True
|
||||
heatmap_widget.oversampling_factor = 2.0
|
||||
qtbot.waitUntil(lambda: heatmap_widget.main_image.raw_data.shape == (20, 20))
|
||||
|
||||
|
||||
def test_step_interpolation_worker_emits_finished(qtbot):
|
||||
worker = _StepInterpolationWorker(
|
||||
x_data=[0.0, 1.0, 0.5, 0.2],
|
||||
y_data=[0.0, 0.0, 1.0, 1.0],
|
||||
z_data=[1.0, 2.0, 3.0, 4.0],
|
||||
interpolation="linear",
|
||||
oversampling_factor=1.0,
|
||||
generation=1,
|
||||
scan_id="scan-1",
|
||||
)
|
||||
with qtbot.waitSignal(worker.finished, timeout=1000) as blocker:
|
||||
worker.run()
|
||||
img, transform, generation, scan_id = blocker.args
|
||||
assert img.shape[0] > 0
|
||||
assert isinstance(transform, QTransform)
|
||||
assert generation == 1
|
||||
assert scan_id == "scan-1"
|
||||
|
||||
|
||||
def test_step_interpolation_worker_emits_failed(qtbot, monkeypatch):
|
||||
def _scan_goes_boom(**kwargs):
|
||||
raise RuntimeError("crash")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"bec_widgets.widgets.plots.heatmap.heatmap.Heatmap.compute_step_scan_image", _scan_goes_boom
|
||||
)
|
||||
worker = _StepInterpolationWorker(
|
||||
x_data=[0.0, 1.0, 0.5, 0.2],
|
||||
y_data=[0.0, 0.0, 1.0, 1.0],
|
||||
z_data=[1.0, 2.0, 3.0, 4.0],
|
||||
interpolation="linear",
|
||||
oversampling_factor=1.0,
|
||||
generation=99,
|
||||
scan_id="scan-err",
|
||||
)
|
||||
with qtbot.waitSignal(worker.failed, timeout=1000) as blocker:
|
||||
worker.run()
|
||||
error, generation, scan_id = blocker.args
|
||||
assert "crash" in error
|
||||
assert generation == 99
|
||||
assert scan_id == "scan-err"
|
||||
|
||||
|
||||
def test_interpolation_generation_invalidation(heatmap_widget):
|
||||
heatmap_widget.scan_id = "scan-1"
|
||||
heatmap_widget._interpolation_generation = 2
|
||||
with (
|
||||
mock.patch.object(heatmap_widget, "_apply_image_update") as apply_mock,
|
||||
mock.patch.object(heatmap_widget, "_finish_interpolation_thread") as finish_mock,
|
||||
mock.patch.object(heatmap_widget, "_maybe_start_pending_interpolation") as maybe_mock,
|
||||
):
|
||||
heatmap_widget._on_interpolation_finished(
|
||||
np.zeros((2, 2)), QTransform(), generation=1, scan_id="scan-1"
|
||||
)
|
||||
apply_mock.assert_not_called()
|
||||
finish_mock.assert_called_once()
|
||||
maybe_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_pending_request_queueing_and_start(heatmap_widget):
|
||||
heatmap_widget.scan_id = "scan-queue"
|
||||
heatmap_widget.status_message = messages.ScanStatusMessage(
|
||||
scan_id="scan-queue",
|
||||
status="open",
|
||||
scan_name="step_scan",
|
||||
scan_type="step",
|
||||
metadata={},
|
||||
info={"positions": [[0, 0], [1, 1], [2, 2], [3, 3]]},
|
||||
)
|
||||
heatmap_widget._interpolation_thread = object() # simulate running thread
|
||||
|
||||
with mock.patch.object(heatmap_widget, "_start_step_scan_interpolation") as start_mock:
|
||||
heatmap_widget._request_step_scan_interpolation(
|
||||
x_data=[0, 1, 2, 3],
|
||||
y_data=[0, 1, 2, 3],
|
||||
z_data=[0, 1, 2, 3],
|
||||
msg=heatmap_widget.status_message,
|
||||
)
|
||||
assert heatmap_widget._pending_interpolation_request is not None
|
||||
|
||||
# Now simulate worker finished and thread cleaned up
|
||||
heatmap_widget._interpolation_thread = None
|
||||
pending = heatmap_widget._pending_interpolation_request
|
||||
heatmap_widget._pending_interpolation_request = pending
|
||||
heatmap_widget._maybe_start_pending_interpolation()
|
||||
|
||||
start_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_finish_interpolation_thread_cleans_references(heatmap_widget):
|
||||
worker_mock = mock.Mock()
|
||||
thread_mock = mock.Mock()
|
||||
heatmap_widget._interpolation_worker = worker_mock
|
||||
heatmap_widget._interpolation_thread = thread_mock
|
||||
|
||||
heatmap_widget._finish_interpolation_thread()
|
||||
|
||||
worker_mock.deleteLater.assert_called_once()
|
||||
thread_mock.quit.assert_called_once()
|
||||
thread_mock.wait.assert_called_once()
|
||||
thread_mock.deleteLater.assert_called_once()
|
||||
assert heatmap_widget._interpolation_worker is None
|
||||
assert heatmap_widget._interpolation_thread is None
|
||||
|
||||
Reference in New Issue
Block a user