diff --git a/bec_widgets/widgets/plots/heatmap/heatmap.py b/bec_widgets/widgets/plots/heatmap/heatmap.py index bb43852a..faf61453 100644 --- a/bec_widgets/widgets/plots/heatmap/heatmap.py +++ b/bec_widgets/widgets/plots/heatmap/heatmap.py @@ -81,9 +81,22 @@ class HeatmapConfig(ConnectionConfig): @dataclass class _InterpolationRequest: + """Immutable payload describing an interpolation request for the worker thread. + + Args: + x_data: X coordinates collected so far. + y_data: Y coordinates collected so far. + z_data: Z values associated with x/y. + data_version: Number of points at request time (len(z_data)); used to reject stale results. + scan_id: Identifier for the scan that produced the data. + interpolation: Interpolation method to apply. + oversampling_factor: Oversampling factor for the interpolation grid. + """ + x_data: list[float] y_data: list[float] z_data: list[float] + data_version: int scan_id: str interpolation: str oversampling_factor: float @@ -95,67 +108,50 @@ class _StepInterpolationWorker(QObject): This worker computes the interpolated heatmap image using the provided data and settings, then emits the result or a failure signal. - Args: - x_data (list[float] or np.ndarray): The x-coordinates of the data points. - y_data (list[float] or np.ndarray): The y-coordinates of the data points. - z_data (list[float] or np.ndarray): The z-values (intensity) of the data points. - interpolation (str): The interpolation method to use. - oversampling_factor (float): The oversampling factor for the interpolation grid. - generation (int): The generation number for tracking requests. - scan_id (str): The scan identifier. - parent (QObject | None, optional): The parent QObject. Defaults to None. - Signals: - finished(image, transform, generation, scan_id): + finished(image, transform, data_version, scan_id): Emitted when interpolation is successful. - image: The resulting image (numpy array or similar). - transform: The QTransform for the image. - - generation: The generation number. + - data_version: The data version for the request. - scan_id: The scan identifier. - failed(error_message, generation, scan_id): + failed(error_message, data_version, scan_id): Emitted when interpolation fails. - error_message: The error message string. - - generation: The generation number. + - data_version: The data version for the request. - scan_id: The scan identifier. """ finished = Signal(object, object, int, str) failed = Signal(str, int, str) - def __init__( - self, - x_data: list[float] | np.ndarray, - y_data: list[float] | np.ndarray, - z_data: list[float] | np.ndarray, - interpolation: str, - oversampling_factor: float, - generation: int, - scan_id: str, - parent: QObject | None = None, - ): + def __init__(self, parent: QObject | None = None): super().__init__(parent=parent) - self._x_data = np.asarray(x_data, dtype=float) - self._y_data = np.asarray(y_data, dtype=float) - self._z_data = np.asarray(z_data, dtype=float) - self._interpolation = interpolation - self._oversampling_factor = oversampling_factor - self._generation = generation - self._scan_id = scan_id + self._active_request: _InterpolationRequest | None = None - def run(self): + @SafeSlot(object, int) + def process(self, request: _InterpolationRequest, data_version: int): + """ + Process an interpolation request in the worker thread. + + Args: + request(_InterpolationRequest): The interpolation request payload. + data_version(int): The data version for the request. + """ + self._active_request = request try: image, transform = Heatmap.compute_step_scan_image( - x_data=self._x_data, - y_data=self._y_data, - z_data=self._z_data, - oversampling_factor=self._oversampling_factor, - interpolation_method=self._interpolation, + x_data=np.asarray(request.x_data, dtype=float), + y_data=np.asarray(request.y_data, dtype=float), + z_data=np.asarray(request.z_data, dtype=float), + oversampling_factor=request.oversampling_factor, + interpolation_method=request.interpolation, ) except Exception as exc: # pragma: no cover - defensive logger.warning(f"Step-scan interpolation failed with: {exc}") - self.failed.emit(str(exc), self._generation, self._scan_id) + self.failed.emit(str(exc), data_version, request.scan_id) return - self.finished.emit(image, transform, self._generation, self._scan_id) + self.finished.emit(image, transform, data_version, request.scan_id) class Heatmap(ImageBase): @@ -208,6 +204,7 @@ class Heatmap(ImageBase): new_scan_id = Signal(str) sync_signal_update = Signal() heatmap_property_changed = Signal() + interpolation_requested = Signal(object, int) def __init__(self, parent=None, config: HeatmapConfig | None = None, **kwargs): if config is None: @@ -230,7 +227,9 @@ class Heatmap(ImageBase): self.scan_item = None self.status_message = None self._grid_index = None - self._interpolation_generation = 0 + # Highest data_version we have dispatched for the current scan; used to drop stale results. + # Initialized to -1 so the first real request (len(z_data) >= 0) always supersedes it. + self._latest_interpolation_version = -1 self._interpolation_thread: QThread | None = None self._interpolation_worker: _StepInterpolationWorker | None = None self._pending_interpolation_request: _InterpolationRequest | None = None @@ -510,7 +509,7 @@ class Heatmap(ImageBase): if current_scan_id is None: return if current_scan_id != self.scan_id: - self._invalidate_interpolation_generation() + self._invalidate_interpolation_generation() # Invalidate any pending interpolation work when a new scan starts self.reset() self.new_scan.emit() self.new_scan_id.emit(current_scan_id) @@ -677,66 +676,77 @@ class Heatmap(ImageBase): x_data=list(x_data), y_data=list(y_data), z_data=list(z_data), + data_version=len(z_data), scan_id=msg.scan_id, interpolation=self._image_config.interpolation, oversampling_factor=self._image_config.oversampling_factor, ) - if self._interpolation_thread is not None: + if self._interpolation_thread is not None and self._interpolation_thread.isRunning(): self._pending_interpolation_request = request return self._start_step_scan_interpolation(request) + def _ensure_interpolation_thread(self): + if self._interpolation_thread is None: + self._interpolation_thread = QThread() + self._interpolation_worker = _StepInterpolationWorker() + self._interpolation_worker.moveToThread(self._interpolation_thread) + self.interpolation_requested.connect( + self._interpolation_worker.process, Qt.ConnectionType.QueuedConnection + ) + self._interpolation_worker.finished.connect( + self._on_interpolation_finished, Qt.ConnectionType.QueuedConnection + ) + self._interpolation_worker.failed.connect( + self._on_interpolation_failed, Qt.ConnectionType.QueuedConnection + ) + if self._interpolation_thread is not None and not self._interpolation_thread.isRunning(): + self._interpolation_thread.start() + def _start_step_scan_interpolation(self, request: _InterpolationRequest): - self._interpolation_generation += 1 - generation = self._interpolation_generation - self._interpolation_thread = QThread() - self._interpolation_worker = _StepInterpolationWorker( - x_data=request.x_data, - y_data=request.y_data, - z_data=request.z_data, - interpolation=request.interpolation, - oversampling_factor=request.oversampling_factor, - generation=generation, - scan_id=request.scan_id, - ) - self._interpolation_worker.moveToThread(self._interpolation_thread) - self._interpolation_thread.started.connect( - self._interpolation_worker.run, Qt.ConnectionType.QueuedConnection - ) - self._interpolation_worker.finished.connect( - self._on_interpolation_finished, Qt.ConnectionType.QueuedConnection - ) - self._interpolation_worker.failed.connect( - self._on_interpolation_failed, Qt.ConnectionType.QueuedConnection - ) - self._interpolation_thread.start() + # data_version = len(z_data) at the time of the request; keep the latest to gate results. + self._ensure_interpolation_thread() + if self._interpolation_thread is not None and not self._interpolation_thread.isRunning(): + self._interpolation_thread.start() + self._latest_interpolation_version = request.data_version + self.interpolation_requested.emit(request, request.data_version) def _on_interpolation_finished( - self, img: np.ndarray, transform: QTransform, generation: int, scan_id: str + self, img: np.ndarray, transform: QTransform, data_version: int, scan_id: str ): - if generation == self._interpolation_generation and scan_id == self.scan_id: + # Only accept results that match the latest dispatched version for the active scan. + if data_version == self._latest_interpolation_version and scan_id == self.scan_id: self._apply_image_update(img, transform) else: - logger.debug("Discarding outdated interpolation result.") - self._finish_interpolation_thread() + logger.info("Discarding outdated interpolation result.") + if self._interpolation_thread is not None and self._interpolation_thread.isRunning(): + self._interpolation_thread.quit() + self._interpolation_thread.wait() self._maybe_start_pending_interpolation() - def _on_interpolation_failed(self, error: str, generation: int, scan_id: str): - logger.warning( - "Interpolation failed for scan %s (generation %s): %s", scan_id, generation, error - ) - self._finish_interpolation_thread() + def _on_interpolation_failed(self, error: str, data_version: int, scan_id: str): + logger.warning(f"Interpolation failed for scan {scan_id} (version {data_version}): {error}") + if self._interpolation_thread is not None and self._interpolation_thread.isRunning(): + self._interpolation_thread.quit() + self._interpolation_thread.wait() self._maybe_start_pending_interpolation() def _finish_interpolation_thread(self): + self._pending_interpolation_request = None if self._interpolation_worker is not None: + try: + self.interpolation_requested.disconnect(self._interpolation_worker.process) + except (TypeError, RuntimeError): + # Defensive: disconnect may fail if already disconnected or during shutdown. + pass self._interpolation_worker.deleteLater() self._interpolation_worker = None if self._interpolation_thread is not None: - self._interpolation_thread.quit() - self._interpolation_thread.wait() + if self._interpolation_thread.isRunning(): + self._interpolation_thread.quit() + self._interpolation_thread.wait() self._interpolation_thread.deleteLater() self._interpolation_thread = None @@ -754,22 +764,17 @@ class Heatmap(ImageBase): def _cancel_interpolation(self): """Cancel any pending interpolation request without invalidating in-flight work. - This clears the pending request queue but does not increment the generation - counter, allowing any currently running interpolation to complete and update - the display if it matches the current scan. + This clears the pending request queue but does not invalidate in-flight work, + allowing any currently running interpolation to complete and update the display + if it matches the current scan. """ self._pending_interpolation_request = None - # Do not bump generation so an in-flight worker can still deliver the latest scan image. + # Do not change the active data version so an in-flight worker can still deliver. def _invalidate_interpolation_generation(self): - """Invalidate all in-flight and pending interpolation results. - - Increments the generation counter so that any currently running or - queued interpolation work will be discarded when it completes. - This is typically called when starting a new scan. - """ - # Bump the generation so any in-flight worker results are ignored. - self._interpolation_generation += 1 + """Invalidate all pending interpolation results and ignore in-flight updates.""" + self._pending_interpolation_request = None + self._latest_interpolation_version = -1 def redraw_config_label(self): scan_msg = self.status_message diff --git a/tests/unit_tests/test_heatmap_widget.py b/tests/unit_tests/test_heatmap_widget.py index 7cf83251..d3a64da7 100644 --- a/tests/unit_tests/test_heatmap_widget.py +++ b/tests/unit_tests/test_heatmap_widget.py @@ -11,6 +11,7 @@ from bec_widgets.widgets.plots.heatmap.heatmap import ( Heatmap, HeatmapConfig, HeatmapDeviceSignal, + _InterpolationRequest, _StepInterpolationWorker, ) @@ -455,7 +456,7 @@ def test_heatmap_widget_reset(heatmap_widget): Test that the reset method clears the plot. """ heatmap_widget._pending_interpolation_request = object() - heatmap_widget._interpolation_generation = 5 + heatmap_widget._latest_interpolation_version = 5 heatmap_widget.scan_item = create_dummy_scan_item() heatmap_widget.plot(x_name="samx", y_name="samy", z_name="bpm4i") @@ -463,7 +464,7 @@ def test_heatmap_widget_reset(heatmap_widget): assert heatmap_widget._grid_index is None assert heatmap_widget.main_image.raw_data is None assert heatmap_widget._pending_interpolation_request is None - assert heatmap_widget._interpolation_generation == 5 + assert heatmap_widget._latest_interpolation_version == 5 def test_heatmap_widget_update_plot_with_scan_history(heatmap_widget, grid_scan_history_msg, qtbot): @@ -491,22 +492,23 @@ def test_heatmap_widget_update_plot_with_scan_history(heatmap_widget, grid_scan_ def test_step_interpolation_worker_emits_finished(qtbot): - worker = _StepInterpolationWorker( + worker = _StepInterpolationWorker() + request = _InterpolationRequest( x_data=[0.0, 1.0, 0.5, 0.2], y_data=[0.0, 0.0, 1.0, 1.0], z_data=[1.0, 2.0, 3.0, 4.0], + data_version=4, + scan_id="scan-1", interpolation="linear", oversampling_factor=1.0, - generation=1, - scan_id="scan-1", ) with qtbot.waitSignal(worker.finished, timeout=1000) as blocker: - worker.run() - img, transform, generation, scan_id = blocker.args + worker.process(request, request.data_version) + img, transform, data_version, scan_id = blocker.args assert img.shape[0] > 0 assert isinstance(transform, QTransform) - assert generation == 1 - assert scan_id == "scan-1" + assert data_version == request.data_version + assert scan_id == request.scan_id def test_step_interpolation_worker_emits_failed(qtbot, monkeypatch): @@ -516,36 +518,35 @@ def test_step_interpolation_worker_emits_failed(qtbot, monkeypatch): monkeypatch.setattr( "bec_widgets.widgets.plots.heatmap.heatmap.Heatmap.compute_step_scan_image", _scan_goes_boom ) - worker = _StepInterpolationWorker( + worker = _StepInterpolationWorker() + request = _InterpolationRequest( x_data=[0.0, 1.0, 0.5, 0.2], y_data=[0.0, 0.0, 1.0, 1.0], z_data=[1.0, 2.0, 3.0, 4.0], + data_version=99, + scan_id="scan-err", interpolation="linear", oversampling_factor=1.0, - generation=99, - scan_id="scan-err", ) with qtbot.waitSignal(worker.failed, timeout=1000) as blocker: - worker.run() - error, generation, scan_id = blocker.args + worker.process(request, request.data_version) + error, data_version, scan_id = blocker.args assert "crash" in error - assert generation == 99 - assert scan_id == "scan-err" + assert data_version == request.data_version + assert scan_id == request.scan_id def test_interpolation_generation_invalidation(heatmap_widget): heatmap_widget.scan_id = "scan-1" - heatmap_widget._interpolation_generation = 2 + heatmap_widget._latest_interpolation_version = 2 with ( mock.patch.object(heatmap_widget, "_apply_image_update") as apply_mock, - mock.patch.object(heatmap_widget, "_finish_interpolation_thread") as finish_mock, mock.patch.object(heatmap_widget, "_maybe_start_pending_interpolation") as maybe_mock, ): heatmap_widget._on_interpolation_finished( - np.zeros((2, 2)), QTransform(), generation=1, scan_id="scan-1" + np.zeros((2, 2)), QTransform(), data_version=1, scan_id="scan-1" ) apply_mock.assert_not_called() - finish_mock.assert_called_once() maybe_mock.assert_called_once() @@ -559,7 +560,8 @@ def test_pending_request_queueing_and_start(heatmap_widget): metadata={}, info={"positions": [[0, 0], [1, 1], [2, 2], [3, 3]]}, ) - heatmap_widget._interpolation_thread = object() # simulate running thread + heatmap_widget._interpolation_thread = mock.MagicMock() + heatmap_widget._interpolation_thread.isRunning.return_value = True with mock.patch.object(heatmap_widget, "_start_step_scan_interpolation") as start_mock: heatmap_widget._request_step_scan_interpolation( @@ -582,6 +584,7 @@ def test_pending_request_queueing_and_start(heatmap_widget): def test_finish_interpolation_thread_cleans_references(heatmap_widget): worker_mock = mock.Mock() thread_mock = mock.Mock() + thread_mock.isRunning.return_value = True heatmap_widget._interpolation_worker = worker_mock heatmap_widget._interpolation_thread = thread_mock