diff --git a/bec_widgets/utils/bec_dispatcher.py b/bec_widgets/utils/bec_dispatcher.py index 3c88f676..4e500ad2 100644 --- a/bec_widgets/utils/bec_dispatcher.py +++ b/bec_widgets/utils/bec_dispatcher.py @@ -1,7 +1,10 @@ +from __future__ import annotations + import argparse import itertools import os from collections.abc import Callable +from typing import Union from bec_lib import BECClient, messages, ServiceConfig from bec_lib.redis_connector import RedisConsumerThreaded @@ -38,8 +41,8 @@ class _BECDispatcher(QObject): self.client.initialize(config=ServiceConfig(config_path=bec_config)) self._connections = {} - def connect_slot(self, slot: Callable, topic: str) -> None: - """Connect widget's pyqt slot, so that it is called on new pub/sub topic message + def _connect_slot_to_topic(self, slot: Callable, topic: str) -> None: + """A helper method to connect a slot to a specific topic Args: slot (Callable): A slot method/function that accepts two inputs: content and metadata of @@ -69,8 +72,22 @@ class _BECDispatcher(QObject): self._connections[topic].signal.connect(slot) self._connections[topic].slots.add(slot) - def disconnect_slot(self, slot: Callable, topic: str) -> None: - """Disconnect widget's pyqt slot from pub/sub updates on a topic. + def connect_slot(self, slot: Callable, topics: Union[str, list]) -> None: + """Connect widget's pyqt slot, so that it is called on new pub/sub topic message. + + Args: + slot (Callable): A slot method/function that accepts two inputs: content and metadata of + the corresponding pub/sub message + topics (str | list): A topic or list of topics that can typically be acquired via bec_lib.MessageEndpoints + """ + if isinstance(topics, str): + topics = [topics] + + for topic in topics: + self._connect_slot_to_topic(slot, topic) + + def _disconnect_slot_from_topic(self, slot: Callable, topic: str) -> None: + """A helper method to disconnect a slot from a specific topic. Args: slot (Callable): A slot to be disconnected @@ -91,6 +108,31 @@ class _BECDispatcher(QObject): self._connections[topic].consumer.shutdown() del self._connections[topic] + def disconnect_slot(self, slot: Callable, topics: Union[str, list]) -> None: + """Disconnect widget's pyqt slot from pub/sub updates on a topic. + + Args: + slot (Callable): A slot to be disconnected + topics (str | list): A corresponding topic or list of topics that can typically be acquired via + bec_lib.MessageEndpoints + """ + if isinstance(topics, str): + topics = [topics] + + for topic in topics: + self._disconnect_slot_from_topic(slot, topic) + + def disconnect_all(self): + """Disconnect all slots from all topics.""" + for topic, connection in list(self._connections.items()): + for slot in list(connection.slots): + self.disconnect_slot(slot, topic) + + # Check if the topic still exists before trying to shutdown and delete + if topic in self._connections and not connection.slots: + connection.consumer.shutdown() + del self._connections[topic] + parser = argparse.ArgumentParser() parser.add_argument("--bec-config", default=None) diff --git a/tests/test_bec_dispatcher.py b/tests/test_bec_dispatcher.py index dd2d904b..1f24b919 100644 --- a/tests/test_bec_dispatcher.py +++ b/tests/test_bec_dispatcher.py @@ -26,7 +26,7 @@ def _consumer(bec_dispatcher): @pytest.mark.filterwarnings("ignore:Failed to connect to redis.") def test_connect_one_slot(bec_dispatcher, consumer): slot1 = Mock() - bec_dispatcher.connect_slot(slot=slot1, topic="topic0") + bec_dispatcher.connect_slot(slot=slot1, topics="topic0") consumer.assert_called_once() # trigger consumer callback as if a message was published consumer.call_args.kwargs["cb"](msg) @@ -37,8 +37,8 @@ def test_connect_one_slot(bec_dispatcher, consumer): def test_connect_identical(bec_dispatcher, consumer): slot1 = Mock() - bec_dispatcher.connect_slot(slot=slot1, topic="topic0") - bec_dispatcher.connect_slot(slot=slot1, topic="topic0") + bec_dispatcher.connect_slot(slot=slot1, topics="topic0") + bec_dispatcher.connect_slot(slot=slot1, topics="topic0") consumer.assert_called_once() consumer.call_args.kwargs["cb"](msg) @@ -47,9 +47,9 @@ def test_connect_identical(bec_dispatcher, consumer): def test_connect_many_slots_one_topic(bec_dispatcher, consumer): slot1, slot2 = Mock(), Mock() - bec_dispatcher.connect_slot(slot=slot1, topic="topic0") + bec_dispatcher.connect_slot(slot=slot1, topics="topic0") consumer.assert_called_once() - bec_dispatcher.connect_slot(slot=slot2, topic="topic0") + bec_dispatcher.connect_slot(slot=slot2, topics="topic0") consumer.assert_called_once() # trigger consumer callback as if a message was published consumer.call_args.kwargs["cb"](msg) @@ -62,9 +62,9 @@ def test_connect_many_slots_one_topic(bec_dispatcher, consumer): def test_connect_one_slot_many_topics(bec_dispatcher, consumer): slot1 = Mock() - bec_dispatcher.connect_slot(slot=slot1, topic="topic0") + bec_dispatcher.connect_slot(slot=slot1, topics="topic0") assert consumer.call_count == 1 - bec_dispatcher.connect_slot(slot=slot1, topic="topic1") + bec_dispatcher.connect_slot(slot=slot1, topics="topic1") assert consumer.call_count == 2 # trigger consumer callback as if a message was published consumer.call_args_list[0].kwargs["cb"](msg) @@ -75,52 +75,52 @@ def test_connect_one_slot_many_topics(bec_dispatcher, consumer): def test_disconnect_one_slot_one_topic(bec_dispatcher, consumer): slot1, slot2 = Mock(), Mock() - bec_dispatcher.connect_slot(slot=slot1, topic="topic0") + bec_dispatcher.connect_slot(slot=slot1, topics="topic0") # disconnect using a different slot - bec_dispatcher.disconnect_slot(slot=slot1, topic="topic1") + bec_dispatcher.disconnect_slot(slot=slot1, topics="topic1") consumer.call_args.kwargs["cb"](msg) assert slot1.call_count == 1 - # disconnect using a different topic - bec_dispatcher.disconnect_slot(slot=slot2, topic="topic0") + # disconnect using a different topics + bec_dispatcher.disconnect_slot(slot=slot2, topics="topic0") consumer.call_args.kwargs["cb"](msg) assert slot1.call_count == 2 - # disconnect using the right slot and topic - bec_dispatcher.disconnect_slot(slot=slot1, topic="topic0") + # disconnect using the right slot and topics + bec_dispatcher.disconnect_slot(slot=slot1, topics="topic0") with pytest.raises(KeyError): consumer.call_args.kwargs["cb"](msg) def test_disconnect_identical(bec_dispatcher, consumer): slot1 = Mock() - bec_dispatcher.connect_slot(slot=slot1, topic="topic0") - bec_dispatcher.connect_slot(slot=slot1, topic="topic0") - bec_dispatcher.disconnect_slot(slot=slot1, topic="topic0") + bec_dispatcher.connect_slot(slot=slot1, topics="topic0") + bec_dispatcher.connect_slot(slot=slot1, topics="topic0") + bec_dispatcher.disconnect_slot(slot=slot1, topics="topic0") with pytest.raises(KeyError): consumer.call_args.kwargs["cb"](msg) def test_disconnect_many_slots_one_topic(bec_dispatcher, consumer): slot1, slot2, slot3 = Mock(), Mock(), Mock() - bec_dispatcher.connect_slot(slot=slot1, topic="topic0") - bec_dispatcher.connect_slot(slot=slot2, topic="topic0") + bec_dispatcher.connect_slot(slot=slot1, topics="topic0") + bec_dispatcher.connect_slot(slot=slot2, topics="topic0") # disconnect using a different slot - bec_dispatcher.disconnect_slot(slot3, topic="topic0") + bec_dispatcher.disconnect_slot(slot3, topics="topic0") consumer.call_args.kwargs["cb"](msg) assert slot1.call_count == 1 assert slot2.call_count == 1 - # disconnect using a different topic - bec_dispatcher.disconnect_slot(slot1, topic="topic1") + # disconnect using a different topics + bec_dispatcher.disconnect_slot(slot1, topics="topic1") consumer.call_args.kwargs["cb"](msg) assert slot1.call_count == 2 assert slot2.call_count == 2 - # disconnect using the right slot and topic - bec_dispatcher.disconnect_slot(slot1, topic="topic0") + # disconnect using the right slot and topics + bec_dispatcher.disconnect_slot(slot1, topics="topic0") consumer.call_args.kwargs["cb"](msg) assert slot1.call_count == 2 assert slot2.call_count == 3 @@ -128,33 +128,64 @@ def test_disconnect_many_slots_one_topic(bec_dispatcher, consumer): def test_disconnect_one_slot_many_topics(bec_dispatcher, consumer): slot1, slot2 = Mock(), Mock() - bec_dispatcher.connect_slot(slot=slot1, topic="topic0") - bec_dispatcher.connect_slot(slot=slot1, topic="topic1") + bec_dispatcher.connect_slot(slot=slot1, topics="topic0") + bec_dispatcher.connect_slot(slot=slot1, topics="topic1") # disconnect using a different slot - bec_dispatcher.disconnect_slot(slot=slot2, topic="topic0") + bec_dispatcher.disconnect_slot(slot=slot2, topics="topic0") consumer.call_args_list[0].kwargs["cb"](msg) assert slot1.call_count == 1 consumer.call_args_list[1].kwargs["cb"](msg) assert slot1.call_count == 2 - # disconnect using a different topic - bec_dispatcher.disconnect_slot(slot=slot1, topic="topic3") + # disconnect using a different topics + bec_dispatcher.disconnect_slot(slot=slot1, topics="topic3") consumer.call_args_list[0].kwargs["cb"](msg) assert slot1.call_count == 3 consumer.call_args_list[1].kwargs["cb"](msg) assert slot1.call_count == 4 - # disconnect using the right slot and topic - bec_dispatcher.disconnect_slot(slot=slot1, topic="topic0") + # disconnect using the right slot and topics + bec_dispatcher.disconnect_slot(slot=slot1, topics="topic0") with pytest.raises(KeyError): consumer.call_args_list[0].kwargs["cb"](msg) consumer.call_args_list[1].kwargs["cb"](msg) assert slot1.call_count == 5 - bec_dispatcher.disconnect_slot(slot=slot1, topic="topic1") + bec_dispatcher.disconnect_slot(slot=slot1, topics="topic1") with pytest.raises(KeyError): consumer.call_args_list[0].kwargs["cb"](msg) with pytest.raises(KeyError): consumer.call_args_list[1].kwargs["cb"](msg) assert slot1.call_count == 5 + + +def test_disconnect_all(bec_dispatcher, consumer): + # Mock slots to connect + slot1, slot2, slot3 = Mock(), Mock(), Mock() + + # Connect slots to different topics + bec_dispatcher.connect_slot(slot=slot1, topics="topic0") + bec_dispatcher.connect_slot(slot=slot2, topics="topic1") + bec_dispatcher.connect_slot(slot=slot3, topics="topic2") + + # Call disconnect_all method + bec_dispatcher.disconnect_all() + + # Simulate messages and verify that none of the slots are called + with pytest.raises(KeyError): + consumer.call_args_list[0].kwargs["cb"](msg) + with pytest.raises(KeyError): + consumer.call_args_list[1].kwargs["cb"](msg) + with pytest.raises(KeyError): + consumer.call_args_list[2].kwargs["cb"](msg) + + # Ensure that the slots have not been called + assert slot1.call_count == 0 + assert slot2.call_count == 0 + assert slot3.call_count == 0 + + # Also, check that the consumer for each topic is shutdown + assert "topic0" not in bec_dispatcher._connections + assert "topic1" not in bec_dispatcher._connections + assert "topic2" not in bec_dispatcher._connections