From 7607d7a3b64b3861f4833c9b8f5afc360f31b38d Mon Sep 17 00:00:00 2001 From: wyzula-jan <133381102+wyzula-jan@users.noreply.github.com> Date: Tue, 16 Jan 2024 16:02:22 +0100 Subject: [PATCH] fix: bec_dispatcher.py can partially disconnect topics from slot --- bec_widgets/utils/bec_dispatcher.py | 138 ++++++++++++---------------- tests/test_bec_dispatcher.py | 50 +++++----- 2 files changed, 86 insertions(+), 102 deletions(-) diff --git a/bec_widgets/utils/bec_dispatcher.py b/bec_widgets/utils/bec_dispatcher.py index f76e138d..b5e9258c 100644 --- a/bec_widgets/utils/bec_dispatcher.py +++ b/bec_widgets/utils/bec_dispatcher.py @@ -1,3 +1,7 @@ +# TODO last backup + +# todo super last refactor + from __future__ import annotations import argparse @@ -10,7 +14,7 @@ from bec_lib import BECClient, messages, ServiceConfig from bec_lib.redis_connector import RedisConsumerThreaded from qtpy.QtCore import QObject, Signal as pyqtSignal -# Adding a new pyqt signal requres a class factory, as they must be part of the class definition +# Adding a new pyqt signal requires a class factory, as they must be part of the class definition # and cannot be dynamically added as class attributes after the class has been defined. _signal_class_factory = ( type(f"Signal{i}", (QObject,), dict(signal=pyqtSignal(dict, dict))) for i in itertools.count() @@ -29,6 +33,8 @@ class _Connection: class _BECDispatcher(QObject): + """Utility class to keep track of slots connected to a particular redis consumer""" + def __init__(self, bec_config=None): super().__init__() self.client = BECClient() @@ -41,64 +47,6 @@ class _BECDispatcher(QObject): self.client.initialize(config=ServiceConfig(config_path=bec_config)) self._connections = {} - def _connect_slot_to_topic(self, slot: Callable, topic: str) -> None: - """A helper method to connect a slot to a specific topic - - Args: - slot (Callable): A slot method/function that accepts two inputs: content and metadata of - the corresponding pub/sub message - topic (str): A topic that can typically be acquired via bec_lib.MessageEndpoints - """ - # create new connection for topic if it doesn't exist - if topic not in self._connections: - - def cb(msg): - msg = messages.MessageReader.loads(msg.value) - # TODO: this can could be replaced with a simple - # self._connections[topic].signal.emit(msg.content, msg.metadata) - # once all dispatcher.connect_slot calls are made with a single topic only - if not isinstance(msg, list): - msg = [msg] - for msg_i in msg: - self._connections[topic].signal.emit(msg_i.content, msg_i.metadata) - - consumer = self.client.connector.consumer(topics=topic, cb=cb) - consumer.start() - - self._connections[topic] = _Connection(consumer) - - # connect slot if it's not connected - if slot not in self._connections[topic].slots: - self._connections[topic].signal.connect(slot) - self._connections[topic].slots.add(slot) - - def _connect_slot_to_multiple_topics(self, slot: Callable, topics: list) -> None: - """ - A helper method to connect a slot to multiple topics using a single callback. - Args: - slot (Callable): A slot method/function that accepts two inputs: content and metadata of - the corresponding pub/sub message - topics (list): A list of topics that can typically be acquired via bec_lib.MessageEndpoints. - """ - - # Creating a unique key for this combination of topics - topics_key = tuple(sorted(topics)) - - if topics_key not in self._connections: - - def cb(msg): - msg = messages.MessageReader.loads(msg.value) - self._connections[topics_key].signal.emit(msg.content, msg.metadata) - - consumer = self.client.connector.consumer(topics=topics, cb=cb) - consumer.start() - - self._connections[topics_key] = _Connection(consumer) - - if slot not in self._connections[topics_key].slots: - self._connections[topics_key].signal.connect(slot) - self._connections[topics_key].slots.add(slot) - def connect_slot( self, slot: Callable, topics: Union[str, list], single_callback_for_all_topics=False ) -> None: @@ -114,11 +62,33 @@ class _BECDispatcher(QObject): if isinstance(topics, str): topics = [topics] - if single_callback_for_all_topics: - self._connect_slot_to_multiple_topics(slot, topics) - else: - for topic in topics: - self._connect_slot_to_topic(slot, topic) + # Ensure topics_key is a tuple, whether single_callback_for_all_topics is True or False. + topics_key = tuple(sorted(topics)) if single_callback_for_all_topics else tuple(topics) + + if topics_key not in self._connections: + self._connections[topics_key] = self._create_connection(topics) + connection = self._connections[topics_key] + if slot not in connection.slots: + connection.signal.connect(slot) + connection.slots.add(slot) + + def _create_connection(self, topics: list) -> _Connection: + """Creates a new connection for given topics.""" + + def cb(msg): + msg = messages.MessageReader.loads(msg.value) + if not isinstance(msg, list): + msg = [msg] + for msg_i in msg: + for connection_key, connection in self._connections.items(): + if set(topics).intersection( + connection_key if isinstance(connection_key, tuple) else [connection_key] + ): + connection.signal.emit(msg_i.content, msg_i.metadata) + + consumer = self.client.connector.consumer(topics=topics, cb=cb) + consumer.start() + return _Connection(consumer) def _disconnect_slot_from_topic(self, slot: Callable, topic: str) -> None: """A helper method to disconnect a slot from a specific topic. @@ -128,19 +98,13 @@ class _BECDispatcher(QObject): topic (str): A corresponding topic that can typically be acquired via bec_lib.MessageEndpoints """ - if topic not in self._connections: - return - - if slot not in self._connections[topic].slots: - return - - self._connections[topic].signal.disconnect(slot) - self._connections[topic].slots.remove(slot) - - if not self._connections[topic].slots: - # shutdown consumer if there are no more connected slots - self._connections[topic].consumer.shutdown() - del self._connections[topic] + connection = self._connections.get(topic) + if connection and slot in connection.slots: + connection.signal.disconnect(slot) + connection.slots.remove(slot) + if not connection.slots: + connection.consumer.shutdown() + del self._connections[topic] def disconnect_slot(self, slot: Callable, topics: Union[str, list]) -> None: """Disconnect widget's pyqt slot from pub/sub updates on a topic. @@ -153,16 +117,28 @@ class _BECDispatcher(QObject): if isinstance(topics, str): topics = [topics] - for topic in topics: - self._disconnect_slot_from_topic(slot, topic) + for key, connection in list(self._connections.items()): + if slot in connection.slots: + common_topics = set(topics).intersection(key) + if common_topics: + remaining_topics = set(key) - set(topics) + # Disconnect slot from common topics + connection.signal.disconnect(slot) + connection.slots.remove(slot) + if not connection.slots: + print(f"{connection.consumer} is shutting down") + connection.consumer.shutdown() + del self._connections[key] + # Reconnect slot to remaining topics if any + if remaining_topics: + self.connect_slot(slot, list(remaining_topics), True) def disconnect_all(self): """Disconnect all slots from all topics.""" for key, connection in list(self._connections.items()): for slot in list(connection.slots): - self._disconnect_slot_from_topic(slot, key) # Updated to pass key + self._disconnect_slot_from_topic(slot, key) - # Check if the topic still exists before trying to shutdown and delete if key in self._connections and not connection.slots: connection.consumer.shutdown() del self._connections[key] diff --git a/tests/test_bec_dispatcher.py b/tests/test_bec_dispatcher.py index 2af5b62d..9020a8d6 100644 --- a/tests/test_bec_dispatcher.py +++ b/tests/test_bec_dispatcher.py @@ -78,29 +78,40 @@ def test_disconnect_one_slot_one_topic(bec_dispatcher, consumer): slot1, slot2 = Mock(), Mock() bec_dispatcher.connect_slot(slot=slot1, topics="topic0") - # disconnect using a different slot + # disconnect using a different topic bec_dispatcher.disconnect_slot(slot=slot1, topics="topic1") consumer.call_args.kwargs["cb"](msg) assert slot1.call_count == 1 - # disconnect using a different topics + # disconnect using a different slot bec_dispatcher.disconnect_slot(slot=slot2, topics="topic0") consumer.call_args.kwargs["cb"](msg) assert slot1.call_count == 2 # disconnect using the right slot and topics bec_dispatcher.disconnect_slot(slot=slot1, topics="topic0") - with pytest.raises(KeyError): - consumer.call_args.kwargs["cb"](msg) + # reset count to 0 for slot + slot1.reset_mock() + consumer.call_args.kwargs["cb"](msg) + assert slot1.call_count == 0 def test_disconnect_identical(bec_dispatcher, consumer): slot1 = Mock() + # Try to connect slot twice bec_dispatcher.connect_slot(slot=slot1, topics="topic0") bec_dispatcher.connect_slot(slot=slot1, topics="topic0") + + # Test to call the slot once (slot should be not connected twice) + consumer.call_args.kwargs["cb"](msg) + assert slot1.call_count == 1 + + # Disconnect the slot bec_dispatcher.disconnect_slot(slot=slot1, topics="topic0") - with pytest.raises(KeyError): - consumer.call_args.kwargs["cb"](msg) + + # Test to call the slot once (slot should be not connected anymore), count remains 1 + consumer.call_args.kwargs["cb"](msg) + assert slot1.call_count == 1 def test_disconnect_many_slots_one_topic(bec_dispatcher, consumer): @@ -148,16 +159,17 @@ def test_disconnect_one_slot_many_topics(bec_dispatcher, consumer): # disconnect using the right slot and topics bec_dispatcher.disconnect_slot(slot=slot1, topics="topic0") - with pytest.raises(KeyError): - consumer.call_args_list[0].kwargs["cb"](msg) + # Calling disconnected topic0 should not call slot1 + consumer.call_args_list[0].kwargs["cb"](msg) + assert slot1.call_count == 4 + # Calling topic1 should still call slot1 consumer.call_args_list[1].kwargs["cb"](msg) assert slot1.call_count == 5 + # disconnect remaining topic1 from slot1, calling any topic should not increase count bec_dispatcher.disconnect_slot(slot=slot1, topics="topic1") - with pytest.raises(KeyError): - consumer.call_args_list[0].kwargs["cb"](msg) - with pytest.raises(KeyError): - consumer.call_args_list[1].kwargs["cb"](msg) + consumer.call_args_list[0].kwargs["cb"](msg) + consumer.call_args_list[1].kwargs["cb"](msg) assert slot1.call_count == 5 @@ -174,12 +186,9 @@ def test_disconnect_all(bec_dispatcher, consumer): bec_dispatcher.disconnect_all() # Simulate messages and verify that none of the slots are called - with pytest.raises(KeyError): - consumer.call_args_list[0].kwargs["cb"](msg) - with pytest.raises(KeyError): - consumer.call_args_list[1].kwargs["cb"](msg) - with pytest.raises(KeyError): - consumer.call_args_list[2].kwargs["cb"](msg) + consumer.call_args_list[0].kwargs["cb"](msg) + consumer.call_args_list[1].kwargs["cb"](msg) + consumer.call_args_list[2].kwargs["cb"](msg) # Ensure that the slots have not been called assert slot1.call_count == 0 @@ -236,6 +245,5 @@ def test_disconnect_all_with_single_callback_for_multiple_topics(bec_dispatcher, assert slot1.call_count == 0 # Slot has not been called # Simulate messages and verify that the slot is not called - msg = MessageObject(topic="topic1", value=ScanMessage(point_id=0, scanID=0, data={}).dumps()) - with pytest.raises(KeyError): - consumer.call_args.kwargs["cb"](msg) + consumer.call_args.kwargs["cb"](msg) + assert slot1.call_count == 0 # Slot has not been called