mirror of
https://github.com/bec-project/bec_widgets.git
synced 2025-12-28 01:41:20 +01:00
feat(curve, waveform): add dap_parameters for lmfit customization in DAP requests
This commit is contained in:
@@ -5418,6 +5418,7 @@ class Waveform(RPCBase):
|
||||
color: "str | None" = None,
|
||||
label: "str | None" = None,
|
||||
dap: "str | None" = None,
|
||||
dap_parameters: "dict | lmfit.Parameters | None | object" = None,
|
||||
scan_id: "str | None" = None,
|
||||
scan_number: "int | None" = None,
|
||||
**kwargs,
|
||||
@@ -5442,6 +5443,8 @@ class Waveform(RPCBase):
|
||||
dap(str): The dap model to use for the curve. When provided, a DAP curve is
|
||||
attached automatically for device, history, or custom data sources. Use
|
||||
the same string as the LMFit model name.
|
||||
dap_parameters(dict | lmfit.Parameters | None): Optional lmfit parameter overrides sent to the DAP server.
|
||||
Values can be numeric (interpreted as fixed parameters) or dicts like`{"value": 1.0, "vary": False}`.
|
||||
scan_id(str): Optional scan ID. When provided, the curve is treated as a **history** curve and
|
||||
the y‑data (and optional x‑data) are fetched from that historical scan. Such curves are
|
||||
never cleared by live‑scan resets.
|
||||
@@ -5458,6 +5461,7 @@ class Waveform(RPCBase):
|
||||
dap_name: "str",
|
||||
color: "str | None" = None,
|
||||
dap_oversample: "int" = 1,
|
||||
dap_parameters: "dict | lmfit.Parameters | None" = None,
|
||||
**kwargs,
|
||||
) -> "Curve":
|
||||
"""
|
||||
@@ -5470,6 +5474,7 @@ class Waveform(RPCBase):
|
||||
dap_name(str): The name of the DAP model to use.
|
||||
color(str): The color of the curve.
|
||||
dap_oversample(int): The oversampling factor for the DAP curve.
|
||||
dap_parameters(dict | lmfit.Parameters | None): Optional lmfit parameter overrides sent to the DAP server.
|
||||
**kwargs
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -24,6 +24,7 @@ class DeviceSignal(BaseModel):
|
||||
entry: str
|
||||
dap: str | None = None
|
||||
dap_oversample: int = 1
|
||||
dap_parameters: dict | None = None
|
||||
|
||||
model_config: dict = {"validate_assignment": True}
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Literal
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import lmfit
|
||||
import numpy as np
|
||||
import pyqtgraph as pg
|
||||
from bec_lib import bec_logger, messages
|
||||
from bec_lib.endpoints import MessageEndpoints
|
||||
from bec_lib.lmfit_serializer import serialize_lmfit_params, serialize_param_object
|
||||
from bec_lib.scan_data_container import ScanDataContainer
|
||||
from pydantic import Field, ValidationError, field_validator
|
||||
from qtpy.QtCore import Qt, QTimer, Signal
|
||||
@@ -41,6 +41,15 @@ from bec_widgets.widgets.services.scan_history_browser.scan_history_browser impo
|
||||
)
|
||||
|
||||
logger = bec_logger.logger
|
||||
_DAP_PARAM = object()
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
import lmfit # type: ignore
|
||||
else:
|
||||
try:
|
||||
import lmfit # type: ignore
|
||||
except Exception: # pragma: no cover
|
||||
lmfit = None
|
||||
|
||||
|
||||
# noinspection PyDataclass
|
||||
@@ -697,6 +706,7 @@ class Waveform(PlotBase):
|
||||
color: str | None = None,
|
||||
label: str | None = None,
|
||||
dap: str | None = None,
|
||||
dap_parameters: dict | lmfit.Parameters | None | object = None,
|
||||
scan_id: str | None = None,
|
||||
scan_number: int | None = None,
|
||||
**kwargs,
|
||||
@@ -721,6 +731,8 @@ class Waveform(PlotBase):
|
||||
dap(str): The dap model to use for the curve. When provided, a DAP curve is
|
||||
attached automatically for device, history, or custom data sources. Use
|
||||
the same string as the LMFit model name.
|
||||
dap_parameters(dict | lmfit.Parameters | None): Optional lmfit parameter overrides sent to the DAP server.
|
||||
Values can be numeric (interpreted as fixed parameters) or dicts like`{"value": 1.0, "vary": False}`.
|
||||
scan_id(str): Optional scan ID. When provided, the curve is treated as a **history** curve and
|
||||
the y‑data (and optional x‑data) are fetched from that historical scan. Such curves are
|
||||
never cleared by live‑scan resets.
|
||||
@@ -733,6 +745,8 @@ class Waveform(PlotBase):
|
||||
source = "custom"
|
||||
x_data = None
|
||||
y_data = None
|
||||
if dap_parameters is _DAP_PARAM:
|
||||
dap_parameters = kwargs.pop("dap_parameters", None) or kwargs.pop("parameters", None)
|
||||
|
||||
# 1. Custom curve logic
|
||||
if x is not None and y is not None:
|
||||
@@ -810,7 +824,9 @@ class Waveform(PlotBase):
|
||||
curve = self._add_curve(config=config, x_data=x_data, y_data=y_data)
|
||||
|
||||
if dap is not None and curve.config.source in ("device", "history", "custom"):
|
||||
self.add_dap_curve(device_label=curve.name(), dap_name=dap, **kwargs)
|
||||
self.add_dap_curve(
|
||||
device_label=curve.name(), dap_name=dap, dap_parameters=dap_parameters, **kwargs
|
||||
)
|
||||
|
||||
return curve
|
||||
|
||||
@@ -823,6 +839,7 @@ class Waveform(PlotBase):
|
||||
dap_name: str,
|
||||
color: str | None = None,
|
||||
dap_oversample: int = 1,
|
||||
dap_parameters: dict | lmfit.Parameters | None = None,
|
||||
**kwargs,
|
||||
) -> Curve:
|
||||
"""
|
||||
@@ -835,6 +852,7 @@ class Waveform(PlotBase):
|
||||
dap_name(str): The name of the DAP model to use.
|
||||
color(str): The color of the curve.
|
||||
dap_oversample(int): The oversampling factor for the DAP curve.
|
||||
dap_parameters(dict | lmfit.Parameters | None): Optional lmfit parameter overrides sent to the DAP server.
|
||||
**kwargs
|
||||
|
||||
Returns:
|
||||
@@ -882,7 +900,11 @@ class Waveform(PlotBase):
|
||||
|
||||
# Attach device signal with DAP
|
||||
config.signal = DeviceSignal(
|
||||
name=dev_name, entry=dev_entry, dap=dap_name, dap_oversample=dap_oversample
|
||||
name=dev_name,
|
||||
entry=dev_entry,
|
||||
dap=dap_name,
|
||||
dap_oversample=dap_oversample,
|
||||
dap_parameters=self._normalize_dap_parameters(dap_parameters),
|
||||
)
|
||||
|
||||
# 4) Create the DAP curve config using `_add_curve(...)`
|
||||
@@ -1762,12 +1784,21 @@ class Waveform(PlotBase):
|
||||
x_min = None
|
||||
x_max = None
|
||||
|
||||
dap_parameters = getattr(dap_curve.config.signal, "dap_parameters", None)
|
||||
dap_kwargs = {
|
||||
"data_x": x_data,
|
||||
"data_y": y_data,
|
||||
"oversample": dap_curve.dap_oversample,
|
||||
}
|
||||
if dap_parameters:
|
||||
dap_kwargs["parameters"] = dap_parameters
|
||||
|
||||
msg = messages.DAPRequestMessage(
|
||||
dap_cls="LmfitService1D",
|
||||
dap_type="on_demand",
|
||||
config={
|
||||
"args": [],
|
||||
"kwargs": {"data_x": x_data, "data_y": y_data},
|
||||
"kwargs": dap_kwargs,
|
||||
"class_args": model._plugin_info["class_args"],
|
||||
"class_kwargs": model._plugin_info["class_kwargs"],
|
||||
"curve_label": dap_curve.name(),
|
||||
@@ -1776,6 +1807,49 @@ class Waveform(PlotBase):
|
||||
)
|
||||
self.client.connector.set_and_publish(MessageEndpoints.dap_request(), msg)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_dap_parameters(parameters: dict | lmfit.Parameters | None) -> dict | None:
|
||||
"""
|
||||
Normalize user-provided lmfit parameters into a JSON-serializable dict suitable for the DAP server.
|
||||
|
||||
Supports:
|
||||
- `lmfit.Parameters`
|
||||
- `dict[name -> number]` (treated as fixed parameter with `vary=False`)
|
||||
- `dict[name -> dict]` (lmfit.Parameter fields; defaults to `vary=False` if unspecified)
|
||||
- `dict[name -> lmfit.Parameter]`
|
||||
"""
|
||||
if parameters is None:
|
||||
return None
|
||||
if lmfit is not None and isinstance(parameters, lmfit.Parameters):
|
||||
return serialize_lmfit_params(parameters)
|
||||
if not isinstance(parameters, dict):
|
||||
if lmfit is None:
|
||||
raise TypeError(
|
||||
"dap_parameters must be a dict when lmfit is not installed on the client."
|
||||
)
|
||||
raise TypeError("dap_parameters must be a dict or lmfit.Parameters (or omitted).")
|
||||
|
||||
normalized: dict[str, dict] = {}
|
||||
for name, spec in parameters.items():
|
||||
if spec is None:
|
||||
continue
|
||||
if isinstance(spec, (int, float, np.number)):
|
||||
normalized[name] = {"name": name, "value": float(spec), "vary": False}
|
||||
continue
|
||||
if lmfit is not None and isinstance(spec, lmfit.Parameter):
|
||||
normalized[name] = serialize_param_object(spec)
|
||||
continue
|
||||
if isinstance(spec, dict):
|
||||
normalized[name] = {"name": name, **spec}
|
||||
if "vary" not in normalized[name]:
|
||||
normalized[name]["vary"] = False
|
||||
continue
|
||||
raise TypeError(
|
||||
f"Invalid dap_parameters entry for '{name}': expected number, dict, or lmfit.Parameter."
|
||||
)
|
||||
|
||||
return normalized or None
|
||||
|
||||
@SafeSlot(dict, dict)
|
||||
def update_dap_curves(self, msg, metadata):
|
||||
"""
|
||||
@@ -1793,14 +1867,6 @@ class Waveform(PlotBase):
|
||||
if not curve:
|
||||
return
|
||||
|
||||
# Get data from the parent (device) curve
|
||||
parent_curve = self._find_curve_by_label(curve.config.parent_label)
|
||||
if parent_curve is None:
|
||||
return
|
||||
x_parent, _ = parent_curve.get_data()
|
||||
if x_parent is None or len(x_parent) == 0:
|
||||
return
|
||||
|
||||
# Retrieve and store the fit parameters and summary from the DAP server response
|
||||
try:
|
||||
curve.dap_params = msg["data"][1]["fit_parameters"]
|
||||
@@ -1809,19 +1875,13 @@ class Waveform(PlotBase):
|
||||
logger.warning(f"Failed to retrieve DAP data for curve '{curve.name()}'")
|
||||
return
|
||||
|
||||
# Render model according to the DAP model name and parameters
|
||||
model_name = curve.config.signal.dap
|
||||
model_function = getattr(lmfit.models, model_name)()
|
||||
|
||||
x_min, x_max = x_parent.min(), x_parent.max()
|
||||
oversample = curve.dap_oversample
|
||||
new_x = np.linspace(x_min, x_max, int(len(x_parent) * oversample))
|
||||
|
||||
# Evaluate the model with the provided parameters to generate the y values
|
||||
new_y = model_function.eval(**curve.dap_params, x=new_x)
|
||||
|
||||
# Update the curve with the new data
|
||||
curve.setData(new_x, new_y)
|
||||
# Plot the fitted curve using the server-provided output to avoid requiring lmfit on the client.
|
||||
try:
|
||||
fit_data = msg["data"][0]
|
||||
curve.setData(np.asarray(fit_data["x"]), np.asarray(fit_data["y"]))
|
||||
except Exception:
|
||||
logger.exception(f"Failed to plot DAP result for curve '{curve.name()}'")
|
||||
return
|
||||
|
||||
metadata.update({"curve_id": curve_id})
|
||||
self.dap_params_update.emit(curve.dap_params, metadata)
|
||||
@@ -2377,8 +2437,48 @@ class DemoApp(QMainWindow): # pragma: no cover
|
||||
sigma = 0.8
|
||||
y = amplitude * np.exp(-((x - center) ** 2) / (2 * sigma**2)) + noise
|
||||
|
||||
# 1) No explicit parameters: server will use lmfit defaults/guesses.
|
||||
self.custom_waveform.plot(x=x, y=y, label="custom-gaussian", dap="GaussianModel")
|
||||
|
||||
# 2) Easy dict: numbers mean "fix this parameter to value" (vary=False).
|
||||
self.custom_waveform.plot(
|
||||
x=x,
|
||||
y=y,
|
||||
label="custom-gaussian-fixed-easy",
|
||||
dap="GaussianModel",
|
||||
dap_parameters={"amplitude": 1.0},
|
||||
dap_oversample=5,
|
||||
)
|
||||
|
||||
# 3) lmfit-style dict: any subset of lmfit.Parameter fields.
|
||||
# Here `center` is not fixed (vary=True) but its initial value is set.
|
||||
self.custom_waveform.plot(
|
||||
x=x,
|
||||
y=y,
|
||||
label="custom-gaussian-override-dict",
|
||||
dap="GaussianModel",
|
||||
dap_parameters={
|
||||
"center": {"value": 1.2, "vary": True},
|
||||
"sigma": {"value": sigma, "vary": False, "min": 0.0},
|
||||
},
|
||||
)
|
||||
|
||||
# 4) Passing a real `lmfit.Parameters` object (optional: requires lmfit on the client).
|
||||
if lmfit is not None:
|
||||
params_gauss = lmfit.models.GaussianModel().make_params()
|
||||
params_gauss["amplitude"].set(value=amplitude, vary=False)
|
||||
params_gauss["center"].set(value=center, vary=False)
|
||||
params_gauss["sigma"].set(value=sigma, vary=False, min=0.0)
|
||||
self.custom_waveform.plot(
|
||||
x=x,
|
||||
y=y,
|
||||
label="custom-gaussian-fixed-params",
|
||||
dap="GaussianModel",
|
||||
dap_parameters=params_gauss,
|
||||
)
|
||||
else:
|
||||
logger.info("Skipping lmfit.Parameters demo (lmfit not installed on client).")
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover
|
||||
import sys
|
||||
|
||||
@@ -75,6 +75,7 @@ def test_rpc_plotting_shortcuts_init_configs(qtbot, connected_client_gui_obj):
|
||||
assert c1._config_dict["signal"] == {
|
||||
"dap": None,
|
||||
"name": "bpm4i",
|
||||
"dap_parameters": None,
|
||||
"entry": "bpm4i",
|
||||
"dap_oversample": 1,
|
||||
}
|
||||
|
||||
@@ -516,6 +516,57 @@ def test_plot_custom_curve_with_inline_dap(qtbot, mocked_client_with_dap):
|
||||
assert dap_curve.config.signal.dap == "GaussianModel"
|
||||
|
||||
|
||||
def test_normalize_dap_parameters_number_dict():
|
||||
normalized = Waveform._normalize_dap_parameters({"amplitude": 1.0, "center": 2})
|
||||
assert normalized == {
|
||||
"amplitude": {"name": "amplitude", "value": 1.0, "vary": False},
|
||||
"center": {"name": "center", "value": 2.0, "vary": False},
|
||||
}
|
||||
|
||||
|
||||
def test_normalize_dap_parameters_dict_spec_defaults_vary_false():
|
||||
normalized = Waveform._normalize_dap_parameters({"sigma": {"value": 0.8, "min": 0.0}})
|
||||
assert normalized["sigma"]["name"] == "sigma"
|
||||
assert normalized["sigma"]["value"] == 0.8
|
||||
assert normalized["sigma"]["min"] == 0.0
|
||||
assert normalized["sigma"]["vary"] is False
|
||||
|
||||
|
||||
def test_normalize_dap_parameters_invalid_type_raises():
|
||||
with pytest.raises(TypeError):
|
||||
Waveform._normalize_dap_parameters(["amplitude", 1.0]) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_request_dap_includes_normalized_parameters(qtbot, mocked_client_with_dap, monkeypatch):
|
||||
wf = create_widget(qtbot, Waveform, client=mocked_client_with_dap)
|
||||
curve = wf.plot(
|
||||
x=[0, 1, 2],
|
||||
y=[1, 2, 3],
|
||||
label="custom-inline-params",
|
||||
dap="GaussianModel",
|
||||
dap_parameters={"amplitude": 1.0},
|
||||
)
|
||||
dap_curve = wf.get_curve(f"{curve.name()}-GaussianModel")
|
||||
assert dap_curve is not None
|
||||
dap_curve.dap_oversample = 3
|
||||
|
||||
captured = {}
|
||||
|
||||
def capture(topic, msg, *args, **kwargs): # noqa: ARG001
|
||||
captured["topic"] = topic
|
||||
captured["msg"] = msg
|
||||
|
||||
monkeypatch.setattr(wf.client.connector, "set_and_publish", capture)
|
||||
wf.request_dap()
|
||||
|
||||
msg = captured["msg"]
|
||||
dap_kwargs = msg.content["config"]["kwargs"]
|
||||
assert dap_kwargs["oversample"] == 3
|
||||
assert dap_kwargs["parameters"] == {
|
||||
"amplitude": {"name": "amplitude", "value": 1.0, "vary": False}
|
||||
}
|
||||
|
||||
|
||||
def test_fetch_scan_data_and_access(qtbot, mocked_client, monkeypatch):
|
||||
"""
|
||||
Test the _fetch_scan_data_and_access method returns live_data/val if in a live scan,
|
||||
|
||||
Reference in New Issue
Block a user