diff --git a/bec_widgets/utils/plugin_utils.py b/bec_widgets/utils/plugin_utils.py index bf672e6c..711e46af 100644 --- a/bec_widgets/utils/plugin_utils.py +++ b/bec_widgets/utils/plugin_utils.py @@ -325,19 +325,117 @@ def _collect_classes_from_package(repo_name: str, package: str) -> BECClassConta return collection +def _build_ast_inheritance_map( + repo_name: str, packages: tuple[str, ...] +) -> dict[str, tuple[str, set[str]]]: + """ + Walk all candidate modules in the given packages and return a map of: + class_name -> (module_name, {direct_base_names}) + + This is used for the transitive-closure widget discovery so that subclasses + of discovered widget bases are themselves discoverable without needing to + repeat ``PLUGIN = True`` on every intermediate or leaf class. + """ + mapping: dict[str, tuple[str, set[str]]] = {} + for package in packages: + try: + package_roots = _find_package_roots(f"{repo_name}.{package}") + except ModuleNotFoundError: + continue + for directory in package_roots: + for root, _, files in sorted(os.walk(directory)): + for file_name in sorted(files): + if ( + not file_name.endswith(".py") + or file_name.startswith("__") + or file_name.startswith("register_") + or file_name.endswith("_plugin.py") + ): + continue + path = os.path.join(root, file_name) + rel_dir = os.path.dirname(os.path.relpath(path, directory)) + module_name = ( + file_name.removesuffix(".py") + if rel_dir in ("", ".") + else ".".join(rel_dir.split(os.sep) + [file_name.removesuffix(".py")]) + ) + full_module = f"{repo_name}.{package}.{module_name}" + with open(path, encoding="utf-8") as fh: + tree = ast.parse(fh.read(), filename=path) + for node in tree.body: + if not isinstance(node, ast.ClassDef): + continue + base_names = {_ast_node_name(b) for b in node.bases} - {None} + mapping[node.name] = (full_module, base_names) # type: ignore[arg-type] + return mapping + + @lru_cache(maxsize=32) def _cached_custom_class_references( repo_name: str, packages: tuple[str, ...] ) -> tuple[BECClassReference, ...]: + """Discover widget/connector class references using a transitive-closure AST scan. + + The first pass identifies classes that directly inherit from + ``_DISCOVERY_BASE_NAMES`` or carry explicit RPC markers (``PLUGIN``, + ``USER_ACCESS``, ``RPC_CONTENT_CLASS``). Subsequent passes treat every + newly found class name as an additional base name, so subclasses of + subclasses are discovered automatically — without requiring each + intermediate class to repeat ``PLUGIN = True``. + """ + inheritance_map = _build_ast_inheritance_map(repo_name, packages) + + # Seed with _class_has_rpc_markers — we need the AST nodes for that check. + # Re-parse only to identify initial RPC-marker classes; inheritance_map + # already has everything else we need. + rpc_marker_names: set[str] = set() + for package in packages: + try: + package_roots = _find_package_roots(f"{repo_name}.{package}") + except ModuleNotFoundError: + continue + for directory in package_roots: + for root, _, files in sorted(os.walk(directory)): + for file_name in sorted(files): + if ( + not file_name.endswith(".py") + or file_name.startswith("__") + or file_name.startswith("register_") + or file_name.endswith("_plugin.py") + ): + continue + path = os.path.join(root, file_name) + with open(path, encoding="utf-8") as fh: + tree = ast.parse(fh.read(), filename=path) + for node in tree.body: + if isinstance(node, ast.ClassDef) and _class_has_rpc_markers(node): + rpc_marker_names.add(node.name) + + # Transitive closure: start with known base names + RPC-marker classes, + # then repeatedly add classes whose direct bases are already known. + known: set[str] = set(_DISCOVERY_BASE_NAMES) | rpc_marker_names + changed = True + while changed: + changed = False + for class_name, (_, bases) in inheritance_map.items(): + if class_name not in known and bases & known: + known.add(class_name) + changed = True + + # Build the final list of references, preserving first-seen order and + # keeping only classes that are in the inheritance map (i.e. have a module). references: list[BECClassReference] = [] seen_names: set[str] = set() - for package in packages: - for module_name, _, class_names in _iter_candidate_modules(repo_name, package): - for class_name in class_names: - if class_name in seen_names: - continue - references.append(BECClassReference(name=class_name, module=module_name)) - seen_names.add(class_name) + for class_name, (module_name, bases) in inheritance_map.items(): + if class_name not in known: + continue + if class_name in seen_names: + continue + # Only emit if the class is actually a candidate (not just a raw base) + if class_name in _DISCOVERY_BASE_NAMES: + continue + references.append(BECClassReference(name=class_name, module=module_name)) + seen_names.add(class_name) return tuple(references) diff --git a/tests/end-2-end/test_rpc_widgets_e2e.py b/tests/end-2-end/test_rpc_widgets_e2e.py index 19e8109f..fb83cafa 100644 --- a/tests/end-2-end/test_rpc_widgets_e2e.py +++ b/tests/end-2-end/test_rpc_widgets_e2e.py @@ -89,13 +89,6 @@ def test_available_widgets(qtbot, connected_client_gui_obj): # Skip private attributes if object_name.startswith("_"): continue - # Skip BECShell as ttyd is not installed - if object_name == "BECShell": - continue - - # Skip BecConsole as ttyd is not installed - if object_name == "BecConsole": - continue ############################# ######### Add widget ########