diff --git a/bec_widgets/cli/auto_updates.py b/bec_widgets/cli/auto_updates.py index dfc7aa42..999ddcce 100644 --- a/bec_widgets/cli/auto_updates.py +++ b/bec_widgets/cli/auto_updates.py @@ -1,5 +1,7 @@ from __future__ import annotations +import threading +from queue import Queue from typing import TYPE_CHECKING from pydantic import BaseModel @@ -25,6 +27,17 @@ class AutoUpdates: def __init__(self, gui: BECDockArea): self.gui = gui + self.msg_queue = Queue() + self.auto_update_thread = None + self._shutdown_sentinel = object() + self.start() + + def start(self): + """ + Start the auto update thread. + """ + self.auto_update_thread = threading.Thread(target=self.process_queue) + self.auto_update_thread.start() def start_default_dock(self): """ @@ -79,6 +92,16 @@ class AutoUpdates: info = self.get_scan_info(msg) self.handler(info) + def process_queue(self): + """ + Process the message queue. + """ + while True: + msg = self.msg_queue.get() + if msg is self._shutdown_sentinel: + break + self.run(msg) + @staticmethod def get_selected_device(monitored_devices, selected_device): """ @@ -151,3 +174,11 @@ class AutoUpdates: fig.clear_all() plt = fig.plot(x_name=dev_x, y_name=dev_y, label=f"Scan {info.scan_number} - {dev_y}") plt.set(title=f"Scan {info.scan_number}", x_label=dev_x, y_label=dev_y) + + def shutdown(self): + """ + Shutdown the auto update thread. + """ + self.msg_queue.put(self._shutdown_sentinel) + if self.auto_update_thread: + self.auto_update_thread.join() diff --git a/bec_widgets/cli/client_utils.py b/bec_widgets/cli/client_utils.py index bcf11300..7be97cae 100644 --- a/bec_widgets/cli/client_utils.py +++ b/bec_widgets/cli/client_utils.py @@ -6,7 +6,6 @@ import json import os import select import subprocess -import sys import threading import time import uuid @@ -16,7 +15,6 @@ from typing import TYPE_CHECKING from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger from bec_lib.utils.import_utils import isinstance_based_on_class_name, lazy_import, lazy_import_from -from qtpy.QtCore import QEventLoop, QSocketNotifier, QTimer import bec_widgets.cli.client as client from bec_widgets.cli.auto_updates import AutoUpdates @@ -24,10 +22,6 @@ from bec_widgets.cli.auto_updates import AutoUpdates if TYPE_CHECKING: from bec_lib.device import DeviceBase - from bec_widgets.cli.client import BECDockArea, BECFigure - -from bec_lib.serialization import MsgpackSerialization - messages = lazy_import("bec_lib.messages") # from bec_lib.connector import MessageObject MessageObject = lazy_import_from("bec_lib.connector", ("MessageObject",)) @@ -184,7 +178,7 @@ class BECGuiClientMixin: if isinstance(msg, messages.ScanStatusMessage): if not self.gui_is_alive(): return - self.auto_updates.run(msg) + self.auto_updates.msg_queue.put(msg) def show(self) -> None: """ @@ -213,6 +207,8 @@ class BECGuiClientMixin: self._process_output_processing_thread.join() self._process.wait() self._process = None + if self.auto_updates is not None: + self.auto_updates.shutdown() class RPCResponseTimeoutError(Exception): @@ -224,54 +220,14 @@ class RPCResponseTimeoutError(Exception): ) -class QtRedisMessageWaiter: - def __init__(self, redis_connector, message_to_wait): - self.ev_loop = QEventLoop() - self.response = None - self.connector = redis_connector - self.message_to_wait = message_to_wait - self.pubsub = redis_connector._redis_conn.pubsub() - self.pubsub.subscribe(self.message_to_wait.endpoint) - fd = self.pubsub.connection._sock.fileno() - self.notifier = QSocketNotifier(fd, QSocketNotifier.Read) - self.notifier.activated.connect(self._pubsub_readable) - - def _msg_received(self, msg_obj): - self.response = msg_obj.value - self.ev_loop.quit() - - def wait(self, timeout=1): - timer = QTimer() - timer.singleShot(timeout * 1000, self.ev_loop.quit) - self.ev_loop.exec_() - timer.stop() - self.notifier.setEnabled(False) - self.pubsub.close() - return self.response - - def _pubsub_readable(self, fd): - while True: - msg = self.pubsub.get_message() - if msg: - if msg["type"] == "subscribe": - # get_message buffers, so we may already have the answer - # let's check... - continue - else: - break - else: - return - channel = msg["channel"].decode() - msg = MessageObject(topic=channel, value=MsgpackSerialization.loads(msg["data"])) - self.connector._execute_callback(self._msg_received, msg, {}) - - class RPCBase: def __init__(self, gui_id: str = None, config: dict = None, parent=None) -> None: self._client = BECDispatcher().client self._config = config if config is not None else {} self._gui_id = gui_id if gui_id is not None else str(uuid.uuid4()) self._parent = parent + self._msg_wait_event = threading.Event() + self._rpc_response = None super().__init__() # print(f"RPCBase: {self._gui_id}") @@ -315,24 +271,39 @@ class RPCBase: # pylint: disable=protected-access receiver = self._root._gui_id if wait_for_rpc_response: - redis_msg = QtRedisMessageWaiter( - self._client.connector, MessageEndpoints.gui_instruction_response(request_id) + self._rpc_response = None + self._msg_wait_event.clear() + self._client.connector.register( + MessageEndpoints.gui_instruction_response(request_id), + cb=self._on_rpc_response, + parent=self, ) self._client.connector.set_and_publish(MessageEndpoints.gui_instructions(receiver), rpc_msg) if wait_for_rpc_response: - response = redis_msg.wait(timeout) - - if response is None: - raise RPCResponseTimeoutError(request_id, timeout) - + try: + finished = self._msg_wait_event.wait(10) + if not finished: + raise RPCResponseTimeoutError(request_id, timeout) + finally: + self._msg_wait_event.clear() + self._client.connector.unregister( + MessageEndpoints.gui_instruction_response(request_id), cb=self._on_rpc_response + ) # get class name - if not response.accepted: - raise ValueError(response.message["error"]) - msg_result = response.message.get("result") + if not self._rpc_response.accepted: + raise ValueError(self._rpc_response.message["error"]) + msg_result = self._rpc_response.message.get("result") + self._rpc_response = None return self._create_widget_from_msg_result(msg_result) + @staticmethod + def _on_rpc_response(msg: MessageObject, parent: RPCBase) -> None: + msg = msg.value + parent._msg_wait_event.set() + parent._rpc_response = msg + def _create_widget_from_msg_result(self, msg_result): if msg_result is None: return None diff --git a/bec_widgets/qt_utils/redis_message_waiter.py b/bec_widgets/qt_utils/redis_message_waiter.py new file mode 100644 index 00000000..4df9dfbc --- /dev/null +++ b/bec_widgets/qt_utils/redis_message_waiter.py @@ -0,0 +1,47 @@ +from bec_lib.serialization import MsgpackSerialization +from bec_lib.utils import lazy_import_from +from qtpy.QtCore import QEventLoop, QSocketNotifier, QTimer + +MessageObject = lazy_import_from("bec_lib.connector", ("MessageObject",)) + + +class QtRedisMessageWaiter: + def __init__(self, redis_connector, message_to_wait): + self.ev_loop = QEventLoop() + self.response = None + self.connector = redis_connector + self.message_to_wait = message_to_wait + self.pubsub = redis_connector._redis_conn.pubsub() + self.pubsub.subscribe(self.message_to_wait.endpoint) + fd = self.pubsub.connection._sock.fileno() + self.notifier = QSocketNotifier(fd, QSocketNotifier.Read) + self.notifier.activated.connect(self._pubsub_readable) + + def _msg_received(self, msg_obj): + self.response = msg_obj.value + self.ev_loop.quit() + + def wait(self, timeout=1): + timer = QTimer() + timer.singleShot(timeout * 1000, self.ev_loop.quit) + self.ev_loop.exec_() + timer.stop() + self.notifier.setEnabled(False) + self.pubsub.close() + return self.response + + def _pubsub_readable(self, fd): + while True: + msg = self.pubsub.get_message() + if msg: + if msg["type"] == "subscribe": + # get_message buffers, so we may already have the answer + # let's check... + continue + else: + break + else: + return + channel = msg["channel"].decode() + msg = MessageObject(topic=channel, value=MsgpackSerialization.loads(msg["data"])) + self.connector._execute_callback(self._msg_received, msg, {}) diff --git a/bec_widgets/utils/bec_dispatcher.py b/bec_widgets/utils/bec_dispatcher.py index 4402ed98..86c9b3ea 100644 --- a/bec_widgets/utils/bec_dispatcher.py +++ b/bec_widgets/utils/bec_dispatcher.py @@ -8,7 +8,7 @@ import redis from bec_lib.client import BECClient from bec_lib.redis_connector import MessageObject, RedisConnector from bec_lib.service_config import ServiceConfig -from qtpy.QtCore import PYQT6, PYSIDE6, QCoreApplication, QObject +from qtpy.QtCore import QObject from qtpy.QtCore import Signal as pyqtSignal if TYPE_CHECKING: @@ -75,7 +75,6 @@ class BECDispatcher: _instance = None _initialized = False - qapp = None def __new__(cls, client=None, config: str = None, *args, **kwargs): if cls._instance is None: @@ -87,9 +86,6 @@ class BECDispatcher: if self._initialized: return - if not QCoreApplication.instance(): - BECDispatcher.qapp = QCoreApplication([]) - self._slots = collections.defaultdict(set) self.client = client @@ -123,16 +119,6 @@ class BECDispatcher: cls._instance = None cls._initialized = False - if not cls.qapp: - return - - # shutdown QCoreApp if it exists - if PYQT6: - cls.qapp.exit() - elif PYSIDE6: - cls.qapp.shutdown() - cls.qapp = None - def connect_slot( self, slot: Callable, diff --git a/tests/end-2-end/test_bec_dock_rpc_e2e.py b/tests/end-2-end/test_bec_dock_rpc_e2e.py index b01431c9..1f5cbb45 100644 --- a/tests/end-2-end/test_bec_dock_rpc_e2e.py +++ b/tests/end-2-end/test_bec_dock_rpc_e2e.py @@ -295,3 +295,4 @@ def test_auto_update(bec_client_lib, rpc_server_dock, qtbot): plt_data[f"Scan {status.scan.scan_number} - {dock.selected_device}"]["y"] == last_scan_data["samy"]["samy"].val ) + dock.auto_updates.shutdown()