mirror of
https://github.com/bec-project/bec_widgets.git
synced 2026-04-20 23:34:36 +02:00
375 lines
14 KiB
Python
375 lines
14 KiB
Python
from __future__ import annotations
|
|
|
|
import ast
|
|
import importlib
|
|
import inspect
|
|
import os
|
|
from dataclasses import dataclass
|
|
from functools import lru_cache
|
|
from importlib import util as importlib_util
|
|
from typing import TYPE_CHECKING, Callable, Iterable
|
|
|
|
if TYPE_CHECKING: # pragma: no cover
|
|
from bec_widgets.utils.bec_connector import BECConnector
|
|
from bec_widgets.utils.bec_widget import BECWidget
|
|
from bec_widgets.widgets.containers.auto_update.auto_updates import AutoUpdates
|
|
|
|
# Base-class names that make a module-level class a discovery candidate when one of them
# appears directly in the class's bases. Matching is done on the bare trailing name of the
# base expression during static (AST) parsing, so ``BECWidget`` and ``pkg.mod.BECWidget``
# both match; indirect subclasses are not detected through this set.
_DISCOVERY_BASE_NAMES = frozenset({"BECConnector", "BECWidget", "ViewBase"})
|
|
|
|
|
|
def get_plugin_widgets() -> dict[str, type[BECConnector]]:
    """
    Get all available widgets from the plugin directory. Widgets are classes that inherit from BECConnector.

    The plugins are provided through python plugins and specified in the respective pyproject.toml file using
    the following key:

        [project.entry-points."bec.widgets.user_widgets"]
        plugin_widgets = "path.to.plugin.module"

    e.g.
        [project.entry-points."bec.widgets.user_widgets"]
        plugin_widgets = "pxiii_bec.bec_widgets.widgets"

    assuming that the widgets module for the package pxiii_bec is located at pxiii_bec/bec_widgets/widgets and
    contains the widgets to be loaded within the pxiii_bec/bec_widgets/widgets/__init__.py file.

    Returns:
        dict[str, type[BECConnector]]: A dictionary mapping the attribute names found in the plugin modules to
            the widget classes (not instances). If two plugin modules expose the same name, a warning is
            printed and the class from the module processed last wins.
    """
    from bec_lib.plugin_helper import _get_available_plugins

    modules = _get_available_plugins("bec.widgets.user_widgets")
    loaded_plugins: dict[str, type[BECConnector]] = {}
    for module in modules:
        for name, mod_cls in inspect.getmembers(module, predicate=_filter_plugins):
            if name in loaded_plugins:
                print(f"Duplicated widgets plugin {name}.")
            # Deliberately overwrite: later plugin modules take precedence over earlier ones.
            loaded_plugins[name] = mod_cls
    return loaded_plugins
|
|
|
|
|
|
def _filter_plugins(obj):
    """
    Predicate for ``inspect.getmembers``: accept classes deriving from BECConnector.

    The base class BECConnector itself also satisfies ``issubclass`` and is therefore accepted.
    """
    from bec_widgets.utils.bec_connector import BECConnector

    if not inspect.isclass(obj):
        return False
    return issubclass(obj, BECConnector)
|
|
|
|
|
|
def get_plugin_auto_updates() -> dict[str, type[AutoUpdates]]:
    """
    Collect the AutoUpdates subclasses exported by installed BEC plugins.

    Plugin packages advertise their auto-update module through an entry point in their
    pyproject.toml. By convention, the module is the plugin repository's bec_widgets/auto_updates
    package:

        [project.entry-points."bec.widgets.auto_updates"]
        plugin_widgets_update = "<beamline_name>.bec_widgets.auto_updates"

    e.g.
        [project.entry-points."bec.widgets.auto_updates"]
        plugin_widgets_update = "pxiii_bec.bec_widgets.auto_updates"

    Returns:
        dict[str, type[AutoUpdates]]: Mapping of attribute name to AutoUpdates subclass (the AutoUpdates
            base class itself is excluded). On a name clash a warning is printed and the later module wins.
    """
    from bec_lib.plugin_helper import _get_available_plugins

    found: dict[str, type[AutoUpdates]] = {}
    for plugin_module in _get_available_plugins("bec.widgets.auto_updates"):
        members = inspect.getmembers(plugin_module, predicate=_filter_auto_updates)
        for name, update_cls in members:
            if name in found:
                print(f"Duplicated auto update {name}.")
            found[name] = update_cls
    return found
|
|
|
|
|
|
def _filter_auto_updates(obj):
    """
    Predicate for ``inspect.getmembers``: accept AutoUpdates subclasses.

    Any class whose ``__name__`` is exactly "AutoUpdates" is rejected, so the base class is not reported.
    """
    from bec_widgets.widgets.containers.auto_update.auto_updates import AutoUpdates

    if not inspect.isclass(obj) or not issubclass(obj, AutoUpdates):
        return False
    # Exclusion is by name, not identity.
    return obj.__name__ != "AutoUpdates"
|
|
|
|
|
|
@dataclass
class BECClassInfo:
    """Record describing one class found while importing a module during class discovery."""

    name: str  # attribute name under which the class was found in its module
    module: str  # dotted name of the module that defines the class
    file: str  # filesystem path of that module's source file
    # The class object itself. NOTE(review): despite the annotation, any class defined in a
    # scanned module may be stored here, not only BECWidget subclasses.
    obj: type["BECWidget"]
    is_connector: bool = False  # presumably: class subclasses BECConnector (set by the collector)
    is_widget: bool = False  # presumably: class subclasses QWidget or BECWidget (set by the collector)
    is_plugin: bool = False  # presumably: class has a truthy PLUGIN attribute (set by the collector)
|
|
|
|
|
|
@dataclass(frozen=True)
class BECClassReference:
    """Immutable, hashable (name, module) pointer to a class, used by the static discovery helpers."""

    name: str  # class name as written in the ``class`` statement
    module: str  # dotted module path in which the class is defined
|
|
|
|
|
|
class BECClassContainer:
    """
    Ordered collection of BECClassInfo records with convenience views.

    Entries keep insertion order. The container does not enforce unique names itself, but
    ``a + b`` skips entries of ``b`` whose name already exists in ``a``.
    """

    def __init__(self, initial: Iterable[BECClassInfo] = ()):
        self._collection: list[BECClassInfo] = list(initial)

    def __repr__(self):
        return repr(self.names)

    def __iter__(self):
        return iter(self._collection)

    def __add__(self, other: BECClassContainer):
        """
        Return a new container holding this container's entries, followed by the entries of
        ``other`` whose names are not already present here.

        Names are only checked against ``self``; duplicates within ``other`` are kept.
        """
        # Build the name set once so merging is linear instead of rescanning self per entry.
        existing_names = set(self.names)
        return BECClassContainer((*self, *(c for c in other if c.name not in existing_names)))

    def as_dict(self, ignores: list[str] | None = None) -> dict[str, type["BECWidget"]]:
        """
        Get a dict of {name: Type} for all the entries in the collection.

        If several entries share a name, the last one wins.

        Args:
            ignores(list[str] | None): Class names to exclude from the dictionary.

        Returns:
            dict[str, type]: Mapping of class name to class object.
        """
        ignore_set = set(ignores or ())
        return {c.name: c.obj for c in self if c.name not in ignore_set}

    def add_class(self, class_info: BECClassInfo):
        """
        Append a class to the collection (no duplicate check).

        Args:
            class_info(BECClassInfo): The class information
        """
        self._collection.append(class_info)

    @property
    def names(self):
        """Return the class names in collection order."""
        return [c.name for c in self]

    @property
    def collection(self):
        """Get the underlying (mutable) list of class infos."""
        return self._collection

    @property
    def connector_classes(self):
        """Get all classes flagged as connectors (``is_connector``)."""
        return [info.obj for info in self._collection if info.is_connector]

    @property
    def top_level_classes(self):
        """Get all classes flagged as plugins (``is_plugin``), i.e. user-facing top-level classes."""
        return [info.obj for info in self._collection if info.is_plugin]

    @property
    def plugins(self):
        """Get all plugins. These are all classes that are on the top level and are widgets."""
        return [info.obj for info in self._collection if info.is_widget and info.is_plugin]

    @property
    def widgets(self):
        """Get all classes flagged as widgets (``is_widget``)."""
        return [info.obj for info in self._collection if info.is_widget]

    @property
    def rpc_top_level_classes(self):
        """Get all top-level classes that are RPC-enabled. These are all classes that users can choose from."""
        return [info.obj for info in self._collection if info.is_plugin and info.is_connector]

    @property
    def classes(self):
        """Get all class objects in collection order."""
        return [info.obj for info in self._collection]
|
|
|
|
|
|
def _ast_node_name(node: ast.expr) -> str | None:
|
|
if isinstance(node, ast.Name):
|
|
return node.id
|
|
if isinstance(node, ast.Attribute):
|
|
return node.attr
|
|
return None
|
|
|
|
|
|
def _class_has_rpc_markers(node: ast.ClassDef) -> bool:
|
|
for stmt in node.body:
|
|
if isinstance(stmt, ast.Assign):
|
|
target_names = {target.id for target in stmt.targets if isinstance(target, ast.Name)}
|
|
if (
|
|
"PLUGIN" in target_names
|
|
and isinstance(stmt.value, ast.Constant)
|
|
and stmt.value.value
|
|
):
|
|
return True
|
|
if {"RPC_CONTENT_CLASS", "USER_ACCESS"} & target_names:
|
|
return True
|
|
if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
|
|
if (
|
|
stmt.target.id == "PLUGIN"
|
|
and isinstance(stmt.value, ast.Constant)
|
|
and stmt.value.value
|
|
):
|
|
return True
|
|
if stmt.target.id in {"RPC_CONTENT_CLASS", "USER_ACCESS"}:
|
|
return True
|
|
return False
|
|
|
|
|
|
def _class_is_candidate(node: ast.ClassDef) -> bool:
    """
    Decide statically whether a class definition is worth importing for discovery.

    True if any direct base's trailing name is in ``_DISCOVERY_BASE_NAMES``, or if the class
    body carries RPC/plugin markers.
    """
    if any(_ast_node_name(base) in _DISCOVERY_BASE_NAMES for base in node.bases):
        return True
    return _class_has_rpc_markers(node)
|
|
|
|
|
|
def _candidate_top_level_class_names(path: str) -> list[str]:
    """
    Parse a Python source file and list its module-level candidate class names.

    The file is parsed with ``ast`` (not imported). Only classes defined directly in the module
    body are inspected, and nested classes are ignored. Names are returned in source order.

    Args:
        path(str): Path of the ``.py`` file to parse (read as UTF-8).

    Returns:
        list[str]: Names of classes accepted by ``_class_is_candidate``.
    """
    with open(path, encoding="utf-8") as source_file:
        source = source_file.read()
    tree = ast.parse(source, filename=path)

    names: list[str] = []
    for statement in tree.body:
        if isinstance(statement, ast.ClassDef) and _class_is_candidate(statement):
            names.append(statement.name)
    return names
|
|
|
|
|
|
@lru_cache(maxsize=64)
|
|
def _find_package_roots(module_name: str) -> tuple[str, ...]:
|
|
spec = importlib_util.find_spec(module_name)
|
|
if spec is None:
|
|
raise ModuleNotFoundError(module_name)
|
|
|
|
package_roots = tuple(spec.submodule_search_locations or ())
|
|
if package_roots:
|
|
return package_roots
|
|
if spec.origin:
|
|
return (os.path.dirname(spec.origin),)
|
|
raise ModuleNotFoundError(module_name)
|
|
|
|
|
|
def _module_name_from_path(module_prefix: str, package_root: str, path: str) -> str:
    """Map a ``.py`` file below ``package_root`` to its dotted module name under ``module_prefix``."""
    rel_parts = os.path.relpath(path, package_root).removesuffix(".py").split(os.sep)
    # A package's __init__.py is importable as the package itself. Importing "pkg.__init__"
    # would execute the file a second time as a separate module with distinct class objects.
    if rel_parts[-1] == "__init__":
        rel_parts.pop()
    return ".".join([module_prefix, *rel_parts])


def _discover_class_references_from_roots(
    module_prefix: str,
    package_roots: Iterable[str],
    *,
    file_name_filter: Callable[[str], bool],
    candidate_filter: Callable[[ast.ClassDef], bool],
) -> tuple[BECClassReference, ...]:
    """
    Statically scan package directories for candidate classes without importing them.

    Every file below each root that is accepted by ``file_name_filter`` is parsed with ``ast``.
    Module-level ``class`` statements accepted by ``candidate_filter`` become BECClassReference
    entries. Directories and files are visited in sorted order. The first occurrence of a class
    name wins, and later duplicates are skipped.

    Args:
        module_prefix(str): Dotted module path corresponding to each package root.
        package_roots(Iterable[str]): Filesystem directories to scan.
        file_name_filter(Callable[[str], bool]): Predicate on the bare file name; only accepted files are parsed.
        candidate_filter(Callable[[ast.ClassDef], bool]): Predicate deciding whether a class is reported.

    Returns:
        tuple[BECClassReference, ...]: References in discovery order. A class defined in a
            package's ``__init__.py`` is attributed to the package module itself.
    """
    references: list[BECClassReference] = []
    seen_names: set[str] = set()

    for package_root in package_roots:
        for root, _, files in sorted(os.walk(package_root)):
            for file_name in sorted(files):
                if not file_name_filter(file_name):
                    continue
                path = os.path.join(root, file_name)
                with open(path, encoding="utf-8") as file_handle:
                    module = ast.parse(file_handle.read(), filename=path)
                module_name = _module_name_from_path(module_prefix, package_root, path)
                for node in module.body:
                    if not isinstance(node, ast.ClassDef) or not candidate_filter(node):
                        continue
                    if node.name in seen_names:
                        continue
                    references.append(BECClassReference(name=node.name, module=module_name))
                    seen_names.add(node.name)

    return tuple(references)
|
|
|
|
|
|
def _iter_candidate_modules(repo_name: str, package: str) -> Iterable[tuple[str, str, list[str]]]:
    """
    List the modules of ``<repo_name>.<package>`` that statically contain candidate classes.

    A file is skipped if it is not a ``.py`` file, starts with ``__`` (e.g. ``__init__.py``) or
    ``register_``, or ends with ``_plugin.py``. Remaining files are parsed (not imported), and
    only those with at least one candidate class are kept.

    Returns:
        tuple of (dotted module name, file path, candidate class names), in sorted walk order.
        An empty tuple is returned if the package cannot be located.
    """
    try:
        roots = _find_package_roots(f"{repo_name}.{package}")
    except ModuleNotFoundError:
        return ()

    base_name = f"{repo_name}.{package}"
    found: list[tuple[str, str, list[str]]] = []
    for directory in roots:
        for root, _, files in sorted(os.walk(directory)):
            for file_name in sorted(files):
                skip = (
                    not file_name.endswith(".py")
                    or file_name.startswith(("__", "register_"))
                    or file_name.endswith("_plugin.py")
                )
                if skip:
                    continue
                path = os.path.join(root, file_name)
                relative = os.path.relpath(path, directory).removesuffix(".py")
                dotted = ".".join(relative.split(os.sep))
                class_names = _candidate_top_level_class_names(path)
                if class_names:
                    found.append((f"{base_name}.{dotted}", path, class_names))
    return tuple(found)
|
|
|
|
|
|
def _collect_classes_from_package(repo_name: str, package: str) -> BECClassContainer:
    """
    Collect classes from a package subtree (for example ``widgets`` or ``applications``).

    Each candidate module reported by ``_iter_candidate_modules`` is imported. Every class whose
    ``__module__`` equals that module (i.e. defined there rather than imported) is recorded and
    flagged as connector (BECConnector subclass), widget (QWidget or BECWidget subclass) and/or
    plugin (truthy ``PLUGIN`` attribute).

    Args:
        repo_name(str): Top-level package name of the repository.
        package(str): Sub-package to scan.

    Returns:
        BECClassContainer: The collected class records, in discovery order.
    """
    collection = BECClassContainer()
    candidates = tuple(_iter_candidate_modules(repo_name, package))
    if not candidates:
        return collection

    # Deferred imports: Qt and the BEC base classes are only loaded when there is something to inspect.
    from qtpy.QtWidgets import QWidget

    from bec_widgets.utils.bec_connector import BECConnector
    from bec_widgets.utils.bec_widget import BECWidget

    for module_name, path, _ in candidates:
        module = importlib.import_module(module_name)
        for name, obj in inspect.getmembers(module, inspect.isclass):
            if obj.__module__ != module.__name__:
                continue
            collection.add_class(
                BECClassInfo(
                    name=name,
                    module=module.__name__,
                    file=path,
                    obj=obj,
                    is_connector=issubclass(obj, BECConnector),
                    is_widget=issubclass(obj, (QWidget, BECWidget)),
                    is_plugin=bool(getattr(obj, "PLUGIN", False)),
                )
            )
    return collection
|
|
|
|
|
|
@lru_cache(maxsize=32)
def _cached_custom_class_references(
    repo_name: str, packages: tuple[str, ...]
) -> tuple[BECClassReference, ...]:
    """
    Build (and memoise) static class references for the given sub-packages of ``repo_name``.

    Packages are scanned in the given order. The first module defining a class name wins.
    """
    by_name: dict[str, BECClassReference] = {}
    for package in packages:
        for module_name, _, class_names in _iter_candidate_modules(repo_name, package):
            for class_name in class_names:
                if class_name not in by_name:
                    by_name[class_name] = BECClassReference(name=class_name, module=module_name)
    # dicts preserve insertion order, so this is discovery order.
    return tuple(by_name.values())
|
|
|
|
|
|
def get_custom_class_references(
    repo_name: str, packages: tuple[str, ...] | None = None, *, use_cache: bool = True
) -> list[BECClassReference]:
    """
    Return class references for ``repo_name``, found by parsing source files.

    Candidate modules are parsed rather than imported.

    Args:
        repo_name(str): Top-level package name of the repository.
        packages(tuple[str, ...] | None): Sub-packages to scan. Falsy values default to ("widgets",).
        use_cache(bool): If False, the reference cache is cleared before scanning, so files are
            re-parsed. This clears the cached results for every repository/package combination.

    Returns:
        list[BECClassReference]: References in discovery order, one per class name.
    """
    selected_packages = tuple(packages or ("widgets",))
    if not use_cache:
        _cached_custom_class_references.cache_clear()
    return list(_cached_custom_class_references(repo_name, selected_packages))
|
|
|
|
|
|
def get_custom_classes(
    repo_name: str, packages: tuple[str, ...] | None = None
) -> BECClassContainer:
    """
    Get all relevant classes for RPC/CLI in the specified repository.

    By default, discovery is limited to ``<repo>.widgets`` for backward compatibility.
    Additional package subtrees (for example ``applications``) can be included explicitly.
    When the same class name appears in several packages, the first package listed wins.

    Args:
        repo_name(str): The name of the repository.
        packages(tuple[str, ...] | None): Optional tuple of package names to scan. Defaults to ("widgets",) for backward compatibility.

    Returns:
        BECClassContainer: Container with collected class information.
    """
    merged = BECClassContainer()
    for package in packages or ("widgets",):
        merged = merged + _collect_classes_from_package(repo_name, package)
    return merged
|