mirror of
https://github.com/bec-project/bec_widgets.git
synced 2026-05-13 01:55:46 +02:00
1546 lines
58 KiB
Python
1546 lines
58 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Literal
|
|
|
|
import numpy as np
|
|
import pyqtgraph as pg
|
|
from bec_lib import bec_logger, messages
|
|
from bec_lib.endpoints import MessageEndpoints
|
|
from bec_lib.utils.import_utils import lazy_import, lazy_import_from
|
|
from pydantic import BaseModel, Field, field_validator
|
|
from qtpy.QtCore import QObject, Qt, QThread, QTimer, Signal
|
|
from qtpy.QtGui import QTransform
|
|
from toolz import partition
|
|
|
|
from bec_widgets.utils.bec_connector import ConnectionConfig
|
|
from bec_widgets.utils.colors import Colors
|
|
from bec_widgets.utils.error_popups import SafeProperty, SafeSlot
|
|
from bec_widgets.utils.settings_dialog import SettingsDialog
|
|
from bec_widgets.utils.toolbars.actions import MaterialIconAction
|
|
from bec_widgets.widgets.plots.heatmap.settings.heatmap_setting import HeatmapSettings
|
|
from bec_widgets.widgets.plots.image.image_base import ImageBase
|
|
from bec_widgets.widgets.plots.image.image_item import ImageItem
|
|
from bec_widgets.widgets.plots.plot_base import PlotBase
|
|
|
|
logger = bec_logger.logger
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from scipy.interpolate import (
|
|
CloughTocher2DInterpolator,
|
|
LinearNDInterpolator,
|
|
NearestNDInterpolator,
|
|
)
|
|
from scipy.spatial import cKDTree
|
|
else:
|
|
|
|
CloughTocher2DInterpolator, LinearNDInterpolator, NearestNDInterpolator = lazy_import_from(
|
|
"scipy.interpolate",
|
|
["CloughTocher2DInterpolator", "LinearNDInterpolator", "NearestNDInterpolator"],
|
|
)
|
|
cKDTree = lazy_import_from("scipy.spatial", ["cKDTree"])
|
|
|
|
|
|
class HeatmapDeviceSignal(BaseModel):
    """The configuration of a single device signal (device name and signal entry) used by the heatmap widget."""

    # Name of the device that provides the data.
    device: str
    # Name of the signal (entry) of that device.
    signal: str

    # Re-validate fields whenever they are assigned after construction.
    model_config: dict = {"validate_assignment": True}
|
|
|
|
|
|
class HeatmapConfig(ConnectionConfig):
    """Configuration model of the heatmap widget.

    Holds the display options (color map, color bar, aspect ratio, config label),
    the interpolation settings and the three device signals (x, y, z) that the
    heatmap is built from.
    """

    parent_id: str | None = Field(None, description="The parent plot of the curve.")
    # validate_default=True makes the default value go through the color map validator as well.
    color_map: str | None = Field(
        "plasma", description="The color palette of the heatmap widget.", validate_default=True
    )
    color_bar: Literal["full", "simple"] | None = Field(
        None, description="The type of the color bar."
    )
    interpolation: Literal["linear", "nearest", "clough"] = Field(
        "linear", description="The interpolation method for the heatmap."
    )
    oversampling_factor: float = Field(
        1.0,
        description="Factor to oversample the grid resolution (1.0 = no oversampling, 2.0 = 2x resolution).",
    )
    show_config_label: bool = Field(
        True, description="Whether to show the configuration label in the heatmap."
    )
    enforce_interpolation: bool = Field(
        False, description="Whether to use the interpolation mode even for grid scans."
    )
    lock_aspect_ratio: bool = Field(
        False, description="Whether to lock the aspect ratio of the image."
    )
    device_x: HeatmapDeviceSignal | None = Field(
        None, description="The x device signal of the heatmap."
    )
    device_y: HeatmapDeviceSignal | None = Field(
        None, description="The y device signal of the heatmap."
    )
    device_z: HeatmapDeviceSignal | None = Field(
        None, description="The z device signal of the heatmap."
    )

    # Re-validate fields whenever they are assigned after construction.
    model_config: dict = {"validate_assignment": True}
    # Delegate validation of `color_map` to Colors.validate_color_map.
    _validate_color_palette = field_validator("color_map")(Colors.validate_color_map)
|
|
|
|
|
|
@dataclass
|
|
class _InterpolationRequest:
|
|
"""Immutable payload describing an interpolation request for the worker thread.
|
|
|
|
Args:
|
|
x_data: X coordinates collected so far.
|
|
y_data: Y coordinates collected so far.
|
|
z_data: Z values associated with x/y.
|
|
data_version: Number of points at request time (len(z_data)); used to reject stale results.
|
|
scan_id: Identifier for the scan that produced the data.
|
|
interpolation: Interpolation method to apply.
|
|
oversampling_factor: Oversampling factor for the interpolation grid.
|
|
"""
|
|
|
|
x_data: list[float]
|
|
y_data: list[float]
|
|
z_data: list[float]
|
|
data_version: int
|
|
scan_id: str
|
|
interpolation: str
|
|
oversampling_factor: float
|
|
|
|
|
|
class _StepInterpolationWorker(QObject):
    """Worker for performing step-scan interpolation in a background thread.

    This worker computes the interpolated heatmap image using the provided data
    and settings, then emits the result or a failure signal.

    Signals:
        finished(image, transform, data_version, scan_id):
            Emitted when interpolation is successful.
            - image: The resulting image (numpy array or similar).
            - transform: The QTransform for the image.
            - data_version: The data version for the request.
            - scan_id: The scan identifier.
        failed(error_message, data_version, scan_id):
            Emitted when interpolation fails.
            - error_message: The error message string.
            - data_version: The data version for the request.
            - scan_id: The scan identifier.
    """

    finished = Signal(object, object, int, str)
    failed = Signal(str, int, str)

    def __init__(self, parent: QObject | None = None):
        """
        Initialize the worker in the idle state.

        Args:
            parent(QObject | None): Optional Qt parent object.
        """
        super().__init__(parent=parent)
        # Last request handed to process(); kept for reference only.
        self._active_request: _InterpolationRequest | None = None
        # True while process() is computing an image.
        self._processing = False

    @property
    def is_processing(self) -> bool:
        """Return whether the worker is currently processing a request."""
        return self._processing

    @SafeSlot(object, int)
    def process(self, request: _InterpolationRequest, data_version: int):
        """
        Process an interpolation request in the worker thread.

        Exactly one of the ``finished`` / ``failed`` signals is emitted per call.
        In both cases the busy flag is cleared *before* the signal is emitted, so
        that a receiver reacting to the signal never observes a stale
        ``is_processing == True`` and skips starting the next queued request.

        Args:
            request(_InterpolationRequest): The interpolation request payload.
            data_version(int): The data version for the request.
        """
        self._active_request = request
        self._processing = True
        try:
            image, transform = Heatmap.compute_step_scan_image(
                x_data=np.asarray(request.x_data, dtype=float),
                y_data=np.asarray(request.y_data, dtype=float),
                z_data=np.asarray(request.z_data, dtype=float),
                oversampling_factor=request.oversampling_factor,
                interpolation_method=request.interpolation,
            )
        except Exception as exc:  # pragma: no cover - defensive
            logger.warning(f"Step-scan interpolation failed with: {exc}")
            # Reset the flag first (same order as the success path below); emitting
            # first would let the receiver race against this assignment.
            self._processing = False
            self.failed.emit(str(exc), data_version, request.scan_id)
            return
        self._processing = False
        self.finished.emit(image, transform, data_version, request.scan_id)
|
|
|
|
|
|
class Heatmap(ImageBase):
    """
    Heatmap widget for visualizing 2d grid data with color mapping for the z-axis.
    """

    # Names of the properties/methods exposed to users of the widget, extending the
    # list inherited from PlotBase. NOTE(review): presumably consumed by the RPC
    # layer (see RPC flag below) - confirm against the base classes.
    USER_ACCESS = [
        *PlotBase.USER_ACCESS,
        # ImageView Specific Settings
        "color_map",
        "color_map.setter",
        "v_range",
        "v_range.setter",
        "v_min",
        "v_min.setter",
        "v_max",
        "v_max.setter",
        "autorange",
        "autorange.setter",
        "autorange_mode",
        "autorange_mode.setter",
        "enable_colorbar",
        "enable_simple_colorbar",
        "enable_simple_colorbar.setter",
        "enable_full_colorbar",
        "enable_full_colorbar.setter",
        "interpolation_method",
        "interpolation_method.setter",
        "oversampling_factor",
        "oversampling_factor.setter",
        "enforce_interpolation",
        "enforce_interpolation.setter",
        "fft",
        "fft.setter",
        "log",
        "log.setter",
        "main_image",
        "add_roi",
        "remove_roi",
        "rois",
        "plot",
        # Device properties
        "device_x",
        "device_x.setter",
        "signal_x",
        "signal_x.setter",
        "device_y",
        "device_y.setter",
        "signal_y",
        "signal_y.setter",
        "device_z",
        "device_z.setter",
        "signal_z",
        "signal_z.setter",
    ]

    # Widget registration flags and icon name read by the surrounding framework.
    PLUGIN = True
    RPC = True
    ICON_NAME = "dataset"

    # Emitted by on_scan_status when a scan with a new scan_id is seen.
    new_scan = Signal()
    # Same event as new_scan, carrying the new scan id.
    new_scan_id = Signal(str)
    # Requests a plot refresh; rate-limited and routed to update_plot in __init__.
    sync_signal_update = Signal()
    # Re-emitted as sync_signal_update (connected in __init__).
    heatmap_property_changed = Signal()
    # Carries (request, data_version) to the background interpolation worker.
    interpolation_requested = Signal(object, int)
|
|
|
|
def __init__(self, parent=None, config: HeatmapConfig | None = None, **kwargs):
    """
    Initialize the heatmap widget.

    Args:
        parent: Optional parent widget, forwarded to the base class.
        config(HeatmapConfig | None): Widget configuration. If None, a default
            configuration (plasma color map, linear interpolation, no devices) is created.
        **kwargs: Additional keyword arguments forwarded to the base class.
    """
    if config is None:
        config = HeatmapConfig(
            widget_class=self.__class__.__name__,
            parent_id=None,
            color_map="plasma",
            color_bar=None,
            interpolation="linear",
            oversampling_factor=1.0,
            lock_aspect_ratio=False,
            device_x=None,
            device_y=None,
            device_z=None,
        )
    super().__init__(parent=parent, config=config, theme_update=True, **kwargs)
    self._image_config = config
    # Scan bookkeeping: id/item of the scan currently displayed and the last status message.
    self.scan_id = None
    self.old_scan_id = None
    self.scan_item = None
    self.status_message = None
    self._grid_index = None
    # Highest data_version we have dispatched for the current scan; used to drop stale results.
    # Initialized to -1 so the first real request (len(z_data) >= 0) always supersedes it.
    self._latest_interpolation_version = -1
    # Background interpolation infrastructure; created lazily by _ensure_interpolation_thread.
    self._interpolation_thread: QThread | None = None
    self._interpolation_worker: _StepInterpolationWorker | None = None
    # At most one request is kept waiting while the worker is busy.
    self._pending_interpolation_request: _InterpolationRequest | None = None
    self.heatmap_dialog = None
    # Semi-transparent legend item used as the in-plot configuration label; hidden until drawn.
    bg_color = pg.mkColor((240, 240, 240, 150))
    self.config_label = pg.LegendItem(
        labelTextColor=(0, 0, 0), offset=(-30, 1), brush=pg.mkBrush(bg_color), horSpacing=0
    )
    self.config_label.setParentItem(self.plot_item.vb)
    self.config_label.setVisible(False)
    self.reload = False
    # Subscribe to scan status/progress messages.
    self.bec_dispatcher.connect_slot(self.on_scan_status, MessageEndpoints.scan_status())
    self.bec_dispatcher.connect_slot(self.on_scan_progress, MessageEndpoints.scan_progress())
    # Any property change requests a (rate-limited) plot refresh.
    self.heatmap_property_changed.connect(lambda: self.sync_signal_update.emit())

    # Rate-limit refresh requests before they reach update_plot (rateLimit=5).
    self.proxy_update_sync = pg.SignalProxy(
        self.sync_signal_update, rateLimit=5, slot=self.update_plot
    )
    self._init_toolbar_heatmap()
    self.toolbar.show_bundles(
        [
            "heatmap_settings",
            "plot_export",
            "image_crosshair",
            "mouse_interaction",
            "image_autorange",
            "image_colorbar",
            "image_processing",
            "axis_popup",
            "interpolation_info",
        ]
    )
|
|
|
|
@property
def main_image(self) -> ImageItem:
    """Return the image item held by the "main" layer of the layer manager."""
    main_layer = self.layer_manager["main"]
    return main_layer.image
|
|
|
|
################################################################################
|
|
# Widget Specific GUI interactions
|
|
################################################################################
|
|
|
|
@SafeSlot(str)
def apply_theme(self, theme: str):
    """
    Apply the current theme to the heatmap widget.

    In addition to the base-class theming, the configuration label gets a dark
    background with white text for the "dark" theme and a light background with
    black text for any other theme.

    Args:
        theme(str): Name of the theme; only "dark" is treated specially.
    """
    super().apply_theme(theme)
    if theme == "dark":
        brush = pg.mkBrush(pg.mkColor(50, 50, 50, 150))
        color = pg.mkColor(255, 255, 255)
    else:
        brush = pg.mkBrush(pg.mkColor(240, 240, 240, 150))
        color = pg.mkColor(0, 0, 0)
    # The label may not exist yet if the theme is applied before __init__ created it.
    if hasattr(self, "config_label"):
        self.config_label.setBrush(brush)
        self.config_label.setLabelTextColor(color)
        # NOTE(review): nesting of this call under the hasattr guard is assumed - confirm.
        self.redraw_config_label()
|
|
|
|
@SafeSlot(popup_error=True)
def plot(
    self,
    device_x: str,
    device_y: str,
    device_z: str,
    signal_x: None | str = None,
    signal_y: None | str = None,
    signal_z: None | str = None,
    color_map: str | None = "plasma",
    validate_bec: bool = True,
    interpolation: Literal["linear", "nearest", "clough"] | None = None,
    enforce_interpolation: bool | None = None,
    oversampling_factor: float | None = None,
    lock_aspect_ratio: bool | None = None,
    show_config_label: bool | None = None,
    reload: bool = False,
):
    """
    Plot the heatmap with the given x, y, and z data.

    Options passed as None fall back to the values stored in the current
    configuration. If the device/signal selection differs from a previously
    complete configuration, the displayed image is cleared and a reload is forced.

    Args:
        device_x (str): The name of the x-axis device signal.
        device_y (str): The name of the y-axis device signal.
        device_z (str): The name of the z-axis device signal.
        signal_x (str | None): The entry for the x-axis device signal.
        signal_y (str | None): The entry for the y-axis device signal.
        signal_z (str | None): The entry for the z-axis device signal.
        color_map (str | None): The color map to use for the heatmap.
        validate_bec (bool): Whether to validate the entries against BEC signals.
        interpolation (Literal["linear", "nearest", "clough"] | None): The interpolation method to use.
        enforce_interpolation (bool | None): Whether to enforce interpolation even for grid scans.
        oversampling_factor (float | None): Factor to oversample the grid resolution.
        lock_aspect_ratio (bool | None): Whether to lock the aspect ratio of the image.
        show_config_label (bool | None): Whether to show the configuration label in the heatmap.
        reload (bool): Whether to reload the heatmap with new data.

    Raises:
        ValueError: If any device name or signal entry is missing after validation.
    """
    if validate_bec:
        # The validator resolves a missing entry for the given device.
        signal_x = self.entry_validator.validate_signal(device_x, signal_x)
        signal_y = self.entry_validator.validate_signal(device_y, signal_y)
        signal_z = self.entry_validator.validate_signal(device_z, signal_z)

    if signal_x is None or signal_y is None or signal_z is None:
        raise ValueError("x, y, and z entries must be provided.")
    if device_x is None or device_y is None or device_z is None:
        raise ValueError("x, y, and z names must be provided.")

    # Unspecified options keep their currently configured value.
    if interpolation is None:
        interpolation = self._image_config.interpolation

    if oversampling_factor is None:
        oversampling_factor = self._image_config.oversampling_factor

    if enforce_interpolation is None:
        enforce_interpolation = self._image_config.enforce_interpolation

    if lock_aspect_ratio is None:
        lock_aspect_ratio = self._image_config.lock_aspect_ratio

    if show_config_label is None:
        show_config_label = self._image_config.show_config_label

    def _device_key(device: HeatmapDeviceSignal | None) -> tuple[str | None, str | None]:
        # (device, signal) pair used to compare the old and the new selection.
        return (device.device if device else None, device.signal if device else None)

    prev_cfg = getattr(self, "_image_config", None)
    config_changed = False
    # Only a previously complete x/y/z selection can count as "changed".
    if prev_cfg and prev_cfg.device_x and prev_cfg.device_y and prev_cfg.device_z:
        config_changed = any(
            (
                _device_key(prev_cfg.device_x) != (device_x, signal_x),
                _device_key(prev_cfg.device_y) != (device_y, signal_y),
                _device_key(prev_cfg.device_z) != (device_z, signal_z),
            )
        )

    self._image_config = HeatmapConfig(
        parent_id=self.gui_id,
        device_x=HeatmapDeviceSignal(device=device_x, signal=signal_x),
        device_y=HeatmapDeviceSignal(device=device_y, signal=signal_y),
        device_z=HeatmapDeviceSignal(device=device_z, signal=signal_z),
        color_map=color_map,
        color_bar=None,
        interpolation=interpolation,
        oversampling_factor=oversampling_factor,
        enforce_interpolation=enforce_interpolation,
        lock_aspect_ratio=lock_aspect_ratio,
        show_config_label=show_config_label,
    )
    self.color_map = color_map
    self.reload = reload or config_changed
    if config_changed:
        # Drop state that belongs to the previous device selection.
        self._grid_index = None
        self.main_image.clear()
    self.update_labels()

    # Pick up a running (or the most recent) scan and request a refresh.
    self._fetch_running_scan()
    self.sync_signal_update.emit()
|
|
|
|
def _fetch_running_scan(self):
|
|
scan = self.client.queue.scan_storage.current_scan
|
|
if scan is not None:
|
|
self.scan_item = scan
|
|
self.scan_id = scan.scan_id
|
|
elif self.client.history and len(self.client.history) > 0:
|
|
self.scan_item = self.client.history[-1]
|
|
self.scan_id = self.client.history._scan_ids[-1]
|
|
self.old_scan_id = None
|
|
|
|
def update_labels(self):
    """
    Refresh the axis labels and the title from the configured devices.

    The x/y labels are set to the x/y device names, their units are taken from
    the device's ``egu()`` when the device is known and provides it, and the
    title is set to the z device name. Unconfigured devices are left untouched.
    """
    cfg = self._image_config
    if cfg is None:
        return

    def _name_of(entry):
        # Device entries may be None while the widget is not fully configured.
        return entry.device if entry else None

    axis_names = {"x": _name_of(cfg.device_x), "y": _name_of(cfg.device_y)}
    title_name = _name_of(cfg.device_z)

    for axis, name in axis_names.items():
        if name is None:
            continue
        setattr(self, f"{axis}_label", name)
        device = self.dev.get(name)
        if device and hasattr(device, "egu"):
            setattr(self, f"{axis}_label_units", device.egu())

    if title_name is not None:
        self.title = title_name
|
|
|
|
def _init_toolbar_heatmap(self):
    """
    Set up the heatmap-specific toolbar entries.

    Adds the checkable "heatmap_settings" and "interpolation_info" actions, wires
    them to their handlers and hides every image-processing action except FFT and log.
    """
    toolbar = self.toolbar

    settings_action = MaterialIconAction(
        icon_name="scatter_plot",
        tooltip="Show Heatmap Settings",
        checkable=True,
        parent=self,
    )
    toolbar.add_action("heatmap_settings", settings_action)
    toolbar.components.get_action("heatmap_settings").action.triggered.connect(
        self.show_heatmap_settings
    )

    # Only FFT and log processing stay visible for the heatmap.
    kept_visible = ("image_processing_fft", "image_processing_log")
    processing_bundle = toolbar.get_bundle("image_processing")
    for action_name, action_ref in processing_bundle.bundle_actions.items():
        if action_name in kept_visible:
            continue
        # Bundle entries are callables that return the action wrapper.
        action_ref().action.setVisible(False)

    info_action = MaterialIconAction(
        icon_name="info", tooltip="Show Interpolation Info", checkable=True, parent=self
    )
    toolbar.add_action("interpolation_info", info_action)
    info_qaction = toolbar.components.get_action("interpolation_info").action
    info_qaction.triggered.connect(self.toggle_interpolation_info)
    # Start with the toggle reflecting the configured label visibility.
    info_qaction.setChecked(self._image_config.show_config_label)
|
|
|
|
def show_heatmap_settings(self):
    """
    Open the heatmap settings dialog, or bring it to the front if it is already visible.

    In both cases the "heatmap_settings" toolbar action ends up checked.
    """
    toggle = self.toolbar.components.get_action("heatmap_settings").action

    dialog_open = self.heatmap_dialog is not None and self.heatmap_dialog.isVisible()
    if dialog_open:
        # Already open: just raise it and keep the toolbar button toggled.
        self.heatmap_dialog.raise_()
        self.heatmap_dialog.activateWindow()
        toggle.setChecked(True)
        return

    settings_widget = HeatmapSettings(parent=self, target_widget=self, popup=True)
    self.heatmap_dialog = SettingsDialog(
        self, settings_widget=settings_widget, window_title="Heatmap Settings", modal=False
    )
    self.heatmap_dialog.resize(700, 350)
    # Closing the dialog resets the toolbar button and drops the reference.
    self.heatmap_dialog.finished.connect(self._heatmap_dialog_closed)
    self.heatmap_dialog.show()
    toggle.setChecked(True)
|
|
|
|
def toggle_interpolation_info(self):
    """
    Flip the visibility setting of the interpolation info label.

    The new state is stored in the configuration, mirrored to the
    "interpolation_info" toolbar action and the label is redrawn.
    """
    config = self._image_config
    config.show_config_label = not config.show_config_label

    info_toggle = self.toolbar.components.get_action("interpolation_info").action
    # Read the value back from the config so the button shows what was actually stored.
    info_toggle.setChecked(config.show_config_label)
    self.redraw_config_label()
|
|
|
|
def _heatmap_dialog_closed(self):
|
|
"""
|
|
Slot for when the heatmap settings dialog is closed.
|
|
"""
|
|
self.heatmap_dialog = None
|
|
self.toolbar.components.get_action("heatmap_settings").action.setChecked(False)
|
|
|
|
@SafeSlot(dict, dict)
def on_scan_status(self, msg: dict, meta: dict):
    """
    Initial scan status message handler, which is triggered at the beginning and end of scan.

    When the message carries a scan id different from the one currently tracked,
    the widget is reset, the new-scan signals are emitted and the new scan becomes
    the tracked one. Messages without a scan id are ignored.

    Args:
        msg(dict): The message content.
        meta(dict): The message metadata.
    """
    current_scan_id = msg.get("scan_id", None)
    if current_scan_id is None:
        return
    if current_scan_id != self.scan_id:
        self._invalidate_interpolation_generation()  # Invalidate any pending interpolation work when a new scan starts
        self.reset()
        # Signals are emitted before self.scan_id is updated to the new id.
        self.new_scan.emit()
        self.new_scan_id.emit(current_scan_id)
        self.old_scan_id = self.scan_id
        self.scan_id = current_scan_id
        self.scan_item = self.queue.scan_storage.find_scan_by_ID(self.scan_id)  # type: ignore

        # First trigger to update the scan curves
        self.sync_signal_update.emit()
|
|
|
|
@SafeSlot(dict, dict)
def on_scan_progress(self, msg: dict, meta: dict):
    """
    Scan progress message handler.

    Every message requests a plot refresh. When the message reports the scan as
    done, two additional delayed refreshes (100 ms and 300 ms) are scheduled.

    Args:
        msg(dict): The message content; the "done" entry marks completion.
        meta(dict): The message metadata.
    """
    self.sync_signal_update.emit()
    if not msg.get("done"):
        return
    for delay_ms in (100, 300):
        QTimer.singleShot(delay_ms, self.update_plot)
|
|
|
|
@SafeSlot(verify_sender=True)
def update_plot(self, _=None) -> None:
    """
    Update the plot with the current data.

    Reads the x/y/z data of the configured signals from the tracked scan, builds
    (or reuses) the scan status message and then either renders the grid-scan
    image directly or requests a background interpolation for step scans.
    The update is skipped (with a log message) whenever data or metadata is
    missing or inconsistent.

    Args:
        _: Ignored; allows the method to be used as a slot that receives arguments.
    """
    if self.scan_item is None:
        logger.info("No scan executed so far; skipping update.")
        return
    data, access_key = self._fetch_scan_data_and_access()
    if data == "none":
        logger.info("No scan executed so far; skipping update.")
        return

    if self._image_config is None:
        return
    try:
        device_x = self._image_config.device_x.device
        signal_x = self._image_config.device_x.signal
        device_y = self._image_config.device_y.device
        signal_y = self._image_config.device_y.signal
        device_z = self._image_config.device_z.device
        signal_z = self._image_config.device_z.signal
    except AttributeError:
        # At least one of the device signals is not configured (None).
        return

    if access_key == "val":
        x_data = data.get(device_x, {}).get(signal_x, {}).get(access_key, None)
        y_data = data.get(device_y, {}).get(signal_y, {}).get(access_key, None)
        z_data = data.get(device_z, {}).get(signal_z, {}).get(access_key, None)
    else:
        # NOTE(review): if a device/signal is missing, the `{}` fallback has no
        # `.read()` and this raises AttributeError - confirm this is intended.
        x_data = data.get(device_x, {}).get(signal_x, {}).read().get("value", None)
        y_data = data.get(device_y, {}).get(signal_y, {}).read().get("value", None)
        z_data = data.get(device_z, {}).get(signal_z, {}).read().get("value", None)

    # Normalize to plain lists; anything that is neither list nor ndarray counts as missing.
    if not isinstance(x_data, list):
        x_data = x_data.tolist() if isinstance(x_data, np.ndarray) else None
    if not isinstance(y_data, list):
        y_data = y_data.tolist() if isinstance(y_data, np.ndarray) else None
    if not isinstance(z_data, list):
        z_data = z_data.tolist() if isinstance(z_data, np.ndarray) else None

    if x_data is None or y_data is None or z_data is None:
        logger.warning("x, y, or z data is None; skipping update.")
        return
    if len(x_data) != len(y_data) or len(x_data) != len(z_data):
        logger.warning(
            "x, y, and z data lengths do not match; skipping update. "
            f"Lengths: x={len(x_data)}, y={len(y_data)}, z={len(z_data)}"
        )
        return

    if hasattr(self.scan_item, "status_message"):
        scan_msg = self.scan_item.status_message
    elif hasattr(self.scan_item, "metadata"):
        # No status message on the item: rebuild one from the stored "bec" metadata.
        metadata = self.scan_item.metadata["bec"]
        status = metadata["status"]
        scan_id = metadata["scan_id"]
        scan_name = metadata["scan_name"]
        scan_type = metadata["scan_type"]
        scan_number = metadata["scan_number"]
        request_inputs = metadata["request_inputs"]
        if "arg_bundle" in request_inputs and isinstance(request_inputs["arg_bundle"], str):
            # Convert the arg_bundle from a JSON string to a dictionary
            request_inputs["arg_bundle"] = json.loads(request_inputs["arg_bundle"])
        positions = metadata.get("positions", [])
        positions = positions.tolist() if isinstance(positions, np.ndarray) else positions

        scan_msg = messages.ScanStatusMessage(
            status=status,
            scan_id=scan_id,
            scan_name=scan_name,
            scan_number=scan_number,
            scan_type=scan_type,
            request_inputs=request_inputs,
            info={"positions": positions},
        )
    else:
        scan_msg = None

    if scan_msg is None:
        logger.warning("Scan message is None; skipping update.")
        return
    self.status_message = scan_msg

    if self._image_config.show_config_label:
        self.redraw_config_label()

    # Grid scans are rendered synchronously without interpolation.
    if self._is_grid_scan_supported(scan_msg):
        img, transform = self.get_grid_scan_image(z_data, scan_msg)
        self._apply_image_update(img, transform)
        return

    if len(z_data) < 4:
        # LinearNDInterpolator requires at least 4 points to interpolate
        logger.warning("Not enough data points to interpolate; skipping update.")
        return

    # Step scans are interpolated in the background worker thread.
    self._request_step_scan_interpolation(x_data, y_data, z_data, scan_msg)
|
|
|
|
def _apply_image_update(self, img: np.ndarray | None, transform: QTransform | None):
|
|
"""Apply interpolated image and transform to the heatmap display.
|
|
|
|
This method updates the main image with the computed data and emits
|
|
the image_updated signal. Color bar signals are temporarily blocked
|
|
during the update to prevent cascading updates.
|
|
|
|
Args:
|
|
img(np.ndarray): The interpolated image data, or None if unavailable
|
|
transform(QTransform): QTransform mapping pixel to world coordinates, or None if unavailable
|
|
"""
|
|
if img is None or transform is None:
|
|
logger.warning("Image data is None; skipping update.")
|
|
return
|
|
|
|
if self._color_bar is not None:
|
|
self._color_bar.blockSignals(True)
|
|
if self.main_image is None:
|
|
logger.warning("Main image item is None; cannot update image.")
|
|
return
|
|
self.main_image.set_data(img, transform=transform)
|
|
if self._color_bar is not None:
|
|
self._color_bar.blockSignals(False)
|
|
self.image_updated.emit()
|
|
if self.crosshair is not None:
|
|
self.crosshair.update_markers_on_image_change()
|
|
|
|
def _request_step_scan_interpolation(
    self,
    x_data: list[float],
    y_data: list[float],
    z_data: list[float],
    msg: messages.ScanStatusMessage,
):
    """Queue or start a step-scan interpolation on the background worker.

    A snapshot of the data and of the current interpolation settings is packed
    into a request. While the worker is busy the request replaces any previously
    pending one; otherwise it is dispatched immediately.

    Args:
        x_data(list[float]): X coordinates of data points
        y_data(list[float]): Y coordinates of data points
        z_data(list[float]): Z values at each point
        msg(messages.ScanStatusMessage): Scan status message containing scan metadata
    """
    config = self._image_config
    # Copy the lists so the request is independent of the caller's data.
    request = _InterpolationRequest(
        x_data=list(x_data),
        y_data=list(y_data),
        z_data=list(z_data),
        data_version=len(z_data),
        scan_id=msg.scan_id,
        interpolation=config.interpolation,
        oversampling_factor=config.oversampling_factor,
    )

    worker = self._interpolation_worker
    worker_busy = worker is not None and worker.is_processing
    if not worker_busy:
        self._start_step_scan_interpolation(request)
        return

    # Only the newest request is kept while the worker is busy.
    self._pending_interpolation_request = request
|
|
|
|
def _ensure_interpolation_thread(self):
    """Create the interpolation thread and worker on first use and make sure the thread runs.

    On creation the worker is moved to the new thread and all connections between
    the widget and the worker are made as queued connections.
    """
    if self._interpolation_thread is None:
        thread = self._interpolation_thread = QThread()
        worker = self._interpolation_worker = _StepInterpolationWorker()
        worker.moveToThread(thread)

        queued = Qt.ConnectionType.QueuedConnection
        self.interpolation_requested.connect(worker.process, queued)
        worker.finished.connect(self._on_interpolation_finished, queued)
        worker.failed.connect(self._on_interpolation_failed, queued)

    thread = self._interpolation_thread
    if thread is not None and not thread.isRunning():
        thread.start()
|
|
|
|
def _start_step_scan_interpolation(self, request: _InterpolationRequest):
|
|
# data_version = len(z_data) at the time of the request; keep the latest to gate results.
|
|
self._ensure_interpolation_thread()
|
|
if self._interpolation_thread is not None and not self._interpolation_thread.isRunning():
|
|
self._interpolation_thread.start()
|
|
self._latest_interpolation_version = request.data_version
|
|
self.interpolation_requested.emit(request, request.data_version)
|
|
|
|
def _on_interpolation_finished(
|
|
self, img: np.ndarray, transform: QTransform, data_version: int, scan_id: str
|
|
):
|
|
# Only accept results that match the latest dispatched version for the active scan.
|
|
if data_version == self._latest_interpolation_version and scan_id == self.scan_id:
|
|
self._apply_image_update(img, transform)
|
|
else:
|
|
logger.info("Discarding outdated interpolation result.")
|
|
self._maybe_start_pending_interpolation()
|
|
|
|
def _on_interpolation_failed(self, error: str, data_version: int, scan_id: str):
    """Log a failed interpolation and give a pending request, if any, its turn.

    Args:
        error(str): The error message reported by the worker.
        data_version(int): The data version of the failed request.
        scan_id(str): The scan the failed request belongs to.
    """
    message = f"Interpolation failed for scan {scan_id} (version {data_version}): {error}"
    logger.warning(message)
    self._maybe_start_pending_interpolation()
|
|
|
|
def _finish_interpolation_thread(self):
    """Tear down the background interpolation worker and thread.

    Drops any pending request, disconnects and schedules the worker for
    deletion, then asks the thread to quit and waits up to 3 seconds for it.
    """
    self._pending_interpolation_request = None
    if self._interpolation_worker is not None:
        try:
            self.interpolation_requested.disconnect(self._interpolation_worker.process)
        except (TypeError, RuntimeError) as ext:
            # Disconnecting fails if the connection no longer exists; only log it.
            logger.warning(f"Processing thread already disconnected: {ext}")
            pass
        self._interpolation_worker.deleteLater()
        self._interpolation_worker = None
    if self._interpolation_thread is not None:
        if self._interpolation_thread.isRunning():
            self._interpolation_thread.quit()
            if not self._interpolation_thread.wait(3000):  # 3s timeout
                logger.error(
                    f"Interpolation thread of widget {self.gui_id} did not stop within timeout 3s; leaving it dangling."
                )
        self._interpolation_thread.deleteLater()
        self._interpolation_thread = None
        # NOTE(review): nesting of this log line under the thread check is assumed - confirm.
        logger.info(f"Interpolation thread finished of widget {self.gui_id}")
|
|
|
|
def _maybe_start_pending_interpolation(self):
|
|
if self._pending_interpolation_request is None:
|
|
return
|
|
if self._pending_interpolation_request.scan_id != self.scan_id:
|
|
self._pending_interpolation_request = None
|
|
return
|
|
if self._interpolation_worker is not None and self._interpolation_worker.is_processing:
|
|
return
|
|
|
|
pending = self._pending_interpolation_request
|
|
self._pending_interpolation_request = None
|
|
self._start_step_scan_interpolation(pending)
|
|
|
|
def _cancel_interpolation(self):
|
|
"""Cancel any pending interpolation request without invalidating in-flight work.
|
|
|
|
This clears the pending request queue but does not invalidate in-flight work,
|
|
allowing any currently running interpolation to complete and update the display
|
|
if it matches the current scan.
|
|
"""
|
|
self._pending_interpolation_request = None
|
|
# Do not change the active data version so an in-flight worker can still deliver.
|
|
|
|
def _invalidate_interpolation_generation(self):
|
|
"""Invalidate all pending interpolation results and ignore in-flight updates."""
|
|
self._pending_interpolation_request = None
|
|
self._latest_interpolation_version = -1
|
|
|
|
def redraw_config_label(self):
    """Rebuild the in-plot configuration label from the last scan status message.

    Does nothing before a status message is available. If the label is disabled
    in the configuration it is hidden; otherwise it is shown with the scan number
    and scan name, plus interpolation method and oversampling factor for scans
    that are not named "grid_scan" or when interpolation is enforced.
    """
    status = self.status_message
    if status is None:
        return

    config = self._image_config
    label = self.config_label
    if not config.show_config_label:
        label.setVisible(False)
        return

    label.setOffset((-30, 1))
    label.setVisible(True)
    label.clear()

    entries = [f"Scan: {status.scan_number}", f"Scan Name: {status.scan_name}"]
    if status.scan_name != "grid_scan" or config.enforce_interpolation:
        entries.append(f"Interpolation: {config.interpolation}")
        entries.append(f"Oversampling: {config.oversampling_factor}x")

    for text in entries:
        label.addItem(self.plot_item, text)
|
|
|
|
def get_image_data(
    self,
    x_data: list[float] | None = None,
    y_data: list[float] | None = None,
    z_data: list[float] | None = None,
) -> tuple[np.ndarray | None, QTransform | None]:
    """
    Compute the heatmap image for the given data and the last scan status message.

    Supported grid scans are rendered through the grid-scan path; every other
    scan is interpolated through the step-scan path.

    Args:
        x_data (list[float] | None): The x data.
        y_data (list[float] | None): The y data.
        z_data (list[float] | None): The z data.

    Returns:
        tuple[np.ndarray | None, QTransform | None]: The image data and the
        QTransform, or (None, None) if data or status message is missing or
        fewer than 4 points are available for interpolation.
    """
    msg = self.status_message
    if any(item is None for item in (x_data, y_data, z_data, msg)):
        logger.warning("x, y, or z data is None; skipping update.")
        return None, None

    if self._is_grid_scan_supported(msg):
        return self.get_grid_scan_image(z_data, msg)

    # LinearNDInterpolator requires at least 4 points to interpolate
    if len(z_data) >= 4:
        return self.get_step_scan_image(x_data, y_data, z_data)
    return None, None
|
|
|
|
def _is_grid_scan_supported(self, msg: messages.ScanStatusMessage) -> bool:
    """Check if the scan can use optimized grid_scan rendering.

    Grid scans can avoid interpolation if both the configured X and Y signals
    appear in the scan's argument bundle and interpolation is not enforced.

    Args:
        msg(messages.ScanStatusMessage): Scan status message containing scan metadata

    Returns:
        True if grid_scan optimization is applicable, False otherwise
    """
    if msg.scan_name != "grid_scan" or self._image_config.enforce_interpolation:
        return False

    device_x = self._image_config.device_x
    device_y = self._image_config.device_y
    if device_x is None or device_y is None:
        # The axis configs can be cleared to None (see the device_x/device_y
        # setters); an unconfigured axis can never match the scan arguments.
        return False

    arg_bundle = msg.request_inputs["arg_bundle"]
    return device_x.signal in arg_bundle and device_y.signal in arg_bundle
|
|
|
|
def get_grid_scan_image(
    self, z_data: list[float], msg: messages.ScanStatusMessage
) -> tuple[np.ndarray, QTransform]:
    """
    Get the image data for a grid scan.

    The z values are written directly into a 2D array (no interpolation). The
    array of the main image is reused when its shape matches, and only the
    values received since the previous call are filled in.

    Args:
        z_data (list[float]): The z data, in acquisition order.
        msg (messages.ScanStatusMessage): The scan status message.

    Returns:
        tuple[np.ndarray, QTransform]: The image data and the QTransform.
    """

    # Group the flat argument bundle into {first element: [remaining 3 elements]}.
    # The code below reads the first two values as start/stop and the last one
    # as the number of points along that axis.
    args = self.arg_bundle_to_dict(4, msg.request_inputs["arg_bundle"])

    signal_x = self._image_config.device_x.signal
    signal_y = self._image_config.device_y.signal
    shape = (args[signal_x][-1], args[signal_y][-1])

    data = self.main_image.raw_data

    # Allocate a NaN-filled buffer when none exists or the grid size changed;
    # on reload, wipe the existing buffer in place.
    if data is None or data.shape != shape:
        data = np.empty(shape)
        data.fill(np.nan)
    elif self.reload:
        data.fill(np.nan)

    # Snaking is assumed unless the scan kwargs say otherwise.
    snaked = msg.request_inputs["kwargs"].get("snaked", True)

    # The first bundle is treated as the slow axis, the second as the fast axis.
    slow_entry, fast_entry = (
        msg.request_inputs["arg_bundle"][0],
        msg.request_inputs["arg_bundle"][4],
    )

    scan_pos = np.asarray(msg.info["positions"], dtype=float)
    relative = bool(msg.request_inputs["kwargs"].get("relative", False))

    def _axis_column(entry: str) -> int:
        # Column of `scan_pos` holding this axis: 0 for the slow entry, 1 otherwise.
        return 0 if entry == slow_entry else 1

    def _axis_levels(entry: str, npts: int) -> np.ndarray:
        # Evenly spaced axis values between start and stop. For relative scans
        # the levels are shifted so that they begin at the first scan position.
        start, stop = args[entry][:2]
        if relative:
            origin = float(scan_pos[0, _axis_column(entry)] - start)
            return origin + np.linspace(start, stop, npts)
        return np.linspace(start, stop, npts)

    x_levels = _axis_levels(signal_x, shape[0])
    y_levels = _axis_levels(signal_y, shape[1])

    # Distance between neighbouring levels; falls back to 1.0 for a single-point axis.
    pixel_size_x = (
        float(x_levels[-1] - x_levels[0]) / max(shape[0] - 1, 1) if shape[0] > 1 else 1.0
    )
    pixel_size_y = (
        float(y_levels[-1] - y_levels[0]) / max(shape[1] - 1, 1) if shape[1] > 1 else 1.0
    )

    # Scale pixels to axis units and shift by half a pixel
    # (presumably so that pixel centres land on the level positions).
    transform = QTransform()
    transform.scale(pixel_size_x, pixel_size_y)
    transform.translate(x_levels[0] / pixel_size_x - 0.5, y_levels[0] / pixel_size_y - 0.5)

    # Fill the data array with the z values
    # Start from the first point on the first call or on reload.
    if self._grid_index is None or self.reload:
        self._grid_index = 0
        self.reload = False

    for i in range(self._grid_index, len(z_data)):
        slow_i, fast_i = divmod(i, args[fast_entry][-1])
        # On snaked scans every odd slow row runs in the reverse direction.
        if snaked and (slow_i % 2 == 1):
            fast_i = args[fast_entry][-1] - 1 - fast_i

        # Map (slow, fast) indices onto (x, y) depending on which axis is the fast one.
        if signal_x == fast_entry:
            x_i, y_i = fast_i, slow_i
        else:
            x_i, y_i = slow_i, fast_i

        data[x_i, y_i] = z_data[i]
    # Remember how many points were consumed so the next call only adds new ones.
    self._grid_index = len(z_data)
    return data, transform
|
|
|
|
def get_step_scan_image(
    self, x_data: list[float], y_data: list[float], z_data: list[float]
) -> tuple[np.ndarray, QTransform]:
    """
    Interpolate the data of an arbitrary step scan into an image.

    Delegates to ``compute_step_scan_image`` using the oversampling factor and
    interpolation method stored in the widget configuration.

    Args:
        x_data (list[float]): The x data.
        y_data (list[float]): The y data.
        z_data (list[float]): The z data.

    Returns:
        tuple[np.ndarray, QTransform]: The image data and the QTransform.
    """
    settings = {
        "oversampling_factor": self._image_config.oversampling_factor,
        "interpolation_method": self._image_config.interpolation,
    }
    return self.compute_step_scan_image(x_data=x_data, y_data=y_data, z_data=z_data, **settings)
|
|
|
|
@staticmethod
def compute_step_scan_image(
    x_data: list[float] | np.ndarray,
    y_data: list[float] | np.ndarray,
    z_data: list[float] | np.ndarray,
    oversampling_factor: float,
    interpolation_method: str,
) -> tuple[np.ndarray, QTransform]:
    """Interpolate scattered step-scan data onto a regular image grid.

    All inputs are passed as arguments and no instance state is read.

    Args:
        x_data(list[float] | np.ndarray): X coordinates of data points
        y_data(list[float] | np.ndarray): Y coordinates of data points
        z_data(list[float] | np.ndarray): Z values at each point
        oversampling_factor(float): Grid resolution multiplier (>1.0 for higher resolution)
        interpolation_method(str): One of 'linear', 'nearest', or 'clough'

    Returns:
        (tuple[np.ndarray, QTransform]): Tuple of (interpolated_grid, transform)
        as produced together with the grid by ``build_image_grid``

    Raises:
        ValueError: If the interpolation method is not one of the supported names.
    """
    points = np.column_stack((x_data, y_data))
    mesh_x, mesh_y, transform = Heatmap.build_image_grid(
        positions=points, oversampling_factor=oversampling_factor
    )

    interpolator_classes = {
        "linear": LinearNDInterpolator,
        "nearest": NearestNDInterpolator,
        "clough": CloughTocher2DInterpolator,
    }
    interpolator_cls = interpolator_classes.get(interpolation_method)
    if interpolator_cls is None:  # pragma: no cover - guarded by validation
        raise ValueError(
            "Interpolation method must be either 'linear', 'nearest', or 'clough'."
        )

    image = interpolator_cls(points, z_data)(mesh_x, mesh_y)
    return image, transform
|
|
|
|
def get_image_grid(self, positions) -> tuple[np.ndarray, np.ndarray, QTransform]:
    """
    Build the interpolation grid for the given positions.

    Thin wrapper around ``build_image_grid`` that supplies the oversampling
    factor from the widget configuration.

    Args:
        positions: positions of the data points.

    Returns:
        tuple[np.ndarray, np.ndarray, QTransform]: The grid x and y coordinates and the QTransform.
    """
    factor = self._image_config.oversampling_factor
    return self.build_image_grid(positions=positions, oversampling_factor=factor)
|
|
|
|
@staticmethod
def build_image_grid(
    positions: np.ndarray, oversampling_factor: float
) -> tuple[np.ndarray, np.ndarray, QTransform]:
    """Build an interpolation grid covering the data positions.

    Args:
        positions: (N, 2) array of (x, y) coordinates
        oversampling_factor: Grid resolution multiplier (>1.0 for higher resolution)

    Returns:
        Tuple of (grid_x, grid_y, transform) where grid_x/grid_y are meshgrids
        for interpolation and transform maps pixel to world coordinates
    """
    base_width, base_height = Heatmap.estimate_image_resolution(positions)
    width = max(1, int(base_width * oversampling_factor))
    height = max(1, int(base_height * oversampling_factor))

    # Compute the extents once, with the vectorised numpy reductions.
    x_min, x_max = positions[:, 0].min(), positions[:, 0].max()
    y_min, y_max = positions[:, 1].min(), positions[:, 1].max()

    # A complex step makes mgrid return that many samples, both end points included.
    grid_x, grid_y = np.mgrid[x_min : x_max : width * 1j, y_min : y_max : height * 1j]

    x_range = x_max - x_min
    y_range = y_max - y_min
    # A zero-extent axis (all points share one coordinate) would give a zero
    # scale: a singular transform and a division by zero below. Use a unit
    # pixel instead, like the single-point fallback in get_grid_scan_image.
    # NOTE(review): the grid has `width` samples including both end points, so
    # the sample spacing is x_range / (width - 1) while the scale uses
    # x_range / width — confirm this is intended.
    x_scale = x_range / width if x_range > 0 else 1.0
    y_scale = y_range / height if y_range > 0 else 1.0

    transform = QTransform()
    transform.scale(x_scale, y_scale)
    transform.translate(x_min / x_scale - 0.5, y_min / y_scale - 0.5)

    return grid_x, grid_y, transform
|
|
|
|
@staticmethod
|
|
def estimate_image_resolution(coords: np.ndarray) -> tuple[int, int]:
|
|
"""
|
|
Estimate the number of pixels needed for the image based on the coordinates.
|
|
|
|
Args:
|
|
coords (np.ndarray): The coordinates of the points.
|
|
|
|
Returns:
|
|
tuple[int, int]: The estimated width and height of the image."""
|
|
if coords.ndim != 2 or coords.shape[1] != 2:
|
|
raise ValueError("Input must be an (m x 2) array of (x, y) coordinates.")
|
|
|
|
x_min, x_max = coords[:, 0].min(), coords[:, 0].max()
|
|
y_min, y_max = coords[:, 1].min(), coords[:, 1].max()
|
|
|
|
tree = cKDTree(coords)
|
|
distances, _ = tree.query(coords, k=2)
|
|
distances = distances[:, 1] # Get the second nearest neighbor distance
|
|
avg_distance = np.mean(distances)
|
|
|
|
width_extent = x_max - x_min
|
|
height_extent = y_max - y_min
|
|
|
|
# Calculate the number of pixels needed based on the average distance
|
|
width_pixels = int(np.ceil(width_extent / avg_distance))
|
|
height_pixels = int(np.ceil(height_extent / avg_distance))
|
|
|
|
return max(1, width_pixels), max(1, height_pixels)
|
|
|
|
@staticmethod
def arg_bundle_to_dict(bundle_size: int, args: list) -> dict:
    """
    Turn a flat argument bundle into a dictionary.

    The arguments are split into consecutive groups of ``bundle_size``
    elements; the first element of each group becomes the key and the
    remaining elements become the value list.

    Args:
        bundle_size (int): The size of each argument bundle.
        args (list): The argument bundle.

    Returns:
        dict: The dictionary representation of the argument bundle.
    """
    return {bundle[0]: list(bundle[1:]) for bundle in partition(bundle_size, args)}
|
|
|
|
def _fetch_scan_data_and_access(self):
    """
    Select the data container and access key matching the current scan item.

    A scan item exposing ``live_data`` is treated as a live scan; any other
    scan item is treated as historical and its ``devices`` are used.

    Returns:
        data_dict (dict): The data structure for the current scan.
        access_key (str): Either 'val' (live) or 'value' (history).
        Both are the string "none" when no scan item is set.
    """
    scan = self.scan_item
    if scan is None:
        logger.info("No scan executed so far; skipping update.")
        return "none", "none"

    if not hasattr(scan, "live_data"):
        # Historical
        return scan.devices, "value"

    # Live scan
    return scan.live_data, "val"
|
|
|
|
def reset(self):
    """Reset the heatmap state and image, then run the base-class reset."""
    # Drop interpolation work first (presumably so that a late result cannot
    # repaint the image that is cleared below — confirm against the helper).
    self._cancel_interpolation()
    # None makes get_grid_scan_image start filling from the first point again.
    self._grid_index = None
    self.main_image.clear()
    if self.crosshair is not None:
        self.crosshair.reset()
    super().reset()
|
|
|
|
################################################################################
|
|
# Widget Specific Properties
|
|
################################################################################
|
|
|
|
@SafeProperty(str)
def device_x(self) -> str:
    """Device name for the X axis (empty string when no X device is set)."""
    if self._image_config.device_x is None:
        return ""
    return self._image_config.device_x.device or ""

@device_x.setter
def device_x(self, device_name: str) -> None:
    """
    Set the X device name.

    A non-empty name is validated and stored together with the signal returned
    by the entry validator; the change is announced, the labels are refreshed
    and an automatic plot is attempted. An empty name clears the X device.
    Failures are logged at debug level and otherwise ignored.

    Args:
        device_name(str): Device name for the X axis
    """
    device_name = device_name or ""

    # Get current entry or validate
    if device_name:
        try:
            signal = self.entry_validator.validate_signal(device_name, None)
            self._image_config.device_x = HeatmapDeviceSignal(device=device_name, signal=signal)
            self.property_changed.emit("device_x", device_name)
            self.update_labels()  # Update axis labels
            self._try_auto_plot()
        except Exception as e:
            # Best effort: the device may not be available yet. Log it instead
            # of swallowing it silently, like the sibling axis setters do.
            logger.debug(f"X device name validation failed: {e}")
    else:
        self._image_config.device_x = None
        self.property_changed.emit("device_x", "")
        self.update_labels()  # Clear axis labels
|
|
|
|
@SafeProperty(str)
def signal_x(self) -> str:
    """Signal entry for the X axis device (empty string when no X device is set)."""
    if self._image_config.device_x is None:
        return ""
    return self._image_config.device_x.signal or ""

@signal_x.setter
def signal_x(self, entry: str) -> None:
    """
    Set the X device entry.

    Empty entries are ignored, and an entry cannot be set before the X device.
    The entry is validated against the current X device; on success the change
    is announced, the labels are refreshed and an automatic plot is attempted.
    Failures are logged at debug level and otherwise ignored.

    Args:
        entry(str): Signal entry for the X axis device
    """
    if not entry:
        return

    if self._image_config.device_x is None:
        logger.warning("Cannot set signal_x without device_x set first.")
        return

    device_name = self._image_config.device_x.device
    try:
        # Validate the entry for this device
        validated_signal = self.entry_validator.validate_signal(device_name, entry)
        self._image_config.device_x = HeatmapDeviceSignal(
            device=device_name, signal=validated_signal
        )
        self.property_changed.emit("signal_x", validated_signal)
        self.update_labels()  # Update axis labels
        self._try_auto_plot()
    except Exception as e:
        # Best effort: keep the previous entry, but leave a trace in the log
        # like the Y and Z entry setters do.
        logger.debug(f"X device entry validation failed: {e}")
|
|
|
|
@SafeProperty(str)
def device_y(self) -> str:
    """Device name for the Y axis (empty string when no Y device is set)."""
    if self._image_config.device_y is None:
        return ""
    return self._image_config.device_y.device or ""

@device_y.setter
def device_y(self, device_name: str) -> None:
    """
    Set the Y device name.

    A non-empty name is validated and stored together with the signal returned
    by the entry validator; the change is announced, the labels are refreshed
    and an automatic plot is attempted. An empty name clears the Y device.
    Failures are logged at debug level and otherwise ignored.

    Args:
        device_name(str): Device name for the Y axis
    """
    device_name = device_name or ""

    # Get current entry or validate
    if device_name:
        try:
            signal = self.entry_validator.validate_signal(device_name, None)
            self._image_config.device_y = HeatmapDeviceSignal(device=device_name, signal=signal)
            self.property_changed.emit("device_y", device_name)
            self.update_labels()  # Update axis labels
            self._try_auto_plot()
        except Exception as e:
            # Best effort: the device may not be available yet. Log it instead
            # of swallowing it silently, like the sibling axis setters do.
            logger.debug(f"Y device name validation failed: {e}")
    else:
        self._image_config.device_y = None
        self.property_changed.emit("device_y", "")
        self.update_labels()  # Clear axis labels
|
|
|
|
@SafeProperty(str)
def signal_y(self) -> str:
    """Signal entry of the Y axis device, or an empty string when no Y device is set."""
    y_config = self._image_config.device_y
    return "" if y_config is None else (y_config.signal or "")

@signal_y.setter
def signal_y(self, entry: str) -> None:
    """
    Set the Y device entry.

    Empty entries are ignored, and an entry cannot be set before the Y device.
    The entry is validated against the current Y device; on success the change
    is announced, the labels are refreshed and an automatic plot is attempted.
    Failures are logged at debug level and otherwise ignored.

    Args:
        entry(str): Signal entry for the Y axis device
    """
    if not entry:
        return

    if self._image_config.device_y is None:
        logger.warning("Cannot set signal_y without device_y set first.")
        return

    current_device = self._image_config.device_y.device
    try:
        # Validate the entry for this device
        checked_entry = self.entry_validator.validate_signal(current_device, entry)
        self._image_config.device_y = HeatmapDeviceSignal(
            device=current_device, signal=checked_entry
        )
        self.property_changed.emit("signal_y", checked_entry)
        self.update_labels()  # Update axis labels
        self._try_auto_plot()
    except Exception as e:
        # Best effort: a failed validation leaves the configuration untouched.
        logger.debug(f"Y device entry validation failed: {e}")
|
|
|
|
@SafeProperty(str)
def device_z(self) -> str:
    """Device name of the Z (color) axis, or an empty string when no Z device is set."""
    z_config = self._image_config.device_z
    return "" if z_config is None else (z_config.device or "")

@device_z.setter
def device_z(self, device_name: str) -> None:
    """
    Set the Z device name.

    An empty name clears the Z device. Any other name is validated and stored
    together with the signal returned by the entry validator; the change is
    announced, the labels are refreshed and an automatic plot is attempted.
    Failures are logged at debug level and otherwise ignored.

    Args:
        device_name(str): Device name for the Z axis
    """
    name = device_name or ""

    if not name:
        self._image_config.device_z = None
        self.property_changed.emit("device_z", "")
        self.update_labels()  # Clear axis labels
        return

    try:
        default_signal = self.entry_validator.validate_signal(name, None)
        self._image_config.device_z = HeatmapDeviceSignal(device=name, signal=default_signal)
        self.property_changed.emit("device_z", name)
        self.update_labels()  # Update axis labels (title)
        self._try_auto_plot()
    except Exception as e:
        # Best effort: the device may not be available yet.
        logger.debug(f"Z device name validation failed: {e}")
|
|
|
|
@SafeProperty(str)
def signal_z(self) -> str:
    """Signal entry of the Z (color) axis device, or an empty string when no Z device is set."""
    z_config = self._image_config.device_z
    return "" if z_config is None else (z_config.signal or "")

@signal_z.setter
def signal_z(self, entry: str) -> None:
    """
    Set the Z device entry.

    Empty entries are ignored, and an entry cannot be set before the Z device.
    The entry is validated against the current Z device; on success the change
    is announced, the labels are refreshed and an automatic plot is attempted.
    Failures are logged at debug level and otherwise ignored.

    Args:
        entry(str): Signal entry for the Z axis device
    """
    if not entry:
        return

    if self._image_config.device_z is None:
        logger.warning("Cannot set signal_z without device_z set first.")
        return

    current_device = self._image_config.device_z.device
    try:
        # Validate the entry for this device
        checked_entry = self.entry_validator.validate_signal(current_device, entry)
        self._image_config.device_z = HeatmapDeviceSignal(
            device=current_device, signal=checked_entry
        )
        self.property_changed.emit("signal_z", checked_entry)
        self.update_labels()  # Update axis labels (title)
        self._try_auto_plot()
    except Exception as e:
        # Best effort: a failed validation leaves the configuration untouched.
        logger.debug(f"Z device entry validation failed: {e}")
|
|
|
|
def _try_auto_plot(self) -> None:
    """
    Call ``plot()`` automatically once the X, Y and Z devices are all set.

    Nothing happens while any of the three device configurations is missing.
    Errors raised by ``plot()`` are logged at debug level and otherwise ignored.
    """
    config = self._image_config
    axis_x, axis_y, axis_z = config.device_x, config.device_y, config.device_z
    if axis_x is None or axis_y is None or axis_z is None:
        return

    plot_kwargs = {
        "device_x": axis_x.device,
        "device_y": axis_y.device,
        "device_z": axis_z.device,
        "signal_x": axis_x.signal,
        "signal_y": axis_y.signal,
        "signal_z": axis_z.signal,
    }
    try:
        # Don't validate - entries already validated
        self.plot(**plot_kwargs, validate_bec=False)
    except Exception as e:
        # Best effort: plot() may not be callable yet.
        logger.debug(f"Auto-plot failed: {e}")
|
|
|
|
@SafeProperty(str)
def interpolation_method(self) -> str:
    """
    The interpolation method used for the heatmap.
    """
    return self._image_config.interpolation

@interpolation_method.setter
def interpolation_method(self, value: str):
    """
    Set the interpolation method for the heatmap.

    Args:
        value(str): The interpolation method: 'linear', 'nearest', or 'clough'.

    Raises:
        ValueError: If the value is not one of the supported methods.
    """
    # Accept every method that compute_step_scan_image implements; 'clough'
    # used to be rejected here although the interpolation code handles it.
    if value not in ("linear", "nearest", "clough"):
        raise ValueError(
            "Interpolation method must be either 'linear', 'nearest', or 'clough'."
        )
    self._image_config.interpolation = value
    self.heatmap_property_changed.emit()
|
|
|
|
@SafeProperty(float)
def oversampling_factor(self) -> float:
    """Multiplier applied to the grid resolution of the heatmap."""
    return self._image_config.oversampling_factor

@oversampling_factor.setter
def oversampling_factor(self, value: float):
    """
    Store a new oversampling factor and signal that a heatmap property changed.

    Args:
        value(float): The oversampling factor (1.0 = no oversampling, 2.0 = 2x resolution).

    Raises:
        ValueError: If the value is zero or negative.
    """
    if value <= 0:
        raise ValueError("Oversampling factor must be greater than 0.")

    config = self._image_config
    config.oversampling_factor = value
    self.heatmap_property_changed.emit()
|
|
|
|
@SafeProperty(bool)
def enforce_interpolation(self) -> bool:
    """Whether grid scans are interpolated as well instead of being drawn directly."""
    return self._image_config.enforce_interpolation

@enforce_interpolation.setter
def enforce_interpolation(self, value: bool):
    """
    Store the enforce-interpolation flag and signal that a heatmap property changed.

    Args:
        value(bool): Whether to enforce interpolation.
    """
    config = self._image_config
    config.enforce_interpolation = value
    self.heatmap_property_changed.emit()
|
|
|
|
################################################################################
|
|
# Post Processing
|
|
################################################################################
|
|
|
|
@SafeProperty(bool, auto_emit=True)
def fft(self) -> bool:
    """FFT postprocessing flag, read from the main image."""
    image = self.main_image
    return image.fft

@fft.setter
def fft(self, enable: bool):
    """
    Forward the FFT postprocessing flag to the main image.

    Args:
        enable(bool): Whether to enable FFT postprocessing.
    """
    image = self.main_image
    image.fft = enable
|
|
|
|
@SafeProperty(bool, auto_emit=True)
def log(self) -> bool:
    """Logarithmic scaling flag, read from the main image."""
    image = self.main_image
    return image.log

@log.setter
def log(self, enable: bool):
    """
    Forward the logarithmic scaling flag to the main image.

    Args:
        enable(bool): Whether to enable logarithmic scaling.
    """
    image = self.main_image
    image.log = enable
|
|
|
|
@SafeProperty(int)
def num_rotation_90(self) -> int:
    """Number of counterclockwise 90° rotations, read from the main image."""
    image = self.main_image
    return image.num_rotation_90

@num_rotation_90.setter
def num_rotation_90(self, value: int):
    """
    Forward the number of counterclockwise 90° rotations to the main image.

    Args:
        value(int): The number of 90° rotations to apply.
    """
    image = self.main_image
    image.num_rotation_90 = value
|
|
|
|
@SafeProperty(bool, auto_emit=True)
def transpose(self) -> bool:
    """Transpose flag, read from the main image."""
    image = self.main_image
    return image.transpose

@transpose.setter
def transpose(self, enable: bool):
    """
    Forward the transpose flag to the main image.

    Args:
        enable(bool): Whether to enable transposing the image.
    """
    image = self.main_image
    image.transpose = enable
|
|
|
|
def cleanup(self):
    """Finish the interpolation thread, then run the base-class cleanup."""
    # The widget's own thread helper runs before the parent teardown.
    self._finish_interpolation_thread()
    super().cleanup()
|
|
|
|
|
|
if __name__ == "__main__":  # pragma: no cover
    # Manual demo: show a heatmap of bpm4i over samx/samy in a standalone application.
    import sys

    from qtpy.QtWidgets import QApplication

    application = QApplication(sys.argv)
    demo_widget = Heatmap()
    demo_widget.plot(device_x="samx", device_y="samy", device_z="bpm4i", oversampling_factor=5.0)
    demo_widget.show()
    # Hand the event loop's exit status back to the shell.
    sys.exit(application.exec_())
|