# pylint: disable=missing-module-docstring from __future__ import annotations import argparse import importlib import inspect import os import sys from pathlib import Path import black import isort from bec_lib.logger import bec_logger from qtpy.QtCore import Property as QtProperty from bec_widgets.utils.generate_designer_plugin import DesignerPluginGenerator, plugin_filenames from bec_widgets.utils.plugin_utils import BECClassContainer, get_custom_classes logger = bec_logger.logger if sys.version_info >= (3, 11): from typing import get_overloads else: print( "Python version is less than 3.11, using dummy function for get_overloads. " "If you want to use the real function 'typing.get_overloads()', please use Python 3.11 or later." ) def get_overloads(_obj): """ Dummy function for Python versions before 3.11. """ return [] class ClientGenerator: def __init__(self, base=False): self._base = base base_imports = ( """import enum import inspect import traceback from typing import Literal, Optional """ if self._base else "\n" ) self.header = f"""# This file was automatically generated by generate_cli.py # type: ignore \n from __future__ import annotations {base_imports} from bec_lib.logger import bec_logger from bec_widgets.cli.rpc.rpc_base import RPCBase, rpc_call {"from bec_widgets.utils.bec_plugin_helper import get_all_plugin_widgets, get_plugin_client_module" if self._base else ""} logger = bec_logger.logger # pylint: skip-file""" self.content = "" def generate_client(self, class_container: BECClassContainer): """ Generate the client for the published classes, skipping any classes that have `RPC = False`. Args: class_container: The class container with the classes to generate the client for. """ # Filter out classes that explicitly have RPC=False rpc_top_level_classes = [ cls for cls in class_container.rpc_top_level_classes if getattr(cls, "RPC", True) ] rpc_top_level_classes.sort(key=lambda x: x.__name__) connector_classes = [ cls for cls in class_container.connector_classes if getattr(cls, "RPC", True) ] connector_classes.sort(key=lambda x: x.__name__) self.write_client_enum(rpc_top_level_classes) for cls in connector_classes: logger.debug(f"generating RPC client class for {cls.__name__}") self.content += "\n\n" self.generate_content_for_class(cls) def write_client_enum(self, published_classes: list[type]): """ Write the client enum to the content. """ if self._base: self.content += """ class _WidgetsEnumType(str, enum.Enum): \"\"\" Enum for the available widgets, to be generated programatically \"\"\" ... """ self.content += """ _Widgets = { """ for cls in published_classes: self.content += f'"{cls.__name__}": "{cls.__name__}",\n ' self.content += """} """ if self._base: self.content += """ try: _plugin_widgets = get_all_plugin_widgets() plugin_client = get_plugin_client_module() Widgets = _WidgetsEnumType("Widgets", {name: name for name in _plugin_widgets} | _Widgets) if (_overlap := _Widgets.keys() & _plugin_widgets.keys()) != set(): for _widget in _overlap: logger.warning(f"Detected duplicate widget {_widget} in plugin repo file: {inspect.getfile(_plugin_widgets[_widget])} !") for plugin_name, plugin_class in inspect.getmembers(plugin_client, inspect.isclass): if issubclass(plugin_class, RPCBase) and plugin_class is not RPCBase: if plugin_name in globals(): conflicting_file = ( inspect.getfile(_plugin_widgets[plugin_name]) if plugin_name in _plugin_widgets else f"{plugin_client}" ) logger.warning( f"Plugin widget {plugin_name} from {conflicting_file} conflicts with a built-in class!" ) continue if plugin_name not in _overlap: globals()[plugin_name] = plugin_class except ImportError as e: logger.error(f"Failed loading plugins: \\n{reduce(add, traceback.format_exception(e))}") """ def generate_content_for_class(self, cls): """ Generate the content for the class. Args: cls: The class for which to generate the content. """ class_name = cls.__name__ if class_name == "BECDockArea": self.content += f""" class {class_name}(RPCBase):""" else: self.content += f""" class {class_name}(RPCBase):""" if cls.__doc__: # We only want the first line of the docstring # But skip the first line if it's a blank line first_line = cls.__doc__.split("\n")[0] if first_line: class_docs = first_line else: class_docs = cls.__doc__.split("\n")[1] self.content += f""" \"\"\"{class_docs}\"\"\" """ if not cls.USER_ACCESS: self.content += """... """ for method in cls.USER_ACCESS: is_property_setter = False obj = getattr(cls, method, None) if obj is None: obj = getattr(cls, method.split(".setter")[0], None) is_property_setter = True method = method.split(".setter")[0] if obj is None: raise AttributeError( f"Method {method} not found in class {cls.__name__}. " f"Please check the USER_ACCESS list." ) if isinstance(obj, (property, QtProperty)): # for the cli, we can map qt properties to regular properties if is_property_setter: self.content += f""" @{method}.setter @rpc_call""" else: self.content += """ @property @rpc_call""" sig = str(inspect.signature(obj.fget)) doc = inspect.getdoc(obj.fget) else: sig = str(inspect.signature(obj)) doc = inspect.getdoc(obj) overloads = get_overloads(obj) for overload in overloads: sig_overload = str(inspect.signature(overload)) self.content += f""" @overload def {method}{str(sig_overload)}: ... """ self.content += """ @rpc_call""" self.content += f""" def {method}{str(sig)}: \"\"\" {doc} \"\"\"""" def write(self, file_name: str): """ Write the content to a file, automatically formatted with black. Args: file_name(str): The name of the file to write to. """ # Combine header and content, then format with black full_content = self.header + "\n" + self.content try: formatted_content = black.format_str(full_content, mode=black.FileMode(line_length=100)) except black.NothingChanged: formatted_content = full_content isort.Config( profile="black", line_length=100, multi_line_output=3, include_trailing_comma=True, known_first_party=["bec_widgets"], ) formatted_content = isort.code(formatted_content) with open(file_name, "w", encoding="utf-8") as file: file.write(formatted_content) def main(): """ Main entry point for the script, controlled by command line arguments. """ parser = argparse.ArgumentParser(description="Auto-generate the client for RPC widgets") parser.add_argument( "--target", action="store", type=str, help="Which package to generate plugin files for. Should be installed in the local environment (example: my_plugin_repo)", ) args = parser.parse_args() if args.target is None: logger.error( "You must provide a target - for safety, the default of running this on bec_widgets core has been removed. To generate the client for bec_widgets, run `bw-generate-cli --target bec_widgets`" ) return logger.info(f"BEC Widget code generation tool started with args: {args}") client_subdir = "cli" if args.target == "bec_widgets" else "widgets" module_name = "bec_widgets" if args.target == "bec_widgets" else f"{args.target}.bec_widgets" try: module = importlib.import_module(module_name) assert module.__file__ is not None module_file = Path(module.__file__) module_dir = module_file.parent if module_file.is_file() else module_file except Exception as e: logger.error(f"Failed to load module {module_name} for code generation: {e}") return client_path = module_dir / client_subdir / "client.py" rpc_classes = get_custom_classes(module_name) logger.info(f"Obtained classes with RPC objects: {rpc_classes!r}") generator = ClientGenerator(base=module_name == "bec_widgets") logger.info(f"Generating client file at {client_path}") generator.generate_client(rpc_classes) generator.write(str(client_path)) if module_name != "bec_widgets": non_overwrite_classes = list(clsinfo.name for clsinfo in get_custom_classes("bec_widgets")) logger.info( f"Not writing plugins which would conflict with builtin classes: {non_overwrite_classes}" ) else: non_overwrite_classes = [] for cls in rpc_classes.plugins: logger.info(f"Writing bec-designer plugin files for {cls.__name__}...") if cls.__name__ in non_overwrite_classes: logger.error( f"Not writing plugin files for {cls.__name__} because a built-in widget with that name exists" ) plugin = DesignerPluginGenerator(cls) if not hasattr(plugin, "info"): continue def _exists(file: str): return os.path.exists(os.path.join(plugin.info.base_path, file)) if any(_exists(file) for file in plugin_filenames(plugin.info.plugin_name_snake)): logger.debug( f"Skipping generation of extra plugin files for {plugin.info.plugin_name_snake} - at least one file out of 'plugin.py', 'pyproject', and 'register_{plugin.info.plugin_name_snake}.py' already exists." ) continue plugin.run() if __name__ == "__main__": # pragma: no cover import sys sys.argv = ["bw-generate-cli", "--target", "csaxs_bec"] main()