From 6c494258f82059a2472f43bb8287390ce1aba704 Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Fri, 11 Jul 2025 13:23:50 +0200 Subject: [PATCH] fix(heatmap): fix pixel size calculation for arbitrary shapes --- bec_widgets/widgets/plots/heatmap/heatmap.py | 37 +++++++++++--------- tests/unit_tests/test_heatmap_widget.py | 9 ++++- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/bec_widgets/widgets/plots/heatmap/heatmap.py b/bec_widgets/widgets/plots/heatmap/heatmap.py index 9bf2ae02..0e9fed45 100644 --- a/bec_widgets/widgets/plots/heatmap/heatmap.py +++ b/bec_widgets/widgets/plots/heatmap/heatmap.py @@ -431,6 +431,8 @@ class Heatmap(ImageBase): if self._color_bar is not None: self._color_bar.blockSignals(False) self.image_updated.emit() + if self.crosshair is not None: + self.crosshair.update_markers_on_image_change() def get_image_data( self, @@ -457,14 +459,19 @@ class Heatmap(ImageBase): return None, None if msg.scan_name == "grid_scan": - return self.get_grid_scan_image(z_data, msg) - if msg.scan_type == "step" and msg.info["positions"]: - if len(z_data) < 4: - # LinearNDInterpolator requires at least 4 points to interpolate - return None, None - return self.get_step_scan_image(x_data, y_data, z_data, msg) - logger.warning(f"Scan type {msg.scan_name} not supported.") - return None, None + # We only support the grid scan mode if both scanning motors + # are configured in the heatmap config. + device_x = self._image_config.x_device.entry + device_y = self._image_config.y_device.entry + if ( + device_x in msg.request_inputs["arg_bundle"] + and device_y in msg.request_inputs["arg_bundle"] + ): + return self.get_grid_scan_image(z_data, msg) + if len(z_data) < 4: + # LinearNDInterpolator requires at least 4 points to interpolate + return None, None + return self.get_step_scan_image(x_data, y_data, z_data, msg) def get_grid_scan_image( self, z_data: list[float], msg: messages.ScanStatusMessage @@ -560,17 +567,16 @@ class Heatmap(ImageBase): Returns: tuple[np.ndarray, QTransform]: The image data and the QTransform. """ - - grid_x, grid_y, transform = self.get_image_grid(msg.scan_id) + xy_data = np.column_stack((x_data, y_data)) + grid_x, grid_y, transform = self.get_image_grid(xy_data) # Interpolate the z data onto the grid - interp = LinearNDInterpolator(np.column_stack((x_data, y_data)), z_data) + interp = LinearNDInterpolator(xy_data, z_data) grid_z = interp(grid_x, grid_y) return grid_z, transform - @functools.lru_cache(maxsize=2) - def get_image_grid(self, _scan_id) -> tuple[np.ndarray, np.ndarray, QTransform]: + def get_image_grid(self, positions) -> tuple[np.ndarray, np.ndarray, QTransform]: """ LRU-cached calculation of the grid for the image. The lru cache is indexed by the scan_id to avoid recalculating the grid for the same scan. @@ -581,8 +587,6 @@ class Heatmap(ImageBase): Returns: tuple[np.ndarray, np.ndarray, QTransform]: The grid x and y coordinates and the QTransform. """ - msg = self.status_message - positions = np.asarray(msg.info["positions"]) width, height = self.estimate_image_resolution(positions) @@ -608,7 +612,8 @@ class Heatmap(ImageBase): return grid_x, grid_y, transform - def estimate_image_resolution(self, coords: np.ndarray) -> tuple[int, int]: + @staticmethod + def estimate_image_resolution(coords: np.ndarray) -> tuple[int, int]: """ Estimate the number of pixels needed for the image based on the coordinates. diff --git a/tests/unit_tests/test_heatmap_widget.py b/tests/unit_tests/test_heatmap_widget.py index 9be3fa03..52ed9179 100644 --- a/tests/unit_tests/test_heatmap_widget.py +++ b/tests/unit_tests/test_heatmap_widget.py @@ -62,8 +62,15 @@ def test_heatmap_get_image_data_missing_data(heatmap_widget): def test_heatmap_get_image_data_grid_scan(heatmap_widget): scan_msg = messages.ScanStatusMessage( - scan_id="123", status="open", scan_name="grid_scan", metadata={}, info={} + scan_id="123", + status="open", + scan_name="grid_scan", + metadata={}, + info={}, + request_inputs={"arg_bundle": ["samx", -5, 5, 10, "samy", -5, 5, 10], "kwargs": {}}, ) + heatmap_widget.plot(x_name="samx", y_name="samy", z_name="bpm4i") + heatmap_widget.status_message = scan_msg with mock.patch.object(heatmap_widget, "get_grid_scan_image") as mock_get_grid_scan_image: heatmap_widget.get_image_data(x_data=[1, 2], y_data=[3, 4], z_data=[5, 6])