mirror of
https://github.com/bec-project/bec_widgets.git
synced 2025-07-14 03:31:50 +02:00
fix(utils/bec_dispatcher): BECDispatcher can accept new EndpointInfo dataclass.
This commit is contained in:
@ -11,6 +11,8 @@ from bec_lib import BECClient, ServiceConfig
|
||||
from qtpy.QtCore import QObject
|
||||
from qtpy.QtCore import Signal as pyqtSignal
|
||||
|
||||
from bec_lib.endpoints import EndpointInfo
|
||||
|
||||
# Adding a new pyqt signal requires a class factory, as they must be part of the class definition
|
||||
# and cannot be dynamically added as class attributes after the class has been defined.
|
||||
_signal_class_factory = (
|
||||
@ -49,31 +51,50 @@ class _BECDispatcher(QObject):
|
||||
self._connections = {}
|
||||
|
||||
def connect_slot(
|
||||
self, slot: Callable, topics: Union[str, list], single_callback_for_all_topics=False
|
||||
self,
|
||||
slot: Callable,
|
||||
topics: Union[EndpointInfo, str, list[Union[EndpointInfo, str]]],
|
||||
single_callback_for_all_topics=False,
|
||||
) -> None:
|
||||
"""Connect widget's pyqt slot, so that it is called on new pub/sub topic message.
|
||||
|
||||
Args:
|
||||
slot (Callable): A slot method/function that accepts two inputs: content and metadata of
|
||||
the corresponding pub/sub message
|
||||
topics (str | list): A topic or list of topics that can typically be acquired via bec_lib.MessageEndpoints
|
||||
topics (EndpointInfo | str | list): A topic or list of topics that can typically be acquired via bec_lib.MessageEndpoints
|
||||
single_callback_for_all_topics (bool): If True, use the same callback for all topics, otherwise use
|
||||
separate callbacks.
|
||||
"""
|
||||
if isinstance(topics, str):
|
||||
# Normalise the topics input
|
||||
if isinstance(topics, (str, EndpointInfo)):
|
||||
topics = [topics]
|
||||
|
||||
endpoints = [
|
||||
topic.endpoint if isinstance(topic, EndpointInfo) else topic for topic in topics
|
||||
]
|
||||
|
||||
# consumer_types = [
|
||||
# topic.consumer_type if isinstance(topic, EndpointInfo) else "SET_PUBLISH"
|
||||
# for topic in topics
|
||||
# ]
|
||||
|
||||
# Ensure topics_key is a tuple, whether single_callback_for_all_topics is True or False.
|
||||
topics_key = tuple(sorted(topics)) if single_callback_for_all_topics else tuple(topics)
|
||||
topics_key = (
|
||||
tuple(sorted(endpoints)) if single_callback_for_all_topics else tuple(endpoints)
|
||||
)
|
||||
|
||||
if topics_key not in self._connections:
|
||||
self._connections[topics_key] = self._create_connection(topics)
|
||||
self._connections[topics_key] = self._create_connection(
|
||||
endpoints
|
||||
) # , consumer_type = )# add here consumer type
|
||||
connection = self._connections[topics_key]
|
||||
if slot not in connection.slots:
|
||||
connection.signal.connect(slot)
|
||||
connection.slots.add(slot)
|
||||
|
||||
def _create_connection(self, topics: list) -> _Connection:
|
||||
def _create_connection(
|
||||
self, topics: list
|
||||
) -> _Connection: # , consumer_type: str) -> _Connection:
|
||||
"""Creates a new connection for given topics."""
|
||||
|
||||
def cb(msg):
|
||||
@ -86,6 +107,10 @@ class _BECDispatcher(QObject):
|
||||
|
||||
try:
|
||||
self.client.connector.register(topics=topics, cb=cb)
|
||||
# if consumer_type == "SET_PUBLISH":
|
||||
# self.client.connector.register(topics=topics, cb=cb)
|
||||
# elif consumer_type == "STREAM":
|
||||
# self.client.connector.register_stream(topics=topics, cb=cb, newest_only=True)
|
||||
except redis.exceptions.ConnectionError:
|
||||
print("Could not connect to Redis, skipping registration of topics.")
|
||||
|
||||
@ -123,14 +148,21 @@ class _BECDispatcher(QObject):
|
||||
topics (str | list): A corresponding topic or list of topics that can typically be acquired via
|
||||
bec_lib.MessageEndpoints
|
||||
"""
|
||||
if isinstance(topics, str):
|
||||
# Normalise the topics input
|
||||
if isinstance(topics, (str, EndpointInfo)):
|
||||
topics = [topics]
|
||||
|
||||
endpoints = [
|
||||
topic.endpoint if isinstance(topic, EndpointInfo) else topic for topic in topics
|
||||
]
|
||||
# if isinstance(topics, str):
|
||||
# topics = [topics]
|
||||
|
||||
for key, connection in list(self._connections.items()):
|
||||
if slot in connection.slots:
|
||||
common_topics = set(topics).intersection(key)
|
||||
common_topics = set(endpoints).intersection(key)
|
||||
if common_topics:
|
||||
remaining_topics = set(key) - set(topics)
|
||||
remaining_topics = set(key) - set(endpoints)
|
||||
# Disconnect slot from common topics
|
||||
self._do_disconnect_slot(key, slot)
|
||||
# Reconnect slot to remaining topics if any
|
||||
|
Reference in New Issue
Block a user