1
0
mirror of https://github.com/bec-project/bec_widgets.git synced 2026-04-11 03:00:54 +02:00

Compare commits

...

5 Commits

18 changed files with 1884 additions and 474 deletions

View File

@@ -2502,16 +2502,30 @@ class Image(RPCBase):
@property
@rpc_call
def monitor(self) -> "str":
def device_name(self) -> "str":
"""
The name of the monitor to use for the image.
The name of the device to monitor for image data.
"""
@monitor.setter
@device_name.setter
@rpc_call
def monitor(self) -> "str":
def device_name(self) -> "str":
"""
The name of the monitor to use for the image.
The name of the device to monitor for image data.
"""
@property
@rpc_call
def device_entry(self) -> "str":
"""
The signal/entry name to monitor on the device.
"""
@device_entry.setter
@rpc_call
def device_entry(self) -> "str":
"""
The signal/entry name to monitor on the device.
"""
@rpc_call
@@ -2617,8 +2631,8 @@ class Image(RPCBase):
@rpc_call
def image(
self,
monitor: "str | tuple | None" = None,
monitor_type: "Literal['auto', '1d', '2d']" = "auto",
device_name: "str | None" = None,
device_entry: "str | None" = None,
color_map: "str | None" = None,
color_bar: "Literal['simple', 'full'] | None" = None,
vrange: "tuple[int, int] | None" = None,
@@ -2627,14 +2641,14 @@ class Image(RPCBase):
Set the image source and update the image.
Args:
monitor(str|tuple|None): The name of the monitor to use for the image, or a tuple of (device, signal) for preview signals. If None or empty string, the current monitor will be disconnected.
monitor_type(str): The type of monitor to use. Options are "1d", "2d", or "auto".
device_name(str|None): The name of the device to monitor. If None or empty string, the current monitor will be disconnected.
device_entry(str|None): The signal/entry name to monitor on the device.
color_map(str): The color map to use for the image.
color_bar(str): The type of color bar to use. Options are "simple" or "full".
vrange(tuple): The range of values to use for the color map.
Returns:
ImageItem: The image object.
ImageItem: The image object, or None if connection failed.
"""
@property
@@ -5959,7 +5973,8 @@ class Waveform(RPCBase):
y_entry: "str | None" = None,
color: "str | None" = None,
label: "str | None" = None,
dap: "str | None" = None,
dap: "str | list[str] | None" = None,
dap_parameters: "dict | list | lmfit.Parameters | None | object" = None,
scan_id: "str | None" = None,
scan_number: "int | None" = None,
**kwargs,
@@ -5981,9 +5996,14 @@ class Waveform(RPCBase):
y_entry(str): The name of the entry for the y-axis.
color(str): The color of the curve.
label(str): The label of the curve.
dap(str): The dap model to use for the curve. When provided, a DAP curve is
dap(str | list[str]): The dap model to use for the curve. When provided, a DAP curve is
attached automatically for device, history, or custom data sources. Use
the same string as the LMFit model name.
the same string as the LMFit model name, or a list of model names to build a composite.
dap_parameters(dict | list | lmfit.Parameters | None): Optional lmfit parameter overrides sent to
the DAP server. For a single model: values can be numeric (interpreted as fixed parameters)
or dicts like `{"value": 1.0, "vary": False}`. For composite models (dap is list), use either
a list aligned to the model list (each item is a param dict), or a dict of
`{ "ModelName": { "param": {...} } }` when model names are unique.
scan_id(str): Optional scan ID. When provided, the curve is treated as a **history** curve and
the ydata (and optional xdata) are fetched from that historical scan. Such curves are
never cleared by livescan resets.
@@ -5997,9 +6017,10 @@ class Waveform(RPCBase):
def add_dap_curve(
self,
device_label: "str",
dap_name: "str",
dap_name: "str | list[str]",
color: "str | None" = None,
dap_oversample: "int" = 1,
dap_parameters: "dict | list | lmfit.Parameters | None" = None,
**kwargs,
) -> "Curve":
"""
@@ -6009,9 +6030,11 @@ class Waveform(RPCBase):
Args:
device_label(str): The label of the source curve to add DAP to.
dap_name(str): The name of the DAP model to use.
dap_name(str | list[str]): The name of the DAP model to use, or a list of model
names to build a composite model.
color(str): The color of the curve.
dap_oversample(int): The oversampling factor for the DAP curve.
dap_parameters(dict | list | lmfit.Parameters | None): Optional lmfit parameter overrides sent to the DAP server.
**kwargs
Returns:

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import re
from functools import lru_cache
from typing import Literal
import numpy as np
@@ -9,6 +10,7 @@ from bec_lib import bec_logger
from bec_qthemes import apply_theme as apply_theme_global
from bec_qthemes._theme import AccentColors
from pydantic_core import PydanticCustomError
from pyqtgraph.graphicsItems.GradientEditorItem import Gradients
from qtpy.QtCore import QEvent, QEventLoop
from qtpy.QtGui import QColor
from qtpy.QtWidgets import QApplication
@@ -57,6 +59,91 @@ def apply_theme(theme: Literal["dark", "light"]):
class Colors:
@staticmethod
def list_available_colormaps() -> list[str]:
"""
List colormap names available via the pyqtgraph colormap registry.
Note: This does not include `GradientEditorItem` presets (used by HistogramLUT menus).
"""
def _list(source: str | None = None) -> list[str]:
try:
return pg.colormap.listMaps() if source is None else pg.colormap.listMaps(source)
except Exception: # pragma: no cover - backend may be missing
return []
return [*_list(None), *_list("matplotlib"), *_list("colorcet")]
@staticmethod
def list_available_gradient_presets() -> list[str]:
"""
List `GradientEditorItem` preset names (HistogramLUT right-click menu entries).
"""
from pyqtgraph.graphicsItems.GradientEditorItem import Gradients
return list(Gradients.keys())
@staticmethod
def canonical_colormap_name(color_map: str) -> str:
"""
Return an available colormap/preset name if a case-insensitive match exists.
"""
requested = (color_map or "").strip()
if not requested:
return requested
registry = Colors.list_available_colormaps()
presets = Colors.list_available_gradient_presets()
available = set(registry) | set(presets)
if requested in available:
return requested
# Case-insensitive match.
lower_to_canonical = {name.lower(): name for name in available}
return lower_to_canonical.get(requested.lower(), requested)
@staticmethod
def get_colormap(color_map: str) -> pg.ColorMap:
"""
Resolve a string into a `pg.ColorMap` using either:
- the `pg.colormap` registry (optionally including matplotlib/colorcet backends), or
- `GradientEditorItem` presets (HistogramLUT right-click menu).
"""
name = Colors.canonical_colormap_name(color_map)
if not name:
raise ValueError("Empty colormap name")
return Colors._get_colormap_cached(name)
@staticmethod
@lru_cache(maxsize=256)
def _get_colormap_cached(name: str) -> pg.ColorMap:
# 1) Registry/backends
try:
cmap = pg.colormap.get(name)
if cmap is not None:
return cmap
except Exception:
pass
for source in ("matplotlib", "colorcet"):
try:
cmap = pg.colormap.get(name, source=source)
if cmap is not None:
return cmap
except Exception:
continue
# 2) Presets -> ColorMap
if name not in Gradients:
raise KeyError(f"Colormap '{name}' not found")
ge = pg.GradientEditorItem()
ge.loadPreset(name)
return ge.colorMap()
@staticmethod
def golden_ratio(num: int) -> list:
@@ -138,7 +225,7 @@ class Colors:
if theme_offset < 0 or theme_offset > 1:
raise ValueError("theme_offset must be between 0 and 1")
cmap = pg.colormap.get(colormap)
cmap = Colors.get_colormap(colormap)
min_pos, max_pos = Colors.set_theme_offset(theme, theme_offset)
# Generate positions that are evenly spaced within the acceptable range
@@ -186,7 +273,7 @@ class Colors:
ValueError: If theme_offset is not between 0 and 1.
"""
cmap = pg.colormap.get(colormap)
cmap = Colors.get_colormap(colormap)
phi = (1 + np.sqrt(5)) / 2 # Golden ratio
golden_angle_conjugate = 1 - (1 / phi) # Approximately 0.38196601125
@@ -452,21 +539,24 @@ class Colors:
Raises:
PydanticCustomError: If colormap is invalid.
"""
available_pg_maps = pg.colormap.listMaps()
available_mpl_maps = pg.colormap.listMaps("matplotlib")
available_mpl_colorcet = pg.colormap.listMaps("colorcet")
available_colormaps = available_pg_maps + available_mpl_maps + available_mpl_colorcet
if color_map not in available_colormaps:
normalized = Colors.canonical_colormap_name(color_map)
try:
Colors.get_colormap(normalized)
except Exception as ext:
logger.warning(f"Colormap validation error: {ext}")
if return_error:
available_colormaps = sorted(
set(Colors.list_available_colormaps())
| set(Colors.list_available_gradient_presets())
)
raise PydanticCustomError(
"unsupported colormap",
f"Colormap '{color_map}' not found in the current installation of pyqtgraph. Choose on the following: {available_colormaps}.",
f"Colormap '{color_map}' not found in the current installation of pyqtgraph. Choose from the following: {available_colormaps}.",
{"wrong_value": color_map},
)
else:
return False
return color_map
return normalized
@staticmethod
def relative_luminance(color: QColor) -> float:

View File

@@ -186,6 +186,11 @@ class DeviceComboBox(DeviceInputBase, QComboBox):
device = self.itemData(idx)[0] # type: ignore[assignment]
return super().validate_device(device)
@property
def is_valid_input(self) -> bool:
"""Whether the current text represents a valid device selection."""
return self._is_valid_input
if __name__ == "__main__": # pragma: no cover
# pylint: disable=import-outside-toplevel

File diff suppressed because it is too large Load Diff

View File

@@ -9,6 +9,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationError
from qtpy.QtCore import QPointF, Signal, SignalInstance
from qtpy.QtWidgets import QDialog, QVBoxLayout
from bec_widgets.utils import Colors
from bec_widgets.utils.container_utils import WidgetContainerUtils
from bec_widgets.utils.error_popups import SafeProperty, SafeSlot
from bec_widgets.utils.side_panel import SidePanel
@@ -131,8 +132,9 @@ class ImageLayerManager:
image.setZValue(z_position)
image.removed.connect(self._remove_destroyed_layer)
# FIXME: For now, we hard-code the default color map here. In the future, this should be configurable.
image.color_map = "plasma"
color_map = getattr(getattr(self.parent, "config", None), "color_map", None)
if color_map:
image.color_map = color_map
self.layers[name] = ImageLayer(name=name, image=image, sync=sync)
self.plot_item.addItem(image)
@@ -249,6 +251,8 @@ class ImageBase(PlotBase):
Base class for the Image widget.
"""
MAX_TICKS_COLORBAR = 10
sync_colorbar_with_autorange = Signal()
image_updated = Signal()
layer_added = Signal(str)
@@ -460,18 +464,20 @@ class ImageBase(PlotBase):
self.setProperty("autorange", False)
if style == "simple":
self._color_bar = pg.ColorBarItem(colorMap=self.config.color_map)
cmap = Colors.get_colormap(self.config.color_map)
self._color_bar = pg.ColorBarItem(colorMap=cmap)
self._color_bar.setImageItem(self.layer_manager["main"].image)
self._color_bar.sigLevelsChangeFinished.connect(disable_autorange)
self.config.color_bar = "simple"
elif style == "full":
self._color_bar = pg.HistogramLUTItem()
self._color_bar.setImageItem(self.layer_manager["main"].image)
self._color_bar.gradient.loadPreset(self.config.color_map)
self.config.color_bar = "full"
self._apply_colormap_to_colorbar(self.config.color_map)
self._color_bar.sigLevelsChanged.connect(disable_autorange)
self.plot_widget.addItem(self._color_bar, row=0, col=1)
self.config.color_bar = style
else:
if self._color_bar:
self.plot_widget.removeItem(self._color_bar)
@@ -484,6 +490,37 @@ class ImageBase(PlotBase):
if vrange: # should be at the end to disable the autorange if defined
self.v_range = vrange
def _apply_colormap_to_colorbar(self, color_map: str) -> None:
if not self._color_bar:
return
cmap = Colors.get_colormap(color_map)
if self.config.color_bar == "simple":
self._color_bar.setColorMap(cmap)
return
if self.config.color_bar != "full":
return
gradient = getattr(self._color_bar, "gradient", None)
if gradient is None:
return
positions = np.linspace(0.0, 1.0, self.MAX_TICKS_COLORBAR)
colors = cmap.map(positions, mode="byte")
colors = np.asarray(colors)
if colors.ndim != 2:
return
if colors.shape[1] == 3: # add alpha
alpha = np.full((colors.shape[0], 1), 255, dtype=colors.dtype)
colors = np.concatenate([colors, alpha], axis=1)
ticks = [(float(p), tuple(int(x) for x in c)) for p, c in zip(positions, colors)]
state = {"mode": "rgb", "ticks": ticks}
gradient.restoreState(state)
################################################################################
# Static rois with roi manager
@@ -754,10 +791,7 @@ class ImageBase(PlotBase):
layer.image.color_map = value
if self._color_bar:
if self.config.color_bar == "simple":
self._color_bar.setColorMap(value)
elif self.config.color_bar == "full":
self._color_bar.gradient.loadPreset(value)
self._apply_colormap_to_colorbar(self.config.color_map)
except ValidationError:
return

View File

@@ -119,7 +119,8 @@ class ImageItem(BECConnector, pg.ImageItem):
"""Set a new color map."""
try:
self.config.color_map = value
self.setColorMap(value)
cmap = Colors.get_colormap(self.config.color_map)
self.setColorMap(cmap)
except ValidationError:
logger.error(f"Invalid colormap '{value}' provided.")

View File

@@ -0,0 +1,253 @@
from qtpy.QtWidgets import QHBoxLayout, QSizePolicy, QWidget
from bec_widgets.utils.toolbars.actions import WidgetAction
from bec_widgets.utils.toolbars.bundles import ToolbarBundle, ToolbarComponents
from bec_widgets.utils.toolbars.connections import BundleConnection
from bec_widgets.widgets.control.device_input.device_combobox.device_combobox import DeviceComboBox
from bec_widgets.widgets.control.device_input.signal_combobox.signal_combobox import SignalComboBox
class DeviceSelection(QWidget):
"""Device and signal selection widget for image toolbar."""
def __init__(self, parent=None, client=None):
super().__init__(parent=parent)
self.client = client
self.supported_signals = [
"PreviewSignal",
"AsyncSignal",
"AsyncMultiSignal",
"DynamicSignal",
]
# Create device combobox with signal class filter
# This will only show devices that have signals matching the supported signal classes
self.device_combo_box = DeviceComboBox(
parent=self, client=self.client, signal_class_filter=self.supported_signals
)
self.device_combo_box.setToolTip("Select Device")
self.device_combo_box.setEditable(True)
# Set expanding size policy so it grows with available space
self.device_combo_box.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed)
# Configure SignalComboBox to filter by PreviewSignal and supported async signals
# Also filter by ndim (1D and 2D only) for Image widget
self.signal_combo_box = SignalComboBox(
parent=self,
client=self.client,
signal_class_filter=[
"PreviewSignal",
"AsyncSignal",
"AsyncMultiSignal",
"DynamicSignal",
],
ndim_filter=[1, 2], # Only show 1D and 2D signals for Image widget
store_signal_config=True,
require_device=True,
)
self.signal_combo_box.setToolTip("Select Signal")
self.signal_combo_box.setEditable(True)
# Set expanding size policy so it grows with available space
self.signal_combo_box.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed)
# Connect comboboxes together
self.device_combo_box.currentTextChanged.connect(self.signal_combo_box.set_device)
self.device_combo_box.device_reset.connect(self.signal_combo_box.reset_selection)
# Simple horizontal layout with stretch to fill space
layout = QHBoxLayout(self)
layout.setContentsMargins(0, 0, 0, 0)
layout.setSpacing(2)
layout.addWidget(self.device_combo_box, stretch=1)
layout.addWidget(self.signal_combo_box, stretch=1)
def set_device_and_signal(self, device_name: str | None, device_entry: str | None) -> None:
"""Set the displayed device and signal without emitting selection signals."""
device_name = device_name or ""
device_entry = device_entry or ""
self.device_combo_box.blockSignals(True)
self.signal_combo_box.blockSignals(True)
try:
if device_name:
# Set device in device_combo_box
index = self.device_combo_box.findText(device_name)
if index >= 0:
self.device_combo_box.setCurrentIndex(index)
else:
# Device not found in list, but still set it
self.device_combo_box.setCurrentText(device_name)
# Only update signal combobox device filter if it's actually changing
# This prevents redundant repopulation which can cause duplicates !!!!
current_device = getattr(self.signal_combo_box, "_device", None)
if current_device != device_name:
self.signal_combo_box.set_device(device_name)
# Sync signal combobox selection
if device_entry:
# Try to find the signal by component_name (which is what's displayed)
found = False
for i in range(self.signal_combo_box.count()):
text = self.signal_combo_box.itemText(i)
config_data = self.signal_combo_box.itemData(i)
# Check if this matches our signal
if config_data:
component_name = config_data.get("component_name", "")
if text == component_name or text == device_entry:
self.signal_combo_box.setCurrentIndex(i)
found = True
break
if not found:
# Fallback: try to match the device_entry directly
index = self.signal_combo_box.findText(device_entry)
if index >= 0:
self.signal_combo_box.setCurrentIndex(index)
else:
# No device set, clear selections
self.device_combo_box.setCurrentText("")
self.signal_combo_box.reset_selection()
finally:
# Always unblock signals
self.device_combo_box.blockSignals(False)
self.signal_combo_box.blockSignals(False)
def set_connection_status(self, status: str, message: str | None = None) -> None:
tooltip = f"Connection status: {status}"
if message:
tooltip = f"{tooltip}\n{message}"
self.device_combo_box.setToolTip(tooltip)
self.signal_combo_box.setToolTip(tooltip)
if not self.device_combo_box.is_valid_input or not self.signal_combo_box.is_valid_input:
return
if status == "error":
style = "border: 1px solid orange;"
else:
style = "border: 1px solid transparent;"
self.device_combo_box.setStyleSheet(style)
self.signal_combo_box.setStyleSheet(style)
def cleanup(self):
"""Clean up the widget resources."""
self.device_combo_box.close()
self.device_combo_box.deleteLater()
self.signal_combo_box.close()
self.signal_combo_box.deleteLater()
def device_selection_bundle(components: ToolbarComponents, client=None) -> ToolbarBundle:
"""
Creates a device selection toolbar bundle for Image widget.
Includes a resizable splitter after the device selection. All subsequent bundles'
actions will appear compactly after the splitter with no gaps.
Args:
components (ToolbarComponents): The components to be added to the bundle.
client: The BEC client instance.
Returns:
ToolbarBundle: The device selection toolbar bundle.
"""
device_selection_widget = DeviceSelection(parent=components.toolbar, client=client)
components.add_safe(
"device_selection", WidgetAction(widget=device_selection_widget, adjust_size=False)
)
bundle = ToolbarBundle("device_selection", components)
bundle.add_action("device_selection")
bundle.add_splitter(
name="device_selection_splitter",
target_widget=device_selection_widget,
min_width=210,
max_width=600,
)
return bundle
class DeviceSelectionConnection(BundleConnection):
"""
Connection helper for the device selection bundle.
"""
def __init__(self, components: ToolbarComponents, target_widget=None):
super().__init__(parent=components.toolbar)
self.bundle_name = "device_selection"
self.components = components
self.target_widget = target_widget
self._connected = False
self.register_property_sync("device_name", self._sync_from_device_name)
self.register_property_sync("device_entry", self._sync_from_device_entry)
self.register_property_sync("connection_status", self._sync_connection_status)
self.register_property_sync("connection_error", self._sync_connection_status)
def _widget(self) -> DeviceSelection:
return self.components.get_action("device_selection").widget
def connect(self):
if self._connected:
return
widget = self._widget()
widget.device_combo_box.device_selected.connect(
self.target_widget.on_device_selection_changed
)
widget.signal_combo_box.device_signal_changed.connect(
self.target_widget.on_device_selection_changed
)
self.connect_property_sync(self.target_widget)
self._connected = True
def disconnect(self):
if not self._connected:
return
widget = self._widget()
widget.device_combo_box.device_selected.disconnect(
self.target_widget.on_device_selection_changed
)
widget.signal_combo_box.device_signal_changed.disconnect(
self.target_widget.on_device_selection_changed
)
self.disconnect_property_sync(self.target_widget)
self._connected = False
widget.cleanup()
def _sync_from_device_name(self, _):
try:
widget = self._widget()
except Exception:
return
widget.set_device_and_signal(
self.target_widget.device_name, self.target_widget.device_entry
)
self.target_widget._sync_device_entry_from_toolbar()
def _sync_from_device_entry(self, _):
try:
widget = self._widget()
except Exception:
return
widget.set_device_and_signal(
self.target_widget.device_name, self.target_widget.device_entry
)
def _sync_connection_status(self, _):
try:
widget = self._widget()
except Exception:
return
widget.set_connection_status(
self.target_widget._config.connection_status,
self.target_widget._config.connection_error,
)

View File

@@ -22,8 +22,9 @@ class DeviceSignal(BaseModel):
name: str
entry: str
dap: str | None = None
dap: str | list[str] | None = None
dap_oversample: int = 1
dap_parameters: dict | list | None = None
model_config: dict = {"validate_assignment": True}

View File

@@ -1,13 +1,13 @@
from __future__ import annotations
import json
from typing import Literal
from typing import TYPE_CHECKING, Literal
import lmfit
import numpy as np
import pyqtgraph as pg
from bec_lib import bec_logger, messages
from bec_lib.endpoints import MessageEndpoints
from bec_lib.lmfit_serializer import serialize_lmfit_params, serialize_param_object
from bec_lib.scan_data_container import ScanDataContainer
from pydantic import Field, ValidationError, field_validator
from qtpy.QtCore import Qt, QTimer, Signal
@@ -41,6 +41,15 @@ from bec_widgets.widgets.services.scan_history_browser.scan_history_browser impo
)
logger = bec_logger.logger
_DAP_PARAM = object()
if TYPE_CHECKING: # pragma: no cover
import lmfit # type: ignore
else:
try:
import lmfit # type: ignore
except Exception: # pragma: no cover
lmfit = None
# noinspection PyDataclass
@@ -696,7 +705,8 @@ class Waveform(PlotBase):
y_entry: str | None = None,
color: str | None = None,
label: str | None = None,
dap: str | None = None,
dap: str | list[str] | None = None,
dap_parameters: dict | list | lmfit.Parameters | None | object = None,
scan_id: str | None = None,
scan_number: int | None = None,
**kwargs,
@@ -718,9 +728,14 @@ class Waveform(PlotBase):
y_entry(str): The name of the entry for the y-axis.
color(str): The color of the curve.
label(str): The label of the curve.
dap(str): The dap model to use for the curve. When provided, a DAP curve is
dap(str | list[str]): The dap model to use for the curve. When provided, a DAP curve is
attached automatically for device, history, or custom data sources. Use
the same string as the LMFit model name.
the same string as the LMFit model name, or a list of model names to build a composite.
dap_parameters(dict | list | lmfit.Parameters | None): Optional lmfit parameter overrides sent to
the DAP server. For a single model: values can be numeric (interpreted as fixed parameters)
or dicts like `{"value": 1.0, "vary": False}`. For composite models (dap is list), use either
a list aligned to the model list (each item is a param dict), or a dict of
`{ "ModelName": { "param": {...} } }` when model names are unique.
scan_id(str): Optional scan ID. When provided, the curve is treated as a **history** curve and
the ydata (and optional xdata) are fetched from that historical scan. Such curves are
never cleared by livescan resets.
@@ -733,6 +748,8 @@ class Waveform(PlotBase):
source = "custom"
x_data = None
y_data = None
if dap_parameters is _DAP_PARAM:
dap_parameters = kwargs.pop("dap_parameters", None) or kwargs.pop("parameters", None)
# 1. Custom curve logic
if x is not None and y is not None:
@@ -810,7 +827,9 @@ class Waveform(PlotBase):
curve = self._add_curve(config=config, x_data=x_data, y_data=y_data)
if dap is not None and curve.config.source in ("device", "history", "custom"):
self.add_dap_curve(device_label=curve.name(), dap_name=dap, **kwargs)
self.add_dap_curve(
device_label=curve.name(), dap_name=dap, dap_parameters=dap_parameters, **kwargs
)
return curve
@@ -820,9 +839,10 @@ class Waveform(PlotBase):
def add_dap_curve(
self,
device_label: str,
dap_name: str,
dap_name: str | list[str],
color: str | None = None,
dap_oversample: int = 1,
dap_parameters: dict | list | lmfit.Parameters | None = None,
**kwargs,
) -> Curve:
"""
@@ -832,9 +852,11 @@ class Waveform(PlotBase):
Args:
device_label(str): The label of the source curve to add DAP to.
dap_name(str): The name of the DAP model to use.
dap_name(str | list[str]): The name of the DAP model to use, or a list of model
names to build a composite model.
color(str): The color of the curve.
dap_oversample(int): The oversampling factor for the DAP curve.
dap_parameters(dict | list | lmfit.Parameters | None): Optional lmfit parameter overrides sent to the DAP server.
**kwargs
Returns:
@@ -859,7 +881,7 @@ class Waveform(PlotBase):
dev_entry = "custom"
# 2) Build a label for the new DAP curve
dap_label = f"{device_label}-{dap_name}"
dap_label = f"{device_label}-{self._format_dap_label(dap_name)}"
# 3) Possibly raise if the DAP curve already exists
if self._check_curve_id(dap_label):
@@ -882,7 +904,11 @@ class Waveform(PlotBase):
# Attach device signal with DAP
config.signal = DeviceSignal(
name=dev_name, entry=dev_entry, dap=dap_name, dap_oversample=dap_oversample
name=dev_name,
entry=dev_entry,
dap=dap_name,
dap_oversample=dap_oversample,
dap_parameters=self._normalize_dap_parameters(dap_parameters, dap_name=dap_name),
)
# 4) Create the DAP curve config using `_add_curve(...)`
@@ -1754,7 +1780,9 @@ class Waveform(PlotBase):
x_data, y_data = parent_curve.get_data()
model_name = dap_curve.config.signal.dap
model = getattr(self.dap, model_name)
model = None
if not isinstance(model_name, (list, tuple)):
model = getattr(self.dap, model_name)
try:
x_min, x_max = self.roi_region
x_data, y_data = self._crop_data(x_data, y_data, x_min, x_max)
@@ -1762,20 +1790,132 @@ class Waveform(PlotBase):
x_min = None
x_max = None
dap_parameters = getattr(dap_curve.config.signal, "dap_parameters", None)
dap_kwargs = {
"data_x": x_data,
"data_y": y_data,
"oversample": dap_curve.dap_oversample,
}
if dap_parameters:
dap_kwargs["parameters"] = dap_parameters
if model is not None:
class_args = model._plugin_info["class_args"]
class_kwargs = model._plugin_info["class_kwargs"]
else:
class_args = []
class_kwargs = {"model": model_name}
msg = messages.DAPRequestMessage(
dap_cls="LmfitService1D",
dap_type="on_demand",
config={
"args": [],
"kwargs": {"data_x": x_data, "data_y": y_data},
"class_args": model._plugin_info["class_args"],
"class_kwargs": model._plugin_info["class_kwargs"],
"kwargs": dap_kwargs,
"class_args": class_args,
"class_kwargs": class_kwargs,
"curve_label": dap_curve.name(),
},
metadata={"RID": f"{self.scan_id}-{self.gui_id}"},
)
self.client.connector.set_and_publish(MessageEndpoints.dap_request(), msg)
@staticmethod
def _normalize_dap_parameters(
parameters: dict | list | lmfit.Parameters | None, dap_name: str | list[str] | None = None
) -> dict | list | None:
"""
Normalize user-provided lmfit parameters into a JSON-serializable dict suitable for the DAP server.
Supports:
- `lmfit.Parameters` (single-model only)
- `dict[name -> number]` (treated as fixed parameter with `vary=False`)
- `dict[name -> dict]` (lmfit.Parameter fields; defaults to `vary=False` if unspecified)
- `dict[name -> lmfit.Parameter]`
- composite: `list[dict[param_name -> spec]]` aligned to model list
- composite: `dict[model_name -> dict[param_name -> spec]]` (unique model names only)
"""
if parameters is None:
return None
if isinstance(dap_name, (list, tuple)):
if lmfit is not None and isinstance(parameters, lmfit.Parameters):
raise TypeError("dap_parameters must be a dict when using composite dap models.")
if isinstance(parameters, (list, tuple)):
normalized_list: list[dict | None] = []
for idx, item in enumerate(parameters):
if item is None:
normalized_list.append(None)
continue
if not isinstance(item, dict):
raise TypeError(
f"dap_parameters list item {idx} must be a dict of parameter overrides."
)
normalized_list.append(Waveform._normalize_param_overrides(item))
return normalized_list or None
if not isinstance(parameters, dict):
raise TypeError(
"dap_parameters must be a dict of model->params when using composite dap models."
)
model_names = set(dap_name)
invalid_models = set(parameters.keys()) - model_names
if invalid_models:
raise TypeError(
f"Invalid dap_parameters keys for composite model: {sorted(invalid_models)}"
)
normalized_composite: dict[str, dict] = {}
for model_name in dap_name:
model_params = parameters.get(model_name)
if model_params is None:
continue
if not isinstance(model_params, dict):
raise TypeError(
f"dap_parameters for '{model_name}' must be a dict of parameter overrides."
)
normalized = Waveform._normalize_param_overrides(model_params)
if normalized:
normalized_composite[model_name] = normalized
return normalized_composite or None
if lmfit is not None and isinstance(parameters, lmfit.Parameters):
return serialize_lmfit_params(parameters)
if not isinstance(parameters, dict):
if lmfit is None:
raise TypeError(
"dap_parameters must be a dict when lmfit is not installed on the client."
)
raise TypeError("dap_parameters must be a dict or lmfit.Parameters (or omitted).")
return Waveform._normalize_param_overrides(parameters)
@staticmethod
def _normalize_param_overrides(parameters: dict) -> dict | None:
normalized: dict[str, dict] = {}
for name, spec in parameters.items():
if spec is None:
continue
if isinstance(spec, (int, float, np.number)):
normalized[name] = {"name": name, "value": float(spec), "vary": False}
continue
if lmfit is not None and isinstance(spec, lmfit.Parameter):
normalized[name] = serialize_param_object(spec)
continue
if isinstance(spec, dict):
normalized[name] = {"name": name, **spec}
if "vary" not in normalized[name]:
normalized[name]["vary"] = False
continue
raise TypeError(
f"Invalid dap_parameters entry for '{name}': expected number, dict, or lmfit.Parameter."
)
return normalized or None
@staticmethod
def _format_dap_label(dap_name: str | list[str]) -> str:
if isinstance(dap_name, (list, tuple)):
return "+".join(dap_name)
return dap_name
@SafeSlot(dict, dict)
def update_dap_curves(self, msg, metadata):
"""
@@ -1793,14 +1933,6 @@ class Waveform(PlotBase):
if not curve:
return
# Get data from the parent (device) curve
parent_curve = self._find_curve_by_label(curve.config.parent_label)
if parent_curve is None:
return
x_parent, _ = parent_curve.get_data()
if x_parent is None or len(x_parent) == 0:
return
# Retrieve and store the fit parameters and summary from the DAP server response
try:
curve.dap_params = msg["data"][1]["fit_parameters"]
@@ -1809,19 +1941,13 @@ class Waveform(PlotBase):
logger.warning(f"Failed to retrieve DAP data for curve '{curve.name()}'")
return
# Render model according to the DAP model name and parameters
model_name = curve.config.signal.dap
model_function = getattr(lmfit.models, model_name)()
x_min, x_max = x_parent.min(), x_parent.max()
oversample = curve.dap_oversample
new_x = np.linspace(x_min, x_max, int(len(x_parent) * oversample))
# Evaluate the model with the provided parameters to generate the y values
new_y = model_function.eval(**curve.dap_params, x=new_x)
# Update the curve with the new data
curve.setData(new_x, new_y)
# Plot the fitted curve using the server-provided output to avoid requiring lmfit on the client.
try:
fit_data = msg["data"][0]
curve.setData(np.asarray(fit_data["x"]), np.asarray(fit_data["y"]))
except Exception:
logger.exception(f"Failed to plot DAP result for curve '{curve.name()}'")
return
metadata.update({"curve_id": curve_id})
self.dap_params_update.emit(curve.dap_params, metadata)
@@ -2341,24 +2467,20 @@ class DemoApp(QMainWindow): # pragma: no cover
def __init__(self):
super().__init__()
self.setWindowTitle("Waveform Demo")
self.resize(1200, 600)
self.resize(1600, 600)
self.main_widget = QWidget(self)
self.layout = QHBoxLayout(self.main_widget)
self.setCentralWidget(self.main_widget)
self.waveform_popup = Waveform(popups=True)
self.waveform_popup.plot(y_name="waveform")
self.waveform_side = Waveform(popups=False)
self.waveform_side.plot(y_name="bpm4i", y_entry="bpm4i", dap="GaussianModel")
self.waveform_side.plot(y_name="bpm3a", y_entry="bpm3a")
self.custom_waveform = Waveform(popups=True)
self._populate_custom_curve_demo()
self.layout.addWidget(self.waveform_side)
self.layout.addWidget(self.waveform_popup)
self.sine_waveform = Waveform(popups=True)
self.sine_waveform.dap_params_update.connect(self._log_sine_dap_params)
self._populate_sine_curve_demo()
self.layout.addWidget(self.custom_waveform)
self.layout.addWidget(self.sine_waveform)
def _populate_custom_curve_demo(self):
"""
@@ -2377,8 +2499,126 @@ class DemoApp(QMainWindow): # pragma: no cover
sigma = 0.8
y = amplitude * np.exp(-((x - center) ** 2) / (2 * sigma**2)) + noise
# 1) No explicit parameters: server will use lmfit defaults/guesses.
self.custom_waveform.plot(x=x, y=y, label="custom-gaussian", dap="GaussianModel")
# 2) Easy dict: numbers mean "fix this parameter to value" (vary=False).
self.custom_waveform.plot(
x=x,
y=y,
label="custom-gaussian-fixed-easy",
dap="GaussianModel",
dap_parameters={"amplitude": 1.0},
dap_oversample=5,
)
# 3) lmfit-style dict: any subset of lmfit.Parameter fields.
# Here `center` is not fixed (vary=True) but its initial value is set.
self.custom_waveform.plot(
x=x,
y=y,
label="custom-gaussian-override-dict",
dap="GaussianModel",
dap_parameters={
"center": {"value": 1.2, "vary": True},
"sigma": {"value": sigma, "vary": False, "min": 0.0},
},
)
# 4) Passing a real `lmfit.Parameters` object (optional: requires lmfit on the client).
if lmfit is not None:
params_gauss = lmfit.models.GaussianModel().make_params()
params_gauss["amplitude"].set(value=amplitude, vary=False)
params_gauss["center"].set(value=center, vary=False)
params_gauss["sigma"].set(value=sigma, vary=False, min=0.0)
self.custom_waveform.plot(
x=x,
y=y,
label="custom-gaussian-fixed-params",
dap="GaussianModel",
dap_parameters=params_gauss,
)
else:
logger.info("Skipping lmfit.Parameters demo (lmfit not installed on client).")
# Composite example: spectrum with three Gaussians (DAP-only)
x_spec = np.linspace(-5, 5, 800)
rng_spec = np.random.default_rng(123)
centers = [-2.0, 0.6, 2.4]
amplitudes = [2.5, 3.2, 1.8]
sigmas = [0.35, 0.5, 0.3]
y_spec = (
amplitudes[0] * np.exp(-((x_spec - centers[0]) ** 2) / (2 * sigmas[0] ** 2))
+ amplitudes[1] * np.exp(-((x_spec - centers[1]) ** 2) / (2 * sigmas[1] ** 2))
+ amplitudes[2] * np.exp(-((x_spec - centers[2]) ** 2) / (2 * sigmas[2] ** 2))
+ rng_spec.normal(loc=0, scale=0.06, size=x_spec.size)
)
self.custom_waveform.plot(
x=x_spec,
y=y_spec,
label="custom-gaussian-spectrum-fit",
dap=["GaussianModel", "GaussianModel", "GaussianModel"],
dap_parameters=[
{"center": {"value": centers[0], "vary": False}},
{"center": {"value": centers[1], "vary": False}},
{"center": {"value": centers[2], "vary": False}},
],
)
def _populate_sine_curve_demo(self):
"""
Showcase how lmfit's base SineModel can struggle with a drifting baseline.
"""
x = np.linspace(0, 6 * np.pi, 600)
rng = np.random.default_rng(7)
amplitude = 1.6
frequency = 0.75
phase = 0.4
offset = 0.8
slope = 0.08
noise = rng.normal(loc=0, scale=0.12, size=x.size)
y = offset + slope * x + amplitude * np.sin(2 * np.pi * frequency * x + phase) + noise
# Base SineModel (no offset support) to show the mismatch
self.sine_waveform.plot(x=x, y=y, label="custom-sine-data", dap="SineModel")
# Composite model: Sine + Linear baseline (offset + slope)
self.sine_waveform.plot(
x=x,
y=y,
label="custom-sine-composite",
dap=["SineModel", "LinearModel"],
dap_oversample=4,
# TODO have to guess correctly units for LMFit SineModel
# dap_parameters={
# "SineModel": {
# "amplitude": {"value": amplitude * 0.9, "vary": True},
# "frequency": {"value": 2 * np.pi * frequency * 1.05, "vary": True},
# "shift": {"value": 0.0, "vary": True},
# },
# "LinearModel": {
# "intercept": {"value": offset, "vary": True},
# "slope": {"value": slope, "vary": True},
# },
# },
)
if lmfit is None:
logger.info("Skipping sine lmfit demo (lmfit not installed on client).")
return
return
def _log_sine_dap_params(self, params: dict, metadata: dict):
curve_id = metadata.get("curve_id")
if curve_id not in {
"custom-sine-data-SineModel",
"custom-sine-composite-SineModel+LinearModel",
}:
return
logger.info(f"SineModel DAP fit params ({curve_id}): {params}")
if __name__ == "__main__": # pragma: no cover
import sys

View File

@@ -32,7 +32,7 @@ dock_area = gui.new()
img_widget = dock_area.new().new(gui.available_widgets.Image)
# Add an ImageWidget to the BECFigure for a 2D detector
img_widget.image(monitor='eiger', monitor_type='2d')
img_widget.image(device_name='eiger', device_entry='preview')
img_widget.title = "Camera Image - Eiger Detector"
```
@@ -46,7 +46,7 @@ dock_area = gui.new()
img_widget = dock_area.new().new(gui.available_widgets.Image)
# Add an ImageWidget to the BECFigure for a 2D detector
img_widget.image(monitor='waveform', monitor_type='1d')
img_widget.image(device_name='waveform', device_entry='data')
img_widget.title = "Line Detector Data"
# Optional: Set the color map and value range
@@ -84,7 +84,7 @@ The Image Widget can be configured for different detectors by specifying the cor
```python
# For a 2D camera detector
img_widget = fig.image(monitor='eiger', monitor_type='2d')
img_widget = fig.image(device_name='eiger', device_entry='preview')
img_widget.set_title("Eiger Camera Image")
```
@@ -92,7 +92,7 @@ img_widget.set_title("Eiger Camera Image")
```python
# For a 1D line detector
img_widget = fig.image(monitor='waveform', monitor_type='1d')
img_widget = fig.image(device_name='waveform', device_entry='data')
img_widget.set_title("Line Detector Data")
```

View File

@@ -59,7 +59,7 @@ def test_rpc_add_dock_with_plots_e2e(qtbot, bec_client_lib, connected_client_gui
mm.map("samx", "samy")
curve = wf.plot(x_name="samx", y_name="bpm4i")
im_item = im.image("eiger")
im_item = im.image(device_name="eiger", device_entry="preview")
assert curve.__class__.__name__ == "RPCReference"
assert curve.__class__ == RPCReference

View File

@@ -42,7 +42,7 @@ def test_rpc_plotting_shortcuts_init_configs(qtbot, connected_client_gui_obj):
c3 = wf.plot(y=[1, 2, 3], x=[1, 2, 3])
assert c3.object_name == "Curve_0"
im.image(monitor="eiger")
im.image(device_name="eiger", device_entry="preview")
mm.map(x_name="samx", y_name="samy")
sw.plot(x_name="samx", y_name="samy", z_name="bpm4a")
mw.plot(monitor="waveform")
@@ -71,6 +71,7 @@ def test_rpc_plotting_shortcuts_init_configs(qtbot, connected_client_gui_obj):
assert c1._config_dict["signal"] == {
"dap": None,
"name": "bpm4i",
"dap_parameters": None,
"entry": "bpm4i",
"dap_oversample": 1,
}
@@ -165,14 +166,14 @@ def test_rpc_image(qtbot, bec_client_lib, connected_client_gui_obj):
scans = client.scans
im = dock_area.new("Image")
im.image(monitor="eiger")
im.image(device_name="eiger", device_entry="preview")
status = scans.line_scan(dev.samx, -5, 5, steps=10, exp_time=0.05, relative=False)
status.wait()
last_image_device = client.connector.get_last(MessageEndpoints.device_monitor_2d("eiger"))[
"data"
].data
last_image_device = client.connector.get_last(
MessageEndpoints.device_preview("eiger", "preview")
)["data"].data
last_image_plot = im.main_image.get_data()
# check plotted data

View File

@@ -15,7 +15,7 @@ def test_rpc_reference_objects(connected_client_gui_obj):
plt.plot(x_name="samx", y_name="bpm4i")
im = dock_area.new("Image")
im.image("eiger")
im.image(device_name="eiger", device_entry="preview")
motor_map = dock_area.new("MotorMap")
motor_map.map("samx", "samy")
plt_z = dock_area.new("Waveform")
@@ -23,7 +23,8 @@ def test_rpc_reference_objects(connected_client_gui_obj):
assert len(plt_z.curves) == 1
assert len(plt.curves) == 1
assert im.monitor == "eiger"
assert im.device_name == "eiger"
assert im.device_entry == "preview"
assert isinstance(im.main_image, RPCReference)
image_item = gui._ipython_registry.get(im.main_image._gui_id, None)

View File

@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING
import numpy as np
import pytest
from bec_lib.endpoints import MessageEndpoints
from bec_widgets.cli.rpc.rpc_base import RPCBase, RPCReference
@@ -233,7 +234,7 @@ def test_widgets_e2e_image(qtbot, connected_client_gui_obj, random_generator_fro
scans = bec.scans
dev = bec.device_manager.devices
# Test rpc calls
img = widget.image(dev.eiger)
img = widget.image(device_name=dev.eiger.name, device_entry="preview")
assert img.get_data() is None
# Run a scan and plot the image
s = scans.line_scan(dev.samx, -3, 3, steps=50, exp_time=0.01, relative=False)
@@ -247,13 +248,13 @@ def test_widgets_e2e_image(qtbot, connected_client_gui_obj, random_generator_fro
qtbot.waitUntil(_wait_for_scan_in_history, timeout=7000)
# Check that last image is equivalent to data in Redis
last_img = bec.device_monitor.get_data(
dev.eiger, count=1
) # Get last image from Redis monitor 2D endpoint
last_img = bec.connector.get_last(MessageEndpoints.device_preview("eiger", "preview"))[
"data"
].data
assert np.allclose(img.get_data(), last_img)
# Now add a device with a preview signal
img = widget.image(["eiger", "preview"])
img = widget.image(device_name="eiger", device_entry="preview")
s = scans.line_scan(dev.samx, -3, 3, steps=50, exp_time=0.01, relative=False)
s.wait()

View File

@@ -82,6 +82,45 @@ def test_rgba_to_hex():
assert Colors.rgba_to_hex(255, 87, 51) == "#FF5733FF"
def test_canonical_colormap_name_case_insensitive():
available = Colors.list_available_colormaps()
presets = Colors.list_available_gradient_presets()
if not available and not presets:
pytest.skip("No colormaps or presets available to test canonical mapping.")
name = (available or presets)[0]
requested = name.swapcase()
assert Colors.canonical_colormap_name(requested) == name
def test_validate_color_map_returns_canonical_name():
available = Colors.list_available_colormaps()
presets = Colors.list_available_gradient_presets()
if not available and not presets:
pytest.skip("No colormaps or presets available to test validation.")
name = (available or presets)[0]
requested = name.swapcase()
assert Colors.validate_color_map(requested) == name
def test_get_colormap_uses_gradient_preset_fallback(monkeypatch):
presets = Colors.list_available_gradient_presets()
if not presets:
pytest.skip("No gradient presets available to test fallback.")
preset = presets[0]
Colors._get_colormap_cached.cache_clear()
def _raise(*args, **kwargs):
raise Exception("registry unavailable")
monkeypatch.setattr(pg.colormap, "get", _raise)
cmap = Colors._get_colormap_cached(preset)
assert isinstance(cmap, pg.ColorMap)
@pytest.mark.parametrize("num", [10, 100, 400])
def test_evenly_spaced_colors(num):
colors_qcolor = Colors.evenly_spaced_colors(colormap="magma", num=num, format="QColor")

View File

@@ -4,7 +4,6 @@ import pyqtgraph as pg
import pytest
from bec_widgets.widgets.plots.image.image_base import ImageLayerManager
from bec_widgets.widgets.plots.image.image_item import ImageItem
@pytest.fixture()

View File

@@ -1,6 +1,7 @@
import numpy as np
import pyqtgraph as pg
import pytest
from bec_lib.endpoints import MessageEndpoints
from qtpy.QtCore import QPointF
from bec_widgets.widgets.plots.image.image import Image
@@ -12,6 +13,23 @@ from tests.unit_tests.conftest import create_widget
##################################################
def _set_signal_config(
client,
device_name: str,
signal_name: str,
signal_class: str,
ndim: int,
obj_name: str | None = None,
):
device = client.device_manager.devices[device_name]
device._info["signals"][signal_name] = {
"obj_name": obj_name or signal_name,
"signal_class": signal_class,
"component_name": signal_name,
"describe": {"signal_info": {"ndim": ndim}},
}
def test_initialization_defaults(qtbot, mocked_client):
bec_image_view = create_widget(qtbot, Image, client=mocked_client)
assert bec_image_view.color_map == "plasma"
@@ -114,32 +132,35 @@ def test_enable_colorbar_with_vrange(qtbot, mocked_client, colorbar_type):
##############################################
# Previewsignal update mechanism
# Device/signal update mechanism
def test_image_setup_preview_signal_1d(qtbot, mocked_client, monkeypatch):
def test_image_setup_preview_signal_1d(qtbot, mocked_client):
"""
Ensure that calling .image() with a (device, signal, config) tuple representing
a 1D PreviewSignal connects using the 1D path and updates correctly.
Ensure that calling .image() with a 1D PreviewSignal connects using the 1D path
and updates correctly.
"""
import numpy as np
view = create_widget(qtbot, Image, client=mocked_client)
signal_config = {
"obj_name": "waveform1d_img",
"signal_class": "PreviewSignal",
"describe": {"signal_info": {"ndim": 1}},
}
_set_signal_config(
mocked_client,
"waveform1d",
"img",
signal_class="PreviewSignal",
ndim=1,
obj_name="waveform1d_img",
)
# Set the image monitor to the preview signal
view.image(monitor=("waveform1d", "img", signal_config))
view.image(device_name="waveform1d", device_entry="img")
# Subscriptions should indicate 1D preview connection
sub = view.subscriptions["main"]
assert sub.source == "device_monitor_1d"
assert sub.monitor_type == "1d"
assert sub.monitor == ("waveform1d", "img", signal_config)
assert view.device_name == "waveform1d"
assert view.device_entry == "img"
# Simulate a waveform update from the dispatcher
waveform = np.arange(25, dtype=float)
@@ -148,29 +169,32 @@ def test_image_setup_preview_signal_1d(qtbot, mocked_client, monkeypatch):
np.testing.assert_array_equal(view.main_image.raw_data[0], waveform)
def test_image_setup_preview_signal_2d(qtbot, mocked_client, monkeypatch):
def test_image_setup_preview_signal_2d(qtbot, mocked_client):
"""
Ensure that calling .image() with a (device, signal, config) tuple representing
a 2D PreviewSignal connects using the 2D path and updates correctly.
Ensure that calling .image() with a 2D PreviewSignal connects using the 2D path
and updates correctly.
"""
import numpy as np
view = create_widget(qtbot, Image, client=mocked_client)
signal_config = {
"obj_name": "eiger_img2d",
"signal_class": "PreviewSignal",
"describe": {"signal_info": {"ndim": 2}},
}
_set_signal_config(
mocked_client,
"eiger",
"img2d",
signal_class="PreviewSignal",
ndim=2,
obj_name="eiger_img2d",
)
# Set the image monitor to the preview signal
view.image(monitor=("eiger", "img2d", signal_config))
view.image(device_name="eiger", device_entry="img2d")
# Subscriptions should indicate 2D preview connection
sub = view.subscriptions["main"]
assert sub.source == "device_monitor_2d"
assert sub.monitor_type == "2d"
assert sub.monitor == ("eiger", "img2d", signal_config)
assert view.device_name == "eiger"
assert view.device_entry == "img2d"
# Simulate a 2D image update
test_data = np.arange(16, dtype=float).reshape(4, 4)
@@ -178,38 +202,197 @@ def test_image_setup_preview_signal_2d(qtbot, mocked_client, monkeypatch):
np.testing.assert_array_equal(view.main_image.image, test_data)
def test_preview_signals_skip_0d_entries(qtbot, mocked_client, monkeypatch):
"""
Preview/async combobox should omit 0D signals.
"""
view = create_widget(qtbot, Image, client=mocked_client)
def fake_get(signal_class_filter):
signal_classes = (
signal_class_filter
if isinstance(signal_class_filter, (list, tuple, set))
else [signal_class_filter]
)
if "PreviewSignal" in signal_classes:
return [
(
"eiger",
"sig0d",
{
"obj_name": "sig0d",
"signal_class": "PreviewSignal",
"describe": {"signal_info": {"ndim": 0}},
},
),
(
"eiger",
"sig2d",
{
"obj_name": "sig2d",
"signal_class": "PreviewSignal",
"describe": {"signal_info": {"ndim": 2}},
},
),
]
return []
monkeypatch.setattr(view.client.device_manager, "get_bec_signals", fake_get)
device_selection = view.toolbar.components.get_action("device_selection").widget
device_selection.signal_combo_box.set_device("eiger")
device_selection.signal_combo_box.update_signals_from_signal_classes()
texts = [
device_selection.signal_combo_box.itemText(i)
for i in range(device_selection.signal_combo_box.count())
]
assert "sig0d" not in texts
assert "sig2d" in texts
def test_image_async_signal_uses_obj_name(qtbot, mocked_client, monkeypatch):
"""
Verify async signals use obj_name for endpoints/payloads and reconnect with scan_id.
"""
view = create_widget(qtbot, Image, client=mocked_client)
_set_signal_config(
mocked_client, "eiger", "img", signal_class="AsyncSignal", ndim=1, obj_name="async_obj"
)
view.image(device_name="eiger", device_entry="img")
assert view.subscriptions["main"].async_signal_name == "async_obj"
assert view.async_update is True
# Prepare scan ids and capture dispatcher calls
view.old_scan_id = "old_scan"
view.scan_id = "new_scan"
connected = []
disconnected = []
monkeypatch.setattr(
view.bec_dispatcher,
"connect_slot",
lambda slot, endpoint, from_start=False, cb_info=None: connected.append(
(slot, endpoint, from_start, cb_info)
),
)
monkeypatch.setattr(
view.bec_dispatcher,
"disconnect_slot",
lambda slot, endpoint: disconnected.append((slot, endpoint)),
)
view._setup_async_image(view.scan_id)
expected_new = MessageEndpoints.device_async_signal("new_scan", "eiger", "async_obj")
expected_old = MessageEndpoints.device_async_signal("old_scan", "eiger", "async_obj")
assert any(ep == expected_new for _, ep, _, _ in connected)
assert any(ep == expected_old for _, ep in disconnected)
# Payload extraction should use obj_name
payload = np.array([1, 2, 3])
msg = {"signals": {"async_obj": {"value": payload}}}
assert np.array_equal(view._get_payload_data(msg), payload)
def test_disconnect_clears_async_state(qtbot, mocked_client, monkeypatch):
view = create_widget(qtbot, Image, client=mocked_client)
_set_signal_config(
mocked_client, "eiger", "img", signal_class="AsyncSignal", ndim=2, obj_name="async_obj"
)
view.image(device_name="eiger", device_entry="img")
view.scan_id = "scan_x"
view.old_scan_id = "scan_y"
view.subscriptions["main"].async_signal_name = "async_obj"
# Avoid touching real dispatcher
monkeypatch.setattr(view.bec_dispatcher, "disconnect_slot", lambda *args, **kwargs: None)
view.disconnect_monitor(device_name="eiger", device_entry="img")
assert view.subscriptions["main"].async_signal_name is None
assert view.async_update is False
##############################################
# Device monitor endpoint update mechanism
# Connection guardrails
def test_image_setup_image_2d(qtbot, mocked_client):
bec_image_view = create_widget(qtbot, Image, client=mocked_client)
bec_image_view.image(monitor="eiger", monitor_type="2d")
assert bec_image_view.monitor == "eiger"
assert bec_image_view.subscriptions["main"].source == "device_monitor_2d"
assert bec_image_view.subscriptions["main"].monitor_type == "2d"
assert bec_image_view.main_image.raw_data is None
assert bec_image_view.main_image.image is None
def test_image_setup_rejects_unsupported_signal_class(qtbot, mocked_client):
view = create_widget(qtbot, Image, client=mocked_client)
_set_signal_config(mocked_client, "eiger", "img", signal_class="Signal", ndim=2)
view.image(device_name="eiger", device_entry="img")
assert view.subscriptions["main"].source is None
assert view.subscriptions["main"].monitor_type is None
assert view.async_update is False
def test_image_setup_image_1d(qtbot, mocked_client):
bec_image_view = create_widget(qtbot, Image, client=mocked_client)
bec_image_view.image(monitor="eiger", monitor_type="1d")
assert bec_image_view.monitor == "eiger"
assert bec_image_view.subscriptions["main"].source == "device_monitor_1d"
assert bec_image_view.subscriptions["main"].monitor_type == "1d"
assert bec_image_view.main_image.raw_data is None
assert bec_image_view.main_image.image is None
def test_image_disconnects_with_missing_entry(qtbot, mocked_client):
view = create_widget(qtbot, Image, client=mocked_client)
_set_signal_config(mocked_client, "eiger", "img", signal_class="PreviewSignal", ndim=2)
view.image(device_name="eiger", device_entry="img")
assert view.device_name == "eiger"
assert view.device_entry == "img"
view.image(device_name="eiger", device_entry=None)
assert view.device_name == ""
assert view.device_entry == ""
def test_image_setup_image_auto(qtbot, mocked_client):
bec_image_view = create_widget(qtbot, Image, client=mocked_client)
bec_image_view.image(monitor="eiger", monitor_type="auto")
assert bec_image_view.monitor == "eiger"
assert bec_image_view.subscriptions["main"].source == "auto"
assert bec_image_view.subscriptions["main"].monitor_type == "auto"
assert bec_image_view.main_image.raw_data is None
assert bec_image_view.main_image.image is None
def test_handle_scan_change_clears_buffers_and_resets_crosshair(qtbot, mocked_client, monkeypatch):
view = create_widget(qtbot, Image, client=mocked_client)
view.scan_id = "scan_1"
view.main_image.buffer = [np.array([1.0, 2.0])]
view.main_image.max_len = 2
clear_called = []
monkeypatch.setattr(view.main_image, "clear", lambda: clear_called.append(True))
reset_called = []
if view.crosshair is not None:
monkeypatch.setattr(view.crosshair, "reset", lambda: reset_called.append(True))
view._handle_scan_change("scan_2")
assert view.old_scan_id == "scan_1"
assert view.scan_id == "scan_2"
assert clear_called == [True]
assert view.main_image.buffer == []
assert view.main_image.max_len == 0
if view.crosshair is not None:
assert reset_called == [True]
def test_handle_scan_change_reconnects_async(qtbot, mocked_client, monkeypatch):
view = create_widget(qtbot, Image, client=mocked_client)
view.scan_id = "scan_1"
view.async_update = True
called = []
monkeypatch.setattr(view, "_setup_async_image", lambda scan_id: called.append(scan_id))
view._handle_scan_change("scan_2")
assert called == ["scan_2"]
def test_handle_scan_change_same_scan_noop(qtbot, mocked_client, monkeypatch):
view = create_widget(qtbot, Image, client=mocked_client)
view.scan_id = "scan_1"
view.main_image.buffer = [np.array([1.0])]
view.main_image.max_len = 1
clear_called = []
monkeypatch.setattr(view.main_image, "clear", lambda: clear_called.append(True))
view._handle_scan_change("scan_1")
assert view.scan_id == "scan_1"
assert clear_called == []
assert view.main_image.buffer == [np.array([1.0])]
assert view.main_image.max_len == 1
def test_image_data_update_2d(qtbot, mocked_client):
@@ -245,8 +428,7 @@ def test_toolbar_actions_presence(qtbot, mocked_client):
assert bec_image_view.toolbar.components.exists("image_autorange")
assert bec_image_view.toolbar.components.exists("lock_aspect_ratio")
assert bec_image_view.toolbar.components.exists("image_processing_fft")
assert bec_image_view.toolbar.components.exists("image_device_combo")
assert bec_image_view.toolbar.components.exists("image_dim_combo")
assert bec_image_view.toolbar.components.exists("device_selection")
def test_auto_emit_syncs_image_toolbar_actions(qtbot, mocked_client):
@@ -327,13 +509,40 @@ def test_setting_vrange_with_colorbar(qtbot, mocked_client, colorbar_type):
###################################
def test_setup_image_from_toolbar(qtbot, mocked_client):
def test_setup_image_from_toolbar(qtbot, mocked_client, monkeypatch):
bec_image_view = create_widget(qtbot, Image, client=mocked_client)
bec_image_view.device_combo_box.setCurrentText("eiger")
bec_image_view.dim_combo_box.setCurrentText("2d")
_set_signal_config(mocked_client, "eiger", "img", signal_class="PreviewSignal", ndim=2)
monkeypatch.setattr(
mocked_client.device_manager,
"get_bec_signals",
lambda signal_class_filter: (
[
(
"eiger",
"img",
{
"obj_name": "img",
"signal_class": "PreviewSignal",
"describe": {"signal_info": {"ndim": 2}},
},
)
]
if "PreviewSignal" in (signal_class_filter or [])
else []
),
)
assert bec_image_view.monitor == "eiger"
device_selection = bec_image_view.toolbar.components.get_action("device_selection").widget
device_selection.device_combo_box.update_devices_from_filters()
device_selection.device_combo_box.setCurrentText("eiger")
device_selection.signal_combo_box.setCurrentText("img")
bec_image_view.on_device_selection_changed(None)
qtbot.wait(200)
assert bec_image_view.device_name == "eiger"
assert bec_image_view.device_entry == "img"
assert bec_image_view.subscriptions["main"].source == "device_monitor_2d"
assert bec_image_view.subscriptions["main"].monitor_type == "2d"
assert bec_image_view.main_image.raw_data is None
@@ -598,90 +807,59 @@ def test_roi_plot_data_from_image(qtbot, mocked_client):
##############################################
# MonitorSelectionToolbarBundle specific tests
# Device selection toolbar sync
##############################################
def test_monitor_selection_reverse_device_items(qtbot, mocked_client):
"""
Verify that _reverse_device_items correctly reverses the order of items in the
device combobox while preserving the current selection.
"""
def test_device_selection_syncs_from_properties(qtbot, mocked_client, monkeypatch):
view = create_widget(qtbot, Image, client=mocked_client)
combo = view.device_combo_box
# Replace existing items with a deterministic list
combo.clear()
combo.addItem("samx", 1)
combo.addItem("samy", 2)
combo.addItem("samz", 3)
combo.setCurrentText("samy")
# Reverse the items
view._reverse_device_items()
# Order should be reversed and selection preserved
assert [combo.itemText(i) for i in range(combo.count())] == ["samz", "samy", "samx"]
assert combo.currentText() == "samy"
def test_monitor_selection_populate_preview_signals(qtbot, mocked_client, monkeypatch):
"""
Verify that _populate_preview_signals adds previewsignal devices to the combobox
with the correct userData.
"""
view = create_widget(qtbot, Image, client=mocked_client)
# Provide a deterministic fake device_manager with get_bec_signals
class _FakeDM:
def get_bec_signals(self, _filter):
return [
("eiger", "img", {"obj_name": "eiger_img"}),
("async_device", "img2", {"obj_name": "async_device_img2"}),
_set_signal_config(mocked_client, "eiger", "img2d", signal_class="PreviewSignal", ndim=2)
monkeypatch.setattr(
view.client.device_manager,
"get_bec_signals",
lambda signal_class_filter: (
[
(
"eiger",
"img2d",
{
"obj_name": "img2d",
"signal_class": "PreviewSignal",
"describe": {"signal_info": {"ndim": 2}},
},
)
]
if "PreviewSignal" in (signal_class_filter or [])
else []
),
)
monkeypatch.setattr(view.client, "device_manager", _FakeDM())
view.device_name = "eiger"
view.device_entry = "img2d"
initial_count = view.device_combo_box.count()
qtbot.wait(200) # Allow signal processing
view._populate_preview_signals()
# Two new entries should have been added
assert view.device_combo_box.count() == initial_count + 2
# The first newly added item should carry tuple userData describing the device/signal
data = view.device_combo_box.itemData(initial_count)
assert isinstance(data, tuple) and data[0] == "eiger"
device_selection = view.toolbar.components.get_action("device_selection").widget
qtbot.waitUntil(
lambda: device_selection.device_combo_box.currentText() == "eiger"
and device_selection.signal_combo_box.currentText() == "img2d",
timeout=1000,
)
def test_monitor_selection_adjust_and_connect(qtbot, mocked_client, monkeypatch):
"""
Verify that _adjust_and_connect performs the full set-up:
- fills the combobox with preview signals,
- reverses their order,
- and resets the currentText to an empty string.
"""
def test_device_entry_syncs_from_toolbar(qtbot, mocked_client):
view = create_widget(qtbot, Image, client=mocked_client)
_set_signal_config(mocked_client, "eiger", "img_a", signal_class="PreviewSignal", ndim=2)
_set_signal_config(mocked_client, "eiger", "img_b", signal_class="PreviewSignal", ndim=2)
# Deterministic fake device_manager
class _FakeDM:
def get_bec_signals(self, _filter):
return [("eiger", "img", {"obj_name": "eiger_img"})]
view.device_name = "eiger"
view.device_entry = "img_a"
monkeypatch.setattr(view.client, "device_manager", _FakeDM())
device_selection = view.toolbar.components.get_action("device_selection").widget
device_selection.signal_combo_box.blockSignals(True)
device_selection.signal_combo_box.setCurrentText("img_b")
device_selection.signal_combo_box.blockSignals(False)
combo = view.device_combo_box
# Start from a clean state
combo.clear()
combo.addItem("", None)
combo.setCurrentText("")
view._sync_device_entry_from_toolbar()
# Execute the method under test
view._adjust_and_connect()
# Expect exactly two items: preview label followed by the empty default
assert combo.count() == 2
# Because of the reversal, the preview label comes first
assert combo.itemText(0) == "eiger_img"
# Current selection remains empty
assert combo.currentText() == ""
assert view.device_entry == "img_b"

View File

@@ -516,6 +516,112 @@ def test_plot_custom_curve_with_inline_dap(qtbot, mocked_client_with_dap):
assert dap_curve.config.signal.dap == "GaussianModel"
def test_normalize_dap_parameters_number_dict():
normalized = Waveform._normalize_dap_parameters({"amplitude": 1.0, "center": 2})
assert normalized == {
"amplitude": {"name": "amplitude", "value": 1.0, "vary": False},
"center": {"name": "center", "value": 2.0, "vary": False},
}
def test_normalize_dap_parameters_dict_spec_defaults_vary_false():
normalized = Waveform._normalize_dap_parameters({"sigma": {"value": 0.8, "min": 0.0}})
assert normalized["sigma"]["name"] == "sigma"
assert normalized["sigma"]["value"] == 0.8
assert normalized["sigma"]["min"] == 0.0
assert normalized["sigma"]["vary"] is False
def test_normalize_dap_parameters_invalid_type_raises():
with pytest.raises(TypeError):
Waveform._normalize_dap_parameters(["amplitude", 1.0]) # type: ignore[arg-type]
def test_normalize_dap_parameters_composite_list():
normalized = Waveform._normalize_dap_parameters(
[{"center": 1.0}, {"sigma": {"value": 0.5, "min": 0.0}}],
dap_name=["GaussianModel", "GaussianModel"],
)
assert normalized == [
{"center": {"name": "center", "value": 1.0, "vary": False}},
{"sigma": {"name": "sigma", "value": 0.5, "min": 0.0, "vary": False}},
]
def test_normalize_dap_parameters_composite_dict():
normalized = Waveform._normalize_dap_parameters(
{
"GaussianModel": {"center": {"value": 1.0, "vary": True}},
"LorentzModel": {"amplitude": 2.0},
},
dap_name=["GaussianModel", "LorentzModel"],
)
assert normalized["GaussianModel"]["center"]["value"] == 1.0
assert normalized["GaussianModel"]["center"]["vary"] is True
assert normalized["LorentzModel"]["amplitude"]["value"] == 2.0
assert normalized["LorentzModel"]["amplitude"]["vary"] is False
def test_request_dap_includes_normalized_parameters(qtbot, mocked_client_with_dap, monkeypatch):
wf = create_widget(qtbot, Waveform, client=mocked_client_with_dap)
curve = wf.plot(
x=[0, 1, 2],
y=[1, 2, 3],
label="custom-inline-params",
dap="GaussianModel",
dap_parameters={"amplitude": 1.0},
)
dap_curve = wf.get_curve(f"{curve.name()}-GaussianModel")
assert dap_curve is not None
dap_curve.dap_oversample = 3
captured = {}
def capture(topic, msg, *args, **kwargs): # noqa: ARG001
captured["topic"] = topic
captured["msg"] = msg
monkeypatch.setattr(wf.client.connector, "set_and_publish", capture)
wf.request_dap()
msg = captured["msg"]
dap_kwargs = msg.content["config"]["kwargs"]
assert dap_kwargs["oversample"] == 3
assert dap_kwargs["parameters"] == {
"amplitude": {"name": "amplitude", "value": 1.0, "vary": False}
}
def test_request_dap_includes_composite_parameters_list(qtbot, mocked_client_with_dap, monkeypatch):
wf = create_widget(qtbot, Waveform, client=mocked_client_with_dap)
curve = wf.plot(
x=[0, 1, 2],
y=[1, 2, 3],
label="custom-composite",
dap=["GaussianModel", "GaussianModel"],
dap_parameters=[{"center": 0.0}, {"center": 1.0}],
)
dap_curve = wf.get_curve(f"{curve.name()}-GaussianModel+GaussianModel")
assert dap_curve is not None
captured = {}
def capture(topic, msg, *args, **kwargs): # noqa: ARG001
captured["topic"] = topic
captured["msg"] = msg
monkeypatch.setattr(wf.client.connector, "set_and_publish", capture)
wf.request_dap()
msg = captured["msg"]
dap_kwargs = msg.content["config"]["kwargs"]
assert dap_kwargs["parameters"] == [
{"center": {"name": "center", "value": 0.0, "vary": False}},
{"center": {"name": "center", "value": 1.0, "vary": False}},
]
assert msg.content["config"]["class_kwargs"]["model"] == ["GaussianModel", "GaussianModel"]
def test_fetch_scan_data_and_access(qtbot, mocked_client, monkeypatch):
"""
Test the _fetch_scan_data_and_access method returns live_data/val if in a live scan,