0
0
mirror of https://github.com/bec-project/bec_widgets.git synced 2025-07-14 03:31:50 +02:00

feat(cli): auto-discover rpc-enabled widgets

This commit is contained in:
2024-06-08 11:14:17 +02:00
parent 954c576131
commit df1be10057
4 changed files with 1412 additions and 1384 deletions

View File

@ -1 +1 @@
from .client import BECDockArea, BECFigure from .client import *

File diff suppressed because it is too large Load Diff

View File

@ -1,10 +1,17 @@
# pylint: disable=missing-module-docstring # pylint: disable=missing-module-docstring
from __future__ import annotations from __future__ import annotations
import argparse
import importlib
import inspect import inspect
import os
import sys import sys
from typing import Literal
import black import black
from qtpy.QtWidgets import QGraphicsWidget, QWidget
from bec_widgets.utils import BECConnector
if sys.version_info >= (3, 11): if sys.version_info >= (3, 11):
from typing import get_overloads from typing import get_overloads
@ -22,22 +29,40 @@ else:
class ClientGenerator: class ClientGenerator:
def __init__(self): def __init__(self):
self.header = """# This file was automatically generated by generate_cli.py\n self.header = """# This file was automatically generated by generate_cli.py\n
from bec_widgets.cli.client_utils import rpc_call, RPCBase, BECGuiClientMixin import enum
from typing import Literal, Optional, overload""" from typing import Literal, Optional, overload
from bec_widgets.cli.client_utils import BECGuiClientMixin, RPCBase, rpc_call"""
self.content = "" self.content = ""
def generate_client(self, published_classes: list): def generate_client(
self, published_classes: dict[Literal["connector_classes", "top_level_classes"], list[type]]
):
""" """
Generate the client for the published classes. Generate the client for the published classes.
Args: Args:
published_classes(list): The list of published classes (e.g. [BECWaveform1D, BECFigure]). published_classes(dict): A dictionary with keys "connector_classes" and "top_level_classes" and values as lists of classes.
""" """
for cls in published_classes: self.write_client_enum(published_classes["top_level_classes"])
for cls in published_classes["connector_classes"]:
self.content += "\n\n" self.content += "\n\n"
self.generate_content_for_class(cls) self.generate_content_for_class(cls)
def write_client_enum(self, published_classes: list[type]):
"""
Write the client enum to the content.
"""
self.content += """
class Widgets(str, enum.Enum):
\"\"\"
Enum for the available widgets.
\"\"\"
"""
for cls in published_classes:
self.content += f'{cls.__name__} = "{cls.__name__}"\n '
def generate_content_for_class(self, cls): def generate_content_for_class(self, cls):
""" """
Generate the content for the class. Generate the content for the class.
@ -104,38 +129,74 @@ class {class_name}(RPCBase):"""
with open(file_name, "w", encoding="utf-8") as file: with open(file_name, "w", encoding="utf-8") as file:
file.write(formatted_content) file.write(formatted_content)
@staticmethod
def get_rpc_classes(
repo_name: str,
) -> dict[Literal["connector_classes", "top_level_classes"], list[type]]:
"""
Get all RPC-enabled classes in the specified repository.
Args:
repo_name(str): The name of the repository.
Returns:
dict: A dictionary with keys "connector_classes" and "top_level_classes" and values as lists of classes.
"""
connector_classes = []
top_level_classes = []
anchor_module = importlib.import_module(f"{repo_name}.widgets")
directory = os.path.dirname(anchor_module.__file__)
for root, _, files in sorted(os.walk(directory)):
for file in files:
if not file.endswith(".py") or file.startswith("__"):
continue
path = os.path.join(root, file)
subs = os.path.dirname(os.path.relpath(path, directory)).split("/")
if len(subs) == 1 and not subs[0]:
module_name = file.split(".")[0]
else:
module_name = ".".join(subs + [file.split(".")[0]])
module = importlib.import_module(f"{repo_name}.widgets.{module_name}")
for name in dir(module):
obj = getattr(module, name)
if not hasattr(obj, "__module__") or obj.__module__ != module.__name__:
continue
if isinstance(obj, type) and issubclass(obj, BECConnector):
connector_classes.append(obj)
if len(subs) == 1 and (
issubclass(obj, QWidget) or issubclass(obj, QGraphicsWidget)
):
top_level_classes.append(obj)
return {"connector_classes": connector_classes, "top_level_classes": top_level_classes}
def main():
"""
Main entry point for the script, controlled by command line arguments.
"""
parser = argparse.ArgumentParser(description="Auto-generate the client for RPC widgets")
parser.add_argument("--core", action="store_true", help="Whether to generate the core client")
args = parser.parse_args()
if args.core:
current_path = os.path.dirname(__file__)
client_path = os.path.join(current_path, "client.py")
rpc_classes = ClientGenerator.get_rpc_classes("bec_widgets")
rpc_classes["connector_classes"].append(BECConnector) # Not sure if this is necessary
rpc_classes["connector_classes"].sort(key=lambda x: x.__name__)
generator = ClientGenerator()
generator.generate_client(rpc_classes)
generator.write(client_path)
if __name__ == "__main__": # pragma: no cover if __name__ == "__main__": # pragma: no cover
import os sys.argv = ["generate_cli.py", "--core"]
main()
from bec_widgets.utils import BECConnector
from bec_widgets.widgets import BECDock, BECDockArea, BECFigure, SpiralProgressBar
from bec_widgets.widgets.figure.plots.image.image import BECImageShow
from bec_widgets.widgets.figure.plots.image.image_item import BECImageItem
from bec_widgets.widgets.figure.plots.motor_map.motor_map import BECMotorMap
from bec_widgets.widgets.figure.plots.plot_base import BECPlotBase
from bec_widgets.widgets.figure.plots.waveform.waveform import BECWaveform
from bec_widgets.widgets.figure.plots.waveform.waveform_curve import BECCurve
from bec_widgets.widgets.spiral_progress_bar.ring import Ring
from bec_widgets.widgets.website.website import WebsiteWidget
current_path = os.path.dirname(__file__)
client_path = os.path.join(current_path, "client.py")
clss = [
BECPlotBase,
BECWaveform,
BECFigure,
BECCurve,
BECImageShow,
BECConnector,
BECImageItem,
BECMotorMap,
BECDock,
BECDockArea,
SpiralProgressBar,
Ring,
WebsiteWidget,
]
generator = ClientGenerator()
generator.generate_client(clss)
generator.write(client_path)

View File

@ -1,7 +1,6 @@
from textwrap import dedent from textwrap import dedent
import black import black
import pytest
from bec_widgets.cli.generate_cli import ClientGenerator from bec_widgets.cli.generate_cli import ClientGenerator
@ -33,16 +32,31 @@ class MockBECFigure:
def test_client_generator_with_black_formatting(): def test_client_generator_with_black_formatting():
generator = ClientGenerator() generator = ClientGenerator()
generator.generate_client([MockBECWaveform1D, MockBECFigure]) rpc_classes = {
"connector_classes": [MockBECWaveform1D, MockBECFigure],
"top_level_classes": [MockBECFigure],
}
generator.generate_client(rpc_classes)
# Format the expected output with black to ensure it matches the generator output # Format the expected output with black to ensure it matches the generator output
expected_output = dedent( expected_output = dedent(
'''\ '''\
# This file was automatically generated by generate_cli.py # This file was automatically generated by generate_cli.py
from bec_widgets.cli.client_utils import rpc_call, RPCBase, BECGuiClientMixin import enum
from typing import Literal, Optional, overload from typing import Literal, Optional, overload
from bec_widgets.cli.client_utils import BECGuiClientMixin, RPCBase, rpc_call
class Widgets(str, enum.Enum):
"""
Enum for the available widgets.
"""
MockBECFigure = "MockBECFigure"
class MockBECWaveform1D(RPCBase): class MockBECWaveform1D(RPCBase):
@rpc_call @rpc_call
def set_frequency(self, frequency: float) -> list: def set_frequency(self, frequency: float) -> list:
@ -79,3 +93,17 @@ def test_client_generator_with_black_formatting():
) )
assert expected_output_formatted == generated_output_formatted assert expected_output_formatted == generated_output_formatted
def test_client_generator_classes():
generator = ClientGenerator()
out = generator.get_rpc_classes("bec_widgets")
assert list(out.keys()) == ["connector_classes", "top_level_classes"]
connector_cls_names = [cls.__name__ for cls in out["connector_classes"]]
top_level_cls_names = [cls.__name__ for cls in out["top_level_classes"]]
assert "BECFigure" in connector_cls_names
assert "BECWaveform" in connector_cls_names
assert "BECDockArea" in top_level_cls_names
assert "BECFigure" in top_level_cls_names
assert "BECWaveform" not in top_level_cls_names