diff --git a/bec_lib/bec_lib/dap_plugin_objects.py b/bec_lib/bec_lib/dap_plugin_objects.py index 644b3c5c..b2d97f88 100644 --- a/bec_lib/bec_lib/dap_plugin_objects.py +++ b/bec_lib/bec_lib/dap_plugin_objects.py @@ -1,5 +1,6 @@ from __future__ import annotations +import builtins import time import uuid from typing import TYPE_CHECKING @@ -16,6 +17,13 @@ from bec_lib.scan_items import ScanItem if TYPE_CHECKING: from bec_lib.client import BECClient +try: + import matplotlib.pyplot as plt + + plt.ion() +except ImportError: + plt = None + class DAPPluginObjectBase: """ @@ -86,6 +94,8 @@ class DAPPluginObjectBase: return self._convert_result(response) def _convert_result(self, result: messages.BECMessage): + if not result.content["data"]: + return None if not callable(self._result_cls): return result.content["data"] # pylint: disable=not-callable @@ -168,10 +178,14 @@ class LmfitService1DResult: Result of fitting 1D data using lmfit. """ - def __init__(self, result: list[dict], model_name: str = None): + def __init__(self, result: list[dict], model_name: str = None, client: BECClient = None): self._data = result[0] self._report = result[1] self._model = model_name + if client: + self._client = client + else: + self._client = builtins.__dict__.get("bec") if "amplitude" in self.params: self.amplitude = self.params["amplitude"] if "center" in self.params: @@ -240,6 +254,44 @@ class LmfitService1DResult: max_index = np.argmax(self._data["y"]) return {"x": self._data["x"][max_index], "y": self._data["y"][max_index]} + @property + def input_data(self): + """ + Get the input data used for the fit. + + Returns: + dict: The input data used for the fit. + """ + input_data = self._report.get("input") + scan_id = input_data.get("scan_id") + if not scan_id: + return None + + scan_item = self._client.queue.scan_storage.find_scan_by_ID(scan_id) + if not scan_item: + return None + + x = scan_item.data[input_data["device_x"]][input_data["signal_x"]].val + y = scan_item.data[input_data["device_y"]][input_data["signal_y"]].val + + return {"x": x, "y": y} + + def plot(self): + """ + Plot the fit. + """ + # move this to BECWidgets once it's available + if not plt: + raise ImportError( + "matplotlib is not installed. Cannot plot. Please install matplotlib using 'pip install matplotlib'." + ) + input_data = self.input_data + plt.figure() + plt.plot(input_data["x"], input_data["y"], label="data", color="black", marker="o") + plt.plot(self._data["x"], self._data["y"], label=f"{self._model}", color="red") + plt.legend() + plt.show() + def __str__(self) -> str: return f"{self._model} fit result: \n Params: {self.params} \n Min: {self.min} \n Max: {self.max}"