0
0
mirror of https://github.com/bec-project/bec_widgets.git synced 2025-07-14 11:41:49 +02:00

refactor: tidy client generation and add options

This commit is contained in:
2025-02-28 17:23:43 +01:00
parent 43e1aa9505
commit b4925918f7
6 changed files with 2878 additions and 3076 deletions

View File

@ -1,5 +1,4 @@
from bec_widgets.qt_utils.error_popups import SafeProperty, SafeSlot
from bec_widgets.utils.bec_widget import BECWidget from bec_widgets.utils.bec_widget import BECWidget
from bec_widgets.utils.error_popups import SafeProperty, SafeSlot
__all__ = ["BECWidget", "SafeSlot", "SafeProperty"] __all__ = ["BECWidget", "SafeSlot", "SafeProperty"]

View File

@ -2,17 +2,22 @@
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import importlib
import inspect import inspect
import os import os
import sys import sys
from pathlib import Path
import black import black
import isort import isort
from bec_lib.logger import bec_logger
from qtpy.QtCore import Property as QtProperty from qtpy.QtCore import Property as QtProperty
from bec_widgets.utils.generate_designer_plugin import DesignerPluginGenerator from bec_widgets.utils.generate_designer_plugin import DesignerPluginGenerator, plugin_filenames
from bec_widgets.utils.plugin_utils import BECClassContainer, get_custom_classes from bec_widgets.utils.plugin_utils import BECClassContainer, get_custom_classes
logger = bec_logger.logger
if sys.version_info >= (3, 11): if sys.version_info >= (3, 11):
from typing import get_overloads from typing import get_overloads
else: else:
@ -193,41 +198,52 @@ def main():
""" """
parser = argparse.ArgumentParser(description="Auto-generate the client for RPC widgets") parser = argparse.ArgumentParser(description="Auto-generate the client for RPC widgets")
parser.add_argument("--core", action="store_true", help="Whether to generate the core client") parser.add_argument(
"--module-name",
action="store",
type=str,
default="bec_widgets",
help="Which module to generate plugin files for (default: bec_widgets, example: my_plugin_repo.bec_widgets)",
)
args = parser.parse_args() args = parser.parse_args()
if args.core: logger.info(f"BEC Widget code generation tool started with args: {args}")
current_path = os.path.dirname(__file__)
client_path = os.path.join(current_path, "client.py")
rpc_classes = get_custom_classes("bec_widgets") try:
module = importlib.import_module(args.module_name)
assert module.__file__ is not None
module_file = Path(module.__file__)
module_dir = module_file.parent if module_file.is_file() else module_file
except Exception as e:
logger.error(f"Failed to load module {args.module_name} for code generation: {e}")
return
generator = ClientGenerator() client_path = module_dir / "client.py"
generator.generate_client(rpc_classes)
generator.write(client_path)
for cls in rpc_classes.plugins: rpc_classes = get_custom_classes(args.module_name)
plugin = DesignerPluginGenerator(cls) logger.info(f"Obtained classes with RPC objects: {rpc_classes!r}")
if not hasattr(plugin, "info"):
continue
# if the class directory already has a register, plugin and pyproject file, skip generator = ClientGenerator(base=args.module_name == "bec_widgets")
if os.path.exists( logger.info(f"Generating client.py")
os.path.join(plugin.info.base_path, f"register_{plugin.info.plugin_name_snake}.py") generator.generate_client(rpc_classes)
): generator.write(str(client_path))
continue
if os.path.exists( for cls in rpc_classes.plugins:
os.path.join(plugin.info.base_path, f"{plugin.info.plugin_name_snake}_plugin.py") logger.info(f"Writing plugins for: {cls}")
): plugin = DesignerPluginGenerator(cls)
continue if not hasattr(plugin, "info"):
if os.path.exists( continue
os.path.join(plugin.info.base_path, f"{plugin.info.plugin_name_snake}.pyproject")
): def _exists(file: str):
continue return os.path.exists(os.path.join(plugin.info.base_path, file))
plugin.run()
if any(_exists(file) for file in plugin_filenames(plugin.info.plugin_name_snake)):
logger.debug(f"Skipping {plugin.info.plugin_name_snake} - a file already exists.")
continue
plugin.run()
if __name__ == "__main__": # pragma: no cover if __name__ == "__main__": # pragma: no cover
sys.argv = ["generate_cli.py", "--core"]
main() main()

File diff suppressed because it is too large Load Diff

View File

@ -1,12 +1,23 @@
import inspect import inspect
import os import os
import re import re
from typing import NamedTuple
from qtpy.QtCore import QObject from qtpy.QtCore import QObject
EXCLUDED_PLUGINS = ["BECConnector", "BECDockArea", "BECDock", "BECFigure"] EXCLUDED_PLUGINS = ["BECConnector", "BECDockArea", "BECDock", "BECFigure"]
class PluginFilenames(NamedTuple):
register: str
plugin: str
pyproj: str
def plugin_filenames(name: str) -> PluginFilenames:
return PluginFilenames(f"register_{name}.py", f"{name}_plugin.py", f"{name}.pyproject")
class DesignerPluginInfo: class DesignerPluginInfo:
def __init__(self, plugin_class): def __init__(self, plugin_class):
self.plugin_class = plugin_class self.plugin_class = plugin_class
@ -53,11 +64,15 @@ class DesignerPluginGenerator:
self._excluded = True self._excluded = True
return return
self.templates = {} self.templates: dict[str, str] = {}
self.template_path = os.path.join( self.template_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "plugin_templates" os.path.dirname(os.path.abspath(__file__)), "plugin_templates"
) )
@property
def filenames(self):
return plugin_filenames(self.info.plugin_name_snake)
def run(self, validate=True): def run(self, validate=True):
if self._excluded: if self._excluded:
print(f"Plugin {self.widget.__name__} is excluded from generation.") print(f"Plugin {self.widget.__name__} is excluded from generation.")
@ -112,26 +127,18 @@ class DesignerPluginGenerator:
f"Widget class {self.widget.__name__} must call the super constructor with parent." f"Widget class {self.widget.__name__} must call the super constructor with parent."
) )
def _write_file(self, name: str, contents: str):
with open(os.path.join(self.info.base_path, name), "w", encoding="utf-8") as f:
f.write(contents)
def _format(self, name: str):
return self.templates[name].format(**self.info.__dict__)
def _write_templates(self): def _write_templates(self):
self._write_register() self._write_file(self.filenames.register, self._format("register"))
self._write_plugin() self._write_file(self.filenames.plugin, self._format("plugin"))
self._write_pyproject() pyproj = str({"files": [f"{self.info.plugin_class.__module__.split('.')[-1]}.py"]})
self._write_file(self.filenames.pyproj, pyproj)
def _write_register(self):
file_path = os.path.join(self.info.base_path, f"register_{self.info.plugin_name_snake}.py")
with open(file_path, "w", encoding="utf-8") as f:
f.write(self.templates["register"].format(**self.info.__dict__))
def _write_plugin(self):
file_path = os.path.join(self.info.base_path, f"{self.info.plugin_name_snake}_plugin.py")
with open(file_path, "w", encoding="utf-8") as f:
f.write(self.templates["plugin"].format(**self.info.__dict__))
def _write_pyproject(self):
file_path = os.path.join(self.info.base_path, f"{self.info.plugin_name_snake}.pyproject")
out = {"files": [f"{self.info.plugin_class.__module__.split('.')[-1]}.py"]}
with open(file_path, "w", encoding="utf-8") as f:
f.write(str(out))
def _load_templates(self): def _load_templates(self):
for file in os.listdir(self.template_path): for file in os.listdir(self.template_path):

View File

@ -58,7 +58,10 @@ class BECClassInfo:
class BECClassContainer: class BECClassContainer:
def __init__(self): def __init__(self):
self._collection = [] self._collection: list[BECClassInfo] = []
def __repr__(self):
return str(list(cl.name for cl in self.collection))
def add_class(self, class_info: BECClassInfo): def add_class(self, class_info: BECClassInfo):
""" """

View File

@ -0,0 +1,98 @@
from unittest import mock
import pytest
from bec_widgets.cli.generate_cli import BECClassContainer, ClientGenerator
def test_client_generator_init():
"""
Test the initialization of the ClientGenerator class.
"""
generator = ClientGenerator()
assert generator.header.startswith("# This file was automatically generated by generate_cli.py")
assert generator.content == ""
def test_generate_client():
"""
Test the generate_client method of the ClientGenerator class.
"""
generator = ClientGenerator()
class_container = mock.MagicMock(spec=BECClassContainer)
class_container.rpc_top_level_classes = [mock.MagicMock(RPC=True, __name__="TestClass1")]
class_container.connector_classes = [mock.MagicMock(RPC=True, __name__="TestClass2")]
generator.generate_client(class_container)
assert '"TestClass1": "TestClass1"' in generator.content
assert "class TestClass2(RPCBase):" in generator.content
@pytest.mark.parametrize("plugin", (True, False))
def test_write_client_enum(plugin):
"""
Test the write_client_enum method of the ClientGenerator class.
"""
generator = ClientGenerator(base=plugin)
published_classes = [
mock.MagicMock(__name__="TestClass1"),
mock.MagicMock(__name__="TestClass2"),
]
generator.write_client_enum(published_classes)
assert ("class _WidgetsEnumType(str, enum.Enum):" in generator.content) is plugin
assert '"TestClass1": "TestClass1",' in generator.content
assert '"TestClass2": "TestClass2",' in generator.content
def test_generate_content_for_class():
"""
Test the generate_content_for_class method of the ClientGenerator class.
"""
generator = ClientGenerator()
cls = mock.MagicMock(__name__="TestClass", USER_ACCESS=["method1"])
method = mock.MagicMock()
method.__name__ = "method1"
method.__doc__ = "Test method"
method_signature = "(self)"
cls.method1 = method
with mock.patch("inspect.signature", return_value=method_signature):
generator.generate_content_for_class(cls)
assert "class TestClass(RPCBase):" in generator.content
assert "def method1(self):" in generator.content
assert "Test method" in generator.content
def test_write_is_black_formatted(tmp_path):
"""
Test the write method of the ClientGenerator class.
"""
generator = ClientGenerator()
generator.content = """
def test_content():
pass
a=1
b=2
c=a+b
"""
corrected = """def test_content():
pass
a = 1
b = 2
c = a + b"""
file_name = tmp_path / "test_client.py"
generator.write(str(file_name))
with open(file_name, "r", encoding="utf-8") as file:
content = file.read()
assert corrected in content