diff --git a/bec_widgets/utils/bec_designer.py b/bec_widgets/utils/bec_designer.py index a1353a3a..e96e1f59 100644 --- a/bec_widgets/utils/bec_designer.py +++ b/bec_widgets/utils/bec_designer.py @@ -1,4 +1,8 @@ +import ast import importlib.metadata +import importlib.util +import inspect +import itertools import json import os import site @@ -6,11 +10,14 @@ import sys import sysconfig from pathlib import Path +from bec_lib.logger import bec_logger from bec_qthemes import material_icon from qtpy import PYSIDE6 from qtpy.QtGui import QIcon +from zmq import PLAIN from bec_widgets.utils.bec_plugin_helper import user_widget_plugin +from bec_widgets.utils.bec_widget import BECWidget if PYSIDE6: from PySide6.scripts.pyside_tool import ( @@ -24,6 +31,8 @@ if PYSIDE6: import bec_widgets +logger = bec_logger.logger + def designer_material_icon(icon_name: str) -> QIcon: """ @@ -122,14 +131,43 @@ def patch_designer(): # pragma: no cover qt_tool_wrapper(ui_tool_binary("designer"), sys.argv[1:]) -def find_plugin_paths(base_path: Path): +def _plugin_classes_for_python_file(file: Path): + logger.debug(f"getting plugin classes for {file}") + if not str(file).endswith(".py"): + raise ValueError("Please pass a python file") + spec = importlib.util.spec_from_file_location("_temp", file) + mod = importlib.util.module_from_spec(spec) + sys.modules["_temp"] = mod + spec.loader.exec_module(mod) + + plugin_widgets = list( + mem[0] + for mem in inspect.getmembers(mod, inspect.isclass) + if issubclass(mem[1], BECWidget) and hasattr(mem[1], "PLUGIN") and mem[1].PLUGIN is True + ) + logger.debug(f"Found: {plugin_widgets}") + return plugin_widgets + + +def _plugin_classes_for_pyproject(path: Path): + if not str(path).endswith(".pyproject"): + raise ValueError("Please pass the path of the designer pyproject file") + with open(path) as pyproject: + plugin_filenames = ast.literal_eval(pyproject.read())["files"] + plugin_files = (path.parent / file for file in plugin_filenames) + return itertools.chain(*(_plugin_classes_for_python_file(f) for f in plugin_files)) + + +def find_plugin_paths(base_path: Path) -> dict[str, list[str]]: """ - Recursively find all directories containing a .pyproject file. + Recursively find all directories containing a .pyproject file. Returns a dictionary with keys of + such paths, and values of the names of the classes contained in them if those classes are + desginer plugins. """ - plugin_paths = [] - for path in base_path.rglob("*.pyproject"): - plugin_paths.append(str(path.parent)) - return plugin_paths + return { + str(path.parent): list(_plugin_classes_for_pyproject(path)) + for path in base_path.rglob("*.pyproject") + } def set_plugin_environment_variable(plugin_paths): @@ -146,6 +184,19 @@ def set_plugin_environment_variable(plugin_paths): os.environ["PYSIDE_DESIGNER_PLUGINS"] = os.pathsep.join(current_paths) +def _extend_plugin_paths(plugin_paths: dict[str, list[str]], plugin_repo_dir: Path): + plugin_plugin_paths = find_plugin_paths(plugin_repo_dir) + builtin_plugin_names = list(itertools.chain(*plugin_paths.values())) + for plugin_file, plugin_classes in plugin_plugin_paths.items(): + logger.info(f"{plugin_classes} {builtin_plugin_names}") + if any(name in builtin_plugin_names for name in plugin_classes): + logger.warning( + f"Ignoring plugin {plugin_file} because it contains widgets {plugin_classes} which include duplicates of built-in widgets!" + ) + else: + plugin_paths[plugin_file] = plugin_classes + + # Patch the designer function def main(): # pragma: no cover if not PYSIDE6: @@ -154,11 +205,12 @@ def main(): # pragma: no cover base_dir = Path(os.path.dirname(bec_widgets.__file__)).resolve() plugin_paths = find_plugin_paths(base_dir) + if (plugin_repo := user_widget_plugin()) and isinstance(plugin_repo.__file__, str): plugin_repo_dir = Path(os.path.dirname(plugin_repo.__file__)).resolve() - plugin_paths.extend(find_plugin_paths(plugin_repo_dir)) + _extend_plugin_paths(plugin_paths, plugin_repo_dir) - set_plugin_environment_variable(plugin_paths) + set_plugin_environment_variable(plugin_paths.keys()) patch_designer()