mirror of
https://github.com/bec-project/bec_widgets.git
synced 2026-03-05 00:12:49 +01:00
perf(heatmap): thread worker optimization
This commit is contained in:
@@ -81,9 +81,22 @@ class HeatmapConfig(ConnectionConfig):
|
||||
|
||||
@dataclass
|
||||
class _InterpolationRequest:
|
||||
"""Immutable payload describing an interpolation request for the worker thread.
|
||||
|
||||
Args:
|
||||
x_data: X coordinates collected so far.
|
||||
y_data: Y coordinates collected so far.
|
||||
z_data: Z values associated with x/y.
|
||||
data_version: Number of points at request time (len(z_data)); used to reject stale results.
|
||||
scan_id: Identifier for the scan that produced the data.
|
||||
interpolation: Interpolation method to apply.
|
||||
oversampling_factor: Oversampling factor for the interpolation grid.
|
||||
"""
|
||||
|
||||
x_data: list[float]
|
||||
y_data: list[float]
|
||||
z_data: list[float]
|
||||
data_version: int
|
||||
scan_id: str
|
||||
interpolation: str
|
||||
oversampling_factor: float
|
||||
@@ -95,67 +108,50 @@ class _StepInterpolationWorker(QObject):
|
||||
This worker computes the interpolated heatmap image using the provided data
|
||||
and settings, then emits the result or a failure signal.
|
||||
|
||||
Args:
|
||||
x_data (list[float] or np.ndarray): The x-coordinates of the data points.
|
||||
y_data (list[float] or np.ndarray): The y-coordinates of the data points.
|
||||
z_data (list[float] or np.ndarray): The z-values (intensity) of the data points.
|
||||
interpolation (str): The interpolation method to use.
|
||||
oversampling_factor (float): The oversampling factor for the interpolation grid.
|
||||
generation (int): The generation number for tracking requests.
|
||||
scan_id (str): The scan identifier.
|
||||
parent (QObject | None, optional): The parent QObject. Defaults to None.
|
||||
|
||||
Signals:
|
||||
finished(image, transform, generation, scan_id):
|
||||
finished(image, transform, data_version, scan_id):
|
||||
Emitted when interpolation is successful.
|
||||
- image: The resulting image (numpy array or similar).
|
||||
- transform: The QTransform for the image.
|
||||
- generation: The generation number.
|
||||
- data_version: The data version for the request.
|
||||
- scan_id: The scan identifier.
|
||||
failed(error_message, generation, scan_id):
|
||||
failed(error_message, data_version, scan_id):
|
||||
Emitted when interpolation fails.
|
||||
- error_message: The error message string.
|
||||
- generation: The generation number.
|
||||
- data_version: The data version for the request.
|
||||
- scan_id: The scan identifier.
|
||||
"""
|
||||
|
||||
finished = Signal(object, object, int, str)
|
||||
failed = Signal(str, int, str)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
x_data: list[float] | np.ndarray,
|
||||
y_data: list[float] | np.ndarray,
|
||||
z_data: list[float] | np.ndarray,
|
||||
interpolation: str,
|
||||
oversampling_factor: float,
|
||||
generation: int,
|
||||
scan_id: str,
|
||||
parent: QObject | None = None,
|
||||
):
|
||||
def __init__(self, parent: QObject | None = None):
|
||||
super().__init__(parent=parent)
|
||||
self._x_data = np.asarray(x_data, dtype=float)
|
||||
self._y_data = np.asarray(y_data, dtype=float)
|
||||
self._z_data = np.asarray(z_data, dtype=float)
|
||||
self._interpolation = interpolation
|
||||
self._oversampling_factor = oversampling_factor
|
||||
self._generation = generation
|
||||
self._scan_id = scan_id
|
||||
self._active_request: _InterpolationRequest | None = None
|
||||
|
||||
def run(self):
|
||||
@SafeSlot(object, int)
|
||||
def process(self, request: _InterpolationRequest, data_version: int):
|
||||
"""
|
||||
Process an interpolation request in the worker thread.
|
||||
|
||||
Args:
|
||||
request(_InterpolationRequest): The interpolation request payload.
|
||||
data_version(int): The data version for the request.
|
||||
"""
|
||||
self._active_request = request
|
||||
try:
|
||||
image, transform = Heatmap.compute_step_scan_image(
|
||||
x_data=self._x_data,
|
||||
y_data=self._y_data,
|
||||
z_data=self._z_data,
|
||||
oversampling_factor=self._oversampling_factor,
|
||||
interpolation_method=self._interpolation,
|
||||
x_data=np.asarray(request.x_data, dtype=float),
|
||||
y_data=np.asarray(request.y_data, dtype=float),
|
||||
z_data=np.asarray(request.z_data, dtype=float),
|
||||
oversampling_factor=request.oversampling_factor,
|
||||
interpolation_method=request.interpolation,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.warning(f"Step-scan interpolation failed with: {exc}")
|
||||
self.failed.emit(str(exc), self._generation, self._scan_id)
|
||||
self.failed.emit(str(exc), data_version, request.scan_id)
|
||||
return
|
||||
self.finished.emit(image, transform, self._generation, self._scan_id)
|
||||
self.finished.emit(image, transform, data_version, request.scan_id)
|
||||
|
||||
|
||||
class Heatmap(ImageBase):
|
||||
@@ -208,6 +204,7 @@ class Heatmap(ImageBase):
|
||||
new_scan_id = Signal(str)
|
||||
sync_signal_update = Signal()
|
||||
heatmap_property_changed = Signal()
|
||||
interpolation_requested = Signal(object, int)
|
||||
|
||||
def __init__(self, parent=None, config: HeatmapConfig | None = None, **kwargs):
|
||||
if config is None:
|
||||
@@ -230,7 +227,9 @@ class Heatmap(ImageBase):
|
||||
self.scan_item = None
|
||||
self.status_message = None
|
||||
self._grid_index = None
|
||||
self._interpolation_generation = 0
|
||||
# Highest data_version we have dispatched for the current scan; used to drop stale results.
|
||||
# Initialized to -1 so the first real request (len(z_data) >= 0) always supersedes it.
|
||||
self._latest_interpolation_version = -1
|
||||
self._interpolation_thread: QThread | None = None
|
||||
self._interpolation_worker: _StepInterpolationWorker | None = None
|
||||
self._pending_interpolation_request: _InterpolationRequest | None = None
|
||||
@@ -510,7 +509,7 @@ class Heatmap(ImageBase):
|
||||
if current_scan_id is None:
|
||||
return
|
||||
if current_scan_id != self.scan_id:
|
||||
self._invalidate_interpolation_generation()
|
||||
self._invalidate_interpolation_generation() # Invalidate any pending interpolation work when a new scan starts
|
||||
self.reset()
|
||||
self.new_scan.emit()
|
||||
self.new_scan_id.emit(current_scan_id)
|
||||
@@ -677,66 +676,77 @@ class Heatmap(ImageBase):
|
||||
x_data=list(x_data),
|
||||
y_data=list(y_data),
|
||||
z_data=list(z_data),
|
||||
data_version=len(z_data),
|
||||
scan_id=msg.scan_id,
|
||||
interpolation=self._image_config.interpolation,
|
||||
oversampling_factor=self._image_config.oversampling_factor,
|
||||
)
|
||||
|
||||
if self._interpolation_thread is not None:
|
||||
if self._interpolation_thread is not None and self._interpolation_thread.isRunning():
|
||||
self._pending_interpolation_request = request
|
||||
return
|
||||
|
||||
self._start_step_scan_interpolation(request)
|
||||
|
||||
def _ensure_interpolation_thread(self):
|
||||
if self._interpolation_thread is None:
|
||||
self._interpolation_thread = QThread()
|
||||
self._interpolation_worker = _StepInterpolationWorker()
|
||||
self._interpolation_worker.moveToThread(self._interpolation_thread)
|
||||
self.interpolation_requested.connect(
|
||||
self._interpolation_worker.process, Qt.ConnectionType.QueuedConnection
|
||||
)
|
||||
self._interpolation_worker.finished.connect(
|
||||
self._on_interpolation_finished, Qt.ConnectionType.QueuedConnection
|
||||
)
|
||||
self._interpolation_worker.failed.connect(
|
||||
self._on_interpolation_failed, Qt.ConnectionType.QueuedConnection
|
||||
)
|
||||
if self._interpolation_thread is not None and not self._interpolation_thread.isRunning():
|
||||
self._interpolation_thread.start()
|
||||
|
||||
def _start_step_scan_interpolation(self, request: _InterpolationRequest):
|
||||
self._interpolation_generation += 1
|
||||
generation = self._interpolation_generation
|
||||
self._interpolation_thread = QThread()
|
||||
self._interpolation_worker = _StepInterpolationWorker(
|
||||
x_data=request.x_data,
|
||||
y_data=request.y_data,
|
||||
z_data=request.z_data,
|
||||
interpolation=request.interpolation,
|
||||
oversampling_factor=request.oversampling_factor,
|
||||
generation=generation,
|
||||
scan_id=request.scan_id,
|
||||
)
|
||||
self._interpolation_worker.moveToThread(self._interpolation_thread)
|
||||
self._interpolation_thread.started.connect(
|
||||
self._interpolation_worker.run, Qt.ConnectionType.QueuedConnection
|
||||
)
|
||||
self._interpolation_worker.finished.connect(
|
||||
self._on_interpolation_finished, Qt.ConnectionType.QueuedConnection
|
||||
)
|
||||
self._interpolation_worker.failed.connect(
|
||||
self._on_interpolation_failed, Qt.ConnectionType.QueuedConnection
|
||||
)
|
||||
self._interpolation_thread.start()
|
||||
# data_version = len(z_data) at the time of the request; keep the latest to gate results.
|
||||
self._ensure_interpolation_thread()
|
||||
if self._interpolation_thread is not None and not self._interpolation_thread.isRunning():
|
||||
self._interpolation_thread.start()
|
||||
self._latest_interpolation_version = request.data_version
|
||||
self.interpolation_requested.emit(request, request.data_version)
|
||||
|
||||
def _on_interpolation_finished(
|
||||
self, img: np.ndarray, transform: QTransform, generation: int, scan_id: str
|
||||
self, img: np.ndarray, transform: QTransform, data_version: int, scan_id: str
|
||||
):
|
||||
if generation == self._interpolation_generation and scan_id == self.scan_id:
|
||||
# Only accept results that match the latest dispatched version for the active scan.
|
||||
if data_version == self._latest_interpolation_version and scan_id == self.scan_id:
|
||||
self._apply_image_update(img, transform)
|
||||
else:
|
||||
logger.debug("Discarding outdated interpolation result.")
|
||||
self._finish_interpolation_thread()
|
||||
logger.info("Discarding outdated interpolation result.")
|
||||
if self._interpolation_thread is not None and self._interpolation_thread.isRunning():
|
||||
self._interpolation_thread.quit()
|
||||
self._interpolation_thread.wait()
|
||||
self._maybe_start_pending_interpolation()
|
||||
|
||||
def _on_interpolation_failed(self, error: str, generation: int, scan_id: str):
|
||||
logger.warning(
|
||||
"Interpolation failed for scan %s (generation %s): %s", scan_id, generation, error
|
||||
)
|
||||
self._finish_interpolation_thread()
|
||||
def _on_interpolation_failed(self, error: str, data_version: int, scan_id: str):
|
||||
logger.warning(f"Interpolation failed for scan {scan_id} (version {data_version}): {error}")
|
||||
if self._interpolation_thread is not None and self._interpolation_thread.isRunning():
|
||||
self._interpolation_thread.quit()
|
||||
self._interpolation_thread.wait()
|
||||
self._maybe_start_pending_interpolation()
|
||||
|
||||
def _finish_interpolation_thread(self):
|
||||
self._pending_interpolation_request = None
|
||||
if self._interpolation_worker is not None:
|
||||
try:
|
||||
self.interpolation_requested.disconnect(self._interpolation_worker.process)
|
||||
except (TypeError, RuntimeError):
|
||||
# Defensive: disconnect may fail if already disconnected or during shutdown.
|
||||
pass
|
||||
self._interpolation_worker.deleteLater()
|
||||
self._interpolation_worker = None
|
||||
if self._interpolation_thread is not None:
|
||||
self._interpolation_thread.quit()
|
||||
self._interpolation_thread.wait()
|
||||
if self._interpolation_thread.isRunning():
|
||||
self._interpolation_thread.quit()
|
||||
self._interpolation_thread.wait()
|
||||
self._interpolation_thread.deleteLater()
|
||||
self._interpolation_thread = None
|
||||
|
||||
@@ -754,22 +764,17 @@ class Heatmap(ImageBase):
|
||||
def _cancel_interpolation(self):
|
||||
"""Cancel any pending interpolation request without invalidating in-flight work.
|
||||
|
||||
This clears the pending request queue but does not increment the generation
|
||||
counter, allowing any currently running interpolation to complete and update
|
||||
the display if it matches the current scan.
|
||||
This clears the pending request queue but does not invalidate in-flight work,
|
||||
allowing any currently running interpolation to complete and update the display
|
||||
if it matches the current scan.
|
||||
"""
|
||||
self._pending_interpolation_request = None
|
||||
# Do not bump generation so an in-flight worker can still deliver the latest scan image.
|
||||
# Do not change the active data version so an in-flight worker can still deliver.
|
||||
|
||||
def _invalidate_interpolation_generation(self):
|
||||
"""Invalidate all in-flight and pending interpolation results.
|
||||
|
||||
Increments the generation counter so that any currently running or
|
||||
queued interpolation work will be discarded when it completes.
|
||||
This is typically called when starting a new scan.
|
||||
"""
|
||||
# Bump the generation so any in-flight worker results are ignored.
|
||||
self._interpolation_generation += 1
|
||||
"""Invalidate all pending interpolation results and ignore in-flight updates."""
|
||||
self._pending_interpolation_request = None
|
||||
self._latest_interpolation_version = -1
|
||||
|
||||
def redraw_config_label(self):
|
||||
scan_msg = self.status_message
|
||||
|
||||
@@ -11,6 +11,7 @@ from bec_widgets.widgets.plots.heatmap.heatmap import (
|
||||
Heatmap,
|
||||
HeatmapConfig,
|
||||
HeatmapDeviceSignal,
|
||||
_InterpolationRequest,
|
||||
_StepInterpolationWorker,
|
||||
)
|
||||
|
||||
@@ -455,7 +456,7 @@ def test_heatmap_widget_reset(heatmap_widget):
|
||||
Test that the reset method clears the plot.
|
||||
"""
|
||||
heatmap_widget._pending_interpolation_request = object()
|
||||
heatmap_widget._interpolation_generation = 5
|
||||
heatmap_widget._latest_interpolation_version = 5
|
||||
heatmap_widget.scan_item = create_dummy_scan_item()
|
||||
heatmap_widget.plot(x_name="samx", y_name="samy", z_name="bpm4i")
|
||||
|
||||
@@ -463,7 +464,7 @@ def test_heatmap_widget_reset(heatmap_widget):
|
||||
assert heatmap_widget._grid_index is None
|
||||
assert heatmap_widget.main_image.raw_data is None
|
||||
assert heatmap_widget._pending_interpolation_request is None
|
||||
assert heatmap_widget._interpolation_generation == 5
|
||||
assert heatmap_widget._latest_interpolation_version == 5
|
||||
|
||||
|
||||
def test_heatmap_widget_update_plot_with_scan_history(heatmap_widget, grid_scan_history_msg, qtbot):
|
||||
@@ -491,22 +492,23 @@ def test_heatmap_widget_update_plot_with_scan_history(heatmap_widget, grid_scan_
|
||||
|
||||
|
||||
def test_step_interpolation_worker_emits_finished(qtbot):
|
||||
worker = _StepInterpolationWorker(
|
||||
worker = _StepInterpolationWorker()
|
||||
request = _InterpolationRequest(
|
||||
x_data=[0.0, 1.0, 0.5, 0.2],
|
||||
y_data=[0.0, 0.0, 1.0, 1.0],
|
||||
z_data=[1.0, 2.0, 3.0, 4.0],
|
||||
data_version=4,
|
||||
scan_id="scan-1",
|
||||
interpolation="linear",
|
||||
oversampling_factor=1.0,
|
||||
generation=1,
|
||||
scan_id="scan-1",
|
||||
)
|
||||
with qtbot.waitSignal(worker.finished, timeout=1000) as blocker:
|
||||
worker.run()
|
||||
img, transform, generation, scan_id = blocker.args
|
||||
worker.process(request, request.data_version)
|
||||
img, transform, data_version, scan_id = blocker.args
|
||||
assert img.shape[0] > 0
|
||||
assert isinstance(transform, QTransform)
|
||||
assert generation == 1
|
||||
assert scan_id == "scan-1"
|
||||
assert data_version == request.data_version
|
||||
assert scan_id == request.scan_id
|
||||
|
||||
|
||||
def test_step_interpolation_worker_emits_failed(qtbot, monkeypatch):
|
||||
@@ -516,36 +518,35 @@ def test_step_interpolation_worker_emits_failed(qtbot, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"bec_widgets.widgets.plots.heatmap.heatmap.Heatmap.compute_step_scan_image", _scan_goes_boom
|
||||
)
|
||||
worker = _StepInterpolationWorker(
|
||||
worker = _StepInterpolationWorker()
|
||||
request = _InterpolationRequest(
|
||||
x_data=[0.0, 1.0, 0.5, 0.2],
|
||||
y_data=[0.0, 0.0, 1.0, 1.0],
|
||||
z_data=[1.0, 2.0, 3.0, 4.0],
|
||||
data_version=99,
|
||||
scan_id="scan-err",
|
||||
interpolation="linear",
|
||||
oversampling_factor=1.0,
|
||||
generation=99,
|
||||
scan_id="scan-err",
|
||||
)
|
||||
with qtbot.waitSignal(worker.failed, timeout=1000) as blocker:
|
||||
worker.run()
|
||||
error, generation, scan_id = blocker.args
|
||||
worker.process(request, request.data_version)
|
||||
error, data_version, scan_id = blocker.args
|
||||
assert "crash" in error
|
||||
assert generation == 99
|
||||
assert scan_id == "scan-err"
|
||||
assert data_version == request.data_version
|
||||
assert scan_id == request.scan_id
|
||||
|
||||
|
||||
def test_interpolation_generation_invalidation(heatmap_widget):
|
||||
heatmap_widget.scan_id = "scan-1"
|
||||
heatmap_widget._interpolation_generation = 2
|
||||
heatmap_widget._latest_interpolation_version = 2
|
||||
with (
|
||||
mock.patch.object(heatmap_widget, "_apply_image_update") as apply_mock,
|
||||
mock.patch.object(heatmap_widget, "_finish_interpolation_thread") as finish_mock,
|
||||
mock.patch.object(heatmap_widget, "_maybe_start_pending_interpolation") as maybe_mock,
|
||||
):
|
||||
heatmap_widget._on_interpolation_finished(
|
||||
np.zeros((2, 2)), QTransform(), generation=1, scan_id="scan-1"
|
||||
np.zeros((2, 2)), QTransform(), data_version=1, scan_id="scan-1"
|
||||
)
|
||||
apply_mock.assert_not_called()
|
||||
finish_mock.assert_called_once()
|
||||
maybe_mock.assert_called_once()
|
||||
|
||||
|
||||
@@ -559,7 +560,8 @@ def test_pending_request_queueing_and_start(heatmap_widget):
|
||||
metadata={},
|
||||
info={"positions": [[0, 0], [1, 1], [2, 2], [3, 3]]},
|
||||
)
|
||||
heatmap_widget._interpolation_thread = object() # simulate running thread
|
||||
heatmap_widget._interpolation_thread = mock.MagicMock()
|
||||
heatmap_widget._interpolation_thread.isRunning.return_value = True
|
||||
|
||||
with mock.patch.object(heatmap_widget, "_start_step_scan_interpolation") as start_mock:
|
||||
heatmap_widget._request_step_scan_interpolation(
|
||||
@@ -582,6 +584,7 @@ def test_pending_request_queueing_and_start(heatmap_widget):
|
||||
def test_finish_interpolation_thread_cleans_references(heatmap_widget):
|
||||
worker_mock = mock.Mock()
|
||||
thread_mock = mock.Mock()
|
||||
thread_mock.isRunning.return_value = True
|
||||
heatmap_widget._interpolation_worker = worker_mock
|
||||
heatmap_widget._interpolation_thread = thread_mock
|
||||
|
||||
|
||||
Reference in New Issue
Block a user