mirror of
https://github.com/bec-project/bec_widgets.git
synced 2025-07-13 19:21:50 +02:00
feat(cli): auto-discover rpc-enabled widgets
This commit is contained in:
@ -1 +1 @@
|
||||
from .client import BECDockArea, BECFigure
|
||||
from .client import *
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,10 +1,17 @@
|
||||
# pylint: disable=missing-module-docstring
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
from typing import Literal
|
||||
|
||||
import black
|
||||
from qtpy.QtWidgets import QGraphicsWidget, QWidget
|
||||
|
||||
from bec_widgets.utils import BECConnector
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import get_overloads
|
||||
@ -22,22 +29,40 @@ else:
|
||||
class ClientGenerator:
|
||||
def __init__(self):
|
||||
self.header = """# This file was automatically generated by generate_cli.py\n
|
||||
from bec_widgets.cli.client_utils import rpc_call, RPCBase, BECGuiClientMixin
|
||||
from typing import Literal, Optional, overload"""
|
||||
import enum
|
||||
from typing import Literal, Optional, overload
|
||||
|
||||
from bec_widgets.cli.client_utils import BECGuiClientMixin, RPCBase, rpc_call"""
|
||||
|
||||
self.content = ""
|
||||
|
||||
def generate_client(self, published_classes: list):
|
||||
def generate_client(
|
||||
self, published_classes: dict[Literal["connector_classes", "top_level_classes"], list[type]]
|
||||
):
|
||||
"""
|
||||
Generate the client for the published classes.
|
||||
|
||||
Args:
|
||||
published_classes(list): The list of published classes (e.g. [BECWaveform1D, BECFigure]).
|
||||
published_classes(dict): A dictionary with keys "connector_classes" and "top_level_classes" and values as lists of classes.
|
||||
"""
|
||||
for cls in published_classes:
|
||||
self.write_client_enum(published_classes["top_level_classes"])
|
||||
for cls in published_classes["connector_classes"]:
|
||||
self.content += "\n\n"
|
||||
self.generate_content_for_class(cls)
|
||||
|
||||
def write_client_enum(self, published_classes: list[type]):
|
||||
"""
|
||||
Write the client enum to the content.
|
||||
"""
|
||||
self.content += """
|
||||
class Widgets(str, enum.Enum):
|
||||
\"\"\"
|
||||
Enum for the available widgets.
|
||||
\"\"\"
|
||||
"""
|
||||
for cls in published_classes:
|
||||
self.content += f'{cls.__name__} = "{cls.__name__}"\n '
|
||||
|
||||
def generate_content_for_class(self, cls):
|
||||
"""
|
||||
Generate the content for the class.
|
||||
@ -104,38 +129,74 @@ class {class_name}(RPCBase):"""
|
||||
with open(file_name, "w", encoding="utf-8") as file:
|
||||
file.write(formatted_content)
|
||||
|
||||
@staticmethod
|
||||
def get_rpc_classes(
|
||||
repo_name: str,
|
||||
) -> dict[Literal["connector_classes", "top_level_classes"], list[type]]:
|
||||
"""
|
||||
Get all RPC-enabled classes in the specified repository.
|
||||
|
||||
Args:
|
||||
repo_name(str): The name of the repository.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with keys "connector_classes" and "top_level_classes" and values as lists of classes.
|
||||
"""
|
||||
connector_classes = []
|
||||
top_level_classes = []
|
||||
anchor_module = importlib.import_module(f"{repo_name}.widgets")
|
||||
directory = os.path.dirname(anchor_module.__file__)
|
||||
for root, _, files in sorted(os.walk(directory)):
|
||||
for file in files:
|
||||
if not file.endswith(".py") or file.startswith("__"):
|
||||
continue
|
||||
|
||||
path = os.path.join(root, file)
|
||||
subs = os.path.dirname(os.path.relpath(path, directory)).split("/")
|
||||
if len(subs) == 1 and not subs[0]:
|
||||
module_name = file.split(".")[0]
|
||||
else:
|
||||
module_name = ".".join(subs + [file.split(".")[0]])
|
||||
|
||||
module = importlib.import_module(f"{repo_name}.widgets.{module_name}")
|
||||
|
||||
for name in dir(module):
|
||||
obj = getattr(module, name)
|
||||
if not hasattr(obj, "__module__") or obj.__module__ != module.__name__:
|
||||
continue
|
||||
if isinstance(obj, type) and issubclass(obj, BECConnector):
|
||||
connector_classes.append(obj)
|
||||
if len(subs) == 1 and (
|
||||
issubclass(obj, QWidget) or issubclass(obj, QGraphicsWidget)
|
||||
):
|
||||
top_level_classes.append(obj)
|
||||
|
||||
return {"connector_classes": connector_classes, "top_level_classes": top_level_classes}
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main entry point for the script, controlled by command line arguments.
|
||||
"""
|
||||
|
||||
parser = argparse.ArgumentParser(description="Auto-generate the client for RPC widgets")
|
||||
parser.add_argument("--core", action="store_true", help="Whether to generate the core client")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.core:
|
||||
current_path = os.path.dirname(__file__)
|
||||
client_path = os.path.join(current_path, "client.py")
|
||||
|
||||
rpc_classes = ClientGenerator.get_rpc_classes("bec_widgets")
|
||||
rpc_classes["connector_classes"].append(BECConnector) # Not sure if this is necessary
|
||||
rpc_classes["connector_classes"].sort(key=lambda x: x.__name__)
|
||||
|
||||
generator = ClientGenerator()
|
||||
generator.generate_client(rpc_classes)
|
||||
generator.write(client_path)
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover
|
||||
import os
|
||||
|
||||
from bec_widgets.utils import BECConnector
|
||||
from bec_widgets.widgets import BECDock, BECDockArea, BECFigure, SpiralProgressBar
|
||||
from bec_widgets.widgets.figure.plots.image.image import BECImageShow
|
||||
from bec_widgets.widgets.figure.plots.image.image_item import BECImageItem
|
||||
from bec_widgets.widgets.figure.plots.motor_map.motor_map import BECMotorMap
|
||||
from bec_widgets.widgets.figure.plots.plot_base import BECPlotBase
|
||||
from bec_widgets.widgets.figure.plots.waveform.waveform import BECWaveform
|
||||
from bec_widgets.widgets.figure.plots.waveform.waveform_curve import BECCurve
|
||||
from bec_widgets.widgets.spiral_progress_bar.ring import Ring
|
||||
from bec_widgets.widgets.website.website import WebsiteWidget
|
||||
|
||||
current_path = os.path.dirname(__file__)
|
||||
client_path = os.path.join(current_path, "client.py")
|
||||
clss = [
|
||||
BECPlotBase,
|
||||
BECWaveform,
|
||||
BECFigure,
|
||||
BECCurve,
|
||||
BECImageShow,
|
||||
BECConnector,
|
||||
BECImageItem,
|
||||
BECMotorMap,
|
||||
BECDock,
|
||||
BECDockArea,
|
||||
SpiralProgressBar,
|
||||
Ring,
|
||||
WebsiteWidget,
|
||||
]
|
||||
generator = ClientGenerator()
|
||||
generator.generate_client(clss)
|
||||
generator.write(client_path)
|
||||
sys.argv = ["generate_cli.py", "--core"]
|
||||
main()
|
||||
|
@ -1,7 +1,6 @@
|
||||
from textwrap import dedent
|
||||
|
||||
import black
|
||||
import pytest
|
||||
|
||||
from bec_widgets.cli.generate_cli import ClientGenerator
|
||||
|
||||
@ -33,16 +32,31 @@ class MockBECFigure:
|
||||
|
||||
def test_client_generator_with_black_formatting():
|
||||
generator = ClientGenerator()
|
||||
generator.generate_client([MockBECWaveform1D, MockBECFigure])
|
||||
rpc_classes = {
|
||||
"connector_classes": [MockBECWaveform1D, MockBECFigure],
|
||||
"top_level_classes": [MockBECFigure],
|
||||
}
|
||||
generator.generate_client(rpc_classes)
|
||||
|
||||
# Format the expected output with black to ensure it matches the generator output
|
||||
expected_output = dedent(
|
||||
'''\
|
||||
# This file was automatically generated by generate_cli.py
|
||||
|
||||
from bec_widgets.cli.client_utils import rpc_call, RPCBase, BECGuiClientMixin
|
||||
import enum
|
||||
from typing import Literal, Optional, overload
|
||||
|
||||
from bec_widgets.cli.client_utils import BECGuiClientMixin, RPCBase, rpc_call
|
||||
|
||||
|
||||
class Widgets(str, enum.Enum):
|
||||
"""
|
||||
Enum for the available widgets.
|
||||
"""
|
||||
|
||||
MockBECFigure = "MockBECFigure"
|
||||
|
||||
|
||||
class MockBECWaveform1D(RPCBase):
|
||||
@rpc_call
|
||||
def set_frequency(self, frequency: float) -> list:
|
||||
@ -79,3 +93,17 @@ def test_client_generator_with_black_formatting():
|
||||
)
|
||||
|
||||
assert expected_output_formatted == generated_output_formatted
|
||||
|
||||
|
||||
def test_client_generator_classes():
|
||||
generator = ClientGenerator()
|
||||
out = generator.get_rpc_classes("bec_widgets")
|
||||
assert list(out.keys()) == ["connector_classes", "top_level_classes"]
|
||||
connector_cls_names = [cls.__name__ for cls in out["connector_classes"]]
|
||||
top_level_cls_names = [cls.__name__ for cls in out["top_level_classes"]]
|
||||
|
||||
assert "BECFigure" in connector_cls_names
|
||||
assert "BECWaveform" in connector_cls_names
|
||||
assert "BECDockArea" in top_level_cls_names
|
||||
assert "BECFigure" in top_level_cls_names
|
||||
assert "BECWaveform" not in top_level_cls_names
|
||||
|
Reference in New Issue
Block a user