0
0
mirror of https://github.com/bec-project/bec_widgets.git synced 2025-07-14 11:41:49 +02:00

WIP X mode works for sync curves

This commit is contained in:
2025-01-18 21:16:49 +01:00
parent 1139eefb66
commit 4eda839948

View File

@ -1,23 +1,20 @@
from __future__ import annotations
import json
import pyqtgraph as pg
import numpy as np
from collections import defaultdict
from typing import Literal, Optional
from pydantic import Field, field_validator
from qtpy.QtCore import Slot
import numpy as np
import pyqtgraph as pg
from bec_lib.device import ReadoutPriority
from bec_lib.endpoints import MessageEndpoints
from pydantic import Field, field_validator
from qtpy.QtCore import Signal
from qtpy.QtWidgets import QWidget
from bec_lib.endpoints import MessageEndpoints
from bec_widgets.qt_utils.error_popups import SafeProperty, SafeSlot
from bec_widgets.utils import ConnectionConfig
from bec_widgets.utils.colors import set_theme, Colors
from bec_widgets.utils.colors import Colors, set_theme
from bec_widgets.widgets.plots_next_gen.plot_base import PlotBase
from bec_widgets.widgets.plots_next_gen.waveform.curve import Curve, CurveConfig, DeviceSignal
@ -27,9 +24,6 @@ class WaveformConfig(ConnectionConfig):
color_palette: Optional[str] = Field(
"magma", description="The color palette of the figure widget.", validate_default=True
)
# curves: dict[str, CurveConfig] = Field(
# {}, description="The list of curves to be added to the 1D waveform widget."
# )
model_config: dict = {"validate_assignment": True}
_validate_color_palette = field_validator("color_palette")(Colors.validate_color_map)
@ -87,16 +81,19 @@ class Waveform(PlotBase):
self.scan_id = None
self.scan_item = None
self.current_sources = {"sync": [], "async": []} # TODO maybe not needed
self.x_mode = "auto" # TODO maybe default could be 'best_effort'
self._x_axis_mode = {
"name": None,
"name": "auto",
"entry": None,
"readout_priority": None,
"label_suffix": "",
} # TODO decide which one to use
# Scan status update loop
self.bec_dispatcher.connect_slot(self.on_scan_status, MessageEndpoints.scan_status())
self.bec_dispatcher.connect_slot(self.on_scan_progress, MessageEndpoints.scan_progress())
# Curve update loop
# TODO review relevant bec_dispatcher signals
# Scan segment update proxy
self.proxy_update_plot = pg.SignalProxy(
self.scan_signal_update, rateLimit=25, slot=self.update_sync_curves
)
@ -105,19 +102,6 @@ class Waveform(PlotBase):
# )
# self.async_signal_update.connect(self.replot_async_curve)
# self.autorange_signal.connect(self.auto_range)
# self.bec_dispatcher.connect_slot(self.on_scan_segment, MessageEndpoints.scan_segment())
# self.bec_dispatcher.connect_slot(
# self.on_scan_segment, MessageEndpoints.scan_segment()
# ) # TODO probably not needed
# Scan status update loop
self.bec_dispatcher.connect_slot(self.on_scan_status, MessageEndpoints.scan_status())
self.bec_dispatcher.connect_slot(self.on_scan_progress, MessageEndpoints.scan_progress())
# Curve update loop
# self.proxy_scan_update = pg.SignalProxy(
# self.scan_signal_update, rateLimit=25, slot=self.update_sync_curves
# ) # TODO implement
# self.proxy_dap_update = pg.SignalProxy(
# self.dap_signal_update, rateLimit=25, slot=self.update_dap_curves
# ) # TODO implement
@ -133,6 +117,15 @@ class Waveform(PlotBase):
################################################################################
# Widget Specific Properties
################################################################################
@SafeProperty(str)
def x_mode(self) -> str:
return self._x_axis_mode["name"]
@x_mode.setter
def x_mode(self, value: str):
self._x_axis_mode["name"] = value
@SafeProperty(str)
def color_palette(self) -> str:
return self.config.color_palette
@ -194,14 +187,9 @@ class Waveform(PlotBase):
x: list | np.ndarray | None = None,
x_name: str | None = None,
y_name: str | None = None,
z_name: str | None = None, # TODO not needed
x_entry: str | None = None,
y_entry: str | None = None,
z_entry: str | None = None, # TODO not needed
color: str | None = None,
color_map_z: (
str | None
) = "magma", # TODO probably not needed here there will be wrapper for this
label: str | None = None,
validate: bool = True,
dap: str | None = None, # TODO add dap custom curve wrapper
@ -221,12 +209,9 @@ class Waveform(PlotBase):
- "index": Use the index signal.
- Custom signal name of device from BEC.
y_name(str): The name of the device for the y-axis.
z_name(str): The name of the device for the z-axis.
x_entry(str): The name of the entry for the x-axis.
y_entry(str): The name of the entry for the y-axis.
z_entry(str): The name of the entry for the z-axis.
color(str): The color of the curve.
color_map_z(str): The color map to use for the z-axis.
label(str): The label of the curve.
validate(bool): If True, validate the device names and entries.
dap(str): The dap model to use for the curve, only available for sync devices. If not specified, none will be added.
@ -260,6 +245,9 @@ class Waveform(PlotBase):
# 2. BEC device curve logic
self._add_device_curve(y_name, y_entry) # TODO change y_name and y_entry
# 3. X mode logic if provided
# if x_name is not None:
# TODO implement x_mode change if putted by user
# FIXME figure out dap logic adding
@ -415,9 +403,23 @@ class Waveform(PlotBase):
)[len(self.plot_item.curves)]
return color
def _remove_curve_by_source(self, source: str):
# TODO consider if this is needed
pass
def _remove_curve_by_source(self, source: Literal["device", "custom", "dap", "sync", "async"]):
"""
Remove all curves by their source from the plot widget.
Args:
source(str): The source of the curves to remove.
"""
# TODO check logic
for curve in self.curves:
if curve.config.source == source:
self.plot_item.removeItem(curve)
if source == "sync":
for curve in self._sync_curves:
self.plot_item.removeItem(curve)
if source == "async":
for curve in self._async_curves:
self.plot_item.removeItem(curve)
def remove_curve(self, curve: int | str):
"""
@ -475,15 +477,17 @@ class Waveform(PlotBase):
# BEC Update Methods
################################################################################
# TODO here will go bec related update slots
@SafeSlot(dict, dict)
def on_scan_segment(self, msg: dict, meta: dict):
# TODO probably not needed
print(f"Scan segment: {msg}")
@SafeSlot(dict, dict)
def on_scan_status(self, msg: dict, meta: dict):
print(f"Scan status: {msg}")
print(f"Scan status meta: {meta}")
"""
Initial scan status message handler, which is triggered at the begging and end of scan.
Used for triggering the update of the sync and async curves.
Args:
msg(dict): The message content.
meta(dict): The message metadata.
"""
current_scan_id = msg.get("scan_id", None)
readout_priority = msg.get("readout_priority", None)
if current_scan_id is None or readout_priority is None:
@ -511,15 +515,19 @@ class Waveform(PlotBase):
self.scan_signal_update.emit()
self.async_signal_update.emit()
# TODO scan progress update loop triggering curve updates
@SafeSlot(dict, dict)
def on_scan_progress(self, msg: dict, meta: dict, *args, **kwargs):
print(f"Scan progress: {msg}")
print(f"Scan progress meta: {meta}")
def on_scan_progress(self, msg: dict, meta: dict):
"""
Slot for handling scan progress messages. Used for triggering the update of the sync curves.
Args:
msg(dict): The message content.
meta(dict): The message metadata.
"""
self.scan_signal_update.emit()
# @SafeSlot()
def update_sync_curves(self):
print("Updating sync curves")
try:
data = (
self.scan_item.live_data
@ -528,15 +536,84 @@ class Waveform(PlotBase):
)
except AttributeError:
return
for curve in self._sync_curves:
device_name = curve.config.signal.name
device_entry = curve.config.signal.entry
device_data = data[device_name][device_entry].val
x_name = self.scan_item.status_message.info["scan_report_devices"][0]
# TODO logic for x_entry
x_entry = x_name
x_data = data[x_name][x_entry].val
curve.setData(x_data, device_data)
device_data = data.get(device_name, {}).get(device_entry, {}).get("val", None)
x_data = self._get_x_data(device_name, device_entry)
if device_data is not None and x_data is not None:
curve.setData(x_data, device_data)
if device_data is not None and x_data is None:
curve.setData(device_data)
def _get_x_data(self, device_name: str, device_entry: str):
"""
Get the x data for the curves with the decision logic based on the widget x mode configuration:
- If x is called 'timestamp', use the timestamp data from the scan item.
- If x is called 'index', use the rolling index.
- If x is a custom signal, use the data from the scan item.
- If x is not specified, use the first device from the scan report.
Additionally, checks and updates the x label suffix.
Args:
device_name(str): The name of the device.
device_entry(str): The entry of the device
Returns:
list|np.ndarray|None: X data for the curve.
"""
x_data = None
new_suffix = None
live_data = (
self.scan_item.live_data
if hasattr(self.scan_item, "live_data")
else self.scan_item.data
)
if self._x_axis_mode["name"] == "timestamp":
timestamps = live_data[device_name][device_entry].timestamps
x_data = timestamps
new_suffix = " [timestamp]"
if self._x_axis_mode["name"] == "index":
x_data = None
new_suffix = " [index]"
if self._x_axis_mode["name"] is None or self._x_axis_mode["name"] == "auto":
if len(self._async_curves) > 0:
x_data = None
new_suffix = " [auto: index]"
self._update_x_label_suffix(new_suffix)
else:
x_name = self.scan_item.status_message.info["scan_report_devices"][0]
x_entry = self.entry_validator.validate_signal(x_name, None)
x_data = live_data.get(x_name, {}).get(x_entry, {}).get("val", None)
new_suffix = f" [auto: {x_name}-{x_entry}]"
self._update_x_label_suffix(new_suffix)
return x_data
def _update_x_label_suffix(self, new_suffix: str):
"""
Update x_label so it ends with `new_suffix`, removing any old suffix.
Args:
new_suffix(str): The new suffix to add to the x_label.
"""
if new_suffix == self._x_axis_mode["label_suffix"]:
return
old_label = self.x_label
if self._x_axis_mode["label_suffix"] and old_label.endswith(
self._x_axis_mode["label_suffix"]
):
old_label = old_label[: -len(self._x_axis_mode["label_suffix"])]
updated_label = old_label
if new_suffix:
updated_label += new_suffix
self.x_label = updated_label
self._x_axis_mode["label_suffix"] = new_suffix
def _categorise_device_curves(self, readout_priority: dict) -> str:
# Reset sync/async curve lists
@ -574,8 +651,6 @@ class Waveform(PlotBase):
mode = "async"
elif found_sync:
mode = "sync"
else:
mode = "sync"
return mode