diff --git a/bec_widgets/examples/device_manager_view/device_manager_view.py b/bec_widgets/examples/device_manager_view/device_manager_view.py index 5cdea710..2f35b3d4 100644 --- a/bec_widgets/examples/device_manager_view/device_manager_view.py +++ b/bec_widgets/examples/device_manager_view/device_manager_view.py @@ -2,7 +2,7 @@ from __future__ import annotations import os from functools import partial -from typing import TYPE_CHECKING, List +from typing import List import PySide6QtAds as QtAds import yaml diff --git a/bec_widgets/widgets/control/device_manager/components/available_device_resources/device_resource_backend.py b/bec_widgets/widgets/control/device_manager/components/available_device_resources/device_resource_backend.py index 1c385131..7da504e3 100644 --- a/bec_widgets/widgets/control/device_manager/components/available_device_resources/device_resource_backend.py +++ b/bec_widgets/widgets/control/device_manager/components/available_device_resources/device_resource_backend.py @@ -16,6 +16,8 @@ from bec_lib.plugin_helper import plugin_package_name, plugin_repo_path, plugins logger = bec_logger.logger +# use the last n recovery files +_N_RECOVERY_FILES = 3 _BASE_REPO_PATH = Path(os.path.dirname(bec_lib.__file__)) / "../.." @@ -75,10 +77,6 @@ def _devices_from_file(file: str, include_source: bool = True): ) -# use the last n recovery files -_N_RECOVERY_FILES = 3 - - class _ConfigFileBackend(DeviceResourceBackend): def __init__(self) -> None: self._raw_device_set: set[HashableDevice] = self._get_config_from_backup_files() @@ -94,14 +92,13 @@ class _ConfigFileBackend(DeviceResourceBackend): dir = _BASE_REPO_PATH / "logs/device_configs/recovery_configs" files = sorted(glob("*.yaml", root_dir=dir)) last_n_files = files[-_N_RECOVERY_FILES:] - if len(last_n_files) == 0: - return set() return reduce( operator.or_, map( partial(_devices_from_file, include_source=False), (str(dir / f) for f in last_n_files), ), + set(), ) def _get_configs_from_plugin_files(self, dir: Path): @@ -131,7 +128,7 @@ class _ConfigFileBackend(DeviceResourceBackend): return {n for n, info in HashableDevice.model_fields.items() if info.annotation is str} def tags(self) -> set[str]: - return reduce(operator.or_, (dev.deviceTags for dev in self._raw_device_set), {}) + return reduce(operator.or_, (dev.deviceTags for dev in self._raw_device_set), set()) def tag_group(self, tag: str) -> set[HashableDevice]: return self.tag_groups[tag] diff --git a/bec_widgets/widgets/control/device_manager/components/dm_ophyd_test.py b/bec_widgets/widgets/control/device_manager/components/dm_ophyd_test.py index 2093eb6c..1a460459 100644 --- a/bec_widgets/widgets/control/device_manager/components/dm_ophyd_test.py +++ b/bec_widgets/widgets/control/device_manager/components/dm_ophyd_test.py @@ -4,20 +4,20 @@ from __future__ import annotations import enum import re -import traceback +from collections import deque +from concurrent.futures import CancelledError, Future, ThreadPoolExecutor from html import escape -from typing import TYPE_CHECKING, Any +from threading import Event, RLock +from typing import Any, Iterable -import bec_lib from bec_lib.logger import bec_logger from bec_qthemes import material_icon -from ophyd import status -from qtpy import QtCore, QtGui, QtWidgets +from PySide6.QtCore import QThreadPool +from qtpy import QtCore, QtWidgets from bec_widgets.utils.bec_widget import BECWidget from bec_widgets.utils.colors import get_accent_colors -from bec_widgets.utils.error_popups import SafeProperty, SafeSlot -from bec_widgets.widgets.editors.web_console.web_console import WebConsole +from bec_widgets.utils.error_popups import SafeSlot from bec_widgets.widgets.utility.spinner.spinner import SpinnerWidget READY_TO_TEST = False @@ -34,11 +34,10 @@ except ImportError: ophyd_devices = None bec_server = None -if TYPE_CHECKING: # pragma no cover - try: - from ophyd_devices.utils.static_device_test import StaticDeviceTest - except ImportError: - StaticDeviceTest = None +try: + from ophyd_devices.utils.static_device_test import StaticDeviceTest +except ImportError: + StaticDeviceTest = None class ValidationStatus(int, enum.Enum): @@ -56,49 +55,77 @@ class DeviceValidationResult(QtCore.QObject): device_validated = QtCore.Signal(str, bool, str) -class DeviceValidationRunnable(QtCore.QRunnable): - """Runnable for validating a device configuration.""" - - def __init__( - self, - device_name: str, - config: dict, - static_device_test: StaticDeviceTest | None, - connect: bool = False, - ): - """ - Initialize the device validation runnable. - - Args: - device_name (str): The name of the device to validate. - config (dict): The configuration dictionary for the device. - static_device_test (StaticDeviceTest): The static device test instance. - connect (bool, optional): Whether to connect to the device. Defaults to False. - """ +class DeviceTester(QtCore.QRunnable): + def __init__(self, config: dict) -> None: super().__init__() - self.device_name = device_name - self.config = config - self._connect = connect - self._static_device_test = static_device_test self.signals = DeviceValidationResult() + self.shutdown_event = Event() + + self._config = config + + self._max_threads = 4 + self._pending_event = Event() + self._lock = RLock() + self._test_executor = ThreadPoolExecutor(self._max_threads, "device_manager_tester") + + self._pending_queue: deque[tuple[str, dict]] = deque([]) + self._active: set[str] = set() + + QtWidgets.QApplication.instance().aboutToQuit.connect(lambda: self.shutdown_event.set()) def run(self): - """Run method for device validation.""" - if self._static_device_test is None: - logger.error( - f"Ophyd devices or bec_server not available, cannot run validation for device {self.device_name}." - ) + if StaticDeviceTest is None: + logger.error("Ophyd devices or bec_server not available, cannot run validation.") return + while not self.shutdown_event.is_set(): + self._pending_event.wait(timeout=0.5) # check if shutting down every 0.5s + if len(self._active) >= self._max_threads: + self._pending_event.clear() # it will be set again on removing something from active + continue + with self._lock: + if len(self._pending_queue) > 0: + item, cfg = self._pending_queue.pop() + self._active.add(item) + fut = self._test_executor.submit(self._run_test, item, {item: cfg}) + fut.__dict__["__device_name"] = item + fut.add_done_callback(self._done_cb) + self._safe_check_and_clear() + + self._cleanup() + + def submit(self, devices: Iterable[tuple[str, dict]]): + with self._lock: + self._pending_queue.extend(devices) + self._pending_event.set() + + @staticmethod + def _run_test(name: str, config: dict) -> tuple[str, bool, str]: + tester = StaticDeviceTest(config_dict=config) # type: ignore # we exit early if it is None + results = tester.run_with_list_output(connect=False) + return name, results[0].success, results[0].message + + def _safe_check_and_clear(self): + with self._lock: + if len(self._pending_queue) == 0: + self._pending_event.clear() + + def _safe_remove_from_active(self, name: str): + with self._lock: + self._active.remove(name) + self._pending_event.set() # check again once a completed task is removed + + def _done_cb(self, future: Future): try: - self._static_device_test.config = {self.device_name: self.config} - results = self._static_device_test.run_with_list_output(connect=self._connect) - success = results[0].success - msg = results[0].message - self.signals.device_validated.emit(self.device_name, success, msg) - except Exception: - content = traceback.format_exc() - logger.error(f"Validation failed for device {self.device_name}. Exception: {content}") - self.signals.device_validated.emit(self.device_name, False, content) + name, success, message = future.result() + except CancelledError: + return + except Exception as e: + name, success, message = future.__dict__["__device_name"], False, str(e) + finally: + self._safe_remove_from_active(future.__dict__["__device_name"]) + self.signals.device_validated.emit(name, success, message) + + def _cleanup(self): ... class ValidationListItem(QtWidgets.QWidget): @@ -177,13 +204,13 @@ class DMOphydTest(BECWidget, QtWidgets.QWidget): super().__init__(parent=parent, client=client) if not READY_TO_TEST: self.setDisabled(True) - self.static_device_test = None + self.tester = None else: - from ophyd_devices.utils.static_device_test import StaticDeviceTest - - self.static_device_test = StaticDeviceTest(config_dict={}) + self.tester = DeviceTester({}) + self.tester.signals.device_validated.connect(self._on_device_validated) + QThreadPool.globalInstance().start(self.tester) self._device_list_items: dict[str, QtWidgets.QListWidgetItem] = {} - self._thread_pool = QtCore.QThreadPool.globalInstance() + self._thread_pool = QtCore.QThreadPool(maxThreadCount=1) self._main_layout = QtWidgets.QVBoxLayout(self) self._main_layout.setContentsMargins(0, 0, 0, 0) @@ -223,7 +250,9 @@ class DMOphydTest(BECWidget, QtWidgets.QWidget): if added: if name in self._device_list_items: continue - self._add_device(name, cfg) + if self.tester: + self._add_device(name, cfg) + self.tester.submit([(name, cfg)]) continue if name not in self._device_list_items: continue @@ -238,7 +267,6 @@ class DMOphydTest(BECWidget, QtWidgets.QWidget): self._list_widget.addItem(item) self._list_widget.setItemWidget(item, widget) self._device_list_items[name] = item - self._run_device_validation(widget) def _remove_list_item(self, device_name: str): """Remove a device from the list.""" @@ -254,34 +282,6 @@ class DMOphydTest(BECWidget, QtWidgets.QWidget): row = self._list_widget.row(item) self._list_widget.takeItem(row) - def _run_device_validation(self, widget: ValidationListItem): - """ - Run the device validation in a separate thread. - - Args: - widget (ValidationListItem): The widget to validate. - """ - if not READY_TO_TEST: - logger.error("Ophyd devices or bec_server not available, cannot run validation.") - return - if ( - widget.device_name in self.client.device_manager.devices - ): # TODO and config has to be exact the same.. - self._on_device_validated( - widget.device_name, - ValidationStatus.VALID, - f"Device {widget.device_name} is already in active config", - ) - return - runnable = DeviceValidationRunnable( - device_name=widget.device_name, - config=widget.device_config, - static_device_test=self.static_device_test, - connect=False, - ) - runnable.signals.device_validated.connect(self._on_device_validated) - self._thread_pool.start(runnable) - @SafeSlot(str, bool, str) def _on_device_validated(self, device_name: str, success: bool, message: str): """Handle the device validation result. @@ -391,6 +391,11 @@ class DMOphydTest(BECWidget, QtWidgets.QWidget): if item: self._list_widget.removeItemWidget(item) + def cleanup(self): + if self.tester: + self.tester.shutdown_event.set() + return super().cleanup() + if __name__ == "__main__": import sys