"""Discovery helpers for BEC widget classes.

Two mechanisms live here:

* entry-point plugins (``bec.widgets.user_widgets`` / ``bec.widgets.auto_updates``),
  loaded through ``bec_lib.plugin_helper``; and
* static scanning of a repository's package subtrees (``widgets`` by default), which
  parses source files with :mod:`ast` to find candidate classes before (optionally)
  importing them.
"""

from __future__ import annotations

import ast
import importlib
import inspect
import os
from dataclasses import dataclass
from functools import lru_cache
from importlib import util as importlib_util
from typing import TYPE_CHECKING, Callable, Iterable

if TYPE_CHECKING:  # pragma: no cover
    from bec_widgets.utils.bec_connector import BECConnector
    from bec_widgets.utils.bec_widget import BECWidget
    from bec_widgets.widgets.containers.auto_update.auto_updates import AutoUpdates

# A top-level class with a direct base spelled like one of these (``Name`` or
# ``module.Name``) is a discovery candidate; no import is needed to decide this.
_DISCOVERY_BASE_NAMES = frozenset({"BECConnector", "BECWidget", "ViewBase"})

# Class-body attributes whose mere assignment marks a class as a discovery candidate.
_RPC_MARKER_NAMES = frozenset({"RPC_CONTENT_CLASS", "USER_ACCESS"})


def _load_entry_point_classes(
    group: str, predicate: Callable[[object], bool], duplicate_message: str
) -> dict[str, type]:
    """
    Collect the members accepted by ``predicate`` from every module registered under an
    entry-point group.

    Args:
        group(str): The entry-point group passed to ``_get_available_plugins``.
        predicate(Callable): Filter passed to ``inspect.getmembers``.
        duplicate_message(str): Message printed on a name clash; formatted with ``name``.

    Returns:
        dict[str, type]: Member name -> member. On a clash the later module wins.
    """
    from bec_lib.plugin_helper import _get_available_plugins

    loaded: dict[str, type] = {}
    for module in _get_available_plugins(group):
        for name, cls in inspect.getmembers(module, predicate=predicate):
            if name in loaded:
                print(duplicate_message.format(name=name))
            loaded[name] = cls
    return loaded


def get_plugin_widgets() -> dict[str, type[BECConnector]]:
    """
    Get all available widgets from the plugin directory. Widgets are classes that inherit
    from BECConnector. The plugins are provided through python plugins and specified in the
    respective pyproject.toml file using the following key:

        [project.entry-points."bec.widgets.user_widgets"]
        plugin_widgets = "path.to.plugin.module"

    e.g.

        [project.entry-points."bec.widgets.user_widgets"]
        plugin_widgets = "pxiii_bec.bec_widgets.widgets"

    assuming that the widgets module for the package pxiii_bec is located at
    pxiii_bec/bec_widgets/widgets and contains the widgets to be loaded within the
    pxiii_bec/bec_widgets/widgets/__init__.py file.

    Returns:
        dict[str, type[BECConnector]]: A dictionary of widget names and their respective
        classes. If two plugin modules expose the same name, a message is printed and the
        later one wins.
    """
    return _load_entry_point_classes(
        "bec.widgets.user_widgets", _filter_plugins, "Duplicated widgets plugin {name}."
    )


def _filter_plugins(obj) -> bool:
    """Return True if ``obj`` is a class deriving from BECConnector."""
    from bec_widgets.utils.bec_connector import BECConnector

    return inspect.isclass(obj) and issubclass(obj, BECConnector)


def get_plugin_auto_updates() -> dict[str, type[AutoUpdates]]:
    """
    Get all available auto update classes from the plugin directory. AutoUpdates must
    inherit from AutoUpdates and be placed in the plugin repository's
    bec_widgets/auto_updates directory. The entry point for the auto updates is specified
    in the respective pyproject.toml file using the following key:

        [project.entry-points."bec.widgets.auto_updates"]
        plugin_widgets_update = "<plugin_package>.bec_widgets.auto_updates"

    e.g.

        [project.entry-points."bec.widgets.auto_updates"]
        plugin_widgets_update = "pxiii_bec.bec_widgets.auto_updates"

    Returns:
        dict[str, type[AutoUpdates]]: A dictionary of auto update class names and their
        respective classes. The ``AutoUpdates`` base class itself is excluded.
    """
    return _load_entry_point_classes(
        "bec.widgets.auto_updates", _filter_auto_updates, "Duplicated auto update {name}."
    )


def _filter_auto_updates(obj) -> bool:
    """Return True for AutoUpdates subclasses, excluding the base class (matched by name)."""
    from bec_widgets.widgets.containers.auto_update.auto_updates import AutoUpdates

    return (
        inspect.isclass(obj)
        and issubclass(obj, AutoUpdates)
        and obj.__name__ != "AutoUpdates"
    )


@dataclass
class BECClassInfo:
    """Metadata about a class found (and imported) from a repository package."""

    name: str
    module: str
    file: str  # path of the source file the class was found in
    obj: type["BECWidget"]
    is_connector: bool = False  # set when the class subclasses BECConnector
    is_widget: bool = False  # set when the class subclasses QWidget or BECWidget
    is_plugin: bool = False  # set when the class attribute PLUGIN is truthy


@dataclass(frozen=True)
class BECClassReference:
    """A (class name, dotted module path) pair found by static AST scanning, not imported."""

    name: str
    module: str


class BECClassContainer:
    """Ordered collection of :class:`BECClassInfo` entries with convenience views."""

    def __init__(self, initial: Iterable[BECClassInfo] = ()):
        self._collection: list[BECClassInfo] = list(initial)

    def __repr__(self):
        return str(self.names)

    def __iter__(self):
        return iter(self._collection)

    def __add__(self, other: BECClassContainer) -> BECClassContainer:
        """
        Return a new container holding this container's entries, followed by the entries of
        ``other`` whose name is not already present here. Neither operand is modified.
        """
        # Build the lookup set once; re-reading ``self.names`` per element is O(n*m).
        existing = set(self.names)
        return BECClassContainer((*self, *(c for c in other if c.name not in existing)))

    def as_dict(self, ignores: list[str] | None = None) -> dict[str, type["BECWidget"]]:
        """
        Get a dict of {name: Type} for all the entries in the collection.

        Args:
            ignores(list[str]): a list of class names to exclude from the dictionary.
        """
        ignore_set = set(ignores or ())
        return {c.name: c.obj for c in self if c.name not in ignore_set}

    def add_class(self, class_info: BECClassInfo):
        """
        Add a class to the collection.

        Args:
            class_info(BECClassInfo): The class information
        """
        self._collection.append(class_info)

    @property
    def names(self) -> list[str]:
        """Return a list of class names, in collection order."""
        return [c.name for c in self]

    @property
    def collection(self) -> list[BECClassInfo]:
        """Get the underlying (mutable) list of class infos."""
        return self._collection

    @property
    def connector_classes(self):
        """Get all connector classes."""
        return [info.obj for info in self._collection if info.is_connector]

    @property
    def top_level_classes(self):
        """Get all top-level classes (those flagged as plugins)."""
        return [info.obj for info in self._collection if info.is_plugin]

    @property
    def plugins(self):
        """Get all plugins. These are all classes that are on the top level and are widgets."""
        return [info.obj for info in self._collection if info.is_widget and info.is_plugin]

    @property
    def widgets(self):
        """Get all widgets. These are all classes inheriting from QWidget or BECWidget."""
        return [info.obj for info in self._collection if info.is_widget]

    @property
    def rpc_top_level_classes(self):
        """
        Get all top-level classes that are RPC-enabled.

        These are all classes that users can choose from.
        """
        return [info.obj for info in self._collection if info.is_plugin and info.is_connector]

    @property
    def classes(self):
        """Get all classes."""
        return [info.obj for info in self._collection]


def _ast_node_name(node: ast.expr) -> str | None:
    """
    Return the trailing identifier of a ``Name`` or ``Attribute`` node, else None.

    ``BECWidget`` and ``module.BECWidget`` both yield ``"BECWidget"``; any other expression
    (calls, subscripts such as ``Generic[T]``) yields None.
    """
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        return node.attr
    return None


def _class_has_rpc_markers(node: ast.ClassDef) -> bool:
    """
    Return True if the class body directly assigns a truthy literal to ``PLUGIN``, or
    assigns (plainly or with an annotation) ``RPC_CONTENT_CLASS`` or ``USER_ACCESS``.
    """
    for stmt in node.body:
        if isinstance(stmt, ast.Assign):
            names = {target.id for target in stmt.targets if isinstance(target, ast.Name)}
        elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
            names = {stmt.target.id}
        else:
            continue
        # ``stmt.value`` is None for a bare annotation (``PLUGIN: bool``) -> not a marker.
        value = stmt.value
        if "PLUGIN" in names and isinstance(value, ast.Constant) and value.value:
            return True
        if _RPC_MARKER_NAMES & names:
            return True
    return False


def _class_is_candidate(node: ast.ClassDef) -> bool:
    """Return True if the class has a discovery base name or carries RPC markers."""
    base_names = {_ast_node_name(base) for base in node.bases}
    return bool(_DISCOVERY_BASE_NAMES & base_names) or _class_has_rpc_markers(node)


def _parse_python_file(path: str) -> ast.Module:
    """Parse the UTF-8 source file at ``path`` into an AST module without importing it."""
    with open(path, encoding="utf-8") as file_handle:
        return ast.parse(file_handle.read(), filename=path)


def _candidate_top_level_class_names(path: str) -> list[str]:
    """Return the names of top-level candidate classes in ``path``, in source order."""
    module = _parse_python_file(path)
    return [
        node.name
        for node in module.body
        if isinstance(node, ast.ClassDef) and _class_is_candidate(node)
    ]


@lru_cache(maxsize=64)
def _find_package_roots(module_name: str) -> tuple[str, ...]:
    """
    Return the directories that make up ``module_name``.

    For a package, these are its ``submodule_search_locations``. For a plain module, it is
    the directory containing the module file. Note that ``importlib.util.find_spec`` imports
    the parent package of a dotted name.

    Raises:
        ModuleNotFoundError: If the module cannot be located.
    """
    spec = importlib_util.find_spec(module_name)
    if spec is None:
        raise ModuleNotFoundError(module_name)
    package_roots = tuple(spec.submodule_search_locations or ())
    if package_roots:
        return package_roots
    if spec.origin:
        return (os.path.dirname(spec.origin),)
    raise ModuleNotFoundError(module_name)


def _discover_class_references_from_roots(
    module_prefix: str,
    package_roots: Iterable[str],
    *,
    file_name_filter: Callable[[str], bool],
    candidate_filter: Callable[[ast.ClassDef], bool],
) -> tuple[BECClassReference, ...]:
    """
    Statically scan ``package_roots`` for top-level classes accepted by ``candidate_filter``.

    Files are visited in sorted order and kept only if ``file_name_filter`` accepts their
    name. Each module name is ``module_prefix`` followed by the file's path relative to its
    root, with a trailing ``.py`` removed. The first occurrence of a class name wins; later
    duplicates are skipped.
    """
    references: list[BECClassReference] = []
    seen_names: set[str] = set()
    for package_root in package_roots:
        for root, _, files in sorted(os.walk(package_root)):
            for file_name in sorted(files):
                if not file_name_filter(file_name):
                    continue
                path = os.path.join(root, file_name)
                module = _parse_python_file(path)
                rel_path = os.path.relpath(path, package_root).removesuffix(".py")
                module_name = ".".join([module_prefix, *rel_path.split(os.sep)])
                for node in module.body:
                    if not isinstance(node, ast.ClassDef) or not candidate_filter(node):
                        continue
                    if node.name in seen_names:
                        continue
                    references.append(BECClassReference(name=node.name, module=module_name))
                    seen_names.add(node.name)
    return tuple(references)


def _is_candidate_file_name(file_name: str) -> bool:
    """Return True for ``.py`` files not named ``__*``, ``register_*`` or ``*_plugin.py``."""
    return (
        file_name.endswith(".py")
        and not file_name.startswith(("__", "register_"))
        and not file_name.endswith("_plugin.py")
    )


def _iter_candidate_modules(repo_name: str, package: str) -> Iterable[tuple[str, str, list[str]]]:
    """
    Find the modules under ``<repo_name>.<package>`` that define candidate classes.

    Returns:
        A tuple of ``(dotted_module_name, file_path, candidate_class_names)`` entries in
        sorted walk order. It is empty if the package cannot be found.
    """
    package_name = f"{repo_name}.{package}"
    try:
        package_roots = _find_package_roots(package_name)
    except ModuleNotFoundError:
        return ()

    modules: list[tuple[str, str, list[str]]] = []
    for directory in package_roots:
        for root, _, files in sorted(os.walk(directory)):
            for file_name in sorted(files):
                if not _is_candidate_file_name(file_name):
                    continue
                path = os.path.join(root, file_name)
                class_names = _candidate_top_level_class_names(path)
                if not class_names:
                    continue
                # "sub/dir/mod.py" relative to the package root -> "sub.dir.mod"
                rel_module = os.path.relpath(path, directory).removesuffix(".py")
                module_name = ".".join(rel_module.split(os.sep))
                modules.append((f"{package_name}.{module_name}", path, class_names))
    return tuple(modules)


def _collect_classes_from_package(repo_name: str, package: str) -> BECClassContainer:
    """Collect classes from a package subtree (for example ``widgets`` or ``applications``).

    Every candidate module is imported. All classes *defined* in it (not re-exported
    imports) are recorded with connector/widget/plugin flags.
    """
    collection = BECClassContainer()
    candidates = _iter_candidate_modules(repo_name, package)
    if not candidates:
        # Nothing to import: avoid pulling in Qt / bec_widgets at all.
        return collection

    from qtpy.QtWidgets import QWidget

    from bec_widgets.utils.bec_connector import BECConnector
    from bec_widgets.utils.bec_widget import BECWidget

    for module_name, path, _ in candidates:
        module = importlib.import_module(module_name)
        for name, obj in inspect.getmembers(module, inspect.isclass):
            if obj.__module__ != module.__name__:
                # Imported from elsewhere; only record classes defined in this module.
                continue
            collection.add_class(
                BECClassInfo(
                    name=name,
                    module=module.__name__,
                    file=path,
                    obj=obj,
                    is_connector=issubclass(obj, BECConnector),
                    is_widget=issubclass(obj, (QWidget, BECWidget)),
                    is_plugin=bool(getattr(obj, "PLUGIN", False)),
                )
            )
    return collection


@lru_cache(maxsize=32)
def _cached_custom_class_references(
    repo_name: str, packages: tuple[str, ...]
) -> tuple[BECClassReference, ...]:
    """Memoised static scan of ``packages``. The first occurrence of a class name wins."""
    references: list[BECClassReference] = []
    seen_names: set[str] = set()
    for package in packages:
        for module_name, _, class_names in _iter_candidate_modules(repo_name, package):
            for class_name in class_names:
                if class_name in seen_names:
                    continue
                references.append(BECClassReference(name=class_name, module=module_name))
                seen_names.add(class_name)
    return tuple(references)


def get_custom_class_references(
    repo_name: str, packages: tuple[str, ...] | None = None, *, use_cache: bool = True
) -> list[BECClassReference]:
    """
    Statically discover candidate classes in a repository without importing them.

    Args:
        repo_name(str): The name of the repository (top-level package).
        packages(tuple[str, ...] | None): Package subtrees to scan. Defaults to
            ("widgets",); an empty tuple also falls back to this default.
        use_cache(bool): If False, discard all memoised scan results (for every
            repository) before scanning again.

    Returns:
        list[BECClassReference]: One reference per unique class name, in discovery order.
    """
    selected_packages = tuple(packages or ("widgets",))
    if not use_cache:
        _cached_custom_class_references.cache_clear()
        # The package-root lookup is memoised too; clear it so the rescan is complete.
        _find_package_roots.cache_clear()
    return list(_cached_custom_class_references(repo_name, selected_packages))


def get_custom_classes(
    repo_name: str, packages: tuple[str, ...] | None = None
) -> BECClassContainer:
    """
    Get all relevant classes for RPC/CLI in the specified repository.

    By default, discovery is limited to ``.widgets`` for backward compatibility.
    Additional package subtrees (for example ``applications``) can be included explicitly.

    Args:
        repo_name(str): The name of the repository.
        packages(tuple[str, ...] | None): Optional tuple of package names to scan. Defaults
            to ("widgets",) for backward compatibility.

    Returns:
        BECClassContainer: Container with collected class information. If a class name
        appears in several packages, the first one scanned is kept.
    """
    selected_packages = packages or ("widgets",)
    collection = BECClassContainer()
    for package in selected_packages:
        collection += _collect_classes_from_package(repo_name, package)
    return collection