diff --git a/bec_widgets/cli/client_utils.py b/bec_widgets/cli/client_utils.py index ae9c594b..d5eff536 100644 --- a/bec_widgets/cli/client_utils.py +++ b/bec_widgets/cli/client_utils.py @@ -154,6 +154,7 @@ class BECFigureClientMixin: self._run_rpc("close", (), wait_for_rpc_response=False) self._process.kill() self._process = None + self._client.shutdown() def _start_plot_process(self) -> None: """ diff --git a/bec_widgets/utils/bec_dispatcher.py b/bec_widgets/utils/bec_dispatcher.py index a41e366e..1239aaad 100644 --- a/bec_widgets/utils/bec_dispatcher.py +++ b/bec_widgets/utils/bec_dispatcher.py @@ -1,37 +1,71 @@ from __future__ import annotations import argparse -import itertools -import os +import collections from collections.abc import Callable -from typing import Union +from typing import TYPE_CHECKING, Union import redis -from bec_lib import BECClient, ServiceConfig -from bec_lib.endpoints import EndpointInfo +from bec_lib import BECClient +from bec_lib.redis_connector import MessageObject, RedisConnector from qtpy.QtCore import QObject from qtpy.QtCore import Signal as pyqtSignal -# Adding a new pyqt signal requires a class factory, as they must be part of the class definition -# and cannot be dynamically added as class attributes after the class has been defined. -_signal_class_factory = ( - type(f"Signal{i}", (QObject,), dict(signal=pyqtSignal(dict, dict))) for i in itertools.count() -) +if TYPE_CHECKING: + from bec_lib.endpoints import EndpointInfo -class _Connection: - """Utility class to keep track of slots connected to a particular redis connector""" +class QtThreadSafeCallback(QObject): + cb_signal = pyqtSignal(dict, dict) - def __init__(self, callback) -> None: - self.callback = callback + def __init__(self, cb): + super().__init__() - self.slots = set() - # keep a reference to a new signal class, so it is not gc'ed - self._signal_container = next(_signal_class_factory)() - self.signal: pyqtSignal = self._signal_container.signal + self.cb = cb + self.cb_signal.connect(self.cb) + + def __hash__(self): + # make 2 differents QtThreadSafeCallback to look + # identical when used as dictionary keys, if the + # callback is the same + return id(self.cb) + + def __call__(self, msg_content, metadata): + self.cb_signal.emit(msg_content, metadata) -class BECDispatcher(QObject): +class QtRedisConnector(RedisConnector): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _execute_callback(self, cb, msg, kwargs): + if not isinstance(cb, QtThreadSafeCallback): + return super()._execute_callback(cb, msg, kwargs) + # if msg.msg_type == "bundle_message": + # # big warning: how to handle bundle messages? + # # message with messages inside ; which slot to call? + # # bundle_msg = msg + # # for msg in bundle_msg: + # # ... + # # for now, only consider the 1st message + # msg = msg[0] + # raise RuntimeError(f" + if isinstance(msg, MessageObject): + if isinstance(msg.value, list): + msg = msg.value[0] + else: + msg = msg.value + + # we can notice kwargs are lost when passed to Qt slot + metadata = msg.metadata + cb(msg.content, metadata) + else: + # from stream + msg = msg["data"] + cb(msg.content, msg.metadata) + + +class BECDispatcher: """Utility class to keep track of slots connected to a particular redis connector""" _instance = None @@ -47,14 +81,21 @@ class BECDispatcher(QObject): if self._initialized: return - super().__init__() - self.client = BECClient() if client is None else client + self._slots = collections.defaultdict(set) + + if client is None: + self.client = BECClient(connector_cls=QtRedisConnector, forced=True) + else: + if self.client.started: + # have to reinitialize client to use proper connector + self.client.shutdown() + self.client._BECClient__init_params["connector_cls"] = QtRedisConnector + try: self.client.start() except redis.exceptions.ConnectionError: print("Could not connect to Redis, skipping start of BECClient.") - self._connections = {} self._initialized = True @@ -67,7 +108,6 @@ class BECDispatcher(QObject): self, slot: Callable, topics: Union[EndpointInfo, str, list[Union[EndpointInfo, str]]], - single_callback_for_all_topics=False, ) -> None: """Connect widget's pyqt slot, so that it is called on new pub/sub topic message. @@ -75,117 +115,27 @@ class BECDispatcher(QObject): slot (Callable): A slot method/function that accepts two inputs: content and metadata of the corresponding pub/sub message topics (EndpointInfo | str | list): A topic or list of topics that can typically be acquired via bec_lib.MessageEndpoints - single_callback_for_all_topics (bool): If True, use the same callback for all topics, otherwise use - separate callbacks. """ - # Normalise the topics input - if isinstance(topics, (str, EndpointInfo)): - topics = [topics] + slot = QtThreadSafeCallback(slot) + self.client.connector.register(topics, cb=slot) + topics_str, _ = self.client.connector._convert_endpointinfo(topics) + self._slots[slot].update(set(topics_str)) - endpoint_to_consumer_type = { - (topic.endpoint if isinstance(topic, EndpointInfo) else topic): ( - topic.message_op.name if isinstance(topic, EndpointInfo) else "SEND" - ) - for topic in topics - } + def disconnect_slot(self, slot: Callable, topics: Union[str, list]): + self.client.connector.unregister(topics, cb=slot) + topics_str, _ = self.client.connector._convert_endpointinfo(topics) + self._slots[slot].difference_update(set(topics_str)) + if not self._slots[slot]: + del self._slots[slot] - # Group topics by consumer type - consumer_type_to_endpoints = {} - for endpoint, consumer_type in endpoint_to_consumer_type.items(): - if consumer_type not in consumer_type_to_endpoints: - consumer_type_to_endpoints[consumer_type] = [] - consumer_type_to_endpoints[consumer_type].append(endpoint) + def disconnect_topics(self, topics: Union[str, list]): + self.client.connector.unregister(topics) + topics_str, _ = self.client.connector._convert_endpointinfo(topics) + for slot in list(self._slots.keys()): + slot_topics = self._slots[slot] + slot_topics.difference_update(set(topics_str)) + if not slot_topics: + del self._slots[slot] - for consumer_type, endpoints in consumer_type_to_endpoints.items(): - topics_key = ( - tuple(sorted(endpoints)) if single_callback_for_all_topics else tuple(endpoints) - ) - - if topics_key not in self._connections: - self._connections[topics_key] = self._create_connection(endpoints, consumer_type) - connection = self._connections[topics_key] - - if slot not in connection.slots: - connection.signal.connect(slot) - connection.slots.add(slot) - - def _create_connection(self, topics: list, consumer_type: str) -> _Connection: - """Creates a new connection for given topics.""" - - def cb(msg): - if isinstance(msg, dict): - msg = msg["data"] - else: - msg = msg.value - for connection_key, connection in self._connections.items(): - if set(topics).intersection(connection_key): - if isinstance(msg, list): - msg = msg[0] - connection.signal.emit(msg.content, msg.metadata) - - try: - if consumer_type == "STREAM": - self.client.connector.register_stream(topics=topics, cb=cb, newest_only=True) - else: - self.client.connector.register(topics=topics, cb=cb) - except redis.exceptions.ConnectionError: - print("Could not connect to Redis, skipping registration of topics.") - - return _Connection(cb) - - def _do_disconnect_slot(self, topic, slot): - print(f"Disconnecting {slot} from {topic}") - connection = self._connections[topic] - try: - connection.signal.disconnect(slot) - except TypeError: - print(f"Could not disconnect slot:'{slot}' from topic:'{topic}'") - print("Continue to remove slot:'{slot}' from 'connection.slots'.") - connection.slots.remove(slot) - if not connection.slots: - del self._connections[topic] - - def _disconnect_slot_from_topic(self, slot: Callable, topic: str) -> None: - """A helper method to disconnect a slot from a specific topic. - - Args: - slot (Callable): A slot to be disconnected - topic (str): A corresponding topic that can typically be acquired via - bec_lib.MessageEndpoints - """ - connection = self._connections.get(topic) - if connection and slot in connection.slots: - self._do_disconnect_slot(topic, slot) - - def disconnect_slot(self, slot: Callable, topics: Union[str, list]) -> None: - """Disconnect widget's pyqt slot from pub/sub updates on a topic. - - Args: - slot (Callable): A slot to be disconnected - topics (str | list): A corresponding topic or list of topics that can typically be acquired via - bec_lib.MessageEndpoints - """ - # Normalise the topics input - if isinstance(topics, (str, EndpointInfo)): - topics = [topics] - - endpoints = [ - topic.endpoint if isinstance(topic, EndpointInfo) else topic for topic in topics - ] - - for key, connection in list(self._connections.items()): - if slot in connection.slots: - common_topics = set(endpoints).intersection(key) - if common_topics: - remaining_topics = set(key) - set(endpoints) - # Disconnect slot from common topics - self._do_disconnect_slot(key, slot) - # Reconnect slot to remaining topics if any - if remaining_topics: - self.connect_slot(slot, list(remaining_topics), True) - - def disconnect_all(self): - """Disconnect all slots from all topics.""" - for key, connection in list(self._connections.items()): - for slot in list(connection.slots): - self._disconnect_slot_from_topic(slot, key) + def disconnect_all(self, *args, **kwargs): + self.disconnect_topics(self.client.connector._topics_cb) diff --git a/bec_widgets/widgets/motor_map/motor_map.py b/bec_widgets/widgets/motor_map/motor_map.py index 7229a8e9..b0a8f20f 100644 --- a/bec_widgets/widgets/motor_map/motor_map.py +++ b/bec_widgets/widgets/motor_map/motor_map.py @@ -217,7 +217,6 @@ class MotorMap(pg.GraphicsLayoutWidget): bec_dispatcher.connect_slot( self.on_device_readback, endpoints, - single_callback_for_all_topics=True, ) def _add_limits_to_plot_data(self):