diff --git a/bec_widgets/widgets/plots/heatmap/heatmap.py b/bec_widgets/widgets/plots/heatmap/heatmap.py index e1142aa1..05501b9f 100644 --- a/bec_widgets/widgets/plots/heatmap/heatmap.py +++ b/bec_widgets/widgets/plots/heatmap/heatmap.py @@ -620,10 +620,9 @@ class Heatmap(ImageBase): args = self.arg_bundle_to_dict(4, msg.request_inputs["arg_bundle"]) - shape = ( - args[self._image_config.x_device.entry][-1], - args[self._image_config.y_device.entry][-1], - ) + x_entry = self._image_config.x_device.entry + y_entry = self._image_config.y_device.entry + shape = (args[x_entry][-1], args[y_entry][-1]) data = self.main_image.raw_data @@ -633,57 +632,39 @@ class Heatmap(ImageBase): elif self.reload: data.fill(np.nan) - def _get_grid_data(axis, snaked=True): - x_grid, y_grid = np.meshgrid(axis[0], axis[1]) - if snaked: - y_grid.T[::2] = np.fliplr(y_grid.T[::2]) - x_flat = x_grid.T.ravel() - y_flat = y_grid.T.ravel() - positions = np.vstack((x_flat, y_flat)).T - return positions - snaked = msg.request_inputs["kwargs"].get("snaked", True) - # If the scan's fast axis is x, we need to swap the x and y axes - swap = bool(msg.request_inputs["arg_bundle"][4] == self._image_config.x_device.entry) - - def _axis_bounds( - axis_values: np.ndarray, limits: list[float], axis: Literal["x", "y"] - ) -> tuple[float, float]: - start, stop = limits[:2] - ascending = start <= stop - if snaked and axis == "y": - ascending = start >= stop - if ascending: - return float(axis_values.min()), float(axis_values.max()) - return float(axis_values.max()), float(axis_values.min()) - - # calculate the QTransform to put (0,0) at the axis origin - scan_pos = np.asarray(msg.info["positions"]) - x_min, x_max = _axis_bounds( - scan_pos[:, 0], args[self._image_config.x_device.entry], axis="x" - ) - y_min, y_max = _axis_bounds( - scan_pos[:, 1], args[self._image_config.y_device.entry], axis="y" + slow_entry, fast_entry = ( + msg.request_inputs["arg_bundle"][0], + msg.request_inputs["arg_bundle"][4], ) - x_range = x_max - x_min - y_range = y_max - y_min + scan_pos = np.asarray(msg.info["positions"], dtype=float) + relative = bool(msg.request_inputs["kwargs"].get("relative", False)) - pixel_size_x = x_range / max(shape[0] - 1, 1) - pixel_size_y = y_range / max(shape[1] - 1, 1) + def _axis_column(entry: str) -> int: + return 0 if entry == slow_entry else 1 + + def _axis_levels(entry: str, npts: int) -> np.ndarray: + start, stop = args[entry][:2] + if relative: + origin = float(scan_pos[0, _axis_column(entry)] - start) + return origin + np.linspace(start, stop, npts) + return np.linspace(start, stop, npts) + + x_levels = _axis_levels(x_entry, shape[0]) + y_levels = _axis_levels(y_entry, shape[1]) + + pixel_size_x = ( + float(x_levels[-1] - x_levels[0]) / max(shape[0] - 1, 1) if shape[0] > 1 else 1.0 + ) + pixel_size_y = ( + float(y_levels[-1] - y_levels[0]) / max(shape[1] - 1, 1) if shape[1] > 1 else 1.0 + ) transform = QTransform() - if swap: - transform.scale(pixel_size_y, pixel_size_x) - transform.translate(y_min / pixel_size_y - 0.5, x_min / pixel_size_x - 0.5) - else: - transform.scale(pixel_size_x, pixel_size_y) - transform.translate(x_min / pixel_size_x - 0.5, y_min / pixel_size_y - 0.5) - - target_positions = _get_grid_data( - (np.arange(shape[int(swap)]), np.arange(shape[int(not swap)])), snaked=snaked - ) + transform.scale(pixel_size_x, pixel_size_y) + transform.translate(x_levels[0] / pixel_size_x - 0.5, y_levels[0] / pixel_size_y - 0.5) # Fill the data array with the z values if self._grid_index is None or self.reload: @@ -691,7 +672,16 @@ class Heatmap(ImageBase): self.reload = False for i in range(self._grid_index, len(z_data)): - data[target_positions[i, int(swap)], target_positions[i, int(not swap)]] = z_data[i] + slow_i, fast_i = divmod(i, args[fast_entry][-1]) + if snaked and (slow_i % 2 == 1): + fast_i = args[fast_entry][-1] - 1 - fast_i + + if x_entry == fast_entry: + x_i, y_i = fast_i, slow_i + else: + x_i, y_i = slow_i, fast_i + + data[x_i, y_i] = z_data[i] self._grid_index = len(z_data) return data, transform diff --git a/tests/unit_tests/test_heatmap_widget.py b/tests/unit_tests/test_heatmap_widget.py index 66010a0e..ff2d4274 100644 --- a/tests/unit_tests/test_heatmap_widget.py +++ b/tests/unit_tests/test_heatmap_widget.py @@ -4,6 +4,7 @@ import numpy as np import pytest from bec_lib import messages from bec_lib.scan_history import ScanHistory +from qtpy.QtCore import QPointF from bec_widgets.widgets.plots.heatmap.heatmap import Heatmap, HeatmapConfig, HeatmapDeviceSignal @@ -125,12 +126,16 @@ def test_heatmap_get_image_data_unsupported_scan(heatmap_widget): def test_heatmap_get_grid_scan_image(heatmap_widget): + x_levels = np.linspace(-5, 5, 10).tolist() + y_levels = np.linspace(-5, 5, 10).tolist() scan_msg = messages.ScanStatusMessage( scan_id="123", status="open", scan_name="grid_scan", metadata={}, - info={"positions": np.random.rand(100, 2).tolist()}, + info={ + "positions": _grid_positions(slow_levels=x_levels, fast_levels=y_levels, snaked=True) + }, request_inputs={"arg_bundle": ["samx", -5, 5, 10, "samy", -5, 5, 10], "kwargs": {}}, ) heatmap_widget._image_config = HeatmapConfig( @@ -145,6 +150,111 @@ def test_heatmap_get_grid_scan_image(heatmap_widget): assert sorted(np.asarray(img, dtype=int).flatten().tolist()) == list(range(100)) +def _grid_positions( + *, slow_levels: list[float], fast_levels: list[float], snaked: bool, slow_is_col0: bool = True +) -> list[list[float]]: + positions: list[list[float]] = [] + for slow_i, slow_val in enumerate(slow_levels): + row_fast = fast_levels if (not snaked or slow_i % 2 == 0) else list(reversed(fast_levels)) + for fast_val in row_fast: + if slow_is_col0: + positions.append([slow_val, fast_val]) + else: + positions.append([fast_val, slow_val]) + return positions + + +def test_heatmap_grid_scan_direction_and_snaking_x_fast(heatmap_widget): + heatmap_widget._image_config = HeatmapConfig( + parent_id="parent_id", + x_device=HeatmapDeviceSignal(name="samx", entry="samx"), + y_device=HeatmapDeviceSignal(name="samy", entry="samy"), + z_device=HeatmapDeviceSignal(name="bpm4i", entry="bpm4i"), + color_map="viridis", + ) + + # x decreases (relative), y increases (relative), x is fast axis + x0 = 10.0 + y0 = -3.0 + x_levels = (x0 + np.linspace(1.0, -1.0, 3)).tolist() + y_levels = (y0 + np.linspace(-2.0, 2.0, 2)).tolist() + snaked = True + + scan_msg = messages.ScanStatusMessage( + scan_id="123", + status="open", + scan_name="grid_scan", + metadata={}, + info={ + "positions": _grid_positions(slow_levels=y_levels, fast_levels=x_levels, snaked=snaked) + }, + request_inputs={ + "arg_bundle": ["samy", -2.0, 2.0, 2, "samx", 1.0, -1.0, 3], + "kwargs": {"snaked": snaked, "relative": True}, + }, + ) + + img, transform = heatmap_widget.get_grid_scan_image(list(range(6)), msg=scan_msg) + + assert img.shape == (3, 2) + assert img[0, 0] == 0 # first point: (x0,y0) in scan order + assert img[2, 1] == 3 # second row first point due to snaking + assert img[0, 1] == 5 # last point in second row + + p0 = transform.map(QPointF(0.5, 0.5)) + p1 = transform.map(QPointF(2.5, 1.5)) + assert p0.x() == pytest.approx(x_levels[0]) + assert p0.y() == pytest.approx(y_levels[0]) + assert p1.x() == pytest.approx(x_levels[-1]) + assert p1.y() == pytest.approx(y_levels[-1]) + + +def test_heatmap_grid_scan_direction_and_snaking_y_fast(heatmap_widget): + heatmap_widget._image_config = HeatmapConfig( + parent_id="parent_id", + x_device=HeatmapDeviceSignal(name="samx", entry="samx"), + y_device=HeatmapDeviceSignal(name="samy", entry="samy"), + z_device=HeatmapDeviceSignal(name="bpm4i", entry="bpm4i"), + color_map="viridis", + ) + + # x decreases (relative), y increases (relative), y is fast axis + x0 = 1.5 + y0 = 22.0 + x_levels = (x0 + np.linspace(1.0, -1.0, 3)).tolist() + y_levels = (y0 + np.linspace(-2.0, 2.0, 2)).tolist() + snaked = True + + scan_msg = messages.ScanStatusMessage( + scan_id="123", + status="open", + scan_name="grid_scan", + metadata={}, + info={ + "positions": _grid_positions(slow_levels=x_levels, fast_levels=y_levels, snaked=snaked) + }, + request_inputs={ + "arg_bundle": ["samx", 1.0, -1.0, 3, "samy", -2.0, 2.0, 2], + "kwargs": {"snaked": snaked, "relative": True}, + }, + ) + + img, transform = heatmap_widget.get_grid_scan_image(list(range(6)), msg=scan_msg) + + assert img.shape == (3, 2) + assert img[0, 0] == 0 + # For y-fast scans, snaking reverses the y index on every odd x row. + assert img[1, 1] == 2 + assert img[1, 0] == 3 + + p0 = transform.map(QPointF(0.5, 0.5)) + p1 = transform.map(QPointF(2.5, 1.5)) + assert p0.x() == pytest.approx(x_levels[0]) + assert p0.y() == pytest.approx(y_levels[0]) + assert p1.x() == pytest.approx(x_levels[-1]) + assert p1.y() == pytest.approx(y_levels[-1]) + + def test_heatmap_get_step_scan_image(heatmap_widget): scan_msg = messages.ScanStatusMessage( @@ -193,12 +303,16 @@ def test_heatmap_update_plot(heatmap_widget): color_map="viridis", ) heatmap_widget.scan_item = create_dummy_scan_item() + x_levels = np.linspace(-5, 5, 10).tolist() + y_levels = np.linspace(-5, 5, 10).tolist() heatmap_widget.scan_item.status_message = messages.ScanStatusMessage( scan_id="123", status="open", scan_name="grid_scan", metadata={}, - info={"positions": np.random.rand(100, 2).tolist()}, + info={ + "positions": _grid_positions(slow_levels=x_levels, fast_levels=y_levels, snaked=True) + }, request_inputs={"arg_bundle": ["samx", -5, 5, 10, "samy", -5, 5, 10], "kwargs": {}}, ) with mock.patch.object(heatmap_widget.main_image, "setImage") as mock_set_image: