mirror of
https://github.com/ivan-usov-org/bec.git
synced 2025-04-22 02:20:02 +02:00
108 lines
4.0 KiB
Python
108 lines
4.0 KiB
Python
import glob
|
|
import importlib
|
|
import importlib.util
|
|
import inspect
|
|
import os
|
|
from pathlib import Path
|
|
|
|
from bec_lib import MessageEndpoints, bec_logger
|
|
from bec_lib.messages import AvailableResourceMessage
|
|
from bec_lib.signature_serializer import signature_to_dict
|
|
|
|
from . import scans as ScanServerScans
|
|
|
|
logger = bec_logger.logger
|
|
|
|
|
|
class ScanManager:
|
|
DEFAULT_PLUGIN_PATH = Path(os.path.dirname(os.path.abspath(__file__)) + "/../../").resolve()
|
|
|
|
def __init__(self, *, parent):
|
|
"""
|
|
Scan Manager loads and manages the available scans.
|
|
"""
|
|
self.parent = parent
|
|
self.available_scans = {}
|
|
self.scan_dict = {}
|
|
self._plugins = {}
|
|
self.load_plugins()
|
|
self.update_available_scans()
|
|
self.publish_available_scans()
|
|
|
|
def load_plugins(self):
|
|
"""load scan plugins"""
|
|
plugin_path = os.environ.get("BEC_PLUGIN_PATH")
|
|
if not plugin_path:
|
|
logger.info("BEC_PLUGIN_PATH not set. Using default plugin path.")
|
|
plugin_path = self.DEFAULT_PLUGIN_PATH
|
|
else:
|
|
logger.info(f"Using plugin path {plugin_path}")
|
|
plugin_path = os.path.join(plugin_path, "scan_server/scan_plugins")
|
|
files = glob.glob(os.path.join(plugin_path, "*.py"))
|
|
for file in files:
|
|
if file.endswith("__init__.py"):
|
|
continue
|
|
module_spec = importlib.util.spec_from_file_location("scan_plugins", file)
|
|
plugin_module = importlib.util.module_from_spec(module_spec)
|
|
module_spec.loader.exec_module(plugin_module)
|
|
module_members = inspect.getmembers(plugin_module)
|
|
for name, cls in module_members:
|
|
if not inspect.isclass(cls):
|
|
continue
|
|
# ignore imported classes
|
|
if cls.__module__ != "scan_plugins":
|
|
continue
|
|
if issubclass(cls, ScanServerScans.RequestBase):
|
|
self._plugins[name] = cls
|
|
logger.info(f"Loading scan plugin {name}")
|
|
|
|
def update_available_scans(self):
|
|
"""load all scans and plugin scans"""
|
|
members = inspect.getmembers(ScanServerScans)
|
|
for member_name, cls in self._plugins.items():
|
|
members.append((member_name, cls))
|
|
|
|
for name, scan_cls in members:
|
|
try:
|
|
is_scan = issubclass(scan_cls, ScanServerScans.RequestBase)
|
|
except TypeError:
|
|
is_scan = False
|
|
|
|
if not is_scan or not scan_cls.scan_name:
|
|
logger.debug(f"Ignoring {name}")
|
|
continue
|
|
if scan_cls.scan_name in self.available_scans:
|
|
logger.error(f"{scan_cls.scan_name} already exists. Skipping.")
|
|
continue
|
|
|
|
report_classes = [
|
|
ScanServerScans.RequestBase,
|
|
ScanServerScans.ScanBase,
|
|
ScanServerScans.AsyncFlyScanBase,
|
|
ScanServerScans.SyncFlyScanBase,
|
|
ScanServerScans.ScanStubs,
|
|
ScanServerScans.ScanComponent,
|
|
]
|
|
|
|
for report_cls in report_classes:
|
|
if issubclass(scan_cls, report_cls):
|
|
base_cls = report_cls.__name__
|
|
self.scan_dict[scan_cls.__name__] = scan_cls
|
|
self.available_scans[scan_cls.scan_name] = {
|
|
"class": scan_cls.__name__,
|
|
"base_class": base_cls,
|
|
"arg_input": scan_cls.arg_input,
|
|
"required_kwargs": scan_cls.required_kwargs,
|
|
"arg_bundle_size": scan_cls.arg_bundle_size,
|
|
"scan_report_hint": scan_cls.scan_report_hint,
|
|
"doc": scan_cls.__doc__ or scan_cls.__init__.__doc__,
|
|
"signature": signature_to_dict(scan_cls.__init__),
|
|
}
|
|
|
|
def publish_available_scans(self):
|
|
"""send all available scans to the broker"""
|
|
self.parent.connector.set(
|
|
MessageEndpoints.available_scans(),
|
|
AvailableResourceMessage(resource=self.available_scans),
|
|
)
|