1
0
mirror of https://github.com/bec-project/bec_widgets.git synced 2026-03-05 00:12:49 +01:00

perf(heatmap): thread worker optimization

This commit is contained in:
2025-12-05 10:58:58 +01:00
parent c354a9b249
commit 063e5d064c
2 changed files with 118 additions and 110 deletions

View File

@@ -81,9 +81,22 @@ class HeatmapConfig(ConnectionConfig):
@dataclass
class _InterpolationRequest:
"""Immutable payload describing an interpolation request for the worker thread.
Args:
x_data: X coordinates collected so far.
y_data: Y coordinates collected so far.
z_data: Z values associated with x/y.
data_version: Number of points at request time (len(z_data)); used to reject stale results.
scan_id: Identifier for the scan that produced the data.
interpolation: Interpolation method to apply.
oversampling_factor: Oversampling factor for the interpolation grid.
"""
x_data: list[float]
y_data: list[float]
z_data: list[float]
data_version: int
scan_id: str
interpolation: str
oversampling_factor: float
@@ -95,67 +108,50 @@ class _StepInterpolationWorker(QObject):
This worker computes the interpolated heatmap image using the provided data
and settings, then emits the result or a failure signal.
Args:
x_data (list[float] or np.ndarray): The x-coordinates of the data points.
y_data (list[float] or np.ndarray): The y-coordinates of the data points.
z_data (list[float] or np.ndarray): The z-values (intensity) of the data points.
interpolation (str): The interpolation method to use.
oversampling_factor (float): The oversampling factor for the interpolation grid.
generation (int): The generation number for tracking requests.
scan_id (str): The scan identifier.
parent (QObject | None, optional): The parent QObject. Defaults to None.
Signals:
finished(image, transform, generation, scan_id):
finished(image, transform, data_version, scan_id):
Emitted when interpolation is successful.
- image: The resulting image (numpy array or similar).
- transform: The QTransform for the image.
- generation: The generation number.
- data_version: The data version for the request.
- scan_id: The scan identifier.
failed(error_message, generation, scan_id):
failed(error_message, data_version, scan_id):
Emitted when interpolation fails.
- error_message: The error message string.
- generation: The generation number.
- data_version: The data version for the request.
- scan_id: The scan identifier.
"""
finished = Signal(object, object, int, str)
failed = Signal(str, int, str)
def __init__(
self,
x_data: list[float] | np.ndarray,
y_data: list[float] | np.ndarray,
z_data: list[float] | np.ndarray,
interpolation: str,
oversampling_factor: float,
generation: int,
scan_id: str,
parent: QObject | None = None,
):
def __init__(self, parent: QObject | None = None):
super().__init__(parent=parent)
self._x_data = np.asarray(x_data, dtype=float)
self._y_data = np.asarray(y_data, dtype=float)
self._z_data = np.asarray(z_data, dtype=float)
self._interpolation = interpolation
self._oversampling_factor = oversampling_factor
self._generation = generation
self._scan_id = scan_id
self._active_request: _InterpolationRequest | None = None
def run(self):
@SafeSlot(object, int)
def process(self, request: _InterpolationRequest, data_version: int):
"""
Process an interpolation request in the worker thread.
Args:
request(_InterpolationRequest): The interpolation request payload.
data_version(int): The data version for the request.
"""
self._active_request = request
try:
image, transform = Heatmap.compute_step_scan_image(
x_data=self._x_data,
y_data=self._y_data,
z_data=self._z_data,
oversampling_factor=self._oversampling_factor,
interpolation_method=self._interpolation,
x_data=np.asarray(request.x_data, dtype=float),
y_data=np.asarray(request.y_data, dtype=float),
z_data=np.asarray(request.z_data, dtype=float),
oversampling_factor=request.oversampling_factor,
interpolation_method=request.interpolation,
)
except Exception as exc: # pragma: no cover - defensive
logger.warning(f"Step-scan interpolation failed with: {exc}")
self.failed.emit(str(exc), self._generation, self._scan_id)
self.failed.emit(str(exc), data_version, request.scan_id)
return
self.finished.emit(image, transform, self._generation, self._scan_id)
self.finished.emit(image, transform, data_version, request.scan_id)
class Heatmap(ImageBase):
@@ -208,6 +204,7 @@ class Heatmap(ImageBase):
new_scan_id = Signal(str)
sync_signal_update = Signal()
heatmap_property_changed = Signal()
interpolation_requested = Signal(object, int)
def __init__(self, parent=None, config: HeatmapConfig | None = None, **kwargs):
if config is None:
@@ -230,7 +227,9 @@ class Heatmap(ImageBase):
self.scan_item = None
self.status_message = None
self._grid_index = None
self._interpolation_generation = 0
# Highest data_version we have dispatched for the current scan; used to drop stale results.
# Initialized to -1 so the first real request (len(z_data) >= 0) always supersedes it.
self._latest_interpolation_version = -1
self._interpolation_thread: QThread | None = None
self._interpolation_worker: _StepInterpolationWorker | None = None
self._pending_interpolation_request: _InterpolationRequest | None = None
@@ -510,7 +509,7 @@ class Heatmap(ImageBase):
if current_scan_id is None:
return
if current_scan_id != self.scan_id:
self._invalidate_interpolation_generation()
self._invalidate_interpolation_generation() # Invalidate any pending interpolation work when a new scan starts
self.reset()
self.new_scan.emit()
self.new_scan_id.emit(current_scan_id)
@@ -677,66 +676,77 @@ class Heatmap(ImageBase):
x_data=list(x_data),
y_data=list(y_data),
z_data=list(z_data),
data_version=len(z_data),
scan_id=msg.scan_id,
interpolation=self._image_config.interpolation,
oversampling_factor=self._image_config.oversampling_factor,
)
if self._interpolation_thread is not None:
if self._interpolation_thread is not None and self._interpolation_thread.isRunning():
self._pending_interpolation_request = request
return
self._start_step_scan_interpolation(request)
def _ensure_interpolation_thread(self):
if self._interpolation_thread is None:
self._interpolation_thread = QThread()
self._interpolation_worker = _StepInterpolationWorker()
self._interpolation_worker.moveToThread(self._interpolation_thread)
self.interpolation_requested.connect(
self._interpolation_worker.process, Qt.ConnectionType.QueuedConnection
)
self._interpolation_worker.finished.connect(
self._on_interpolation_finished, Qt.ConnectionType.QueuedConnection
)
self._interpolation_worker.failed.connect(
self._on_interpolation_failed, Qt.ConnectionType.QueuedConnection
)
if self._interpolation_thread is not None and not self._interpolation_thread.isRunning():
self._interpolation_thread.start()
def _start_step_scan_interpolation(self, request: _InterpolationRequest):
self._interpolation_generation += 1
generation = self._interpolation_generation
self._interpolation_thread = QThread()
self._interpolation_worker = _StepInterpolationWorker(
x_data=request.x_data,
y_data=request.y_data,
z_data=request.z_data,
interpolation=request.interpolation,
oversampling_factor=request.oversampling_factor,
generation=generation,
scan_id=request.scan_id,
)
self._interpolation_worker.moveToThread(self._interpolation_thread)
self._interpolation_thread.started.connect(
self._interpolation_worker.run, Qt.ConnectionType.QueuedConnection
)
self._interpolation_worker.finished.connect(
self._on_interpolation_finished, Qt.ConnectionType.QueuedConnection
)
self._interpolation_worker.failed.connect(
self._on_interpolation_failed, Qt.ConnectionType.QueuedConnection
)
self._interpolation_thread.start()
# data_version = len(z_data) at the time of the request; keep the latest to gate results.
self._ensure_interpolation_thread()
if self._interpolation_thread is not None and not self._interpolation_thread.isRunning():
self._interpolation_thread.start()
self._latest_interpolation_version = request.data_version
self.interpolation_requested.emit(request, request.data_version)
def _on_interpolation_finished(
self, img: np.ndarray, transform: QTransform, generation: int, scan_id: str
self, img: np.ndarray, transform: QTransform, data_version: int, scan_id: str
):
if generation == self._interpolation_generation and scan_id == self.scan_id:
# Only accept results that match the latest dispatched version for the active scan.
if data_version == self._latest_interpolation_version and scan_id == self.scan_id:
self._apply_image_update(img, transform)
else:
logger.debug("Discarding outdated interpolation result.")
self._finish_interpolation_thread()
logger.info("Discarding outdated interpolation result.")
if self._interpolation_thread is not None and self._interpolation_thread.isRunning():
self._interpolation_thread.quit()
self._interpolation_thread.wait()
self._maybe_start_pending_interpolation()
def _on_interpolation_failed(self, error: str, generation: int, scan_id: str):
logger.warning(
"Interpolation failed for scan %s (generation %s): %s", scan_id, generation, error
)
self._finish_interpolation_thread()
def _on_interpolation_failed(self, error: str, data_version: int, scan_id: str):
logger.warning(f"Interpolation failed for scan {scan_id} (version {data_version}): {error}")
if self._interpolation_thread is not None and self._interpolation_thread.isRunning():
self._interpolation_thread.quit()
self._interpolation_thread.wait()
self._maybe_start_pending_interpolation()
def _finish_interpolation_thread(self):
self._pending_interpolation_request = None
if self._interpolation_worker is not None:
try:
self.interpolation_requested.disconnect(self._interpolation_worker.process)
except (TypeError, RuntimeError):
# Defensive: disconnect may fail if already disconnected or during shutdown.
pass
self._interpolation_worker.deleteLater()
self._interpolation_worker = None
if self._interpolation_thread is not None:
self._interpolation_thread.quit()
self._interpolation_thread.wait()
if self._interpolation_thread.isRunning():
self._interpolation_thread.quit()
self._interpolation_thread.wait()
self._interpolation_thread.deleteLater()
self._interpolation_thread = None
@@ -754,22 +764,17 @@ class Heatmap(ImageBase):
def _cancel_interpolation(self):
"""Cancel any pending interpolation request without invalidating in-flight work.
This clears the pending request queue but does not increment the generation
counter, allowing any currently running interpolation to complete and update
the display if it matches the current scan.
This clears the pending request queue but does not invalidate in-flight work,
allowing any currently running interpolation to complete and update the display
if it matches the current scan.
"""
self._pending_interpolation_request = None
# Do not bump generation so an in-flight worker can still deliver the latest scan image.
# Do not change the active data version so an in-flight worker can still deliver.
def _invalidate_interpolation_generation(self):
"""Invalidate all in-flight and pending interpolation results.
Increments the generation counter so that any currently running or
queued interpolation work will be discarded when it completes.
This is typically called when starting a new scan.
"""
# Bump the generation so any in-flight worker results are ignored.
self._interpolation_generation += 1
"""Invalidate all pending interpolation results and ignore in-flight updates."""
self._pending_interpolation_request = None
self._latest_interpolation_version = -1
def redraw_config_label(self):
scan_msg = self.status_message

View File

@@ -11,6 +11,7 @@ from bec_widgets.widgets.plots.heatmap.heatmap import (
Heatmap,
HeatmapConfig,
HeatmapDeviceSignal,
_InterpolationRequest,
_StepInterpolationWorker,
)
@@ -455,7 +456,7 @@ def test_heatmap_widget_reset(heatmap_widget):
Test that the reset method clears the plot.
"""
heatmap_widget._pending_interpolation_request = object()
heatmap_widget._interpolation_generation = 5
heatmap_widget._latest_interpolation_version = 5
heatmap_widget.scan_item = create_dummy_scan_item()
heatmap_widget.plot(x_name="samx", y_name="samy", z_name="bpm4i")
@@ -463,7 +464,7 @@ def test_heatmap_widget_reset(heatmap_widget):
assert heatmap_widget._grid_index is None
assert heatmap_widget.main_image.raw_data is None
assert heatmap_widget._pending_interpolation_request is None
assert heatmap_widget._interpolation_generation == 5
assert heatmap_widget._latest_interpolation_version == 5
def test_heatmap_widget_update_plot_with_scan_history(heatmap_widget, grid_scan_history_msg, qtbot):
@@ -491,22 +492,23 @@ def test_heatmap_widget_update_plot_with_scan_history(heatmap_widget, grid_scan_
def test_step_interpolation_worker_emits_finished(qtbot):
worker = _StepInterpolationWorker(
worker = _StepInterpolationWorker()
request = _InterpolationRequest(
x_data=[0.0, 1.0, 0.5, 0.2],
y_data=[0.0, 0.0, 1.0, 1.0],
z_data=[1.0, 2.0, 3.0, 4.0],
data_version=4,
scan_id="scan-1",
interpolation="linear",
oversampling_factor=1.0,
generation=1,
scan_id="scan-1",
)
with qtbot.waitSignal(worker.finished, timeout=1000) as blocker:
worker.run()
img, transform, generation, scan_id = blocker.args
worker.process(request, request.data_version)
img, transform, data_version, scan_id = blocker.args
assert img.shape[0] > 0
assert isinstance(transform, QTransform)
assert generation == 1
assert scan_id == "scan-1"
assert data_version == request.data_version
assert scan_id == request.scan_id
def test_step_interpolation_worker_emits_failed(qtbot, monkeypatch):
@@ -516,36 +518,35 @@ def test_step_interpolation_worker_emits_failed(qtbot, monkeypatch):
monkeypatch.setattr(
"bec_widgets.widgets.plots.heatmap.heatmap.Heatmap.compute_step_scan_image", _scan_goes_boom
)
worker = _StepInterpolationWorker(
worker = _StepInterpolationWorker()
request = _InterpolationRequest(
x_data=[0.0, 1.0, 0.5, 0.2],
y_data=[0.0, 0.0, 1.0, 1.0],
z_data=[1.0, 2.0, 3.0, 4.0],
data_version=99,
scan_id="scan-err",
interpolation="linear",
oversampling_factor=1.0,
generation=99,
scan_id="scan-err",
)
with qtbot.waitSignal(worker.failed, timeout=1000) as blocker:
worker.run()
error, generation, scan_id = blocker.args
worker.process(request, request.data_version)
error, data_version, scan_id = blocker.args
assert "crash" in error
assert generation == 99
assert scan_id == "scan-err"
assert data_version == request.data_version
assert scan_id == request.scan_id
def test_interpolation_generation_invalidation(heatmap_widget):
heatmap_widget.scan_id = "scan-1"
heatmap_widget._interpolation_generation = 2
heatmap_widget._latest_interpolation_version = 2
with (
mock.patch.object(heatmap_widget, "_apply_image_update") as apply_mock,
mock.patch.object(heatmap_widget, "_finish_interpolation_thread") as finish_mock,
mock.patch.object(heatmap_widget, "_maybe_start_pending_interpolation") as maybe_mock,
):
heatmap_widget._on_interpolation_finished(
np.zeros((2, 2)), QTransform(), generation=1, scan_id="scan-1"
np.zeros((2, 2)), QTransform(), data_version=1, scan_id="scan-1"
)
apply_mock.assert_not_called()
finish_mock.assert_called_once()
maybe_mock.assert_called_once()
@@ -559,7 +560,8 @@ def test_pending_request_queueing_and_start(heatmap_widget):
metadata={},
info={"positions": [[0, 0], [1, 1], [2, 2], [3, 3]]},
)
heatmap_widget._interpolation_thread = object() # simulate running thread
heatmap_widget._interpolation_thread = mock.MagicMock()
heatmap_widget._interpolation_thread.isRunning.return_value = True
with mock.patch.object(heatmap_widget, "_start_step_scan_interpolation") as start_mock:
heatmap_widget._request_step_scan_interpolation(
@@ -582,6 +584,7 @@ def test_pending_request_queueing_and_start(heatmap_widget):
def test_finish_interpolation_thread_cleans_references(heatmap_widget):
worker_mock = mock.Mock()
thread_mock = mock.Mock()
thread_mock.isRunning.return_value = True
heatmap_widget._interpolation_worker = worker_mock
heatmap_widget._interpolation_thread = thread_mock