0
0
mirror of https://github.com/bec-project/bec_widgets.git synced 2025-07-14 11:41:49 +02:00

WIP Waveform sync and async operation working

This commit is contained in:
2025-01-21 20:34:24 +01:00
parent 03d0dbb7f5
commit 627ac91f55

View File

@ -1,11 +1,12 @@
from __future__ import annotations
import json
from collections import defaultdict
from typing import Literal, Optional
import numpy as np
import pyqtgraph as pg
from bec_lib import bec_logger
from bec_lib.device import ReadoutPriority
from bec_lib.endpoints import MessageEndpoints
from pydantic import Field, field_validator
@ -18,6 +19,8 @@ from bec_widgets.utils.colors import Colors, set_theme
from bec_widgets.widgets.plots_next_gen.plot_base import PlotBase
from bec_widgets.widgets.plots_next_gen.waveform.curve import Curve, CurveConfig, DeviceSignal
logger = bec_logger.logger
# noinspection PyDataclass
class WaveformConfig(ConnectionConfig):
@ -68,10 +71,8 @@ class Waveform(PlotBase):
self.setObjectName("Waveform")
# Curve data
self._curves_by_class = defaultdict(dict) # TODO needed can be 'device', 'custom','dap'
self._sync_curves = []
self._async_curves = []
self._curves = self.plot_item.curves
self._mode: Literal["sync", "async", "mixed"] = (
"sync" # TODO mode probably not needed as well, both wil be allowed
)
@ -80,7 +81,6 @@ class Waveform(PlotBase):
self.old_scan_id = None
self.scan_id = None
self.scan_item = None
self.current_sources = {"sync": [], "async": []} # TODO maybe not needed
self._x_axis_mode = {
"name": "auto",
"entry": None,
@ -109,11 +109,6 @@ class Waveform(PlotBase):
# self.async_signal_update, self.update_async_curves
# ) # TODO implement
# TODO test curves
# self.plot([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], label="test_curve")
# self.plot([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], label="test_curve2")
################################################################################
# Widget Specific Properties
################################################################################
@ -126,6 +121,11 @@ class Waveform(PlotBase):
@x_mode.setter
def x_mode(self, value: str):
self._x_axis_mode["name"] = value
self._switch_x_axis_item(mode=value)
# self._update_x_label_suffix() #TODO update straight away or wait for the next scan??
self.async_signal_update.emit()
self.scan_signal_update.emit()
self.plot_item.enableAutoRange(x=True)
@SafeProperty(str)
def color_palette(self) -> str:
@ -153,7 +153,7 @@ class Waveform(PlotBase):
A JSON string property that serializes all curves' pydantic configs.
"""
raw_list = []
for c in self._curves:
for c in self.plot_item.curves:
cfg_dict = c.config.dict()
raw_list.append(cfg_dict)
return json.dumps(raw_list, indent=2)
@ -171,11 +171,11 @@ class Waveform(PlotBase):
Returns:
list: List of curves.
"""
return self._curves
return self.plot_item.curves
@curves.setter
def curves(self, value: list[Curve]):
self._curves = value
self.plot_item.curves = value
################################################################################
# High Level methods for API
@ -222,35 +222,69 @@ class Waveform(PlotBase):
"""
# 1. Custom curve logic
if x is not None and y is not None:
return self._add_curve_custom(x=x, y=y, label=label, color=color, **kwargs)
return self._add_curve(
source="custom", label=label, color=color, x_data=x, y_data=y, **kwargs
)
# Another custom case if user put 'arg1' as data
if isinstance(arg1, list) or isinstance(arg1, np.ndarray):
# if user also gave 'y' => custom
if isinstance(y, list) or isinstance(y, np.ndarray):
return self._add_curve(
source="custom",
label=label,
color=color,
x_data=np.asarray(arg1),
y_data=np.asarray(y),
**kwargs,
)
# if user did not pass 'y', we guess we want to do x=..., y=...
if y is None:
x_ary = np.arange(len(arg1))
return self._add_curve(
source="custom",
label=label,
color=color,
x_data=x_ary,
y_data=np.asarray(arg1),
**kwargs,
)
# if it's a 2D array
if isinstance(arg1, np.ndarray) and arg1.ndim == 2 and y is None:
x_ary = arg1[:, 0]
y_ary = arg1[:, 1]
return self._add_curve(
source="custom", label=label, color=color, x_data=x_ary, y_data=y_ary, **kwargs
)
# 2) If user gave 'arg1' as str => interpret as y_name
if isinstance(arg1, str):
y_name = arg1
elif isinstance(arg1, list):
if isinstance(y, list):
return self._add_curve_custom(x=arg1, y=y, label=label, color=color, **kwargs)
if y is None:
x = np.arange(len(arg1))
return self._add_curve_custom(x=x, y=arg1, label=label, color=color, **kwargs)
elif isinstance(arg1, np.ndarray) and y is None:
if arg1.ndim == 1:
x = np.arange(arg1.size)
return self._add_curve_custom(x=x, y=arg1, label=label, color=color, **kwargs)
if arg1.ndim == 2:
x = arg1[:, 0]
y = arg1[:, 1]
return self._add_curve_custom(x=x, y=y, label=label, color=color, **kwargs)
if y_name is None:
raise ValueError("y_name must be provided.") # TODO provide logger
# 2. BEC device curve logic
# 3) If y_name => device
if y_name is None:
raise ValueError(
"y_name must be provided if not using custom data"
) # TODO provide logger
# TODO make more robust
# if user didn't specify y_entry, fallback
if y_entry is None:
y_entry = y_name
self._add_device_curve(y_name, y_entry) # TODO change y_name and y_entry
# 3. X mode logic if provided
# TODO double check the x_mode logic
# device curve
curve = self._add_curve(
source="device",
label=label,
color=color,
device_name=y_name,
device_entry=y_entry,
**kwargs,
)
# 4) If user gave x_name => store in x_axis_mode
# TODO double check the logic
if x_name is not None:
self._x_axis_mode["name"] = x_name
if x_entry is not None:
@ -260,156 +294,105 @@ class Waveform(PlotBase):
# FIXME figure out dap logic adding
# TODO implement the plot method
return curve
################################################################################
# Curve Management Methods
################################################################################
# TODO implement curve management methods
def _add_device_curve(self, device_name: str, device_signal: str):
"""Add BEC Device curve, can be sync(monitored device) or async device."""
# TODO implement signal fetch from BEC if not provided
# Setup identifiers
source = "device"
curve_id = f"{device_name}-{device_signal}"
# Check if curve already exists
curve_exits = self._check_curve_id(curve_id)
if curve_exits:
raise ValueError(
f"Curve with ID '{curve_id}' already exists in widget '{self.gui_id}'."
) # TODO change to logger
# TODO do device check with BEC if it is loaded
# Create curve by config
color = self._generate_color_from_palette() # TODO check the refresh logic of this
curve_config = CurveConfig(
widget_class="BECCurve",
parent_id=self.gui_id,
label=curve_id,
color=color,
source=source,
signal=DeviceSignal(name=device_name, entry=device_signal),
)
self._add_curve_object(name=curve_id, source=source, config=curve_config)
# TODO consolidate with adding curve object
def _add_curve_custom(
def _add_curve(
self,
x: list | np.ndarray,
y: list | np.ndarray,
label: str = None,
color: str = None,
source: Literal["custom", "device", "dap"],
label: str | None = None,
color: str | None = None,
device_name: str | None = None,
device_entry: str | None = None,
x_data: np.ndarray | None = None,
y_data: np.ndarray | None = None,
**kwargs,
) -> Curve:
"""
Add a custom data curve to the plot widget.
# TODO check the label logic
if not label:
# Generate fallback
if source == "custom":
label = f"Curve {len(self.plot_item.curves) + 1}"
if source == "device":
label = f"{device_name}-{device_entry}"
if source == "dap":
label = f"{device_name}-{device_entry}-DAP"
Args:
x(list|np.ndarray): X data of the curve.
y(list|np.ndarray): Y data of the curve.
label(str, optional): Label of the curve. Defaults to None.
color(str, optional): Color of the curve. Defaults to None.
curve_source(str, optional): Tag for source of the curve. Defaults to "custom".
**kwargs: Additional keyword arguments for the curve configuration.
if self._check_curve_id(label):
raise ValueError(f"Curve with ID '{label}' already exists in widget '{self.gui_id}'.")
Returns:
BECCurve: The curve object.
"""
# If color not provided, pick from the palette
if not color:
color = self._generate_color_from_palette()
curve_id = label or f"Curve {len(self.plot_item.curves) + 1}"
curve_exits = self._check_curve_id(curve_id)
if curve_exits:
raise ValueError(
f"Curve with ID '{curve_id}' already exists in widget '{self.gui_id}'."
) # TODO change to logger
color = (
color
or Colors.golden_angle_color(
colormap="magma", # FIXME Config do not have color_palette anymore
num=max(10, len(self.plot_item.curves) + 1),
format="HEX",
)[len(self.plot_item.curves)]
)
# Create curve by config
curve_config = CurveConfig(
widget_class="BECCurve",
# Build the config
config = CurveConfig(
widget_class="Curve",
parent_id=self.gui_id,
label=curve_id,
label=label,
color=color,
source="custom",
source=source,
**kwargs,
)
curve = self._add_curve_object(
name=curve_id, source="custom", config=curve_config, data=(x, y)
)
# If device-based, add device signal
if source == "device":
if not device_name or not device_entry:
raise ValueError("device_name and device_entry are required for 'device' source.")
config.signal = DeviceSignal(name=device_name, entry=device_entry)
# If custom, we might want x_data, y_data
final_data = None
if source == "custom":
if x_data is None or y_data is None:
raise ValueError("x_data,y_data must be provided for 'custom' source.")
final_data = (x_data, y_data)
# Finally, create the curve item
curve = self._add_curve_object(name=label, source=source, config=config, data=final_data)
return curve
def _add_curve_object(
self,
name: str,
source: str, # todo probably also not needed
source: str,
config: CurveConfig,
data: tuple[list | np.ndarray, list | np.ndarray] = None,
) -> Curve:
"""
Add a curve object to the plot widget.
Args:
name(str): ID of the curve.
source(str): Source of the curve.
config(CurveConfig): Configuration of the curve.
data(tuple[list|np.ndarray,list|np.ndarray], optional): Data (x,y) to be plotted. Defaults to None.
Returns:
BECCurve: The curve object.
"""
# curve_exits = self._check_curve_id(config.label)
# if curve_exits:
# raise ValueError(
# f"Curve with ID '{config.label}' already exists in widget '{self.gui_id}'."
# ) # TODO change to logger
#
# color = (
# color
# or Colors.golden_angle_color(
# colormap="magma", # FIXME Config do not have color_palette anymore
# num=max(10, len(self.plot_item.curves) + 1),
# format="HEX",
# )[len(self.plot_item.curves)]
# )
curve = Curve(config=config, name=name, parent_item=self)
self._curves_by_class[source][name] = curve
# self._curves_by_class[source][name] = curve
self.plot_item.addItem(curve)
# self.config.curves[name] = curve.config #TODO will be changed
if data is not None:
curve.setData(data[0], data[1])
# self.set_legend_label_size() #TODO will be changed
return curve
# TODO decide if needed
def _add_curve(
self,
name: str,
config: CurveConfig,
data: tuple[list | np.ndarray, list | np.ndarray] = None,
):
curve = Curve(name=name, config=config, parent_item=self)
self.plot_item.addItem(curve)
if source == "device":
self.async_signal_update.emit()
self.scan_signal_update.emit()
return curve
def _generate_color_from_palette(self) -> str:
# TODO think about refreshing all colors during this
color = Colors.golden_angle_color(
colormap=self.color_palette, num=max(10, len(self.plot_item.curves) + 1), format="HEX"
)[len(self.plot_item.curves)]
return color
"""
Generate a color for the next new curve, based on the current number of curves.
"""
current_count = len(self.plot_item.curves)
color_list = Colors.golden_angle_color(
colormap=self.color_palette, num=max(10, current_count + 1), format="HEX"
)
return color_list[current_count]
def _refresh_colors(self):
"""
Re-assign colors to all existing curves so they match the new count-based distribution.
"""
all_curves = self.plot_item.curves
# Generate enough colors for the new total
color_list = Colors.golden_angle_color(
colormap=self.color_palette, num=max(10, len(all_curves)), format="HEX"
)
for i, curve in enumerate(all_curves):
curve.set_color(color_list[i])
def _remove_curve_by_source(self, source: Literal["device", "custom", "dap", "sync", "async"]):
"""
@ -429,6 +412,14 @@ class Waveform(PlotBase):
for curve in self._async_curves:
self.plot_item.removeItem(curve)
def clear_all(self):
"""
Clear all curves from the plot widget.
"""
curve_list = [curve for curve in self.plot_item.curves]
for curve in curve_list:
self.remove_curve(curve.name())
def remove_curve(self, curve: int | str):
"""
Remove a curve from the plot widget.
@ -436,11 +427,14 @@ class Waveform(PlotBase):
Args:
curve(int|str): The curve to remove. Can be the order of the curve or the name of the curve.
"""
# TODO check if it removes curve from rpc register !!!!
if isinstance(curve, int):
self._remove_curve_by_order(curve)
elif isinstance(curve, str):
self._remove_curve_by_name(curve)
self._refresh_colors()
def _remove_curve_by_name(self, name: str):
"""
Remove a curve by its name from the plot widget.
@ -451,6 +445,7 @@ class Waveform(PlotBase):
for curve in self.plot_item.curves:
if curve.name() == name:
self.plot_item.removeItem(curve)
self._curve_clean_up(curve)
return
def _remove_curve_by_order(self, N: int):
@ -463,9 +458,22 @@ class Waveform(PlotBase):
if N < len(self.plot_item.curves):
curve = self.plot_item.curves[N]
self.plot_item.removeItem(curve)
self._curve_clean_up(curve)
else:
raise IndexError(f"Curve order {N} out of range.") # TODO can be logged
def _curve_clean_up(self, curve: Curve):
"""
Clean up the curve by disconnecting the async update signal (even for sync curves).
Args:
curve(Curve): The curve to clean up.
"""
self.bec_dispatcher.disconnect_slot(
self.on_async_readback,
MessageEndpoints.device_async_readback(self.scan_id, curve.name()),
)
def _check_curve_id(self, curve_id: str) -> bool:
"""
Check if a curve ID exists in the plot widget.
@ -476,11 +484,25 @@ class Waveform(PlotBase):
Returns:
bool: True if the curve ID exists, False otherwise.
"""
curve_ids = [curve.name() for curve in self._curves]
curve_ids = [curve.name() for curve in self.plot_item.curves]
if curve_id in curve_ids:
return True
return False
# TODO extend and implement
def _get_device_readout_priority(self, name: str):
"""
Get the type of device from the entry_validator.
Args:
name(str): Name of the device.
entry(str): Entry of the device.
Returns:
str: Type of the device.
"""
return self.READOUT_PRIORITY_HANDLER[self.dev[name].readout_priority]
################################################################################
# BEC Update Methods
################################################################################
@ -516,13 +538,18 @@ class Waveform(PlotBase):
# First trigger to sync and async data
if self._mode == "sync":
self.scan_signal_update.emit()
print("Sync mode") # TODO change to logger
elif self._mode == "async":
for curve in self._async_curves:
self._setup_async_curve(curve)
self.async_signal_update.emit()
print("Async mode") # TODO change to logger
else:
self.scan_signal_update.emit()
for curve in self._async_curves:
self._setup_async_curve(curve)
self.async_signal_update.emit()
print("Mixed mode") # TODO change to logger
@SafeSlot(dict, dict)
def on_scan_progress(self, msg: dict, meta: dict):
@ -551,6 +578,8 @@ class Waveform(PlotBase):
device_entry = curve.config.signal.entry
device_data = data.get(device_name, {}).get(device_entry, {}).get("val", None)
x_data = self._get_x_data(device_name, device_entry)
if len(data) == 0: # case if the data is empty because motor is not scanned
return
if device_data is not None and x_data is not None:
curve.setData(x_data, device_data)
if device_data is not None and x_data is None:
@ -570,6 +599,7 @@ class Waveform(PlotBase):
MessageEndpoints.device_async_readback(self.scan_id, name),
from_start=True,
)
print(f"Setup async curve {name}") # TODO change to logger
@SafeSlot(dict, dict)
def on_async_readback(self, msg, metadata):
@ -599,6 +629,7 @@ class Waveform(PlotBase):
x_data = np.hstack((x_data, async_data["timestamp"]))
else:
x_data = async_data["timestamp"]
# FIXME x axis wrong if timestamp switched during scan
curve.setData(x_data, new_data)
else:
curve.setData(new_data)
@ -634,18 +665,38 @@ class Waveform(PlotBase):
else self.scan_item.data
)
# 1 User wants custom signal
# TODO extend validation
if self._x_axis_mode["name"] not in ["timestamp", "index", "auto"]:
x_name = self._x_axis_mode["name"]
x_entry = self._x_axis_mode.get("entry", None)
if x_entry is None:
x_entry = self.entry_validator.validate_signal(x_name, None)
# if the motor was not scanned, an empty list is returned and curves are not updated
x_data = live_data.get(x_name, {}).get(x_entry, {}).get("val", [])
new_suffix = f" [custom: {x_name}-{x_entry}]"
# 2 User wants timestamp
if self._x_axis_mode["name"] == "timestamp":
print("Timestamp mode") # TODO change to logger
print(f"Device name: {device_name}, entry: {device_entry}") # TODO change to logger
timestamps = live_data[device_name][device_entry].timestamps
x_data = timestamps
new_suffix = " [timestamp]"
# 3 User wants index
if self._x_axis_mode["name"] == "index":
x_data = None
new_suffix = " [index]"
# 4 Best effort automatic mode
if self._x_axis_mode["name"] is None or self._x_axis_mode["name"] == "auto":
# 4.1 If there are async curves, use index
if len(self._async_curves) > 0:
x_data = None
new_suffix = " [auto: index]"
self._update_x_label_suffix(new_suffix)
# 4.2 If there are sync curves, use the first device from the scan report
else:
x_name = self.scan_item.status_message.info["scan_report_devices"][0]
x_entry = self.entry_validator.validate_signal(x_name, None)
@ -662,6 +713,7 @@ class Waveform(PlotBase):
Args:
new_suffix(str): The new suffix to add to the x_label.
"""
if new_suffix == self._x_axis_mode["label_suffix"]:
return
@ -678,7 +730,32 @@ class Waveform(PlotBase):
self.x_label = updated_label
self._x_axis_mode["label_suffix"] = new_suffix
def _switch_x_axis_item(self, mode: str):
"""
Switch the x-axis mode between timestamp, index, the best effort and custom signal.
Args:
mode(str): Mode of the x-axis.
- "timestamp": Use the timestamp signal.
- "index": Use the index signal.
- "best_effort": Use the best effort signal.
- Custom signal name of device from BEC.
"""
print(f'Switching x-axis mode to "{mode}"') # TODO change to logger
date_axis = pg.graphicsItems.DateAxisItem.DateAxisItem(orientation="bottom")
default_axis = pg.AxisItem(orientation="bottom")
if mode == "timestamp":
self.plot_item.setAxisItems({"bottom": date_axis})
else:
self.plot_item.setAxisItems({"bottom": default_axis})
def _categorise_device_curves(self, readout_priority: dict) -> str:
"""
Categorise the device curves into sync and async based on the readout priority.
Args:
readout_priority(dict): The readout priority of the scan.
"""
# Reset sync/async curve lists
self._async_curves = []
self._sync_curves = []
@ -690,7 +767,7 @@ class Waveform(PlotBase):
readout_priority_sync = readout_priority.get("monitored", [])
# Iterate over all curves
for curve_id, curve in self._curves_by_class["device"].items():
for curve in self.curves:
dev_name = curve.config.signal.name
if dev_name in readout_priority_async:
self._async_curves.append(curve)
@ -715,6 +792,10 @@ class Waveform(PlotBase):
elif found_sync:
mode = "sync"
print(f"Mode: {mode}") # TODO change to logger
print(f"Sync curves: {self._sync_curves}")
print(f"Async curves: {self._async_curves}")
return mode
################################################################################