diff --git a/bec_widgets/utils/bec_dispatcher.py b/bec_widgets/utils/bec_dispatcher.py index 4e500ad2..f76e138d 100644 --- a/bec_widgets/utils/bec_dispatcher.py +++ b/bec_widgets/utils/bec_dispatcher.py @@ -72,19 +72,53 @@ class _BECDispatcher(QObject): self._connections[topic].signal.connect(slot) self._connections[topic].slots.add(slot) - def connect_slot(self, slot: Callable, topics: Union[str, list]) -> None: + def _connect_slot_to_multiple_topics(self, slot: Callable, topics: list) -> None: + """ + A helper method to connect a slot to multiple topics using a single callback. + Args: + slot (Callable): A slot method/function that accepts two inputs: content and metadata of + the corresponding pub/sub message + topics (list): A list of topics that can typically be acquired via bec_lib.MessageEndpoints. + """ + + # Creating a unique key for this combination of topics + topics_key = tuple(sorted(topics)) + + if topics_key not in self._connections: + + def cb(msg): + msg = messages.MessageReader.loads(msg.value) + self._connections[topics_key].signal.emit(msg.content, msg.metadata) + + consumer = self.client.connector.consumer(topics=topics, cb=cb) + consumer.start() + + self._connections[topics_key] = _Connection(consumer) + + if slot not in self._connections[topics_key].slots: + self._connections[topics_key].signal.connect(slot) + self._connections[topics_key].slots.add(slot) + + def connect_slot( + self, slot: Callable, topics: Union[str, list], single_callback_for_all_topics=False + ) -> None: """Connect widget's pyqt slot, so that it is called on new pub/sub topic message. Args: slot (Callable): A slot method/function that accepts two inputs: content and metadata of the corresponding pub/sub message topics (str | list): A topic or list of topics that can typically be acquired via bec_lib.MessageEndpoints + single_callback_for_all_topics (bool): If True, use the same callback for all topics, otherwise use + separate callbacks. """ if isinstance(topics, str): topics = [topics] - for topic in topics: - self._connect_slot_to_topic(slot, topic) + if single_callback_for_all_topics: + self._connect_slot_to_multiple_topics(slot, topics) + else: + for topic in topics: + self._connect_slot_to_topic(slot, topic) def _disconnect_slot_from_topic(self, slot: Callable, topic: str) -> None: """A helper method to disconnect a slot from a specific topic. @@ -124,14 +158,14 @@ class _BECDispatcher(QObject): def disconnect_all(self): """Disconnect all slots from all topics.""" - for topic, connection in list(self._connections.items()): + for key, connection in list(self._connections.items()): for slot in list(connection.slots): - self.disconnect_slot(slot, topic) + self._disconnect_slot_from_topic(slot, key) # Updated to pass key # Check if the topic still exists before trying to shutdown and delete - if topic in self._connections and not connection.slots: + if key in self._connections and not connection.slots: connection.consumer.shutdown() - del self._connections[topic] + del self._connections[key] parser = argparse.ArgumentParser() diff --git a/tests/test_bec_dispatcher.py b/tests/test_bec_dispatcher.py index 1f24b919..2af5b62d 100644 --- a/tests/test_bec_dispatcher.py +++ b/tests/test_bec_dispatcher.py @@ -1,3 +1,4 @@ +# pylint: disable=missing-function-docstring from unittest.mock import Mock import pytest @@ -189,3 +190,52 @@ def test_disconnect_all(bec_dispatcher, consumer): assert "topic0" not in bec_dispatcher._connections assert "topic1" not in bec_dispatcher._connections assert "topic2" not in bec_dispatcher._connections + + +def test_connect_one_slot_multiple_topics_single_callback(bec_dispatcher, consumer): + slot1 = Mock() + + # Connect the slot to multiple topics using a single callback + topics = ["topic1", "topic2"] + bec_dispatcher.connect_slot(slot=slot1, topics=topics, single_callback_for_all_topics=True) + + # Verify the initial state + assert len(bec_dispatcher._connections) == 1 # One connection for all topics + assert len(bec_dispatcher._connections[tuple(sorted(topics))].slots) == 1 # One slot connected + + # Simulate messages being published on each topic + for topic in topics: + msg_with_topic = MessageObject( + topic=topic, value=ScanMessage(point_id=0, scanID=0, data={}).dumps() + ) + consumer.call_args.kwargs["cb"](msg_with_topic) + + # Verify that the slot is called once for each topic + assert slot1.call_count == len(topics) + + # Verify that a single consumer is created for all topics + consumer.assert_called_once() + + +def test_disconnect_all_with_single_callback_for_multiple_topics(bec_dispatcher, consumer): + slot1 = Mock() + + # Connect the slot to multiple topics using a single callback + topics = ["topic1", "topic2"] + bec_dispatcher.connect_slot(slot=slot1, topics=topics, single_callback_for_all_topics=True) + + # Verify the initial state + assert len(bec_dispatcher._connections) == 1 # One connection for all topics + assert len(bec_dispatcher._connections[tuple(sorted(topics))].slots) == 1 # One slot connected + + # Call disconnect_all method + bec_dispatcher.disconnect_all() + + # Verify that the slot is disconnected + assert len(bec_dispatcher._connections) == 0 # All connections are removed + assert slot1.call_count == 0 # Slot has not been called + + # Simulate messages and verify that the slot is not called + msg = MessageObject(topic="topic1", value=ScanMessage(point_id=0, scanID=0, data={}).dumps()) + with pytest.raises(KeyError): + consumer.call_args.kwargs["cb"](msg)