mirror of
https://github.com/bec-project/bec_widgets.git
synced 2025-07-14 11:41:49 +02:00
WIP Dap implemnetation work in progress
This commit is contained in:
@ -7,6 +7,7 @@ import numpy as np
|
||||
import pyqtgraph as pg
|
||||
|
||||
from bec_lib import bec_logger
|
||||
from bec_lib import messages
|
||||
from bec_lib.device import ReadoutPriority
|
||||
from bec_lib.endpoints import MessageEndpoints
|
||||
from pydantic import Field, field_validator
|
||||
@ -47,6 +48,7 @@ class Waveform(PlotBase):
|
||||
# TODO implement signals
|
||||
scan_signal_update = Signal() # TODO maybe rename to async_signal_update
|
||||
async_signal_update = Signal()
|
||||
dap_signal_update = Signal()
|
||||
# dap_params_update = Signal(dict, dict)
|
||||
# dap_summary_update = Signal(dict, dict)
|
||||
# autorange_signal = Signal()
|
||||
@ -73,6 +75,7 @@ class Waveform(PlotBase):
|
||||
# Curve data
|
||||
self._sync_curves = []
|
||||
self._async_curves = []
|
||||
self._dap_curves = []
|
||||
self._mode: Literal["sync", "async", "mixed"] = (
|
||||
"sync" # TODO mode probably not needed as well, both wil be allowed
|
||||
)
|
||||
@ -97,6 +100,9 @@ class Waveform(PlotBase):
|
||||
self.proxy_update_plot = pg.SignalProxy(
|
||||
self.scan_signal_update, rateLimit=25, slot=self.update_sync_curves
|
||||
)
|
||||
self.proxy_dap_update = pg.SignalProxy(
|
||||
self.dap_signal_update, rateLimit=10, slot=self._update_dap
|
||||
)
|
||||
# self.proxy_update_dap = pg.SignalProxy(
|
||||
# self.scan_signal_update, rateLimit=25, slot=self.refresh_dap
|
||||
# )
|
||||
@ -117,7 +123,6 @@ class Waveform(PlotBase):
|
||||
def x_mode(self) -> str:
|
||||
return self._x_axis_mode["name"]
|
||||
|
||||
# TODO implement automatic x mode suffix update according to mode
|
||||
@x_mode.setter
|
||||
def x_mode(self, value: str):
|
||||
self._x_axis_mode["name"] = value
|
||||
@ -181,6 +186,7 @@ class Waveform(PlotBase):
|
||||
# High Level methods for API
|
||||
################################################################################
|
||||
# TODO such as plot, add, remove curve, etc.
|
||||
@SafeSlot(popup_error=True)
|
||||
def plot(
|
||||
self,
|
||||
arg1: list | np.ndarray | str | None = None,
|
||||
@ -192,7 +198,7 @@ class Waveform(PlotBase):
|
||||
y_entry: str | None = None,
|
||||
color: str | None = None,
|
||||
label: str | None = None,
|
||||
validate: bool = True,
|
||||
validate: bool = True, # TODO global vs local validation rules
|
||||
dap: str | None = None, # TODO add dap custom curve wrapper
|
||||
**kwargs,
|
||||
) -> Curve:
|
||||
@ -264,14 +270,14 @@ class Waveform(PlotBase):
|
||||
|
||||
# 3) If y_name => device
|
||||
if y_name is None:
|
||||
raise ValueError(
|
||||
"y_name must be provided if not using custom data"
|
||||
) # TODO provide logger
|
||||
logger.error("y_name must be provided if not using custom data")
|
||||
raise ValueError("y_name must be provided if not using custom data")
|
||||
|
||||
# TODO make more robust
|
||||
# if user didn't specify y_entry, fallback
|
||||
if y_entry is None:
|
||||
y_entry = y_name
|
||||
# TODO decide if to use logger or raise
|
||||
# try:
|
||||
y_entry = self.entry_validator.validate_signal(name=y_name, entry=y_entry)
|
||||
# except ValueError:
|
||||
# self.notification_label()
|
||||
|
||||
# device curve
|
||||
curve = self._add_curve(
|
||||
@ -287,17 +293,22 @@ class Waveform(PlotBase):
|
||||
# TODO double check the logic
|
||||
if x_name is not None:
|
||||
self._x_axis_mode["name"] = x_name
|
||||
if x_entry is not None:
|
||||
self._x_axis_mode["entry"] = x_entry
|
||||
if x_name not in ["timestamp", "index", "auto"]:
|
||||
self._x_axis_mode["entry"] = self.entry_validator.validate_signal(
|
||||
name=x_name, entry=x_entry
|
||||
)
|
||||
|
||||
# TODO implement x_mode change if putted by user
|
||||
if dap is not None:
|
||||
self._add_curve(
|
||||
source="dap", device_name=y_name, device_entry=y_entry, dap=dap, **kwargs
|
||||
)
|
||||
|
||||
# FIXME figure out dap logic adding
|
||||
# TODO implement the plot method
|
||||
return curve
|
||||
|
||||
################################################################################
|
||||
# Curve Management Methods
|
||||
def _add_dap_curve(self): ...
|
||||
|
||||
def _add_curve(
|
||||
self,
|
||||
source: Literal["custom", "device", "dap"],
|
||||
@ -307,9 +318,12 @@ class Waveform(PlotBase):
|
||||
device_entry: str | None = None,
|
||||
x_data: np.ndarray | None = None,
|
||||
y_data: np.ndarray | None = None,
|
||||
dap: str | None = None,
|
||||
**kwargs,
|
||||
) -> Curve:
|
||||
# TODO check the label logic
|
||||
# TODO parent_label has to be done better with consideration of custom labels
|
||||
parent_label = None
|
||||
if not label:
|
||||
# Generate fallback
|
||||
if source == "custom":
|
||||
@ -317,7 +331,8 @@ class Waveform(PlotBase):
|
||||
if source == "device":
|
||||
label = f"{device_name}-{device_entry}"
|
||||
if source == "dap":
|
||||
label = f"{device_name}-{device_entry}-DAP"
|
||||
label = f"{device_name}-{device_entry}-{dap}"
|
||||
parent_label = f"{device_name}-{device_entry}"
|
||||
|
||||
if self._check_curve_id(label):
|
||||
raise ValueError(f"Curve with ID '{label}' already exists in widget '{self.gui_id}'.")
|
||||
@ -333,13 +348,15 @@ class Waveform(PlotBase):
|
||||
label=label,
|
||||
color=color,
|
||||
source=source,
|
||||
parent_label=parent_label,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# If device-based, add device signal
|
||||
if source == "device":
|
||||
if not device_name or not device_entry:
|
||||
raise ValueError("device_name and device_entry are required for 'device' source.")
|
||||
self.entry_validator.validate_signal(device_name, device_entry) # TODO notify user
|
||||
# if not device_name or not device_entry:
|
||||
# raise ValueError("device_name and device_entry are required for 'device' source.")
|
||||
config.signal = DeviceSignal(name=device_name, entry=device_entry)
|
||||
|
||||
# If custom, we might want x_data, y_data
|
||||
@ -349,6 +366,9 @@ class Waveform(PlotBase):
|
||||
raise ValueError("x_data,y_data must be provided for 'custom' source.")
|
||||
final_data = (x_data, y_data)
|
||||
|
||||
if source == "dap": # TODO change logic
|
||||
config.signal.dap = dap
|
||||
|
||||
# Finally, create the curve item
|
||||
curve = self._add_curve_object(name=label, source=source, config=config, data=final_data)
|
||||
return curve
|
||||
@ -394,24 +414,6 @@ class Waveform(PlotBase):
|
||||
for i, curve in enumerate(all_curves):
|
||||
curve.set_color(color_list[i])
|
||||
|
||||
def _remove_curve_by_source(self, source: Literal["device", "custom", "dap", "sync", "async"]):
|
||||
"""
|
||||
Remove all curves by their source from the plot widget.
|
||||
|
||||
Args:
|
||||
source(str): The source of the curves to remove.
|
||||
"""
|
||||
# TODO check logic
|
||||
for curve in self.curves:
|
||||
if curve.config.source == source:
|
||||
self.plot_item.removeItem(curve)
|
||||
if source == "sync":
|
||||
for curve in self._sync_curves:
|
||||
self.plot_item.removeItem(curve)
|
||||
if source == "async":
|
||||
for curve in self._async_curves:
|
||||
self.plot_item.removeItem(curve)
|
||||
|
||||
def clear_all(self):
|
||||
"""
|
||||
Clear all curves from the plot widget.
|
||||
@ -503,6 +505,21 @@ class Waveform(PlotBase):
|
||||
"""
|
||||
return self.READOUT_PRIORITY_HANDLER[self.dev[name].readout_priority]
|
||||
|
||||
def _find_curve_by_label(self, label: str) -> Curve | None:
|
||||
"""
|
||||
Find a curve by its label.
|
||||
|
||||
Args:
|
||||
label(str): The label of the curve to find.
|
||||
|
||||
Returns:
|
||||
Curve|None: The curve object if found, None otherwise.
|
||||
"""
|
||||
for c in self.curves:
|
||||
if c.name() == label:
|
||||
return c
|
||||
return None
|
||||
|
||||
################################################################################
|
||||
# BEC Update Methods
|
||||
################################################################################
|
||||
@ -640,6 +657,65 @@ class Waveform(PlotBase):
|
||||
else:
|
||||
curve.setData(data_plot)
|
||||
|
||||
@SafeSlot()
|
||||
def _update_dap(self):
|
||||
for dap_curve in self._dap_curves:
|
||||
parent_label = getattr(dap_curve.config, "dap_parent_label", None)
|
||||
if not parent_label:
|
||||
continue
|
||||
# find the device curve
|
||||
parent_curve = self._find_curve_by_label(parent_label)
|
||||
if parent_curve is None:
|
||||
logger.warning(f"No device curve found for DAP curve '{dap_curve.name()}'!")
|
||||
continue
|
||||
|
||||
x_data, y_data = parent_curve.get_data()
|
||||
model_name = dap_curve.config.signals.dap
|
||||
model = getattr(self.dap, model_name)
|
||||
|
||||
# TODO implement DAP logic
|
||||
msg = messages.DAPRequestMessage(
|
||||
dap_cls="LmfitService1D",
|
||||
dap_type="on_demand",
|
||||
config={
|
||||
"args": [],
|
||||
"kwargs": {"data_x": x_data, "data_y": y_data}, # TODO add xmin,xmax as before
|
||||
"class_args": model._plugin_info["class_args"],
|
||||
"class_kwargs": model._plugin_info["class_kwargs"],
|
||||
},
|
||||
metadata={"RID": f"{self.scan_id}-{self.gui_id}"},
|
||||
)
|
||||
self.client.connector.set_and_publish(MessageEndpoints.dap_request(), msg)
|
||||
|
||||
# TODO get data from corresponding curves
|
||||
# for curve in self._dap_curves:
|
||||
# corresponding_curve =
|
||||
|
||||
@SafeSlot(dict, dict)
|
||||
def update_dap(self, msg, metadata):
|
||||
"""Callback for DAP response message."""
|
||||
...
|
||||
|
||||
# pylint: disable=unused-variable
|
||||
# scan_id, x_name, x_entry, y_name, y_entry = msg["dap_request"].content["config"]["args"]
|
||||
# model = msg["dap_request"].content["config"]["class_kwargs"]["model"]
|
||||
#
|
||||
# curve_id_request = f"{y_name}-{y_entry}-{model}"
|
||||
#
|
||||
# for curve_id, curve in self._curves_data["DAP"].items():
|
||||
# if curve_id == curve_id_request:
|
||||
# if msg["data"] is not None:
|
||||
# x = msg["data"][0]["x"]
|
||||
# y = msg["data"][0]["y"]
|
||||
# curve.setData(x, y)
|
||||
# curve.dap_params = msg["data"][1]["fit_parameters"]
|
||||
# curve.dap_summary = msg["data"][1]["fit_summary"]
|
||||
# metadata.update({"curve_id": curve_id_request})
|
||||
# self.dap_params_update.emit(curve.dap_params, metadata)
|
||||
# self.dap_summary_update.emit(curve.dap_summary, metadata)
|
||||
# break
|
||||
#
|
||||
|
||||
def _get_x_data(self, device_name: str, device_entry: str):
|
||||
"""
|
||||
Get the x data for the curves with the decision logic based on the widget x mode configuration:
|
||||
@ -759,6 +835,7 @@ class Waveform(PlotBase):
|
||||
# Reset sync/async curve lists
|
||||
self._async_curves = []
|
||||
self._sync_curves = []
|
||||
self._dap_curves = []
|
||||
found_async = False
|
||||
found_sync = False
|
||||
mode = "sync"
|
||||
@ -768,6 +845,10 @@ class Waveform(PlotBase):
|
||||
|
||||
# Iterate over all curves
|
||||
for curve in self.curves:
|
||||
# categorise dap curves firsts
|
||||
if curve.config.source == "dap":
|
||||
self._dap_curves.append(curve)
|
||||
continue
|
||||
dev_name = curve.config.signal.name
|
||||
if dev_name in readout_priority_async:
|
||||
self._async_curves.append(curve)
|
||||
@ -780,21 +861,19 @@ class Waveform(PlotBase):
|
||||
f"Device {dev_name} not found in readout priority list."
|
||||
) # TODO change to logger
|
||||
|
||||
# Determine mode of the scan
|
||||
# Determine the mode of the scan
|
||||
if found_async and found_sync:
|
||||
mode = "mixed"
|
||||
print(
|
||||
logger.warning(
|
||||
f"Found both async and sync devices in the scan. X-axis integrity cannot be guaranteed."
|
||||
) # TODO change to logger
|
||||
)
|
||||
# TODO do some prompt to user to decide which mode to use
|
||||
elif found_async:
|
||||
mode = "async"
|
||||
elif found_sync:
|
||||
mode = "sync"
|
||||
|
||||
print(f"Mode: {mode}") # TODO change to logger
|
||||
print(f"Sync curves: {self._sync_curves}")
|
||||
print(f"Async curves: {self._async_curves}")
|
||||
logger.info(f"Curve acquisition mode: {mode}")
|
||||
|
||||
return mode
|
||||
|
||||
@ -813,6 +892,7 @@ if __name__ == "__main__":
|
||||
widget = Waveform()
|
||||
widget.show()
|
||||
widget.plot("monitor_async")
|
||||
widget.plot("bullshit")
|
||||
# widget.plot(y_name="bpm4i", y_entry="bpm4i")
|
||||
# widget.plot(y_name="bpm3a", y_entry="bpm3a")
|
||||
sys.exit(app.exec_())
|
||||
|
Reference in New Issue
Block a user