0
0
mirror of https://github.com/bec-project/bec_widgets.git synced 2025-07-14 03:31:50 +02:00

fix(utils/bec_dispatcher): BECDispatcher can accept new EndpointInfo dataclass.

This commit is contained in:
2024-03-07 13:46:51 +01:00
parent 814768525f
commit c319dacb24

View File

@ -11,6 +11,8 @@ from bec_lib import BECClient, ServiceConfig
from qtpy.QtCore import QObject
from qtpy.QtCore import Signal as pyqtSignal
from bec_lib.endpoints import EndpointInfo
# Adding a new pyqt signal requires a class factory, as they must be part of the class definition
# and cannot be dynamically added as class attributes after the class has been defined.
_signal_class_factory = (
@ -49,31 +51,50 @@ class _BECDispatcher(QObject):
self._connections = {}
def connect_slot(
self, slot: Callable, topics: Union[str, list], single_callback_for_all_topics=False
self,
slot: Callable,
topics: Union[EndpointInfo, str, list[Union[EndpointInfo, str]]],
single_callback_for_all_topics=False,
) -> None:
"""Connect widget's pyqt slot, so that it is called on new pub/sub topic message.
Args:
slot (Callable): A slot method/function that accepts two inputs: content and metadata of
the corresponding pub/sub message
topics (str | list): A topic or list of topics that can typically be acquired via bec_lib.MessageEndpoints
topics (EndpointInfo | str | list): A topic or list of topics that can typically be acquired via bec_lib.MessageEndpoints
single_callback_for_all_topics (bool): If True, use the same callback for all topics, otherwise use
separate callbacks.
"""
if isinstance(topics, str):
# Normalise the topics input
if isinstance(topics, (str, EndpointInfo)):
topics = [topics]
endpoints = [
topic.endpoint if isinstance(topic, EndpointInfo) else topic for topic in topics
]
# consumer_types = [
# topic.consumer_type if isinstance(topic, EndpointInfo) else "SET_PUBLISH"
# for topic in topics
# ]
# Ensure topics_key is a tuple, whether single_callback_for_all_topics is True or False.
topics_key = tuple(sorted(topics)) if single_callback_for_all_topics else tuple(topics)
topics_key = (
tuple(sorted(endpoints)) if single_callback_for_all_topics else tuple(endpoints)
)
if topics_key not in self._connections:
self._connections[topics_key] = self._create_connection(topics)
self._connections[topics_key] = self._create_connection(
endpoints
) # , consumer_type = )# add here consumer type
connection = self._connections[topics_key]
if slot not in connection.slots:
connection.signal.connect(slot)
connection.slots.add(slot)
def _create_connection(self, topics: list) -> _Connection:
def _create_connection(
self, topics: list
) -> _Connection: # , consumer_type: str) -> _Connection:
"""Creates a new connection for given topics."""
def cb(msg):
@ -86,6 +107,10 @@ class _BECDispatcher(QObject):
try:
self.client.connector.register(topics=topics, cb=cb)
# if consumer_type == "SET_PUBLISH":
# self.client.connector.register(topics=topics, cb=cb)
# elif consumer_type == "STREAM":
# self.client.connector.register_stream(topics=topics, cb=cb, newest_only=True)
except redis.exceptions.ConnectionError:
print("Could not connect to Redis, skipping registration of topics.")
@ -123,14 +148,21 @@ class _BECDispatcher(QObject):
topics (str | list): A corresponding topic or list of topics that can typically be acquired via
bec_lib.MessageEndpoints
"""
if isinstance(topics, str):
# Normalise the topics input
if isinstance(topics, (str, EndpointInfo)):
topics = [topics]
endpoints = [
topic.endpoint if isinstance(topic, EndpointInfo) else topic for topic in topics
]
# if isinstance(topics, str):
# topics = [topics]
for key, connection in list(self._connections.items()):
if slot in connection.slots:
common_topics = set(topics).intersection(key)
common_topics = set(endpoints).intersection(key)
if common_topics:
remaining_topics = set(key) - set(topics)
remaining_topics = set(key) - set(endpoints)
# Disconnect slot from common topics
self._do_disconnect_slot(key, slot)
# Reconnect slot to remaining topics if any