Mirror of https://github.com/bec-project/bec_widgets.git
Synced 2025-12-31 11:11:17 +01:00
212 lines · 6.8 KiB · Python
# pylint: disable=missing-module-docstring
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import importlib
|
|
import inspect
|
|
import os
|
|
import sys
|
|
from typing import Literal
|
|
|
|
import black
|
|
import isort
|
|
from qtpy.QtWidgets import QGraphicsWidget, QWidget
|
|
|
|
from bec_widgets.utils import BECConnector
|
|
|
|
if sys.version_info < (3, 11):
    # typing.get_overloads only exists from Python 3.11 on; warn and fall back to a stub.
    print(
        "Python version is less than 3.11, using dummy function for get_overloads. "
        "If you want to use the real function 'typing.get_overloads()', please use Python 3.11 or later."
    )

    def get_overloads(_obj):
        """
        Stand-in for ``typing.get_overloads`` on interpreters older than 3.11.

        Always reports that the given object has no registered overloads.
        """
        return []

else:
    from typing import get_overloads
|
|
|
|
|
|
class ClientGenerator:
    """
    Generate the RPC client module for the published BEC widget classes.

    Generated class/enum source is accumulated in ``self.content`` by the ``generate_*`` /
    ``write_*`` methods and written to disk, prefixed by ``self.header``, by ``write``.
    """

    def __init__(self):
        # Fixed preamble of the generated module: imports used by the generated stubs.
        self.header = """# This file was automatically generated by generate_cli.py\n
import enum
from typing import Literal, Optional, overload

from bec_widgets.cli.client_utils import RPCBase, rpc_call, BECGuiClientMixin

# pylint: skip-file"""

        # Generated definitions, appended to by generate_client / write_client_enum /
        # generate_content_for_class.
        self.content = ""

    def generate_client(
        self, published_classes: dict[Literal["connector_classes", "top_level_classes"], list[type]]
    ):
        """
        Generate the client for the published classes.

        Args:
            published_classes(dict): A dictionary with keys "connector_classes" and
                "top_level_classes" and values as lists of classes.
        """
        self.write_client_enum(published_classes["top_level_classes"])
        for cls in published_classes["connector_classes"]:
            self.content += "\n\n"
            self.generate_content_for_class(cls)

    def write_client_enum(self, published_classes: list[type]):
        """
        Append a ``Widgets`` string enum with one member per given class to the content.

        Args:
            published_classes(list[type]): Classes to list; each member's name and value
                are the class name.
        """
        self.content += """
class Widgets(str, enum.Enum):
    \"\"\"
    Enum for the available widgets.
    \"\"\"
    """
        for cls in published_classes:
            self.content += f'{cls.__name__} = "{cls.__name__}"\n    '

    @staticmethod
    def _format_docstring(doc: str | None) -> str:
        """
        Make a docstring safe to embed in a generated triple-double-quoted literal.

        Returns an empty string for a missing docstring (``None``) so the generated stub does
        not contain the literal text "None". Backslashes are doubled so the generated literal
        evaluates to the same text, and triple double quotes are escaped so they cannot
        terminate the generated literal early.

        Args:
            doc(str | None): The docstring as returned by ``inspect.getdoc``.

        Returns:
            str: The escaped docstring text.
        """
        if not doc:
            return ""
        return doc.replace("\\", "\\\\").replace('"""', '\\"\\"\\"')

    def generate_content_for_class(self, cls):
        """
        Generate the content for the class.

        Every name in ``cls.USER_ACCESS`` becomes a stub decorated with ``@rpc_call`` that
        carries the original signature and docstring. Properties are emitted as
        ``@property`` stubs, and ``typing.overload`` variants of plain methods are re-emitted.

        Args:
            cls: The class for which to generate the content.
        """
        class_name = cls.__name__

        # BECDockArea additionally gets the GUI-client mixin.
        if class_name == "BECDockArea":
            self.content += f"""
class {class_name}(RPCBase, BECGuiClientMixin):"""
        else:
            self.content += f"""
class {class_name}(RPCBase):"""

        for method in cls.USER_ACCESS:
            obj = getattr(cls, method)
            if isinstance(obj, property):
                self.content += """
    @property
    @rpc_call"""
                sig = str(inspect.signature(obj.fget))
                doc = inspect.getdoc(obj.fget)
            else:
                sig = str(inspect.signature(obj))
                doc = inspect.getdoc(obj)
                # Re-emit overload variants so type checkers see the same overloads.
                for overload_func in get_overloads(obj):
                    sig_overload = str(inspect.signature(overload_func))
                    self.content += f"""
    @overload
    def {method}{sig_overload}: ...
    """

                self.content += """
    @rpc_call"""

            doc_text = self._format_docstring(doc)
            self.content += f"""
    def {method}{sig}:
        \"\"\"
        {doc_text}
        \"\"\""""

    def write(self, file_name: str):
        """
        Write the content to a file, formatted with black and import-sorted with isort.

        Args:
            file_name(str): The name of the file to write to.
        """
        # Combine header and content, then format with black
        full_content = self.header + "\n" + self.content
        try:
            formatted_content = black.format_str(full_content, mode=black.FileMode(line_length=100))
        except black.NothingChanged:
            formatted_content = full_content

        # The config has to be handed to isort.code explicitly; without it isort uses its
        # default settings instead of the black-compatible profile configured here.
        isort_config = isort.Config(
            profile="black",
            line_length=100,
            multi_line_output=3,
            include_trailing_comma=True,
            known_first_party=["bec_widgets"],
        )
        formatted_content = isort.code(formatted_content, config=isort_config)

        with open(file_name, "w", encoding="utf-8") as file:
            file.write(formatted_content)

    @staticmethod
    def _relative_module_parts(path: str, directory: str) -> list[str]:
        """
        Return the dotted-module components of a ``.py`` file relative to a directory.

        For ``<directory>/figure/plots/waveform.py`` this returns
        ``["figure", "plots", "waveform"]``. The path is normalised and split on ``os.sep``,
        so this also works where relative paths use backslashes (Windows).

        Args:
            path(str): Path of the Python file.
            directory(str): Directory the module name is relative to.

        Returns:
            list[str]: Sub-package names followed by the module name (extension removed).
        """
        relative = os.path.normpath(os.path.relpath(path, directory))
        parts = relative.split(os.sep)
        parts[-1] = os.path.splitext(parts[-1])[0]
        return parts

    @staticmethod
    def get_rpc_classes(
        repo_name: str,
    ) -> dict[Literal["connector_classes", "top_level_classes"], list[type]]:
        """
        Get all RPC-enabled classes in the specified repository.

        Imports every non-dunder ``.py`` module below ``<repo_name>.widgets``. It collects the
        ``BECConnector`` subclasses defined in those modules. Of these, the ``QWidget`` /
        ``QGraphicsWidget`` subclasses defined directly in ``widgets/`` or one sub-package
        deep are also reported as top-level classes.

        Args:
            repo_name(str): The name of the repository.

        Returns:
            dict: A dictionary with keys "connector_classes" and "top_level_classes" and
                values as lists of classes.
        """
        connector_classes = []
        top_level_classes = []
        anchor_module = importlib.import_module(f"{repo_name}.widgets")
        directory = os.path.dirname(anchor_module.__file__)
        for root, _, files in sorted(os.walk(directory)):
            # Sort file names too, so the generated client does not depend on filesystem order.
            for file in sorted(files):
                if not file.endswith(".py") or file.startswith("__"):
                    continue

                parts = ClientGenerator._relative_module_parts(os.path.join(root, file), directory)
                # Files directly in widgets/ or one sub-package deep may hold top-level widgets.
                is_top_level_module = len(parts) <= 2
                module = importlib.import_module(f"{repo_name}.widgets.{'.'.join(parts)}")

                for name in dir(module):
                    obj = getattr(module, name)
                    # Skip names that were only imported into this module.
                    if not hasattr(obj, "__module__") or obj.__module__ != module.__name__:
                        continue
                    if isinstance(obj, type) and issubclass(obj, BECConnector):
                        connector_classes.append(obj)
                        if is_top_level_module and issubclass(obj, (QWidget, QGraphicsWidget)):
                            top_level_classes.append(obj)

        return {"connector_classes": connector_classes, "top_level_classes": top_level_classes}
|
|
|
|
|
|
def main(argv: list[str] | None = None):
    """
    Main entry point for the script, controlled by command line arguments.

    With ``--core``, regenerates ``client.py`` next to this file from the RPC classes found
    in ``bec_widgets``. Without it, prints the usage help instead of silently doing nothing.

    Args:
        argv(list[str] | None): Arguments to parse instead of ``sys.argv[1:]``; useful when
            calling the generator programmatically.
    """
    parser = argparse.ArgumentParser(description="Auto-generate the client for RPC widgets")
    parser.add_argument("--core", action="store_true", help="Whether to generate the core client")

    args = parser.parse_args(argv)

    if not args.core:
        parser.print_help()
        return

    current_path = os.path.dirname(__file__)
    client_path = os.path.join(current_path, "client.py")

    rpc_classes = ClientGenerator.get_rpc_classes("bec_widgets")
    # BECConnector is defined outside bec_widgets.widgets, so get_rpc_classes never collects
    # it; add it explicitly so the client also gets a stub for it.
    # NOTE(review): the original author was unsure whether this is required — confirm.
    rpc_classes["connector_classes"].append(BECConnector)
    rpc_classes["connector_classes"].sort(key=lambda cls: cls.__name__)

    generator = ClientGenerator()
    generator.generate_client(rpc_classes)
    generator.write(client_path)
|
|
|
|
|
|
if __name__ == "__main__":  # pragma: no cover
    # Generate the core client when run without arguments, while still honouring
    # explicitly passed arguments such as --help.
    if len(sys.argv) <= 1:
        sys.argv = ["generate_cli.py", "--core"]
    main()
|