0
0
mirror of https://github.com/bec-project/bec_widgets.git synced 2025-07-14 03:31:50 +02:00

fix(heatmap): fix pixel size calculation for arbitrary shapes

This commit is contained in:
2025-07-11 13:23:50 +02:00
committed by Jan Wyzula
parent 63a8da680d
commit 6c494258f8
2 changed files with 29 additions and 17 deletions

View File

@ -431,6 +431,8 @@ class Heatmap(ImageBase):
if self._color_bar is not None: if self._color_bar is not None:
self._color_bar.blockSignals(False) self._color_bar.blockSignals(False)
self.image_updated.emit() self.image_updated.emit()
if self.crosshair is not None:
self.crosshair.update_markers_on_image_change()
def get_image_data( def get_image_data(
self, self,
@ -457,14 +459,19 @@ class Heatmap(ImageBase):
return None, None return None, None
if msg.scan_name == "grid_scan": if msg.scan_name == "grid_scan":
return self.get_grid_scan_image(z_data, msg) # We only support the grid scan mode if both scanning motors
if msg.scan_type == "step" and msg.info["positions"]: # are configured in the heatmap config.
if len(z_data) < 4: device_x = self._image_config.x_device.entry
# LinearNDInterpolator requires at least 4 points to interpolate device_y = self._image_config.y_device.entry
return None, None if (
return self.get_step_scan_image(x_data, y_data, z_data, msg) device_x in msg.request_inputs["arg_bundle"]
logger.warning(f"Scan type {msg.scan_name} not supported.") and device_y in msg.request_inputs["arg_bundle"]
return None, None ):
return self.get_grid_scan_image(z_data, msg)
if len(z_data) < 4:
# LinearNDInterpolator requires at least 4 points to interpolate
return None, None
return self.get_step_scan_image(x_data, y_data, z_data, msg)
def get_grid_scan_image( def get_grid_scan_image(
self, z_data: list[float], msg: messages.ScanStatusMessage self, z_data: list[float], msg: messages.ScanStatusMessage
@ -560,17 +567,16 @@ class Heatmap(ImageBase):
Returns: Returns:
tuple[np.ndarray, QTransform]: The image data and the QTransform. tuple[np.ndarray, QTransform]: The image data and the QTransform.
""" """
xy_data = np.column_stack((x_data, y_data))
grid_x, grid_y, transform = self.get_image_grid(msg.scan_id) grid_x, grid_y, transform = self.get_image_grid(xy_data)
# Interpolate the z data onto the grid # Interpolate the z data onto the grid
interp = LinearNDInterpolator(np.column_stack((x_data, y_data)), z_data) interp = LinearNDInterpolator(xy_data, z_data)
grid_z = interp(grid_x, grid_y) grid_z = interp(grid_x, grid_y)
return grid_z, transform return grid_z, transform
@functools.lru_cache(maxsize=2) def get_image_grid(self, positions) -> tuple[np.ndarray, np.ndarray, QTransform]:
def get_image_grid(self, _scan_id) -> tuple[np.ndarray, np.ndarray, QTransform]:
""" """
LRU-cached calculation of the grid for the image. The lru cache is indexed by the scan_id LRU-cached calculation of the grid for the image. The lru cache is indexed by the scan_id
to avoid recalculating the grid for the same scan. to avoid recalculating the grid for the same scan.
@ -581,8 +587,6 @@ class Heatmap(ImageBase):
Returns: Returns:
tuple[np.ndarray, np.ndarray, QTransform]: The grid x and y coordinates and the QTransform. tuple[np.ndarray, np.ndarray, QTransform]: The grid x and y coordinates and the QTransform.
""" """
msg = self.status_message
positions = np.asarray(msg.info["positions"])
width, height = self.estimate_image_resolution(positions) width, height = self.estimate_image_resolution(positions)
@ -608,7 +612,8 @@ class Heatmap(ImageBase):
return grid_x, grid_y, transform return grid_x, grid_y, transform
def estimate_image_resolution(self, coords: np.ndarray) -> tuple[int, int]: @staticmethod
def estimate_image_resolution(coords: np.ndarray) -> tuple[int, int]:
""" """
Estimate the number of pixels needed for the image based on the coordinates. Estimate the number of pixels needed for the image based on the coordinates.

View File

@ -62,8 +62,15 @@ def test_heatmap_get_image_data_missing_data(heatmap_widget):
def test_heatmap_get_image_data_grid_scan(heatmap_widget): def test_heatmap_get_image_data_grid_scan(heatmap_widget):
scan_msg = messages.ScanStatusMessage( scan_msg = messages.ScanStatusMessage(
scan_id="123", status="open", scan_name="grid_scan", metadata={}, info={} scan_id="123",
status="open",
scan_name="grid_scan",
metadata={},
info={},
request_inputs={"arg_bundle": ["samx", -5, 5, 10, "samy", -5, 5, 10], "kwargs": {}},
) )
heatmap_widget.plot(x_name="samx", y_name="samy", z_name="bpm4i")
heatmap_widget.status_message = scan_msg heatmap_widget.status_message = scan_msg
with mock.patch.object(heatmap_widget, "get_grid_scan_image") as mock_get_grid_scan_image: with mock.patch.object(heatmap_widget, "get_grid_scan_image") as mock_get_grid_scan_image:
heatmap_widget.get_image_data(x_data=[1, 2], y_data=[3, 4], z_data=[5, 6]) heatmap_widget.get_image_data(x_data=[1, 2], y_data=[3, 4], z_data=[5, 6])