diff --git a/bec_widgets/utils/bec_plugin_helper.py b/bec_widgets/utils/bec_plugin_helper.py index 242adbdc..c1fa88b9 100644 --- a/bec_widgets/utils/bec_plugin_helper.py +++ b/bec_widgets/utils/bec_plugin_helper.py @@ -1,7 +1,9 @@ from __future__ import annotations +import ast import importlib.metadata import inspect +import os import pkgutil import traceback from importlib import util as importlib_util @@ -11,11 +13,61 @@ from typing import Generator from bec_lib.logger import bec_logger -from bec_widgets.utils.plugin_utils import BECClassContainer, BECClassInfo +from bec_widgets.utils.plugin_utils import ( + BECClassContainer, + BECClassInfo, + BECClassReference, + _ast_node_name, + _class_has_rpc_markers, +) logger = bec_logger.logger +def _plugin_class_is_candidate(node: ast.ClassDef) -> bool: + base_names = {_ast_node_name(base) for base in node.bases} + return bool({"BECWidget", "BECConnector"} & base_names) or _class_has_rpc_markers(node) + + +def get_all_plugin_widget_references() -> list[BECClassReference]: + references: list[BECClassReference] = [] + seen_names: set[str] = set() + for entry_point in importlib.metadata.entry_points(group="bec.widgets.user_widgets"): # type: ignore + spec = importlib_util.find_spec(entry_point.module) + if spec is None: + continue + + package_roots = list(spec.submodule_search_locations or ()) + if spec.origin and not package_roots: + package_roots = [os.path.dirname(spec.origin)] + + for package_root in package_roots: + for root, _, files in sorted(os.walk(package_root)): + for file_name in sorted(files): + if not file_name.endswith(".py") or file_name.startswith("__"): + continue + path = os.path.join(root, file_name) + with open(path, encoding="utf-8") as file_handle: + module = ast.parse(file_handle.read(), filename=path) + module_name = ".".join( + os.path.relpath(path, package_root).removesuffix(".py").split(os.sep) + ) + for node in module.body: + if not isinstance(node, ast.ClassDef) or not _plugin_class_is_candidate( + node + ): + continue + if node.name in seen_names: + continue + references.append( + BECClassReference( + name=node.name, module=f"{entry_point.module}.{module_name}" + ) + ) + seen_names.add(node.name) + return references + + def _submodule_specs(module: ModuleType) -> tuple[ModuleSpec | None, ...]: """Return specs for all submodules of the given module.""" return tuple( diff --git a/bec_widgets/utils/plugin_utils.py b/bec_widgets/utils/plugin_utils.py index c9c367ae..433c3516 100644 --- a/bec_widgets/utils/plugin_utils.py +++ b/bec_widgets/utils/plugin_utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import ast import importlib import inspect import os @@ -7,16 +8,16 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Iterable from bec_lib.plugin_helper import _get_available_plugins -from qtpy.QtWidgets import QWidget - -from bec_widgets.utils.bec_connector import BECConnector -from bec_widgets.utils.bec_widget import BECWidget if TYPE_CHECKING: # pragma: no cover + from bec_widgets.utils.bec_connector import BECConnector + from bec_widgets.utils.bec_widget import BECWidget from bec_widgets.widgets.containers.auto_update.auto_updates import AutoUpdates +_DISCOVERY_BASE_NAMES = frozenset({"BECConnector", "BECWidget", "ViewBase"}) -def get_plugin_widgets() -> dict[str, BECConnector]: + +def get_plugin_widgets() -> dict[str, "BECConnector"]: """ Get all available widgets from the plugin directory. Widgets are classes that inherit from BECConnector. The plugins are provided through python plugins and specified in the respective pyproject.toml file using @@ -48,6 +49,8 @@ def get_plugin_widgets() -> dict[str, BECConnector]: def _filter_plugins(obj): + from bec_widgets.utils.bec_connector import BECConnector + return inspect.isclass(obj) and issubclass(obj, BECConnector) @@ -90,14 +93,20 @@ class BECClassInfo: name: str module: str file: str - obj: type[BECWidget] + obj: type["BECWidget"] is_connector: bool = False is_widget: bool = False is_plugin: bool = False +@dataclass(frozen=True) +class BECClassReference: + name: str + module: str + + class BECClassContainer: - def __init__(self, initial: Iterable[BECClassInfo] = []): + def __init__(self, initial: Iterable[BECClassInfo] = ()): self._collection: list[BECClassInfo] = list(initial) def __repr__(self): @@ -109,12 +118,13 @@ class BECClassContainer: def __add__(self, other: BECClassContainer): return BECClassContainer((*self, *(c for c in other if c.name not in self.names))) - def as_dict(self, ignores: list[str] = []) -> dict[str, type[BECWidget]]: + def as_dict(self, ignores: list[str] | None = None) -> dict[str, type["BECWidget"]]: """get a dict of {name: Type} for all the entries in the collection. Args: ignores(list[str]): a list of class names to exclude from the dictionary.""" - return {c.name: c.obj for c in self if c.name not in ignores} + ignore_set = set(ignores or ()) + return {c.name: c.obj for c in self if c.name not in ignore_set} def add_class(self, class_info: BECClassInfo): """ @@ -166,48 +176,126 @@ class BECClassContainer: return [info.obj for info in self.collection] -def _collect_classes_from_package(repo_name: str, package: str) -> BECClassContainer: - """Collect classes from a package subtree (for example ``widgets`` or ``applications``).""" - collection = BECClassContainer() +def _ast_node_name(node: ast.expr) -> str | None: + if isinstance(node, ast.Name): + return node.id + if isinstance(node, ast.Attribute): + return node.attr + return None + + +def _class_has_rpc_markers(node: ast.ClassDef) -> bool: + for stmt in node.body: + if isinstance(stmt, ast.Assign): + target_names = {target.id for target in stmt.targets if isinstance(target, ast.Name)} + if ( + "PLUGIN" in target_names + and isinstance(stmt.value, ast.Constant) + and stmt.value.value + ): + return True + if {"RPC_CONTENT_CLASS", "USER_ACCESS"} & target_names: + return True + if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name): + if ( + stmt.target.id == "PLUGIN" + and isinstance(stmt.value, ast.Constant) + and stmt.value.value + ): + return True + if stmt.target.id in {"RPC_CONTENT_CLASS", "USER_ACCESS"}: + return True + return False + + +def _class_is_candidate(node: ast.ClassDef) -> bool: + base_names = {_ast_node_name(base) for base in node.bases} + return bool(_DISCOVERY_BASE_NAMES & base_names) or _class_has_rpc_markers(node) + + +def _candidate_top_level_class_names(path: str) -> list[str]: + with open(path, encoding="utf-8") as file_handle: + module = ast.parse(file_handle.read(), filename=path) + return [ + node.name + for node in module.body + if isinstance(node, ast.ClassDef) and _class_is_candidate(node) + ] + + +def _iter_candidate_modules(repo_name: str, package: str) -> Iterable[tuple[str, str, list[str]]]: try: anchor_module = importlib.import_module(f"{repo_name}.{package}") except ModuleNotFoundError as exc: - # Some plugin repositories expose only one subtree. Skip gracefully if it does not exist. if exc.name == f"{repo_name}.{package}": - return collection + return () raise directory = os.path.dirname(anchor_module.__file__) - for root, _, files in sorted(os.walk(directory)): - for file in files: - if not file.endswith(".py") or file.startswith("__"): + return ( + (f"{repo_name}.{package}.{module_name}", path, class_names) + for root, _, files in sorted(os.walk(directory)) + for file_name in sorted(files) + if file_name.endswith(".py") + and not file_name.startswith("__") + and not file_name.startswith("register_") + and not file_name.endswith("_plugin.py") + for path in (os.path.join(root, file_name),) + for rel_dir in (os.path.dirname(os.path.relpath(path, directory)),) + for module_name in ( + [ + ( + file_name.removesuffix(".py") + if rel_dir in ("", ".") + else ".".join(rel_dir.split(os.sep) + [file_name.removesuffix(".py")]) + ) + ] + ) + for class_names in (_candidate_top_level_class_names(path),) + if class_names + ) + + +def _collect_classes_from_package(repo_name: str, package: str) -> BECClassContainer: + """Collect classes from a package subtree (for example ``widgets`` or ``applications``).""" + collection = BECClassContainer() + for module_name, path, _ in _iter_candidate_modules(repo_name, package): + from qtpy.QtWidgets import QWidget + + from bec_widgets.utils.bec_connector import BECConnector + from bec_widgets.utils.bec_widget import BECWidget + + module = importlib.import_module(module_name) + for name, obj in inspect.getmembers(module, inspect.isclass): + if obj.__module__ != module.__name__: continue - - path = os.path.join(root, file) - rel_dir = os.path.dirname(os.path.relpath(path, directory)) - if rel_dir in ("", "."): - module_name = file.split(".")[0] - else: - module_name = ".".join(rel_dir.split(os.sep) + [file.split(".")[0]]) - - module = importlib.import_module(f"{repo_name}.{package}.{module_name}") - - for name in dir(module): - obj = getattr(module, name) - if not hasattr(obj, "__module__") or obj.__module__ != module.__name__: - continue - if isinstance(obj, type): - class_info = BECClassInfo(name=name, module=module.__name__, file=path, obj=obj) - if issubclass(obj, BECConnector): - class_info.is_connector = True - if issubclass(obj, QWidget) or issubclass(obj, BECWidget): - class_info.is_widget = True - if hasattr(obj, "PLUGIN") and obj.PLUGIN: - class_info.is_plugin = True - collection.add_class(class_info) + class_info = BECClassInfo(name=name, module=module.__name__, file=path, obj=obj) + if issubclass(obj, BECConnector): + class_info.is_connector = True + if issubclass(obj, QWidget) or issubclass(obj, BECWidget): + class_info.is_widget = True + if hasattr(obj, "PLUGIN") and obj.PLUGIN: + class_info.is_plugin = True + collection.add_class(class_info) return collection +def get_custom_class_references( + repo_name: str, packages: tuple[str, ...] | None = None +) -> list[BECClassReference]: + selected_packages = packages or ("widgets",) + references: list[BECClassReference] = [] + seen_names: set[str] = set() + for package in selected_packages: + for module_name, _, class_names in _iter_candidate_modules(repo_name, package): + for class_name in class_names: + if class_name in seen_names: + continue + references.append(BECClassReference(name=class_name, module=module_name)) + seen_names.add(class_name) + return references + + def get_custom_classes( repo_name: str, packages: tuple[str, ...] | None = None ) -> BECClassContainer: