From 76bd0d339ac9ae9e8a3baa0d0d4e951ec1d09670 Mon Sep 17 00:00:00 2001 From: wyzula-jan Date: Wed, 8 May 2024 15:12:56 +0200 Subject: [PATCH] feat(widgets/progressbar): SpiralProgressBar added with rpc interface --- bec_widgets/cli/client.py | 252 ++++++++ bec_widgets/cli/generate_cli.py | 5 +- bec_widgets/cli/rpc_wigdet_handler.py | 3 +- .../jupyter_console/jupyter_console_window.py | 6 +- bec_widgets/widgets/__init__.py | 1 + .../widgets/spiral_progress_bar/__init__.py | 1 + .../widgets/spiral_progress_bar/ring.py | 184 ++++++ .../spiral_progress_bar.py | 594 ++++++++++++++++++ tests/end-2-end/test_bec_dock_rpc_e2e.py | 81 +++ tests/unit_tests/test_spiral_progress_bar.py | 338 ++++++++++ 10 files changed, 1461 insertions(+), 4 deletions(-) create mode 100644 bec_widgets/widgets/spiral_progress_bar/__init__.py create mode 100644 bec_widgets/widgets/spiral_progress_bar/ring.py create mode 100644 bec_widgets/widgets/spiral_progress_bar/spiral_progress_bar.py create mode 100644 tests/unit_tests/test_spiral_progress_bar.py diff --git a/bec_widgets/cli/client.py b/bec_widgets/cli/client.py index 50f8278a..9bd674fb 100644 --- a/bec_widgets/cli/client.py +++ b/bec_widgets/cli/client.py @@ -1630,3 +1630,255 @@ class BECDockArea(RPCBase, BECGuiClientMixin): """ Get all registered RPC objects. """ + + +class SpiralProgressBar(RPCBase): + @rpc_call + def get_all_rpc(self) -> "dict": + """ + Get all registered RPC objects. + """ + + @property + @rpc_call + def rpc_id(self) -> "str": + """ + Get the RPC ID of the widget. + """ + + @property + @rpc_call + def config_dict(self) -> "dict": + """ + Get the configuration of the widget. + + Returns: + dict: The configuration of the widget. + """ + + @property + @rpc_call + def rings(self): + """ + None + """ + + @rpc_call + def update_config(self, config: "SpiralProgressBarConfig | dict"): + """ + Update the configuration of the widget. + + Args: + config(SpiralProgressBarConfig|dict): Configuration to update. + """ + + @rpc_call + def add_ring(self, **kwargs) -> "Ring": + """ + Add a new progress bar. + + Args: + **kwargs: Keyword arguments for the new progress bar. + + Returns: + Ring: Ring object. + """ + + @rpc_call + def remove_ring(self, index: "int"): + """ + Remove a progress bar by index. + + Args: + index(int): Index of the progress bar to remove. + """ + + @rpc_call + def set_precision(self, precision: "int", bar_index: "int" = None): + """ + Set the precision for the progress bars. If bar_index is not provide, the precision will be set for all progress bars. + + Args: + precision(int): Precision for the progress bars. + bar_index(int): Index of the progress bar to set the precision for. If provided, only a single precision can be set. + """ + + @rpc_call + def set_min_max_values( + self, + min_values: "int | float | list[int | float]", + max_values: "int | float | list[int | float]", + ): + """ + Set the minimum and maximum values for the progress bars. + + Args: + min_values(int|float | list[float]): Minimum value(s) for the progress bars. If multiple progress bars are displayed, provide a list of minimum values for each progress bar. + max_values(int|float | list[float]): Maximum value(s) for the progress bars. If multiple progress bars are displayed, provide a list of maximum values for each progress bar. + """ + + @rpc_call + def set_number_of_bars(self, num_bars: "int"): + """ + Set the number of progress bars to display. + + Args: + num_bars(int): Number of progress bars to display. + """ + + @rpc_call + def set_value(self, values: "int | list", ring_index: "int" = None): + """ + Set the values for the progress bars. + + Args: + values(int | tuple): Value(s) for the progress bars. If multiple progress bars are displayed, provide a tuple of values for each progress bar. + ring_index(int): Index of the progress bar to set the value for. If provided, only a single value can be set. + + Examples: + >>> SpiralProgressBar.set_value(50) + >>> SpiralProgressBar.set_value([30, 40, 50]) # (outer, middle, inner) + >>> SpiralProgressBar.set_value(60, bar_index=1) # Set the value for the middle progress bar. + """ + + @rpc_call + def set_colors_from_map(self, colormap, color_format: "Literal['RGB', 'HEX']" = "RGB"): + """ + Set the colors for the progress bars from a colormap. + + Args: + colormap(str): Name of the colormap. + color_format(Literal["RGB","HEX"]): Format of the returned colors ('RGB', 'HEX'). + """ + + @rpc_call + def set_colors_directly( + self, colors: "list[str | tuple] | str | tuple", bar_index: "int" = None + ): + """ + Set the colors for the progress bars directly. + + Args: + colors(list[str | tuple] | str | tuple): Color(s) for the progress bars. If multiple progress bars are displayed, provide a list of colors for each progress bar. + bar_index(int): Index of the progress bar to set the color for. If provided, only a single color can be set. + """ + + @rpc_call + def set_line_widths(self, widths: "int | list[int]", bar_index: "int" = None): + """ + Set the line widths for the progress bars. + + Args: + widths(int | list[int]): Line width(s) for the progress bars. If multiple progress bars are displayed, provide a list of line widths for each progress bar. + bar_index(int): Index of the progress bar to set the line width for. If provided, only a single line width can be set. + """ + + @rpc_call + def set_gap(self, gap: "int"): + """ + Set the gap between the progress bars. + + Args: + gap(int): Gap between the progress bars. + """ + + @rpc_call + def set_diameter(self, diameter: "int"): + """ + Set the diameter of the widget. + + Args: + diameter(int): Diameter of the widget. + """ + + @rpc_call + def reset_diameter(self): + """ + Reset the fixed size of the widget. + """ + + @rpc_call + def enable_auto_updates(self, enable: "bool" = True): + """ + Enable or disable updates based on scan status. Overrides manual updates. + The behaviour of the whole progress bar widget will be driven by the scan queue status. + + Args: + enable(bool): True or False. + + Returns: + bool: True if scan segment updates are enabled. + """ + + +class Ring(RPCBase): + @rpc_call + def get_all_rpc(self) -> "dict": + """ + Get all registered RPC objects. + """ + + @property + @rpc_call + def rpc_id(self) -> "str": + """ + Get the RPC ID of the widget. + """ + + @property + @rpc_call + def config_dict(self) -> "dict": + """ + Get the configuration of the widget. + + Returns: + dict: The configuration of the widget. + """ + + @rpc_call + def set_value(self, value: "int | float"): + """ + None + """ + + @rpc_call + def set_color(self, color: "str | tuple"): + """ + None + """ + + @rpc_call + def set_background(self, color: "str | tuple"): + """ + None + """ + + @rpc_call + def set_line_width(self, width: "int"): + """ + None + """ + + @rpc_call + def set_min_max_values(self, min_value: "int", max_value: "int"): + """ + None + """ + + @rpc_call + def set_start_angle(self, start_angle: "int"): + """ + None + """ + + @rpc_call + def set_connections(self, slot: "str", endpoint: "str | EndpointInfo"): + """ + None + """ + + @rpc_call + def reset_connection(self): + """ + None + """ diff --git a/bec_widgets/cli/generate_cli.py b/bec_widgets/cli/generate_cli.py index c9f8dd61..051b2d36 100644 --- a/bec_widgets/cli/generate_cli.py +++ b/bec_widgets/cli/generate_cli.py @@ -109,13 +109,14 @@ if __name__ == "__main__": # pragma: no cover import os from bec_widgets.utils import BECConnector - from bec_widgets.widgets import BECDock, BECDockArea, BECFigure + from bec_widgets.widgets import BECDock, BECDockArea, BECFigure, SpiralProgressBar from bec_widgets.widgets.figure.plots.image.image import BECImageShow from bec_widgets.widgets.figure.plots.image.image_item import BECImageItem from bec_widgets.widgets.figure.plots.motor_map.motor_map import BECMotorMap from bec_widgets.widgets.figure.plots.plot_base import BECPlotBase from bec_widgets.widgets.figure.plots.waveform.waveform import BECWaveform from bec_widgets.widgets.figure.plots.waveform.waveform_curve import BECCurve + from bec_widgets.widgets.spiral_progress_bar.ring import Ring current_path = os.path.dirname(__file__) client_path = os.path.join(current_path, "client.py") @@ -130,6 +131,8 @@ if __name__ == "__main__": # pragma: no cover BECMotorMap, BECDock, BECDockArea, + SpiralProgressBar, + Ring, ] generator = ClientGenerator() generator.generate_client(clss) diff --git a/bec_widgets/cli/rpc_wigdet_handler.py b/bec_widgets/cli/rpc_wigdet_handler.py index 92c46fbc..bc25c632 100644 --- a/bec_widgets/cli/rpc_wigdet_handler.py +++ b/bec_widgets/cli/rpc_wigdet_handler.py @@ -1,11 +1,12 @@ from bec_widgets.utils import BECConnector from bec_widgets.widgets.figure import BECFigure +from bec_widgets.widgets.spiral_progress_bar.spiral_progress_bar import SpiralProgressBar class RPCWidgetHandler: """Handler class for creating widgets from RPC messages.""" - widget_classes = {"BECFigure": BECFigure} + widget_classes = {"BECFigure": BECFigure, "SpiralProgressBar": SpiralProgressBar} @staticmethod def create_widget(widget_type, **kwargs) -> BECConnector: diff --git a/bec_widgets/examples/jupyter_console/jupyter_console_window.py b/bec_widgets/examples/jupyter_console/jupyter_console_window.py index d95db00a..ce27d21a 100644 --- a/bec_widgets/examples/jupyter_console/jupyter_console_window.py +++ b/bec_widgets/examples/jupyter_console/jupyter_console_window.py @@ -13,6 +13,7 @@ from bec_widgets.cli.rpc_register import RPCRegister from bec_widgets.utils import BECDispatcher from bec_widgets.widgets import BECFigure from bec_widgets.widgets.dock.dock_area import BECDockArea +from bec_widgets.widgets.spiral_progress_bar.spiral_progress_bar import SpiralProgressBar class JupyterConsoleWidget(RichJupyterWidget): # pragma: no cover: @@ -62,6 +63,7 @@ class JupyterConsoleWindow(QWidget): # pragma: no cover: "d1": self.d1, "d2": self.d2, "d3": self.d3, + "bar": self.bar, "b2a": self.button_2_a, "b2b": self.button_2_b, "b2c": self.button_2_c, @@ -114,14 +116,14 @@ class JupyterConsoleWindow(QWidget): # pragma: no cover: self.button_2_b = QtWidgets.QPushButton("button after without postions specified") self.button_2_c = QtWidgets.QPushButton("button super late") self.button_3 = QtWidgets.QPushButton("Button above Figure ") - self.label_1 = QtWidgets.QLabel("some scan info label with useful information") + self.bar = SpiralProgressBar() self.label_2 = QtWidgets.QLabel("label which is added separately") self.label_3 = QtWidgets.QLabel("Label above figure") self.d1 = self.dock.add_dock(widget=self.button_1, position="left") self.d1.addWidget(self.label_2) - self.d2 = self.dock.add_dock(widget=self.label_1, position="right") + self.d2 = self.dock.add_dock(widget=self.bar, position="right") self.d3 = self.dock.add_dock(name="figure") self.fig_dock3 = BECFigure() self.fig_dock3.plot(x_name="samx", y_name="bpm4d") diff --git a/bec_widgets/widgets/__init__.py b/bec_widgets/widgets/__init__.py index 268d685a..5e041775 100644 --- a/bec_widgets/widgets/__init__.py +++ b/bec_widgets/widgets/__init__.py @@ -1,3 +1,4 @@ from .dock import BECDock, BECDockArea from .figure import BECFigure, FigureConfig from .scan_control import ScanControl +from .spiral_progress_bar import SpiralProgressBar diff --git a/bec_widgets/widgets/spiral_progress_bar/__init__.py b/bec_widgets/widgets/spiral_progress_bar/__init__.py new file mode 100644 index 00000000..81ed0f36 --- /dev/null +++ b/bec_widgets/widgets/spiral_progress_bar/__init__.py @@ -0,0 +1 @@ +from .spiral_progress_bar import SpiralProgressBar diff --git a/bec_widgets/widgets/spiral_progress_bar/ring.py b/bec_widgets/widgets/spiral_progress_bar/ring.py new file mode 100644 index 00000000..b1f6b538 --- /dev/null +++ b/bec_widgets/widgets/spiral_progress_bar/ring.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +from typing import Literal, Optional + +from bec_lib.endpoints import EndpointInfo +from pydantic import BaseModel, Field, field_validator +from pydantic_core import PydanticCustomError +from qtpy import QtGui + +from bec_widgets.utils import BECConnector, ConnectionConfig + + +class RingConnections(BaseModel): + slot: Literal["on_scan_progress", "on_device_readback"] = None + endpoint: EndpointInfo | str = None + + @field_validator("endpoint") + def validate_endpoint(cls, v, values): + slot = values.data["slot"] + endpoint = v.endpoint if isinstance(v, EndpointInfo) else v + if slot == "on_scan_progress": + if endpoint != "scans/scan_progress": + raise PydanticCustomError( + "unsupported endpoint", + "For slot 'on_scan_progress', endpoint must be MessageEndpoint.scan_progress or 'scans/scan_progress'.", + {"wrong_value": v}, + ) + elif slot == "on_device_readback": + if not endpoint.startswith("internal/devices/readback/"): + raise PydanticCustomError( + "unsupported endpoint", + "For slot 'on_device_readback', endpoint must be MessageEndpoint.device_readback(device) or 'internal/devices/readback/{device}'.", + {"wrong_value": v}, + ) + return v + + +class RingConfig(ConnectionConfig): + direction: int | None = Field( + -1, description="Direction of the progress bars. -1 for clockwise, 1 for counter-clockwise." + ) + color: str | tuple | None = Field( + (0, 159, 227, 255), + description="Color for the progress bars. Can be tuple (R, G, B, A) or string HEX Code.", + ) + background_color: str | tuple | None = Field( + (200, 200, 200, 50), + description="Background color for the progress bars. Can be tuple (R, G, B, A) or string HEX Code.", + ) + index: int | None = Field(0, description="Index of the progress bar. 0 is outer ring.") + line_width: int | None = Field(5, description="Line widths for the progress bars.") + start_position: int | None = Field( + 90, + description="Start position for the progress bars in degrees. Default is 90 degrees - corespons to " + "the top of the ring.", + ) + min_value: int | None = Field(0, description="Minimum value for the progress bars.") + max_value: int | None = Field(100, description="Maximum value for the progress bars.") + precision: int | None = Field(3, description="Precision for the progress bars.") + update_behaviour: Literal["manual", "auto"] | None = Field( + "auto", description="Update behaviour for the progress bars." + ) + connections: RingConnections | None = Field( + default_factory=RingConnections, description="Connections for the progress bars." + ) + + +class Ring(BECConnector): + USER_ACCESS = [ + "get_all_rpc", + "rpc_id", + "config_dict", + "set_value", + "set_color", + "set_background", + "set_line_width", + "set_min_max_values", + "set_start_angle", + "set_connections", + "reset_connection", + ] + + def __init__( + self, + parent=None, + parent_progress_widget=None, + config: RingConfig | dict | None = None, + client=None, + gui_id: Optional[str] = None, + ): + if config is None: + config = RingConfig(widget_class=self.__class__.__name__) + self.config = config + else: + if isinstance(config, dict): + config = RingConfig(**config) + self.config = config + super().__init__(client=client, config=config, gui_id=gui_id) + + self.parent_progress_widget = parent_progress_widget + self.color = None + self.background_color = None + self.start_position = None + self.config = config + self.value = 0 + self.RID = None + self._init_config_params() + + def _init_config_params(self): + self.color = self.convert_color(self.config.color) + self.background_color = self.convert_color(self.config.background_color) + self.set_start_angle(self.config.start_position) + if self.config.connections: + self.set_connections(self.config.connections.slot, self.config.connections.endpoint) + + def set_value(self, value: int | float): + self.value = round( + max(self.config.min_value, min(self.config.max_value, value)), self.config.precision + ) + + def set_color(self, color: str | tuple): + self.config.color = color + self.color = self.convert_color(color) + + def set_background(self, color: str | tuple): + self.config.background_color = color + self.color = self.convert_color(color) + + def set_line_width(self, width: int): + self.config.line_width = width + + def set_min_max_values(self, min_value: int, max_value: int): + self.config.min_value = min_value + self.config.max_value = max_value + + def set_start_angle(self, start_angle: int): + self.config.start_position = start_angle + self.start_position = start_angle * 16 + + @staticmethod + def convert_color(color): + converted_color = None + if isinstance(color, str): + converted_color = QtGui.QColor(color) + elif isinstance(color, tuple): + converted_color = QtGui.QColor(*color) + return converted_color + + def set_connections(self, slot: str, endpoint: str | EndpointInfo): + if self.config.connections.endpoint == endpoint and self.config.connections.slot == slot: + return + else: + self.bec_dispatcher.disconnect_slot( + self.config.connections.slot, self.config.connections.endpoint + ) + self.config.connections = RingConnections(slot=slot, endpoint=endpoint) + self.bec_dispatcher.connect_slot(getattr(self, slot), endpoint) + + def reset_connection(self): + self.bec_dispatcher.disconnect_slot( + self.config.connections.slot, self.config.connections.endpoint + ) + self.config.connections = RingConnections() + + def on_scan_progress(self, msg, meta): + current_RID = meta.get("RID", None) + if current_RID != self.RID: + self.set_min_max_values(0, msg.get("max_value", 100)) + self.set_value(msg.get("value", 0)) + self.parent_progress_widget.update() + + def on_device_readback(self, msg, meta): + if isinstance(self.config.connections.endpoint, EndpointInfo): + endpoint = self.config.connections.endpoint.endpoint + else: + endpoint = self.config.connections.endpoint + device = endpoint.split("/")[-1] + value = msg.get("signals").get(device).get("value") + self.set_value(value) + self.parent_progress_widget.update() + + def cleanup(self): + self.reset_connection() + super().cleanup() diff --git a/bec_widgets/widgets/spiral_progress_bar/spiral_progress_bar.py b/bec_widgets/widgets/spiral_progress_bar/spiral_progress_bar.py new file mode 100644 index 00000000..ec9f14ba --- /dev/null +++ b/bec_widgets/widgets/spiral_progress_bar/spiral_progress_bar.py @@ -0,0 +1,594 @@ +from __future__ import annotations + +from typing import Literal, Optional + +import pyqtgraph as pg +from bec_lib.endpoints import MessageEndpoints +from pydantic import Field, field_validator +from pydantic_core import PydanticCustomError +from qtpy import QtCore, QtGui +from qtpy.QtCore import QSize, Slot +from qtpy.QtWidgets import QSizePolicy, QWidget + +from bec_widgets.utils import BECConnector, Colors, ConnectionConfig, EntryValidator +from bec_widgets.widgets.spiral_progress_bar.ring import Ring, RingConfig + + +class SpiralProgressBarConfig(ConnectionConfig): + color_map: str | None = Field("magma", description="Color scheme for the progress bars.") + min_number_of_bars: int | None = Field( + 1, description="Minimum number of progress bars to display." + ) + max_number_of_bars: int | None = Field( + 10, description="Maximum number of progress bars to display." + ) + num_bars: int | None = Field(1, description="Number of progress bars to display.") + gap: int | None = Field(10, description="Gap between progress bars.") + auto_updates: bool | None = Field( + True, description="Enable or disable updates based on scan queue status." + ) + rings: list[RingConfig] | None = Field([], description="List of ring configurations.") + + @field_validator("num_bars") + def validate_num_bars(cls, v, values): + min_number_of_bars = values.data.get("min_number_of_bars", None) + max_number_of_bars = values.data.get("max_number_of_bars", None) + if min_number_of_bars is not None and max_number_of_bars is not None: + print( + f"Number of bars adjusted to be between defined min:{min_number_of_bars} and max:{max_number_of_bars} number of bars." + ) + v = max(min_number_of_bars, min(v, max_number_of_bars)) + return v + + @field_validator("rings") + def validate_rings(cls, v, values): + if v is not None and v is not []: + num_bars = values.data.get("num_bars", None) + if len(v) != num_bars: + raise PydanticCustomError( + "different number of configs", + f"Length of rings configuration ({len(v)}) does not match the number of bars ({num_bars}).", + {"wrong_value": len(v)}, + ) + indices = [ring.index for ring in v] + if sorted(indices) != list(range(len(indices))): + raise PydanticCustomError( + "wrong indices", + f"Indices of ring configurations must be unique and in order from 0 to num_bars {num_bars}.", + {"wrong_value": indices}, + ) + return v + + @field_validator("color_map") + def validate_color_map(cls, v, values): + if v is not None and v != "": + if v not in pg.colormap.listMaps(): + raise PydanticCustomError( + "unsupported colormap", + f"Colormap '{v}' not found in the current installation of pyqtgraph", + {"wrong_value": v}, + ) + return v + + +class SpiralProgressBar(BECConnector, QWidget): + USER_ACCESS = [ + "get_all_rpc", + "rpc_id", + "config_dict", + "rings", + "update_config", + "add_ring", + "remove_ring", + "set_precision", + "set_min_max_values", + "set_number_of_bars", + "set_value", + "set_colors_from_map", + "set_colors_directly", + "set_line_widths", + "set_gap", + "set_diameter", + "reset_diameter", + "enable_auto_updates", + ] + + def __init__( + self, + parent=None, + config: SpiralProgressBarConfig | dict | None = None, + client=None, + gui_id: str | None = None, + num_bars: int | None = None, + ): + if config is None: + config = SpiralProgressBarConfig(widget_class=self.__class__.__name__) + self.config = config + else: + if isinstance(config, dict): + config = SpiralProgressBarConfig(**config, widget_class=self.__class__.__name__) + self.config = config + super().__init__(client=client, config=config, gui_id=gui_id) + QWidget.__init__(self, parent=None) + + self.get_bec_shortcuts() + self.entry_validator = EntryValidator(self.dev) + + self.RID = None + self.values = None + + # For updating bar behaviour + self._auto_updates = True + self._rings = [] + + if num_bars is not None: + self.config.num_bars = max( + self.config.min_number_of_bars, min(num_bars, self.config.max_number_of_bars) + ) + self.initialize_bars() + + self.enable_auto_updates(self.config.auto_updates) + + @property + def rings(self): + return self._rings + + @rings.setter + def rings(self, value): + self._rings = value + + def update_config(self, config: SpiralProgressBarConfig | dict): + """ + Update the configuration of the widget. + + Args: + config(SpiralProgressBarConfig|dict): Configuration to update. + """ + if isinstance(config, dict): + config = SpiralProgressBarConfig(**config, widget_class=self.__class__.__name__) + self.config = config + self.clear_all() + + def initialize_bars(self): + """ + Initialize the progress bars. + """ + start_positions = [90 * 16] * self.config.num_bars + directions = [-1] * self.config.num_bars + + self.config.rings = [ + RingConfig( + widget_class="Ring", + index=i, + start_positions=start_positions[i], + directions=directions[i], + ) + for i in range(self.config.num_bars) + ] + self._rings = [ + Ring(parent_progress_widget=self, config=config) for config in self.config.rings + ] + + if self.config.color_map: + self.set_colors_from_map(self.config.color_map) + + min_size = self._calculate_minimum_size() + self.setMinimumSize(min_size) + self.update() + + def add_ring(self, **kwargs) -> Ring: + """ + Add a new progress bar. + + Args: + **kwargs: Keyword arguments for the new progress bar. + + Returns: + Ring: Ring object. + """ + if self.config.num_bars < self.config.max_number_of_bars: + ring = Ring(parent_progress_widget=self, **kwargs) + ring.config.index = self.config.num_bars + self.config.num_bars += 1 + self._rings.append(ring) + self.config.rings.append(ring.config) + if self.config.color_map: + self.set_colors_from_map(self.config.color_map) + self.update() + return ring + + def remove_ring(self, index: int): + """ + Remove a progress bar by index. + + Args: + index(int): Index of the progress bar to remove. + """ + ring = self._find_ring_by_index(index) + ring.cleanup() + self._rings.remove(ring) + self.config.rings.remove(ring.config) + self.config.num_bars -= 1 + self._reindex_rings() + if self.config.color_map: + self.set_colors_from_map(self.config.color_map) + self.update() + + def _reindex_rings(self): + """ + Reindex the progress bars. + """ + for i, ring in enumerate(self._rings): + ring.config.index = i + + def set_precision(self, precision: int, bar_index: int = None): + """ + Set the precision for the progress bars. If bar_index is not provide, the precision will be set for all progress bars. + + Args: + precision(int): Precision for the progress bars. + bar_index(int): Index of the progress bar to set the precision for. If provided, only a single precision can be set. + """ + if bar_index is not None: + bar_index = self._bar_index_check(bar_index) + ring = self._find_ring_by_index(bar_index) + ring.config.precision = precision + else: + for ring in self._rings: + ring.config.precision = precision + self.update() + + def set_min_max_values( + self, + min_values: int | float | list[int | float], + max_values: int | float | list[int | float], + ): + """ + Set the minimum and maximum values for the progress bars. + + Args: + min_values(int|float | list[float]): Minimum value(s) for the progress bars. If multiple progress bars are displayed, provide a list of minimum values for each progress bar. + max_values(int|float | list[float]): Maximum value(s) for the progress bars. If multiple progress bars are displayed, provide a list of maximum values for each progress bar. + """ + if isinstance(min_values, int) or isinstance(min_values, float): + min_values = [min_values] + if isinstance(max_values, int) or isinstance(max_values, float): + max_values = [max_values] + min_values = self._adjust_list_to_bars(min_values) + max_values = self._adjust_list_to_bars(max_values) + for ring, min_value, max_value in zip(self._rings, min_values, max_values): + ring.set_min_max_values(min_value, max_value) + self.update() + + def set_number_of_bars(self, num_bars: int): + """ + Set the number of progress bars to display. + + Args: + num_bars(int): Number of progress bars to display. + """ + num_bars = max( + self.config.min_number_of_bars, min(num_bars, self.config.max_number_of_bars) + ) + if num_bars != self.config.num_bars: + self.config.num_bars = num_bars + self.initialize_bars() + + def set_value(self, values: int | list, ring_index: int = None): + """ + Set the values for the progress bars. + + Args: + values(int | tuple): Value(s) for the progress bars. If multiple progress bars are displayed, provide a tuple of values for each progress bar. + ring_index(int): Index of the progress bar to set the value for. If provided, only a single value can be set. + + Examples: + >>> SpiralProgressBar.set_value(50) + >>> SpiralProgressBar.set_value([30, 40, 50]) # (outer, middle, inner) + >>> SpiralProgressBar.set_value(60, bar_index=1) # Set the value for the middle progress bar. + """ + if ring_index is not None: + ring = self._find_ring_by_index(ring_index) + if isinstance(values, list): + values = values[0] + print( + f"Warning: Only a single value can be set for a single progress bar. Using the first value in the list {values}" + ) + ring.set_value(values) + else: + if isinstance(values, int): + values = [values] + values = self._adjust_list_to_bars(values) + for ring, value in zip(self._rings, values): + ring.set_value(value) + self.update() + + def set_colors_from_map(self, colormap, color_format: Literal["RGB", "HEX"] = "RGB"): + """ + Set the colors for the progress bars from a colormap. + + Args: + colormap(str): Name of the colormap. + color_format(Literal["RGB","HEX"]): Format of the returned colors ('RGB', 'HEX'). + """ + if colormap not in pg.colormap.listMaps(): + raise ValueError( + f"Colormap '{colormap}' not found in the current installation of pyqtgraph" + ) + colors = Colors.golden_angle_color(colormap, self.config.num_bars, color_format) + self.set_colors_directly(colors) + self.config.color_map = colormap + self.update() + + def set_colors_directly(self, colors: list[str | tuple] | str | tuple, bar_index: int = None): + """ + Set the colors for the progress bars directly. + + Args: + colors(list[str | tuple] | str | tuple): Color(s) for the progress bars. If multiple progress bars are displayed, provide a list of colors for each progress bar. + bar_index(int): Index of the progress bar to set the color for. If provided, only a single color can be set. + """ + if bar_index is not None and isinstance(colors, (str, tuple)): + bar_index = self._bar_index_check(bar_index) + ring = self._find_ring_by_index(bar_index) + ring.set_color(colors) + else: + if isinstance(colors, (str, tuple)): + colors = [colors] + colors = self._adjust_list_to_bars(colors) + for ring, color in zip(self._rings, colors): + ring.set_color(color) + self.config.color_map = None + self.update() + + def set_line_widths(self, widths: int | list[int], bar_index: int = None): + """ + Set the line widths for the progress bars. + + Args: + widths(int | list[int]): Line width(s) for the progress bars. If multiple progress bars are displayed, provide a list of line widths for each progress bar. + bar_index(int): Index of the progress bar to set the line width for. If provided, only a single line width can be set. + """ + if bar_index is not None: + bar_index = self._bar_index_check(bar_index) + ring = self._find_ring_by_index(bar_index) + if isinstance(widths, list): + widths = widths[0] + print( + f"Warning: Only a single line width can be set for a single progress bar. Using the first value in the list {widths}" + ) + ring.set_line_width(widths) + else: + if isinstance(widths, int): + widths = [widths] + widths = self._adjust_list_to_bars(widths) + self.config.gap = max(widths) * 2 + for ring, width in zip(self._rings, widths): + ring.set_line_width(width) + min_size = self._calculate_minimum_size() + self.setMinimumSize(min_size) + self.update() + + def set_gap(self, gap: int): + """ + Set the gap between the progress bars. + + Args: + gap(int): Gap between the progress bars. + """ + self.config.gap = gap + self.update() + + def set_diameter(self, diameter: int): + """ + Set the diameter of the widget. + + Args: + diameter(int): Diameter of the widget. + """ + size = QSize(diameter, diameter) + self.resize(size) + self.setFixedSize(size) + + def _find_ring_by_index(self, index: int) -> Ring: + """ + Find the ring by index. + + Args: + index(int): Index of the ring. + + Returns: + Ring: Ring object. + """ + found_ring = None + for ring in self._rings: + if ring.config.index == index: + found_ring = ring + break + if found_ring is None: + raise ValueError(f"Ring with index {index} not found.") + return found_ring + + def enable_auto_updates(self, enable: bool = True): + """ + Enable or disable updates based on scan status. Overrides manual updates. + The behaviour of the whole progress bar widget will be driven by the scan queue status. + + Args: + enable(bool): True or False. + + Returns: + bool: True if scan segment updates are enabled. + """ + + self._auto_updates = enable + if enable is True: + self.bec_dispatcher.connect_slot( + self.on_scan_queue_status, MessageEndpoints.scan_queue_status() + ) + else: + self.bec_dispatcher.disconnect_slot( + self.on_scan_queue_status, MessageEndpoints.scan_queue_status() + ) + return self._auto_updates + + @Slot(dict, dict) + def on_scan_queue_status(self, msg, meta): + primary_queue = msg.get("queue").get("primary") + info = primary_queue.get("info", None) + + if info: + active_request_block = info[0].get("active_request_block", None) + if active_request_block: + report_instructions = active_request_block.get("report_instructions", None) + if report_instructions: + instruction_type = list(report_instructions[0].keys())[0] + if instruction_type == "scan_progress": + if self.config.num_bars != 1: + self.set_number_of_bars(1) + self._hook_scan_progress(ring_index=0) + elif instruction_type == "readback": + devices = report_instructions[0].get("readback").get("devices") + start = report_instructions[0].get("readback").get("start") + end = report_instructions[0].get("readback").get("end") + if self.config.num_bars != len(devices): + self.set_number_of_bars(len(devices)) + for index, device in enumerate(devices): + self._hook_readback(index, device, start[index], end[index]) + else: + print(f"{instruction_type} not supported yet.") + + # elif instruction_type == "device_progress": + # print("hook device_progress") + + def _hook_scan_progress(self, ring_index: int = None): + if ring_index is not None: + ring = self._find_ring_by_index(ring_index) + else: + ring = self._rings[0] + + if ring.config.connections.slot == "on_scan_progress": + return + else: + ring.set_connections("on_scan_progress", MessageEndpoints.scan_progress()) + + def _hook_readback(self, bar_index: int, device: str, min: float | int, max: float | int): + ring = self._find_ring_by_index(bar_index) + ring.set_min_max_values(min, max) + endpoint = MessageEndpoints.device_readback(device) + ring.set_connections("on_device_readback", endpoint) + + def _adjust_list_to_bars(self, items: list) -> list: + """ + Utility method to adjust the list of parameters to match the number of progress bars. + + Args: + items(list): List of parameters for the progress bars. + + Returns: + list: List of parameters for the progress bars. + """ + if items is None: + raise ValueError( + "Items cannot be None. Please provide a list for parameters for the progress bars." + ) + if not isinstance(items, list): + items = [items] + if len(items) < self.config.num_bars: + last_item = items[-1] + items.extend([last_item] * (self.config.num_bars - len(items))) + elif len(items) > self.config.num_bars: + items = items[: self.config.num_bars] + return items + + def _bar_index_check(self, bar_index: int): + """ + Utility method to check if the bar index is within the range of the number of progress bars. + + Args: + bar_index(int): Index of the progress bar to set the value for. + """ + if not (0 <= bar_index < self.config.num_bars): + raise ValueError( + f"bar_index {bar_index} out of range of number of bars {self.config.num_bars}." + ) + return bar_index + + def paintEvent(self, event): + painter = QtGui.QPainter(self) + painter.setRenderHint(QtGui.QPainter.Antialiasing) + size = min(self.width(), self.height()) + rect = QtCore.QRect(0, 0, size, size) + rect.adjust( + max(ring.config.line_width for ring in self._rings), + max(ring.config.line_width for ring in self._rings), + -max(ring.config.line_width for ring in self._rings), + -max(ring.config.line_width for ring in self._rings), + ) + + for i, ring in enumerate(self._rings): + # Background arc + painter.setPen( + QtGui.QPen(ring.background_color, ring.config.line_width, QtCore.Qt.SolidLine) + ) + offset = self.config.gap * i + adjusted_rect = QtCore.QRect( + rect.left() + offset, + rect.top() + offset, + rect.width() - 2 * offset, + rect.height() - 2 * offset, + ) + painter.drawArc(adjusted_rect, ring.config.start_position, 360 * 16) + + # Foreground arc + pen = QtGui.QPen(ring.color, ring.config.line_width, QtCore.Qt.SolidLine) + pen.setCapStyle(QtCore.Qt.RoundCap) + painter.setPen(pen) + proportion = (ring.value - ring.config.min_value) / ( + (ring.config.max_value - ring.config.min_value) + 1e-3 + ) + angle = int(proportion * 360 * 16 * ring.config.direction) + painter.drawArc(adjusted_rect, ring.start_position, angle) + + def reset_diameter(self): + """ + Reset the fixed size of the widget. + """ + self.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Preferred) + self.setMinimumSize(self._calculate_minimum_size()) + self.setMaximumSize(16777215, 16777215) + + def _calculate_minimum_size(self): + """ + Calculate the minimum size of the widget. + """ + if not self.config.rings: + print("no rings to get size from setting size to 10x10") + return QSize(10, 10) + ring_widths = [self.config.rings[i].line_width for i in range(self.config.num_bars)] + total_width = sum(ring_widths) + self.config.gap * (self.config.num_bars - 1) + diameter = total_width * 2 + if diameter < 50: + diameter = 50 + return QSize(diameter, diameter) + + def sizeHint(self): + min_size = self._calculate_minimum_size() + return min_size + + def clear_all(self): + for ring in self._rings: + ring.cleanup() + del ring + self._rings = [] + self.update() + self.initialize_bars() + + def cleanup(self): + self.bec_dispatcher.disconnect_slot( + self.on_scan_queue_status, MessageEndpoints.scan_queue_status() + ) + for ring in self._rings: + ring.cleanup() + del ring + super().cleanup() diff --git a/tests/end-2-end/test_bec_dock_rpc_e2e.py b/tests/end-2-end/test_bec_dock_rpc_e2e.py index dbb0d86d..38dc75b5 100644 --- a/tests/end-2-end/test_bec_dock_rpc_e2e.py +++ b/tests/end-2-end/test_bec_dock_rpc_e2e.py @@ -3,6 +3,7 @@ import pytest from bec_lib.endpoints import MessageEndpoints from bec_widgets.cli.client import BECDockArea, BECFigure, BECImageShow, BECMotorMap, BECWaveform +from bec_widgets.utils import Colors def test_rpc_add_dock_with_figure_e2e(rpc_server_dock, qtbot): @@ -143,3 +144,83 @@ def test_dock_manipulations_e2e(rpc_server_dock, qtbot): assert len(dock_server.docks) == 0 assert len(dock_server.tempAreas) == 0 + + +def test_spiral_bar(rpc_server_dock): + dock = BECDockArea(rpc_server_dock.gui_id) + dock_server = rpc_server_dock.gui + + d0 = dock.add_dock("dock_0") + + bar = d0.add_widget_bec("SpiralProgressBar") + assert bar.__class__.__name__ == "SpiralProgressBar" + + bar.set_number_of_bars(5) + bar.set_colors_from_map("viridis") + bar.set_value([10, 20, 30, 40, 50]) + + bar_server = dock_server.docks["dock_0"].widgets[0] + + expected_colors = Colors.golden_angle_color("viridis", 5, "RGB") + bar_colors = [ring.color.getRgb() for ring in bar_server.rings] + bar_values = [ring.value for ring in bar_server.rings] + assert bar_values == [10, 20, 30, 40, 50] + assert bar_colors == expected_colors + + +def test_spiral_bar_scan_update(rpc_server_dock, qtbot): + dock = BECDockArea(rpc_server_dock.gui_id) + dock_server = rpc_server_dock.gui + + d0 = dock.add_dock("dock_0") + + d0.add_widget_bec("SpiralProgressBar") + + client = rpc_server_dock.client + dev = client.device_manager.devices + scans = client.scans + + status = scans.line_scan(dev.samx, -5, 5, steps=10, exp_time=0.05, relative=False) + + while not status.status == "COMPLETED": + qtbot.wait(200) + + qtbot.wait(200) + bar_server = dock_server.docks["dock_0"].widgets[0] + assert bar_server.config.num_bars == 1 + np.testing.assert_allclose(bar_server.rings[0].value, 10, atol=0.1) + np.testing.assert_allclose(bar_server.rings[0].config.min_value, 0, atol=0.1) + np.testing.assert_allclose(bar_server.rings[0].config.max_value, 10, atol=0.1) + + status = scans.grid_scan(dev.samx, -5, 5, 4, dev.samy, -10, 10, 4, relative=True, exp_time=0.1) + + while not status.status == "COMPLETED": + qtbot.wait(200) + + qtbot.wait(200) + assert bar_server.config.num_bars == 1 + np.testing.assert_allclose(bar_server.rings[0].value, 16, atol=0.1) + np.testing.assert_allclose(bar_server.rings[0].config.min_value, 0, atol=0.1) + np.testing.assert_allclose(bar_server.rings[0].config.max_value, 16, atol=0.1) + + init_samx = dev.samx.read()["samx"]["value"] + init_samy = dev.samy.read()["samy"]["value"] + final_samx = init_samx + 5 + final_samy = init_samy + 10 + + dev.samx.velocity.put(5) + dev.samy.velocity.put(5) + + status = scans.umv(dev.samx, 5, dev.samy, 10, relative=True) + + while not status.status == "COMPLETED": + qtbot.wait(200) + + qtbot.wait(200) + assert bar_server.config.num_bars == 2 + np.testing.assert_allclose(bar_server.rings[0].value, final_samx, atol=0.1) + np.testing.assert_allclose(bar_server.rings[1].value, final_samy, atol=0.1) + np.testing.assert_allclose(bar_server.rings[0].config.min_value, init_samx, atol=0.1) + np.testing.assert_allclose(bar_server.rings[1].config.min_value, init_samy, atol=0.1) + np.testing.assert_allclose(bar_server.rings[0].config.max_value, final_samx, atol=0.1) + np.testing.assert_allclose(bar_server.rings[1].config.max_value, final_samy, atol=0.1) diff --git a/tests/unit_tests/test_spiral_progress_bar.py b/tests/unit_tests/test_spiral_progress_bar.py new file mode 100644 index 00000000..9bea8dab --- /dev/null +++ b/tests/unit_tests/test_spiral_progress_bar.py @@ -0,0 +1,338 @@ +# pylint: disable=missing-function-docstring, missing-module-docstring, unused-import + +import pytest +from bec_lib.endpoints import MessageEndpoints +from pydantic import ValidationError + +from bec_widgets.utils import Colors +from bec_widgets.widgets import SpiralProgressBar +from bec_widgets.widgets.spiral_progress_bar.ring import RingConfig, RingConnections +from bec_widgets.widgets.spiral_progress_bar.spiral_progress_bar import SpiralProgressBarConfig + +from .client_mocks import mocked_client + + +@pytest.fixture +def spiral_progress_bar(qtbot, mocked_client): + widget = SpiralProgressBar(client=mocked_client) + qtbot.addWidget(widget) + qtbot.waitExposed(widget) + yield widget + widget.close() + + +def test_bar_init(spiral_progress_bar): + assert spiral_progress_bar is not None + assert spiral_progress_bar.client is not None + assert isinstance(spiral_progress_bar, SpiralProgressBar) + assert spiral_progress_bar.config.widget_class == "SpiralProgressBar" + assert spiral_progress_bar.config.gui_id is not None + assert spiral_progress_bar.gui_id == spiral_progress_bar.config.gui_id + + +def test_config_validation_num_of_bars(): + config = SpiralProgressBarConfig(num_bars=100, min_num_bars=1, max_num_bars=10) + + assert config.num_bars == 10 + + +def test_config_validation_num_of_ring_error(): + ring_config_0 = RingConfig(index=0) + ring_config_1 = RingConfig(index=1) + + with pytest.raises(ValidationError) as excinfo: + SpiralProgressBarConfig(rings=[ring_config_0, ring_config_1], num_bars=1) + errors = excinfo.value.errors() + assert len(errors) == 1 + assert errors[0]["type"] == "different number of configs" + assert "Length of rings configuration (2) does not match the number of bars (1)." in str( + excinfo.value + ) + + +def test_config_validation_ring_indices_wrong_order(): + ring_config_0 = RingConfig(index=2) + ring_config_1 = RingConfig(index=5) + + with pytest.raises(ValidationError) as excinfo: + SpiralProgressBarConfig(rings=[ring_config_0, ring_config_1], num_bars=2) + errors = excinfo.value.errors() + assert len(errors) == 1 + assert errors[0]["type"] == "wrong indices" + assert ( + "Indices of ring configurations must be unique and in order from 0 to num_bars 2." + in str(excinfo.value) + ) + + +def test_config_validation_ring_same_indices(): + ring_config_0 = RingConfig(index=0) + ring_config_1 = RingConfig(index=0) + + with pytest.raises(ValidationError) as excinfo: + SpiralProgressBarConfig(rings=[ring_config_0, ring_config_1], num_bars=2) + errors = excinfo.value.errors() + assert len(errors) == 1 + assert errors[0]["type"] == "wrong indices" + assert ( + "Indices of ring configurations must be unique and in order from 0 to num_bars 2." + in str(excinfo.value) + ) + + +def test_config_validation_invalid_colormap(): + with pytest.raises(ValueError) as excinfo: + SpiralProgressBarConfig(color_map="crazy_colors") + errors = excinfo.value.errors() + assert len(errors) == 1 + assert errors[0]["type"] == "unsupported colormap" + assert "Colormap 'crazy_colors' not found in the current installation of pyqtgraph" in str( + excinfo.value + ) + + +def test_ring_connection_endpoint_validation(): + with pytest.raises(ValueError) as excinfo: + RingConnections(slot="on_scan_progress", endpoint="non_existing") + errors = excinfo.value.errors() + assert len(errors) == 1 + assert errors[0]["type"] == "unsupported endpoint" + assert ( + "For slot 'on_scan_progress', endpoint must be MessageEndpoint.scan_progress or 'scans/scan_progress'." + in str(excinfo.value) + ) + + with pytest.raises(ValueError) as excinfo: + RingConnections(slot="on_device_readback", endpoint="non_existing") + errors = excinfo.value.errors() + assert len(errors) == 1 + assert errors[0]["type"] == "unsupported endpoint" + assert ( + "For slot 'on_device_readback', endpoint must be MessageEndpoint.device_readback(device) or 'internal/devices/readback/{device}'." + in str(excinfo.value) + ) + + +def test_bar_add_number_of_bars(spiral_progress_bar): + assert spiral_progress_bar.config.num_bars == 1 + + spiral_progress_bar.set_number_of_bars(5) + assert spiral_progress_bar.config.num_bars == 5 + + spiral_progress_bar.set_number_of_bars(2) + assert spiral_progress_bar.config.num_bars == 2 + + +def test_add_remove_bars_individually(spiral_progress_bar): + spiral_progress_bar.add_ring() + spiral_progress_bar.add_ring() + + assert spiral_progress_bar.config.num_bars == 3 + assert len(spiral_progress_bar.config.rings) == 3 + + spiral_progress_bar.remove_ring(1) + assert spiral_progress_bar.config.num_bars == 2 + assert len(spiral_progress_bar.config.rings) == 2 + assert spiral_progress_bar.rings[0].config.index == 0 + assert spiral_progress_bar.rings[1].config.index == 1 + + +def test_bar_set_value(spiral_progress_bar): + spiral_progress_bar.set_number_of_bars(5) + + assert spiral_progress_bar.config.num_bars == 5 + assert len(spiral_progress_bar.config.rings) == 5 + assert len(spiral_progress_bar.rings) == 5 + + spiral_progress_bar.set_value([10, 20, 30, 40, 50]) + ring_values = [ring.value for ring in spiral_progress_bar.rings] + assert ring_values == [10, 20, 30, 40, 50] + + # update just one bar + spiral_progress_bar.set_value(90, 1) + ring_values = [ring.value for ring in spiral_progress_bar.rings] + assert ring_values == [10, 90, 30, 40, 50] + + +def test_bar_set_precision(spiral_progress_bar): + spiral_progress_bar.set_number_of_bars(3) + + assert spiral_progress_bar.config.num_bars == 3 + assert len(spiral_progress_bar.config.rings) == 3 + assert len(spiral_progress_bar.rings) == 3 + + spiral_progress_bar.set_precision(2) + ring_precision = [ring.config.precision for ring in spiral_progress_bar.rings] + assert ring_precision == [2, 2, 2] + + spiral_progress_bar.set_value([10.1234, 20.1234, 30.1234]) + ring_values = [ring.value for ring in spiral_progress_bar.rings] + assert ring_values == [10.12, 20.12, 30.12] + + spiral_progress_bar.set_precision(4, 1) + ring_precision = [ring.config.precision for ring in spiral_progress_bar.rings] + assert ring_precision == [2, 4, 2] + + spiral_progress_bar.set_value([10.1234, 20.1234, 30.1234]) + ring_values = [ring.value for ring in spiral_progress_bar.rings] + assert ring_values == [10.12, 20.1234, 30.12] + + +def test_set_min_max_value(spiral_progress_bar): + spiral_progress_bar.set_number_of_bars(2) + + spiral_progress_bar.set_min_max_values(0, 10) + ring_min_values = [ring.config.min_value for ring in spiral_progress_bar.rings] + ring_max_values = [ring.config.max_value for ring in spiral_progress_bar.rings] + + assert ring_min_values == [0, 0] + assert ring_max_values == [10, 10] + + spiral_progress_bar.set_value([5, 15]) + ring_values = [ring.value for ring in spiral_progress_bar.rings] + assert ring_values == [5, 10] + + +def test_setup_colors_from_colormap(spiral_progress_bar): + spiral_progress_bar.set_number_of_bars(5) + spiral_progress_bar.set_colors_from_map("viridis", "RGB") + + expected_colors = Colors.golden_angle_color("viridis", 5, "RGB") + converted_colors = [ring.color.getRgb() for ring in spiral_progress_bar.rings] + ring_config_colors = [ring.config.color for ring in spiral_progress_bar.rings] + + assert expected_colors == converted_colors + assert ring_config_colors == expected_colors + + +def get_colors_from_rings(rings): + converted_colors = [ring.color.getRgb() for ring in rings] + ring_config_colors = [ring.config.color for ring in rings] + return converted_colors, ring_config_colors + + +def test_set_colors_from_colormap_and_change_num_of_bars(spiral_progress_bar): + spiral_progress_bar.set_number_of_bars(2) + spiral_progress_bar.set_colors_from_map("viridis", "RGB") + + expected_colors = Colors.golden_angle_color("viridis", 2, "RGB") + converted_colors, ring_config_colors = get_colors_from_rings(spiral_progress_bar.rings) + + assert expected_colors == converted_colors + assert ring_config_colors == expected_colors + + # increase the number of bars to 6 + spiral_progress_bar.set_number_of_bars(6) + expected_colors = Colors.golden_angle_color("viridis", 6, "RGB") + converted_colors, ring_config_colors = get_colors_from_rings(spiral_progress_bar.rings) + + assert expected_colors == converted_colors + assert ring_config_colors == expected_colors + + # decrease the number of bars to 3 + spiral_progress_bar.set_number_of_bars(3) + expected_colors = Colors.golden_angle_color("viridis", 3, "RGB") + converted_colors, ring_config_colors = get_colors_from_rings(spiral_progress_bar.rings) + + assert expected_colors == converted_colors + assert ring_config_colors == expected_colors + + +def test_set_colors_directly(spiral_progress_bar): + spiral_progress_bar.set_number_of_bars(3) + + # setting as a list of rgb tuples + colors = [(255, 0, 0, 255), (0, 255, 0, 255), (0, 0, 255, 255)] + spiral_progress_bar.set_colors_directly(colors) + converted_colors = get_colors_from_rings(spiral_progress_bar.rings)[0] + + assert colors == converted_colors + + spiral_progress_bar.set_colors_directly((255, 0, 0, 255), 1) + converted_colors = get_colors_from_rings(spiral_progress_bar.rings)[0] + + assert converted_colors == [(255, 0, 0, 255), (255, 0, 0, 255), (0, 0, 255, 255)] + + +def test_set_line_width(spiral_progress_bar): + spiral_progress_bar.set_number_of_bars(3) + + spiral_progress_bar.set_line_widths(5) + line_widths = [ring.config.line_width for ring in spiral_progress_bar.rings] + + assert line_widths == [5, 5, 5] + + spiral_progress_bar.set_line_widths([10, 20, 30]) + line_widths = [ring.config.line_width for ring in spiral_progress_bar.rings] + + assert line_widths == [10, 20, 30] + + spiral_progress_bar.set_line_widths(15, 1) + line_widths = [ring.config.line_width for ring in spiral_progress_bar.rings] + + assert line_widths == [10, 15, 30] + + +def test_set_gap(spiral_progress_bar): + spiral_progress_bar.set_number_of_bars(3) + spiral_progress_bar.set_gap(20) + + assert spiral_progress_bar.config.gap == 20 + + +def test_auto_update(spiral_progress_bar): + spiral_progress_bar.enable_auto_updates(True) + + scan_queue_status_scan_progress = { + "queue": { + "primary": { + "info": [{"active_request_block": {"report_instructions": [{"scan_progress": 10}]}}] + } + } + } + meta = {} + + spiral_progress_bar.on_scan_queue_status(scan_queue_status_scan_progress, meta) + + assert spiral_progress_bar._auto_updates is True + assert len(spiral_progress_bar._rings) == 1 + assert spiral_progress_bar._rings[0].config.connections == RingConnections( + slot="on_scan_progress", endpoint=MessageEndpoints.scan_progress() + ) + + scan_queue_status_device_readback = { + "queue": { + "primary": { + "info": [ + { + "active_request_block": { + "report_instructions": [ + { + "readback": { + "devices": ["samx", "samy"], + "start": [1, 2], + "end": [10, 20], + } + } + ] + } + } + ] + } + } + } + spiral_progress_bar.on_scan_queue_status(scan_queue_status_device_readback, meta) + + assert spiral_progress_bar._auto_updates is True + assert len(spiral_progress_bar._rings) == 2 + assert spiral_progress_bar._rings[0].config.connections == RingConnections( + slot="on_device_readback", endpoint=MessageEndpoints.device_readback("samx") + ) + assert spiral_progress_bar._rings[1].config.connections == RingConnections( + slot="on_device_readback", endpoint=MessageEndpoints.device_readback("samy") + ) + + assert spiral_progress_bar._rings[0].config.min_value == 1 + assert spiral_progress_bar._rings[0].config.max_value == 10 + assert spiral_progress_bar._rings[1].config.min_value == 2 + assert spiral_progress_bar._rings[1].config.max_value == 20