0
0
mirror of https://github.com/bec-project/bec_widgets.git synced 2025-07-14 11:41:49 +02:00

WIP Dap implemnetation work in progress

This commit is contained in:
2025-01-22 12:27:17 +01:00
parent 627ac91f55
commit a990ad4bf4

View File

@ -7,6 +7,7 @@ import numpy as np
import pyqtgraph as pg
from bec_lib import bec_logger
from bec_lib import messages
from bec_lib.device import ReadoutPriority
from bec_lib.endpoints import MessageEndpoints
from pydantic import Field, field_validator
@ -47,6 +48,7 @@ class Waveform(PlotBase):
# TODO implement signals
scan_signal_update = Signal() # TODO maybe rename to async_signal_update
async_signal_update = Signal()
dap_signal_update = Signal()
# dap_params_update = Signal(dict, dict)
# dap_summary_update = Signal(dict, dict)
# autorange_signal = Signal()
@ -73,6 +75,7 @@ class Waveform(PlotBase):
# Curve data
self._sync_curves = []
self._async_curves = []
self._dap_curves = []
self._mode: Literal["sync", "async", "mixed"] = (
"sync" # TODO mode probably not needed as well, both wil be allowed
)
@ -97,6 +100,9 @@ class Waveform(PlotBase):
self.proxy_update_plot = pg.SignalProxy(
self.scan_signal_update, rateLimit=25, slot=self.update_sync_curves
)
self.proxy_dap_update = pg.SignalProxy(
self.dap_signal_update, rateLimit=10, slot=self._update_dap
)
# self.proxy_update_dap = pg.SignalProxy(
# self.scan_signal_update, rateLimit=25, slot=self.refresh_dap
# )
@ -117,7 +123,6 @@ class Waveform(PlotBase):
def x_mode(self) -> str:
return self._x_axis_mode["name"]
# TODO implement automatic x mode suffix update according to mode
@x_mode.setter
def x_mode(self, value: str):
self._x_axis_mode["name"] = value
@ -181,6 +186,7 @@ class Waveform(PlotBase):
# High Level methods for API
################################################################################
# TODO such as plot, add, remove curve, etc.
@SafeSlot(popup_error=True)
def plot(
self,
arg1: list | np.ndarray | str | None = None,
@ -192,7 +198,7 @@ class Waveform(PlotBase):
y_entry: str | None = None,
color: str | None = None,
label: str | None = None,
validate: bool = True,
validate: bool = True, # TODO global vs local validation rules
dap: str | None = None, # TODO add dap custom curve wrapper
**kwargs,
) -> Curve:
@ -264,14 +270,14 @@ class Waveform(PlotBase):
# 3) If y_name => device
if y_name is None:
raise ValueError(
"y_name must be provided if not using custom data"
) # TODO provide logger
logger.error("y_name must be provided if not using custom data")
raise ValueError("y_name must be provided if not using custom data")
# TODO make more robust
# if user didn't specify y_entry, fallback
if y_entry is None:
y_entry = y_name
# TODO decide if to use logger or raise
# try:
y_entry = self.entry_validator.validate_signal(name=y_name, entry=y_entry)
# except ValueError:
# self.notification_label()
# device curve
curve = self._add_curve(
@ -287,17 +293,22 @@ class Waveform(PlotBase):
# TODO double check the logic
if x_name is not None:
self._x_axis_mode["name"] = x_name
if x_entry is not None:
self._x_axis_mode["entry"] = x_entry
if x_name not in ["timestamp", "index", "auto"]:
self._x_axis_mode["entry"] = self.entry_validator.validate_signal(
name=x_name, entry=x_entry
)
# TODO implement x_mode change if putted by user
if dap is not None:
self._add_curve(
source="dap", device_name=y_name, device_entry=y_entry, dap=dap, **kwargs
)
# FIXME figure out dap logic adding
# TODO implement the plot method
return curve
################################################################################
# Curve Management Methods
def _add_dap_curve(self): ...
def _add_curve(
self,
source: Literal["custom", "device", "dap"],
@ -307,9 +318,12 @@ class Waveform(PlotBase):
device_entry: str | None = None,
x_data: np.ndarray | None = None,
y_data: np.ndarray | None = None,
dap: str | None = None,
**kwargs,
) -> Curve:
# TODO check the label logic
# TODO parent_label has to be done better with consideration of custom labels
parent_label = None
if not label:
# Generate fallback
if source == "custom":
@ -317,7 +331,8 @@ class Waveform(PlotBase):
if source == "device":
label = f"{device_name}-{device_entry}"
if source == "dap":
label = f"{device_name}-{device_entry}-DAP"
label = f"{device_name}-{device_entry}-{dap}"
parent_label = f"{device_name}-{device_entry}"
if self._check_curve_id(label):
raise ValueError(f"Curve with ID '{label}' already exists in widget '{self.gui_id}'.")
@ -333,13 +348,15 @@ class Waveform(PlotBase):
label=label,
color=color,
source=source,
parent_label=parent_label,
**kwargs,
)
# If device-based, add device signal
if source == "device":
if not device_name or not device_entry:
raise ValueError("device_name and device_entry are required for 'device' source.")
self.entry_validator.validate_signal(device_name, device_entry) # TODO notify user
# if not device_name or not device_entry:
# raise ValueError("device_name and device_entry are required for 'device' source.")
config.signal = DeviceSignal(name=device_name, entry=device_entry)
# If custom, we might want x_data, y_data
@ -349,6 +366,9 @@ class Waveform(PlotBase):
raise ValueError("x_data,y_data must be provided for 'custom' source.")
final_data = (x_data, y_data)
if source == "dap": # TODO change logic
config.signal.dap = dap
# Finally, create the curve item
curve = self._add_curve_object(name=label, source=source, config=config, data=final_data)
return curve
@ -394,24 +414,6 @@ class Waveform(PlotBase):
for i, curve in enumerate(all_curves):
curve.set_color(color_list[i])
def _remove_curve_by_source(self, source: Literal["device", "custom", "dap", "sync", "async"]):
"""
Remove all curves by their source from the plot widget.
Args:
source(str): The source of the curves to remove.
"""
# TODO check logic
for curve in self.curves:
if curve.config.source == source:
self.plot_item.removeItem(curve)
if source == "sync":
for curve in self._sync_curves:
self.plot_item.removeItem(curve)
if source == "async":
for curve in self._async_curves:
self.plot_item.removeItem(curve)
def clear_all(self):
"""
Clear all curves from the plot widget.
@ -503,6 +505,21 @@ class Waveform(PlotBase):
"""
return self.READOUT_PRIORITY_HANDLER[self.dev[name].readout_priority]
def _find_curve_by_label(self, label: str) -> Curve | None:
"""
Find a curve by its label.
Args:
label(str): The label of the curve to find.
Returns:
Curve|None: The curve object if found, None otherwise.
"""
for c in self.curves:
if c.name() == label:
return c
return None
################################################################################
# BEC Update Methods
################################################################################
@ -640,6 +657,65 @@ class Waveform(PlotBase):
else:
curve.setData(data_plot)
@SafeSlot()
def _update_dap(self):
for dap_curve in self._dap_curves:
parent_label = getattr(dap_curve.config, "dap_parent_label", None)
if not parent_label:
continue
# find the device curve
parent_curve = self._find_curve_by_label(parent_label)
if parent_curve is None:
logger.warning(f"No device curve found for DAP curve '{dap_curve.name()}'!")
continue
x_data, y_data = parent_curve.get_data()
model_name = dap_curve.config.signals.dap
model = getattr(self.dap, model_name)
# TODO implement DAP logic
msg = messages.DAPRequestMessage(
dap_cls="LmfitService1D",
dap_type="on_demand",
config={
"args": [],
"kwargs": {"data_x": x_data, "data_y": y_data}, # TODO add xmin,xmax as before
"class_args": model._plugin_info["class_args"],
"class_kwargs": model._plugin_info["class_kwargs"],
},
metadata={"RID": f"{self.scan_id}-{self.gui_id}"},
)
self.client.connector.set_and_publish(MessageEndpoints.dap_request(), msg)
# TODO get data from corresponding curves
# for curve in self._dap_curves:
# corresponding_curve =
@SafeSlot(dict, dict)
def update_dap(self, msg, metadata):
"""Callback for DAP response message."""
...
# pylint: disable=unused-variable
# scan_id, x_name, x_entry, y_name, y_entry = msg["dap_request"].content["config"]["args"]
# model = msg["dap_request"].content["config"]["class_kwargs"]["model"]
#
# curve_id_request = f"{y_name}-{y_entry}-{model}"
#
# for curve_id, curve in self._curves_data["DAP"].items():
# if curve_id == curve_id_request:
# if msg["data"] is not None:
# x = msg["data"][0]["x"]
# y = msg["data"][0]["y"]
# curve.setData(x, y)
# curve.dap_params = msg["data"][1]["fit_parameters"]
# curve.dap_summary = msg["data"][1]["fit_summary"]
# metadata.update({"curve_id": curve_id_request})
# self.dap_params_update.emit(curve.dap_params, metadata)
# self.dap_summary_update.emit(curve.dap_summary, metadata)
# break
#
def _get_x_data(self, device_name: str, device_entry: str):
"""
Get the x data for the curves with the decision logic based on the widget x mode configuration:
@ -759,6 +835,7 @@ class Waveform(PlotBase):
# Reset sync/async curve lists
self._async_curves = []
self._sync_curves = []
self._dap_curves = []
found_async = False
found_sync = False
mode = "sync"
@ -768,6 +845,10 @@ class Waveform(PlotBase):
# Iterate over all curves
for curve in self.curves:
# categorise dap curves firsts
if curve.config.source == "dap":
self._dap_curves.append(curve)
continue
dev_name = curve.config.signal.name
if dev_name in readout_priority_async:
self._async_curves.append(curve)
@ -780,21 +861,19 @@ class Waveform(PlotBase):
f"Device {dev_name} not found in readout priority list."
) # TODO change to logger
# Determine mode of the scan
# Determine the mode of the scan
if found_async and found_sync:
mode = "mixed"
print(
logger.warning(
f"Found both async and sync devices in the scan. X-axis integrity cannot be guaranteed."
) # TODO change to logger
)
# TODO do some prompt to user to decide which mode to use
elif found_async:
mode = "async"
elif found_sync:
mode = "sync"
print(f"Mode: {mode}") # TODO change to logger
print(f"Sync curves: {self._sync_curves}")
print(f"Async curves: {self._async_curves}")
logger.info(f"Curve acquisition mode: {mode}")
return mode
@ -813,6 +892,7 @@ if __name__ == "__main__":
widget = Waveform()
widget.show()
widget.plot("monitor_async")
widget.plot("bullshit")
# widget.plot(y_name="bpm4i", y_entry="bpm4i")
# widget.plot(y_name="bpm3a", y_entry="bpm3a")
sys.exit(app.exec_())