0
0
mirror of https://github.com/bec-project/bec_widgets.git synced 2025-07-14 11:41:49 +02:00

feat(plugins): moved plugin dict to dataclass and container

This commit is contained in:
2024-07-06 09:50:50 +02:00
parent d6d0777113
commit 03819a3d90
7 changed files with 137 additions and 44 deletions

View File

@ -13,12 +13,12 @@ class Widgets(str, enum.Enum):
Enum for the available widgets. Enum for the available widgets.
""" """
BECQueue = "BECQueue"
BECStatusBox = "BECStatusBox"
BECDock = "BECDock" BECDock = "BECDock"
BECDockArea = "BECDockArea" BECDockArea = "BECDockArea"
BECFigure = "BECFigure" BECFigure = "BECFigure"
BECMotorMapWidget = "BECMotorMapWidget" BECMotorMapWidget = "BECMotorMapWidget"
BECQueue = "BECQueue"
BECStatusBox = "BECStatusBox"
RingProgressBar = "RingProgressBar" RingProgressBar = "RingProgressBar"
ScanControl = "ScanControl" ScanControl = "ScanControl"
TextBox = "TextBox" TextBox = "TextBox"

View File

@ -5,13 +5,12 @@ import argparse
import inspect import inspect
import os import os
import sys import sys
from typing import Literal
import black import black
import isort import isort
from bec_widgets.utils.generate_designer_plugin import DesignerPluginGenerator from bec_widgets.utils.generate_designer_plugin import DesignerPluginGenerator
from bec_widgets.utils.plugin_utils import get_rpc_classes from bec_widgets.utils.plugin_utils import BECClassContainer, get_rpc_classes
if sys.version_info >= (3, 11): if sys.version_info >= (3, 11):
from typing import get_overloads from typing import get_overloads
@ -40,17 +39,20 @@ from bec_widgets.cli.client_utils import RPCBase, rpc_call, BECGuiClientMixin
self.content = "" self.content = ""
def generate_client( def generate_client(self, class_container: BECClassContainer):
self, published_classes: dict[Literal["connector_classes", "top_level_classes"], list[type]]
):
""" """
Generate the client for the published classes. Generate the client for the published classes.
Args: Args:
published_classes(dict): A dictionary with keys "connector_classes" and "top_level_classes" and values as lists of classes. class_container: The class container with the classes to generate the client for.
""" """
self.write_client_enum(published_classes["top_level_classes"]) rpc_top_level_classes = class_container.rpc_top_level_classes
for cls in published_classes["connector_classes"]: rpc_top_level_classes.sort(key=lambda x: x.__name__)
connector_classes = class_container.connector_classes
connector_classes.sort(key=lambda x: x.__name__)
self.write_client_enum(rpc_top_level_classes)
for cls in connector_classes:
self.content += "\n\n" self.content += "\n\n"
self.generate_content_for_class(cls) self.generate_content_for_class(cls)
@ -156,13 +158,12 @@ def main():
client_path = os.path.join(current_path, "client.py") client_path = os.path.join(current_path, "client.py")
rpc_classes = get_rpc_classes("bec_widgets") rpc_classes = get_rpc_classes("bec_widgets")
rpc_classes["connector_classes"].sort(key=lambda x: x.__name__)
generator = ClientGenerator() generator = ClientGenerator()
generator.generate_client(rpc_classes) generator.generate_client(rpc_classes)
generator.write(client_path) generator.write(client_path)
for cls in rpc_classes["top_level_classes"]: for cls in rpc_classes.plugins:
plugin = DesignerPluginGenerator(cls) plugin = DesignerPluginGenerator(cls)
if not hasattr(plugin, "info"): if not hasattr(plugin, "info"):
continue continue

View File

@ -29,7 +29,7 @@ class RPCWidgetHandler:
from bec_widgets.utils.plugin_utils import get_rpc_classes from bec_widgets.utils.plugin_utils import get_rpc_classes
clss = get_rpc_classes("bec_widgets") clss = get_rpc_classes("bec_widgets")
self._widget_classes = {cls.__name__: cls for cls in clss["top_level_classes"]} self._widget_classes = {cls.__name__: cls for cls in clss.top_level_classes}
def create_widget(self, widget_type, **kwargs) -> BECConnector: def create_widget(self, widget_type, **kwargs) -> BECConnector:
""" """

View File

@ -1,7 +1,7 @@
import importlib import importlib
import inspect import inspect
import os import os
from typing import Literal from dataclasses import dataclass
from bec_lib.plugin_helper import _get_available_plugins from bec_lib.plugin_helper import _get_available_plugins
from qtpy.QtWidgets import QGraphicsWidget, QWidget from qtpy.QtWidgets import QGraphicsWidget, QWidget
@ -45,9 +45,74 @@ def _filter_plugins(obj):
return inspect.isclass(obj) and issubclass(obj, BECConnector) return inspect.isclass(obj) and issubclass(obj, BECConnector)
def get_rpc_classes( @dataclass
repo_name: str, class BECClassInfo:
) -> dict[Literal["connector_classes", "top_level_classes"], list[type]]: name: str
module: str
file: str
obj: type
is_connector: bool = False
is_widget: bool = False
is_top_level: bool = False
class BECClassContainer:
def __init__(self):
self._collection = []
def add_class(self, class_info: BECClassInfo):
"""
Add a class to the collection.
Args:
class_info(BECClassInfo): The class information
"""
self.collection.append(class_info)
@property
def collection(self):
"""
Get the collection of classes.
"""
return self._collection
@property
def connector_classes(self):
"""
Get all connector classes.
"""
return [info.obj for info in self.collection if info.is_connector]
@property
def top_level_classes(self):
"""
Get all top-level classes.
"""
return [info.obj for info in self.collection if info.is_top_level]
@property
def plugins(self):
"""
Get all plugins. These are all classes that are on the top level and are widgets.
"""
return [info.obj for info in self.collection if info.is_widget and info.is_top_level]
@property
def widgets(self):
"""
Get all widgets. These are all classes inheriting from BECWidget.
"""
return [info.obj for info in self.collection if info.is_widget]
@property
def rpc_top_level_classes(self):
"""
Get all top-level classes that are RPC-enabled. These are all classes that users can choose from.
"""
return [info.obj for info in self.collection if info.is_top_level and info.is_connector]
def get_rpc_classes(repo_name: str) -> BECClassContainer:
""" """
Get all RPC-enabled classes in the specified repository. Get all RPC-enabled classes in the specified repository.
@ -57,8 +122,7 @@ def get_rpc_classes(
Returns: Returns:
dict: A dictionary with keys "connector_classes" and "top_level_classes" and values as lists of classes. dict: A dictionary with keys "connector_classes" and "top_level_classes" and values as lists of classes.
""" """
connector_classes = [] collection = BECClassContainer()
top_level_classes = []
anchor_module = importlib.import_module(f"{repo_name}.widgets") anchor_module = importlib.import_module(f"{repo_name}.widgets")
directory = os.path.dirname(anchor_module.__file__) directory = os.path.dirname(anchor_module.__file__)
for root, _, files in sorted(os.walk(directory)): for root, _, files in sorted(os.walk(directory)):
@ -79,11 +143,16 @@ def get_rpc_classes(
obj = getattr(module, name) obj = getattr(module, name)
if not hasattr(obj, "__module__") or obj.__module__ != module.__name__: if not hasattr(obj, "__module__") or obj.__module__ != module.__name__:
continue continue
if isinstance(obj, type) and issubclass(obj, BECWidget): if isinstance(obj, type):
connector_classes.append(obj) class_info = BECClassInfo(name=name, module=module_name, file=path, obj=obj)
if issubclass(obj, BECConnector):
class_info.is_connector = True
if issubclass(obj, BECWidget):
class_info.is_widget = True
if len(subs) == 1 and ( if len(subs) == 1 and (
issubclass(obj, QWidget) or issubclass(obj, QGraphicsWidget) issubclass(obj, QWidget) or issubclass(obj, QGraphicsWidget)
): ):
top_level_classes.append(obj) class_info.is_top_level = True
collection.add_class(class_info)
return {"connector_classes": connector_classes, "top_level_classes": top_level_classes} return collection

View File

@ -30,7 +30,7 @@ class UILoader:
def __init__(self, parent=None): def __init__(self, parent=None):
self.parent = parent self.parent = parent
widgets = get_rpc_classes("bec_widgets").get("top_level_classes", []) widgets = get_rpc_classes("bec_widgets").top_level_classes
self.custom_widgets = {widget.__name__: widget for widget in widgets} self.custom_widgets = {widget.__name__: widget for widget in widgets}

View File

@ -4,6 +4,7 @@ import black
import isort import isort
from bec_widgets.cli.generate_cli import ClientGenerator from bec_widgets.cli.generate_cli import ClientGenerator
from bec_widgets.utils.plugin_utils import BECClassContainer, BECClassInfo
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
@ -33,11 +34,31 @@ class MockBECFigure:
def test_client_generator_with_black_formatting(): def test_client_generator_with_black_formatting():
generator = ClientGenerator() generator = ClientGenerator()
rpc_classes = { container = BECClassContainer()
"connector_classes": [MockBECWaveform1D, MockBECFigure], container.add_class(
"top_level_classes": [MockBECFigure], BECClassInfo(
} name="MockBECWaveform1D",
generator.generate_client(rpc_classes) module="test_module",
file="test_file",
obj=MockBECWaveform1D,
is_connector=True,
is_widget=True,
is_top_level=False,
)
)
container.add_class(
BECClassInfo(
name="MockBECFigure",
module="test_module",
file="test_file",
obj=MockBECFigure,
is_connector=True,
is_widget=True,
is_top_level=True,
)
)
generator.generate_client(container)
# Format the expected output with black to ensure it matches the generator output # Format the expected output with black to ensure it matches the generator output
expected_output = dedent( expected_output = dedent(
@ -51,6 +72,7 @@ def test_client_generator_with_black_formatting():
# pylint: skip-file # pylint: skip-file
class Widgets(str, enum.Enum): class Widgets(str, enum.Enum):
""" """
Enum for the available widgets. Enum for the available widgets.
@ -59,18 +81,6 @@ def test_client_generator_with_black_formatting():
MockBECFigure = "MockBECFigure" MockBECFigure = "MockBECFigure"
class MockBECWaveform1D(RPCBase):
@rpc_call
def set_frequency(self, frequency: float) -> list:
"""
Set the frequency of the waveform.
"""
@rpc_call
def set_amplitude(self, amplitude: float) -> tuple[float, float]:
"""
Set the amplitude of the waveform.
"""
class MockBECFigure(RPCBase): class MockBECFigure(RPCBase):
@rpc_call @rpc_call
def add_plot(self, plot_id: str): def add_plot(self, plot_id: str):
@ -83,6 +93,20 @@ def test_client_generator_with_black_formatting():
""" """
Remove a plot from the figure. Remove a plot from the figure.
""" """
class MockBECWaveform1D(RPCBase):
@rpc_call
def set_frequency(self, frequency: float) -> list:
"""
Set the frequency of the waveform.
"""
@rpc_call
def set_amplitude(self, amplitude: float) -> tuple[float, float]:
"""
Set the amplitude of the waveform.
"""
''' '''
) )

View File

@ -3,9 +3,8 @@ from bec_widgets.utils.plugin_utils import get_rpc_classes
def test_client_generator_classes(): def test_client_generator_classes():
out = get_rpc_classes("bec_widgets") out = get_rpc_classes("bec_widgets")
assert list(out.keys()) == ["connector_classes", "top_level_classes"] connector_cls_names = [cls.__name__ for cls in out.connector_classes]
connector_cls_names = [cls.__name__ for cls in out["connector_classes"]] top_level_cls_names = [cls.__name__ for cls in out.top_level_classes]
top_level_cls_names = [cls.__name__ for cls in out["top_level_classes"]]
assert "BECFigure" in connector_cls_names assert "BECFigure" in connector_cls_names
assert "BECWaveform" in connector_cls_names assert "BECWaveform" in connector_cls_names