From 2e971e8cc54fe46f2ac28e683a40d47b167c9f16 Mon Sep 17 00:00:00 2001 From: wyzula-jan Date: Wed, 28 Jan 2026 21:52:00 +0100 Subject: [PATCH] feat(waveform): composite DAP with multiple models --- bec_widgets/cli/client.py | 24 ++- bec_widgets/widgets/plots/waveform/curve.py | 4 +- .../widgets/plots/waveform/waveform.py | 194 +++++++++++++++--- tests/unit_tests/test_waveform.py | 55 +++++ 4 files changed, 238 insertions(+), 39 deletions(-) diff --git a/bec_widgets/cli/client.py b/bec_widgets/cli/client.py index ccb69716..13d697f6 100644 --- a/bec_widgets/cli/client.py +++ b/bec_widgets/cli/client.py @@ -5959,8 +5959,8 @@ class Waveform(RPCBase): y_entry: "str | None" = None, color: "str | None" = None, label: "str | None" = None, - dap: "str | None" = None, - dap_parameters: "dict | lmfit.Parameters | None | object" = None, + dap: "str | list[str] | None" = None, + dap_parameters: "dict | list | lmfit.Parameters | None | object" = None, scan_id: "str | None" = None, scan_number: "int | None" = None, **kwargs, @@ -5982,11 +5982,14 @@ class Waveform(RPCBase): y_entry(str): The name of the entry for the y-axis. color(str): The color of the curve. label(str): The label of the curve. - dap(str): The dap model to use for the curve. When provided, a DAP curve is + dap(str | list[str]): The dap model to use for the curve. When provided, a DAP curve is attached automatically for device, history, or custom data sources. Use - the same string as the LMFit model name. - dap_parameters(dict | lmfit.Parameters | None): Optional lmfit parameter overrides sent to the DAP server. - Values can be numeric (interpreted as fixed parameters) or dicts like`{"value": 1.0, "vary": False}`. + the same string as the LMFit model name, or a list of model names to build a composite. + dap_parameters(dict | list | lmfit.Parameters | None): Optional lmfit parameter overrides sent to + the DAP server. For a single model: values can be numeric (interpreted as fixed parameters) + or dicts like `{"value": 1.0, "vary": False}`. For composite models (dap is list), use either + a list aligned to the model list (each item is a param dict), or a dict of + `{ "ModelName": { "param": {...} } }` when model names are unique. scan_id(str): Optional scan ID. When provided, the curve is treated as a **history** curve and the y‑data (and optional x‑data) are fetched from that historical scan. Such curves are never cleared by live‑scan resets. @@ -6000,10 +6003,10 @@ class Waveform(RPCBase): def add_dap_curve( self, device_label: "str", - dap_name: "str", + dap_name: "str | list[str]", color: "str | None" = None, dap_oversample: "int" = 1, - dap_parameters: "dict | lmfit.Parameters | None" = None, + dap_parameters: "dict | list | lmfit.Parameters | None" = None, **kwargs, ) -> "Curve": """ @@ -6013,10 +6016,11 @@ class Waveform(RPCBase): Args: device_label(str): The label of the source curve to add DAP to. - dap_name(str): The name of the DAP model to use. + dap_name(str | list[str]): The name of the DAP model to use, or a list of model + names to build a composite model. color(str): The color of the curve. dap_oversample(int): The oversampling factor for the DAP curve. - dap_parameters(dict | lmfit.Parameters | None): Optional lmfit parameter overrides sent to the DAP server. + dap_parameters(dict | list | lmfit.Parameters | None): Optional lmfit parameter overrides sent to the DAP server. **kwargs Returns: diff --git a/bec_widgets/widgets/plots/waveform/curve.py b/bec_widgets/widgets/plots/waveform/curve.py index 8ab8c47d..e95ccc37 100644 --- a/bec_widgets/widgets/plots/waveform/curve.py +++ b/bec_widgets/widgets/plots/waveform/curve.py @@ -22,9 +22,9 @@ class DeviceSignal(BaseModel): name: str entry: str - dap: str | None = None + dap: str | list[str] | None = None dap_oversample: int = 1 - dap_parameters: dict | None = None + dap_parameters: dict | list | None = None model_config: dict = {"validate_assignment": True} diff --git a/bec_widgets/widgets/plots/waveform/waveform.py b/bec_widgets/widgets/plots/waveform/waveform.py index e6bd43ee..aeadcec1 100644 --- a/bec_widgets/widgets/plots/waveform/waveform.py +++ b/bec_widgets/widgets/plots/waveform/waveform.py @@ -705,8 +705,8 @@ class Waveform(PlotBase): y_entry: str | None = None, color: str | None = None, label: str | None = None, - dap: str | None = None, - dap_parameters: dict | lmfit.Parameters | None | object = None, + dap: str | list[str] | None = None, + dap_parameters: dict | list | lmfit.Parameters | None | object = None, scan_id: str | None = None, scan_number: int | None = None, **kwargs, @@ -728,11 +728,14 @@ class Waveform(PlotBase): y_entry(str): The name of the entry for the y-axis. color(str): The color of the curve. label(str): The label of the curve. - dap(str): The dap model to use for the curve. When provided, a DAP curve is + dap(str | list[str]): The dap model to use for the curve. When provided, a DAP curve is attached automatically for device, history, or custom data sources. Use - the same string as the LMFit model name. - dap_parameters(dict | lmfit.Parameters | None): Optional lmfit parameter overrides sent to the DAP server. - Values can be numeric (interpreted as fixed parameters) or dicts like`{"value": 1.0, "vary": False}`. + the same string as the LMFit model name, or a list of model names to build a composite. + dap_parameters(dict | list | lmfit.Parameters | None): Optional lmfit parameter overrides sent to + the DAP server. For a single model: values can be numeric (interpreted as fixed parameters) + or dicts like `{"value": 1.0, "vary": False}`. For composite models (dap is list), use either + a list aligned to the model list (each item is a param dict), or a dict of + `{ "ModelName": { "param": {...} } }` when model names are unique. scan_id(str): Optional scan ID. When provided, the curve is treated as a **history** curve and the y‑data (and optional x‑data) are fetched from that historical scan. Such curves are never cleared by live‑scan resets. @@ -836,10 +839,10 @@ class Waveform(PlotBase): def add_dap_curve( self, device_label: str, - dap_name: str, + dap_name: str | list[str], color: str | None = None, dap_oversample: int = 1, - dap_parameters: dict | lmfit.Parameters | None = None, + dap_parameters: dict | list | lmfit.Parameters | None = None, **kwargs, ) -> Curve: """ @@ -849,10 +852,11 @@ class Waveform(PlotBase): Args: device_label(str): The label of the source curve to add DAP to. - dap_name(str): The name of the DAP model to use. + dap_name(str | list[str]): The name of the DAP model to use, or a list of model + names to build a composite model. color(str): The color of the curve. dap_oversample(int): The oversampling factor for the DAP curve. - dap_parameters(dict | lmfit.Parameters | None): Optional lmfit parameter overrides sent to the DAP server. + dap_parameters(dict | list | lmfit.Parameters | None): Optional lmfit parameter overrides sent to the DAP server. **kwargs Returns: @@ -877,7 +881,7 @@ class Waveform(PlotBase): dev_entry = "custom" # 2) Build a label for the new DAP curve - dap_label = f"{device_label}-{dap_name}" + dap_label = f"{device_label}-{self._format_dap_label(dap_name)}" # 3) Possibly raise if the DAP curve already exists if self._check_curve_id(dap_label): @@ -904,7 +908,7 @@ class Waveform(PlotBase): entry=dev_entry, dap=dap_name, dap_oversample=dap_oversample, - dap_parameters=self._normalize_dap_parameters(dap_parameters), + dap_parameters=self._normalize_dap_parameters(dap_parameters, dap_name=dap_name), ) # 4) Create the DAP curve config using `_add_curve(...)` @@ -1776,7 +1780,9 @@ class Waveform(PlotBase): x_data, y_data = parent_curve.get_data() model_name = dap_curve.config.signal.dap - model = getattr(self.dap, model_name) + model = None + if not isinstance(model_name, (list, tuple)): + model = getattr(self.dap, model_name) try: x_min, x_max = self.roi_region x_data, y_data = self._crop_data(x_data, y_data, x_min, x_max) @@ -1793,14 +1799,21 @@ class Waveform(PlotBase): if dap_parameters: dap_kwargs["parameters"] = dap_parameters + if model is not None: + class_args = model._plugin_info["class_args"] + class_kwargs = model._plugin_info["class_kwargs"] + else: + class_args = [] + class_kwargs = {"model": model_name} + msg = messages.DAPRequestMessage( dap_cls="LmfitService1D", dap_type="on_demand", config={ "args": [], "kwargs": dap_kwargs, - "class_args": model._plugin_info["class_args"], - "class_kwargs": model._plugin_info["class_kwargs"], + "class_args": class_args, + "class_kwargs": class_kwargs, "curve_label": dap_curve.name(), }, metadata={"RID": f"{self.scan_id}-{self.gui_id}"}, @@ -1808,18 +1821,61 @@ class Waveform(PlotBase): self.client.connector.set_and_publish(MessageEndpoints.dap_request(), msg) @staticmethod - def _normalize_dap_parameters(parameters: dict | lmfit.Parameters | None) -> dict | None: + def _normalize_dap_parameters( + parameters: dict | list | lmfit.Parameters | None, dap_name: str | list[str] | None = None + ) -> dict | list | None: """ Normalize user-provided lmfit parameters into a JSON-serializable dict suitable for the DAP server. Supports: - - `lmfit.Parameters` + - `lmfit.Parameters` (single-model only) - `dict[name -> number]` (treated as fixed parameter with `vary=False`) - `dict[name -> dict]` (lmfit.Parameter fields; defaults to `vary=False` if unspecified) - `dict[name -> lmfit.Parameter]` + - composite: `list[dict[param_name -> spec]]` aligned to model list + - composite: `dict[model_name -> dict[param_name -> spec]]` (unique model names only) """ if parameters is None: return None + if isinstance(dap_name, (list, tuple)): + if lmfit is not None and isinstance(parameters, lmfit.Parameters): + raise TypeError("dap_parameters must be a dict when using composite dap models.") + if isinstance(parameters, (list, tuple)): + normalized_list: list[dict | None] = [] + for idx, item in enumerate(parameters): + if item is None: + normalized_list.append(None) + continue + if not isinstance(item, dict): + raise TypeError( + f"dap_parameters list item {idx} must be a dict of parameter overrides." + ) + normalized_list.append(Waveform._normalize_param_overrides(item)) + return normalized_list or None + if not isinstance(parameters, dict): + raise TypeError( + "dap_parameters must be a dict of model->params when using composite dap models." + ) + model_names = set(dap_name) + invalid_models = set(parameters.keys()) - model_names + if invalid_models: + raise TypeError( + f"Invalid dap_parameters keys for composite model: {sorted(invalid_models)}" + ) + normalized_composite: dict[str, dict] = {} + for model_name in dap_name: + model_params = parameters.get(model_name) + if model_params is None: + continue + if not isinstance(model_params, dict): + raise TypeError( + f"dap_parameters for '{model_name}' must be a dict of parameter overrides." + ) + normalized = Waveform._normalize_param_overrides(model_params) + if normalized: + normalized_composite[model_name] = normalized + return normalized_composite or None + if lmfit is not None and isinstance(parameters, lmfit.Parameters): return serialize_lmfit_params(parameters) if not isinstance(parameters, dict): @@ -1829,6 +1885,10 @@ class Waveform(PlotBase): ) raise TypeError("dap_parameters must be a dict or lmfit.Parameters (or omitted).") + return Waveform._normalize_param_overrides(parameters) + + @staticmethod + def _normalize_param_overrides(parameters: dict) -> dict | None: normalized: dict[str, dict] = {} for name, spec in parameters.items(): if spec is None: @@ -1850,6 +1910,12 @@ class Waveform(PlotBase): return normalized or None + @staticmethod + def _format_dap_label(dap_name: str | list[str]) -> str: + if isinstance(dap_name, (list, tuple)): + return "+".join(dap_name) + return dap_name + @SafeSlot(dict, dict) def update_dap_curves(self, msg, metadata): """ @@ -2401,24 +2467,20 @@ class DemoApp(QMainWindow): # pragma: no cover def __init__(self): super().__init__() self.setWindowTitle("Waveform Demo") - self.resize(1200, 600) + self.resize(1600, 600) self.main_widget = QWidget(self) self.layout = QHBoxLayout(self.main_widget) self.setCentralWidget(self.main_widget) - self.waveform_popup = Waveform(popups=True) - self.waveform_popup.plot(y_name="waveform") - - self.waveform_side = Waveform(popups=False) - self.waveform_side.plot(y_name="bpm4i", y_entry="bpm4i", dap="GaussianModel") - self.waveform_side.plot(y_name="bpm3a", y_entry="bpm3a") - self.custom_waveform = Waveform(popups=True) self._populate_custom_curve_demo() - self.layout.addWidget(self.waveform_side) - self.layout.addWidget(self.waveform_popup) + self.sine_waveform = Waveform(popups=True) + self.sine_waveform.dap_params_update.connect(self._log_sine_dap_params) + self._populate_sine_curve_demo() + self.layout.addWidget(self.custom_waveform) + self.layout.addWidget(self.sine_waveform) def _populate_custom_curve_demo(self): """ @@ -2479,6 +2541,84 @@ class DemoApp(QMainWindow): # pragma: no cover else: logger.info("Skipping lmfit.Parameters demo (lmfit not installed on client).") + # Composite example: spectrum with three Gaussians (DAP-only) + x_spec = np.linspace(-5, 5, 800) + rng_spec = np.random.default_rng(123) + centers = [-2.0, 0.6, 2.4] + amplitudes = [2.5, 3.2, 1.8] + sigmas = [0.35, 0.5, 0.3] + y_spec = ( + amplitudes[0] * np.exp(-((x_spec - centers[0]) ** 2) / (2 * sigmas[0] ** 2)) + + amplitudes[1] * np.exp(-((x_spec - centers[1]) ** 2) / (2 * sigmas[1] ** 2)) + + amplitudes[2] * np.exp(-((x_spec - centers[2]) ** 2) / (2 * sigmas[2] ** 2)) + + rng_spec.normal(loc=0, scale=0.06, size=x_spec.size) + ) + + self.custom_waveform.plot( + x=x_spec, + y=y_spec, + label="custom-gaussian-spectrum-fit", + dap=["GaussianModel", "GaussianModel", "GaussianModel"], + dap_parameters=[ + {"center": {"value": centers[0], "vary": False}}, + {"center": {"value": centers[1], "vary": False}}, + {"center": {"value": centers[2], "vary": False}}, + ], + ) + + def _populate_sine_curve_demo(self): + """ + Showcase how lmfit's base SineModel can struggle with a drifting baseline. + """ + x = np.linspace(0, 6 * np.pi, 600) + rng = np.random.default_rng(7) + amplitude = 1.6 + frequency = 0.75 + phase = 0.4 + offset = 0.8 + slope = 0.08 + noise = rng.normal(loc=0, scale=0.12, size=x.size) + y = offset + slope * x + amplitude * np.sin(2 * np.pi * frequency * x + phase) + noise + + # Base SineModel (no offset support) to show the mismatch + self.sine_waveform.plot(x=x, y=y, label="custom-sine-data", dap="SineModel") + + # Composite model: Sine + Linear baseline (offset + slope) + self.sine_waveform.plot( + x=x, + y=y, + label="custom-sine-composite", + dap=["SineModel", "LinearModel"], + dap_oversample=4, + # TODO have to guess correctly units for LMFit SineModel + # dap_parameters={ + # "SineModel": { + # "amplitude": {"value": amplitude * 0.9, "vary": True}, + # "frequency": {"value": 2 * np.pi * frequency * 1.05, "vary": True}, + # "shift": {"value": 0.0, "vary": True}, + # }, + # "LinearModel": { + # "intercept": {"value": offset, "vary": True}, + # "slope": {"value": slope, "vary": True}, + # }, + # }, + ) + + if lmfit is None: + logger.info("Skipping sine lmfit demo (lmfit not installed on client).") + return + + return + + def _log_sine_dap_params(self, params: dict, metadata: dict): + curve_id = metadata.get("curve_id") + if curve_id not in { + "custom-sine-data-SineModel", + "custom-sine-composite-SineModel+LinearModel", + }: + return + logger.info(f"SineModel DAP fit params ({curve_id}): {params}") + if __name__ == "__main__": # pragma: no cover import sys diff --git a/tests/unit_tests/test_waveform.py b/tests/unit_tests/test_waveform.py index a2a8570c..a484da89 100644 --- a/tests/unit_tests/test_waveform.py +++ b/tests/unit_tests/test_waveform.py @@ -537,6 +537,31 @@ def test_normalize_dap_parameters_invalid_type_raises(): Waveform._normalize_dap_parameters(["amplitude", 1.0]) # type: ignore[arg-type] +def test_normalize_dap_parameters_composite_list(): + normalized = Waveform._normalize_dap_parameters( + [{"center": 1.0}, {"sigma": {"value": 0.5, "min": 0.0}}], + dap_name=["GaussianModel", "GaussianModel"], + ) + assert normalized == [ + {"center": {"name": "center", "value": 1.0, "vary": False}}, + {"sigma": {"name": "sigma", "value": 0.5, "min": 0.0, "vary": False}}, + ] + + +def test_normalize_dap_parameters_composite_dict(): + normalized = Waveform._normalize_dap_parameters( + { + "GaussianModel": {"center": {"value": 1.0, "vary": True}}, + "LorentzModel": {"amplitude": 2.0}, + }, + dap_name=["GaussianModel", "LorentzModel"], + ) + assert normalized["GaussianModel"]["center"]["value"] == 1.0 + assert normalized["GaussianModel"]["center"]["vary"] is True + assert normalized["LorentzModel"]["amplitude"]["value"] == 2.0 + assert normalized["LorentzModel"]["amplitude"]["vary"] is False + + def test_request_dap_includes_normalized_parameters(qtbot, mocked_client_with_dap, monkeypatch): wf = create_widget(qtbot, Waveform, client=mocked_client_with_dap) curve = wf.plot( @@ -567,6 +592,36 @@ def test_request_dap_includes_normalized_parameters(qtbot, mocked_client_with_da } +def test_request_dap_includes_composite_parameters_list(qtbot, mocked_client_with_dap, monkeypatch): + wf = create_widget(qtbot, Waveform, client=mocked_client_with_dap) + curve = wf.plot( + x=[0, 1, 2], + y=[1, 2, 3], + label="custom-composite", + dap=["GaussianModel", "GaussianModel"], + dap_parameters=[{"center": 0.0}, {"center": 1.0}], + ) + dap_curve = wf.get_curve(f"{curve.name()}-GaussianModel+GaussianModel") + assert dap_curve is not None + + captured = {} + + def capture(topic, msg, *args, **kwargs): # noqa: ARG001 + captured["topic"] = topic + captured["msg"] = msg + + monkeypatch.setattr(wf.client.connector, "set_and_publish", capture) + wf.request_dap() + + msg = captured["msg"] + dap_kwargs = msg.content["config"]["kwargs"] + assert dap_kwargs["parameters"] == [ + {"center": {"name": "center", "value": 0.0, "vary": False}}, + {"center": {"name": "center", "value": 1.0, "vary": False}}, + ] + assert msg.content["config"]["class_kwargs"]["model"] == ["GaussianModel", "GaussianModel"] + + def test_fetch_scan_data_and_access(qtbot, mocked_client, monkeypatch): """ Test the _fetch_scan_data_and_access method returns live_data/val if in a live scan,