diff --git a/bec_widgets/cli/client.py b/bec_widgets/cli/client.py index 3d7ac7f2..9b02d0a2 100644 --- a/bec_widgets/cli/client.py +++ b/bec_widgets/cli/client.py @@ -13,12 +13,12 @@ class Widgets(str, enum.Enum): Enum for the available widgets. """ - BECQueue = "BECQueue" - BECStatusBox = "BECStatusBox" BECDock = "BECDock" BECDockArea = "BECDockArea" BECFigure = "BECFigure" BECMotorMapWidget = "BECMotorMapWidget" + BECQueue = "BECQueue" + BECStatusBox = "BECStatusBox" RingProgressBar = "RingProgressBar" ScanControl = "ScanControl" TextBox = "TextBox" diff --git a/bec_widgets/cli/generate_cli.py b/bec_widgets/cli/generate_cli.py index 25ef0e70..a70c7add 100644 --- a/bec_widgets/cli/generate_cli.py +++ b/bec_widgets/cli/generate_cli.py @@ -5,13 +5,12 @@ import argparse import inspect import os import sys -from typing import Literal import black import isort from bec_widgets.utils.generate_designer_plugin import DesignerPluginGenerator -from bec_widgets.utils.plugin_utils import get_rpc_classes +from bec_widgets.utils.plugin_utils import BECClassContainer, get_rpc_classes if sys.version_info >= (3, 11): from typing import get_overloads @@ -40,17 +39,20 @@ from bec_widgets.cli.client_utils import RPCBase, rpc_call, BECGuiClientMixin self.content = "" - def generate_client( - self, published_classes: dict[Literal["connector_classes", "top_level_classes"], list[type]] - ): + def generate_client(self, class_container: BECClassContainer): """ Generate the client for the published classes. Args: - published_classes(dict): A dictionary with keys "connector_classes" and "top_level_classes" and values as lists of classes. + class_container: The class container with the classes to generate the client for. """ - self.write_client_enum(published_classes["top_level_classes"]) - for cls in published_classes["connector_classes"]: + rpc_top_level_classes = class_container.rpc_top_level_classes + rpc_top_level_classes.sort(key=lambda x: x.__name__) + connector_classes = class_container.connector_classes + connector_classes.sort(key=lambda x: x.__name__) + + self.write_client_enum(rpc_top_level_classes) + for cls in connector_classes: self.content += "\n\n" self.generate_content_for_class(cls) @@ -156,13 +158,12 @@ def main(): client_path = os.path.join(current_path, "client.py") rpc_classes = get_rpc_classes("bec_widgets") - rpc_classes["connector_classes"].sort(key=lambda x: x.__name__) generator = ClientGenerator() generator.generate_client(rpc_classes) generator.write(client_path) - for cls in rpc_classes["top_level_classes"]: + for cls in rpc_classes.plugins: plugin = DesignerPluginGenerator(cls) if not hasattr(plugin, "info"): continue diff --git a/bec_widgets/cli/rpc_wigdet_handler.py b/bec_widgets/cli/rpc_wigdet_handler.py index 44c1d61f..f302b804 100644 --- a/bec_widgets/cli/rpc_wigdet_handler.py +++ b/bec_widgets/cli/rpc_wigdet_handler.py @@ -29,7 +29,7 @@ class RPCWidgetHandler: from bec_widgets.utils.plugin_utils import get_rpc_classes clss = get_rpc_classes("bec_widgets") - self._widget_classes = {cls.__name__: cls for cls in clss["top_level_classes"]} + self._widget_classes = {cls.__name__: cls for cls in clss.top_level_classes} def create_widget(self, widget_type, **kwargs) -> BECConnector: """ diff --git a/bec_widgets/utils/plugin_utils.py b/bec_widgets/utils/plugin_utils.py index 4f213ce9..fa392c18 100644 --- a/bec_widgets/utils/plugin_utils.py +++ b/bec_widgets/utils/plugin_utils.py @@ -1,7 +1,7 @@ import importlib import inspect import os -from typing import Literal +from dataclasses import dataclass from bec_lib.plugin_helper import _get_available_plugins from qtpy.QtWidgets import QGraphicsWidget, QWidget @@ -45,9 +45,74 @@ def _filter_plugins(obj): return inspect.isclass(obj) and issubclass(obj, BECConnector) -def get_rpc_classes( - repo_name: str, -) -> dict[Literal["connector_classes", "top_level_classes"], list[type]]: +@dataclass +class BECClassInfo: + name: str + module: str + file: str + obj: type + is_connector: bool = False + is_widget: bool = False + is_top_level: bool = False + + +class BECClassContainer: + def __init__(self): + self._collection = [] + + def add_class(self, class_info: BECClassInfo): + """ + Add a class to the collection. + + Args: + class_info(BECClassInfo): The class information + """ + self.collection.append(class_info) + + @property + def collection(self): + """ + Get the collection of classes. + """ + return self._collection + + @property + def connector_classes(self): + """ + Get all connector classes. + """ + return [info.obj for info in self.collection if info.is_connector] + + @property + def top_level_classes(self): + """ + Get all top-level classes. + """ + return [info.obj for info in self.collection if info.is_top_level] + + @property + def plugins(self): + """ + Get all plugins. These are all classes that are on the top level and are widgets. + """ + return [info.obj for info in self.collection if info.is_widget and info.is_top_level] + + @property + def widgets(self): + """ + Get all widgets. These are all classes inheriting from BECWidget. + """ + return [info.obj for info in self.collection if info.is_widget] + + @property + def rpc_top_level_classes(self): + """ + Get all top-level classes that are RPC-enabled. These are all classes that users can choose from. + """ + return [info.obj for info in self.collection if info.is_top_level and info.is_connector] + + +def get_rpc_classes(repo_name: str) -> BECClassContainer: """ Get all RPC-enabled classes in the specified repository. @@ -57,8 +122,7 @@ def get_rpc_classes( Returns: dict: A dictionary with keys "connector_classes" and "top_level_classes" and values as lists of classes. """ - connector_classes = [] - top_level_classes = [] + collection = BECClassContainer() anchor_module = importlib.import_module(f"{repo_name}.widgets") directory = os.path.dirname(anchor_module.__file__) for root, _, files in sorted(os.walk(directory)): @@ -79,11 +143,16 @@ def get_rpc_classes( obj = getattr(module, name) if not hasattr(obj, "__module__") or obj.__module__ != module.__name__: continue - if isinstance(obj, type) and issubclass(obj, BECWidget): - connector_classes.append(obj) + if isinstance(obj, type): + class_info = BECClassInfo(name=name, module=module_name, file=path, obj=obj) + if issubclass(obj, BECConnector): + class_info.is_connector = True + if issubclass(obj, BECWidget): + class_info.is_widget = True if len(subs) == 1 and ( issubclass(obj, QWidget) or issubclass(obj, QGraphicsWidget) ): - top_level_classes.append(obj) + class_info.is_top_level = True + collection.add_class(class_info) - return {"connector_classes": connector_classes, "top_level_classes": top_level_classes} + return collection diff --git a/bec_widgets/utils/ui_loader.py b/bec_widgets/utils/ui_loader.py index 40336ead..9b7f03be 100644 --- a/bec_widgets/utils/ui_loader.py +++ b/bec_widgets/utils/ui_loader.py @@ -30,7 +30,7 @@ class UILoader: def __init__(self, parent=None): self.parent = parent - widgets = get_rpc_classes("bec_widgets").get("top_level_classes", []) + widgets = get_rpc_classes("bec_widgets").top_level_classes self.custom_widgets = {widget.__name__: widget for widget in widgets} diff --git a/tests/unit_tests/test_generate_cli_client.py b/tests/unit_tests/test_generate_cli_client.py index 69d80606..88a4d5dc 100644 --- a/tests/unit_tests/test_generate_cli_client.py +++ b/tests/unit_tests/test_generate_cli_client.py @@ -4,6 +4,7 @@ import black import isort from bec_widgets.cli.generate_cli import ClientGenerator +from bec_widgets.utils.plugin_utils import BECClassContainer, BECClassInfo # pylint: disable=missing-function-docstring @@ -33,11 +34,31 @@ class MockBECFigure: def test_client_generator_with_black_formatting(): generator = ClientGenerator() - rpc_classes = { - "connector_classes": [MockBECWaveform1D, MockBECFigure], - "top_level_classes": [MockBECFigure], - } - generator.generate_client(rpc_classes) + container = BECClassContainer() + container.add_class( + BECClassInfo( + name="MockBECWaveform1D", + module="test_module", + file="test_file", + obj=MockBECWaveform1D, + is_connector=True, + is_widget=True, + is_top_level=False, + ) + ) + container.add_class( + BECClassInfo( + name="MockBECFigure", + module="test_module", + file="test_file", + obj=MockBECFigure, + is_connector=True, + is_widget=True, + is_top_level=True, + ) + ) + + generator.generate_client(container) # Format the expected output with black to ensure it matches the generator output expected_output = dedent( @@ -51,6 +72,7 @@ def test_client_generator_with_black_formatting(): # pylint: skip-file + class Widgets(str, enum.Enum): """ Enum for the available widgets. @@ -59,18 +81,6 @@ def test_client_generator_with_black_formatting(): MockBECFigure = "MockBECFigure" - class MockBECWaveform1D(RPCBase): - @rpc_call - def set_frequency(self, frequency: float) -> list: - """ - Set the frequency of the waveform. - """ - @rpc_call - def set_amplitude(self, amplitude: float) -> tuple[float, float]: - """ - Set the amplitude of the waveform. - """ - class MockBECFigure(RPCBase): @rpc_call def add_plot(self, plot_id: str): @@ -83,6 +93,20 @@ def test_client_generator_with_black_formatting(): """ Remove a plot from the figure. """ + + + class MockBECWaveform1D(RPCBase): + @rpc_call + def set_frequency(self, frequency: float) -> list: + """ + Set the frequency of the waveform. + """ + + @rpc_call + def set_amplitude(self, amplitude: float) -> tuple[float, float]: + """ + Set the amplitude of the waveform. + """ ''' ) diff --git a/tests/unit_tests/test_plugin_utils.py b/tests/unit_tests/test_plugin_utils.py index 82276f3d..2a0616b9 100644 --- a/tests/unit_tests/test_plugin_utils.py +++ b/tests/unit_tests/test_plugin_utils.py @@ -3,9 +3,8 @@ from bec_widgets.utils.plugin_utils import get_rpc_classes def test_client_generator_classes(): out = get_rpc_classes("bec_widgets") - assert list(out.keys()) == ["connector_classes", "top_level_classes"] - connector_cls_names = [cls.__name__ for cls in out["connector_classes"]] - top_level_cls_names = [cls.__name__ for cls in out["top_level_classes"]] + connector_cls_names = [cls.__name__ for cls in out.connector_classes] + top_level_cls_names = [cls.__name__ for cls in out.top_level_classes] assert "BECFigure" in connector_cls_names assert "BECWaveform" in connector_cls_names