diff --git a/bec_widgets/utils/bec_dispatcher.py b/bec_widgets/utils/bec_dispatcher.py index 4506646c..1ab6c15d 100644 --- a/bec_widgets/utils/bec_dispatcher.py +++ b/bec_widgets/utils/bec_dispatcher.py @@ -4,8 +4,9 @@ import collections import random import string from collections.abc import Callable -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, DefaultDict, Hashable, Union +import louie import redis from bec_lib.client import BECClient from bec_lib.logger import bec_logger @@ -41,15 +42,25 @@ class QtThreadSafeCallback(QObject): self.cb_info = cb_info self.cb = cb + self.cb_ref = louie.saferef.safe_ref(cb) self.cb_signal.connect(self.cb) + self.topics = set() def __hash__(self): # make 2 differents QtThreadSafeCallback to look # identical when used as dictionary keys, if the # callback is the same - return f"{id(self.cb)}{self.cb_info}".__hash__() + return f"{id(self.cb_ref)}{self.cb_info}".__hash__() + + def __eq__(self, other): + if not isinstance(other, QtThreadSafeCallback): + return False + return self.cb_ref == other.cb_ref and self.cb_info == other.cb_info def __call__(self, msg_content, metadata): + if self.cb_ref() is None: + # callback has been deleted + return self.cb_signal.emit(msg_content, metadata) @@ -96,7 +107,7 @@ class BECDispatcher: cls, client=None, config: str | ServiceConfig | None = None, - gui_id: str = None, + gui_id: str | None = None, *args, **kwargs, ): @@ -109,7 +120,9 @@ class BECDispatcher: if self._initialized: return - self._slots = collections.defaultdict(set) + self._registered_slots: DefaultDict[Hashable, QtThreadSafeCallback] = ( + collections.defaultdict() + ) self.client = client if self.client is None: @@ -162,10 +175,13 @@ class BECDispatcher: topics (EndpointInfo | str | list): A topic or list of topics that can typically be acquired via bec_lib.MessageEndpoints cb_info (dict | None): A dictionary containing information about the callback. Defaults to None. """ - slot = QtThreadSafeCallback(cb=slot, cb_info=cb_info) - self.client.connector.register(topics, cb=slot, **kwargs) + qt_slot = QtThreadSafeCallback(cb=slot, cb_info=cb_info) + if qt_slot not in self._registered_slots: + self._registered_slots[qt_slot] = qt_slot + qt_slot = self._registered_slots[qt_slot] + self.client.connector.register(topics, cb=qt_slot, **kwargs) topics_str, _ = self.client.connector._convert_endpointinfo(topics) - self._slots[slot].update(set(topics_str)) + qt_slot.topics.update(set(topics_str)) def disconnect_slot(self, slot: Callable, topics: Union[str, list]): """ @@ -178,16 +194,16 @@ class BECDispatcher: # find the right slot to disconnect from ; # slot callbacks are wrapped in QtThreadSafeCallback objects, # but the slot we receive here is the original callable - for connected_slot in self._slots: + for connected_slot in self._registered_slots.values(): if connected_slot.cb == slot: break else: return self.client.connector.unregister(topics, cb=connected_slot) topics_str, _ = self.client.connector._convert_endpointinfo(topics) - self._slots[connected_slot].difference_update(set(topics_str)) - if not self._slots[connected_slot]: - del self._slots[connected_slot] + self._registered_slots[connected_slot].topics.difference_update(set(topics_str)) + if not self._registered_slots[connected_slot].topics: + del self._registered_slots[connected_slot] def disconnect_topics(self, topics: Union[str, list]): """ @@ -198,11 +214,16 @@ class BECDispatcher: """ self.client.connector.unregister(topics) topics_str, _ = self.client.connector._convert_endpointinfo(topics) - for slot in list(self._slots.keys()): - slot_topics = self._slots[slot] - slot_topics.difference_update(set(topics_str)) - if not slot_topics: - del self._slots[slot] + + remove_slots = [] + for connected_slot in self._registered_slots.values(): + connected_slot.topics.difference_update(set(topics_str)) + + if not connected_slot.topics: + remove_slots.append(connected_slot) + + for connected_slot in remove_slots: + self._registered_slots.pop(connected_slot, None) def disconnect_all(self, *args, **kwargs): """ diff --git a/tests/unit_tests/test_bec_dispatcher.py b/tests/unit_tests/test_bec_dispatcher.py index 8ca0aa3b..69b28a50 100644 --- a/tests/unit_tests/test_bec_dispatcher.py +++ b/tests/unit_tests/test_bec_dispatcher.py @@ -7,7 +7,7 @@ import pytest from bec_lib.messages import ScanMessage from bec_lib.serialization import MsgpackSerialization -from bec_widgets.utils.bec_dispatcher import QtRedisConnector +from bec_widgets.utils.bec_dispatcher import QtRedisConnector, QtThreadSafeCallback @pytest.fixture @@ -27,6 +27,7 @@ def bec_dispatcher_w_connector(bec_dispatcher, topics_msg_list, send_msg_event): connector = QtRedisConnector("localhost:1", redis_class_mock) bec_dispatcher.client.connector = connector yield bec_dispatcher + connector.shutdown() dummy_msg = MsgpackSerialization.dumps(ScanMessage(point_id=0, scan_id="0", data={})) @@ -62,7 +63,6 @@ def test_dispatcher_disconnect_all(bec_dispatcher_w_connector, qtbot, send_msg_e @pytest.mark.parametrize("topics_msg_list", [(("topic1", dummy_msg), ("topic2", dummy_msg))]) def test_dispatcher_disconnect_one(bec_dispatcher_w_connector, qtbot, send_msg_event): - # test for BEC issue #276 bec_dispatcher = bec_dispatcher_w_connector cb1 = mock.Mock(spec=[]) cb2 = mock.Mock(spec=[]) @@ -86,12 +86,21 @@ def test_dispatcher_2_cb_same_topic(bec_dispatcher_w_connector, qtbot, send_msg_ cb1 = mock.Mock(spec=[]) cb2 = mock.Mock(spec=[]) + num_slots = len(bec_dispatcher._registered_slots) + bec_dispatcher.connect_slot(cb1, "topic1") bec_dispatcher.connect_slot(cb2, "topic1") + + # The redis connector should only subscribe once to the topic assert len(bec_dispatcher.client.connector._topics_cb) == 1 - assert len(bec_dispatcher._slots) == 2 + + # The the given topic, two callbacks should be registered + assert len(bec_dispatcher.client.connector._topics_cb["topic1"]) == 2 + + # The dispatcher should have two slots + assert len(bec_dispatcher._registered_slots) == num_slots + 2 bec_dispatcher.disconnect_slot(cb1, "topic1") - assert len(bec_dispatcher._slots) == 1 + assert len(bec_dispatcher._registered_slots) == num_slots + 1 send_msg_event.set() qtbot.wait(10) @@ -99,9 +108,31 @@ def test_dispatcher_2_cb_same_topic(bec_dispatcher_w_connector, qtbot, send_msg_ cb2.assert_called_once() +@pytest.mark.parametrize("topics_msg_list", [(("topic1", dummy_msg),)]) +def test_dispatcher_2_cb_same_topic_same_slot(bec_dispatcher_w_connector, qtbot, send_msg_event): + bec_dispatcher = bec_dispatcher_w_connector + cb1 = mock.Mock(spec=[]) + + bec_dispatcher.connect_slot(cb1, "topic1") + bec_dispatcher.connect_slot(cb1, "topic1") + assert len(bec_dispatcher.client.connector._topics_cb) == 1 + assert ( + len(list(filter(lambda slot: slot.cb == cb1, bec_dispatcher._registered_slots.values()))) + == 1 + ) + + send_msg_event.set() + qtbot.wait(10) + assert cb1.call_count == 1 + bec_dispatcher.disconnect_slot(cb1, "topic1") + assert ( + len(list(filter(lambda slot: slot.cb == cb1, bec_dispatcher._registered_slots.values()))) + == 0 + ) + + @pytest.mark.parametrize("topics_msg_list", [(("topic1", dummy_msg), ("topic2", dummy_msg))]) def test_dispatcher_2_topic_same_cb(bec_dispatcher_w_connector, qtbot, send_msg_event): - # test for BEC issue #276 bec_dispatcher = bec_dispatcher_w_connector cb1 = mock.Mock(spec=[]) @@ -114,3 +145,36 @@ def test_dispatcher_2_topic_same_cb(bec_dispatcher_w_connector, qtbot, send_msg_ send_msg_event.set() qtbot.wait(10) cb1.assert_called_once() + + +@pytest.mark.parametrize("topics_msg_list", [(("topic1", dummy_msg), ("topic2", dummy_msg))]) +def test_dispatcher_2_topic_same_cb_with_boundmethod( + bec_dispatcher_w_connector, qtbot, send_msg_event +): + bec_dispatcher = bec_dispatcher_w_connector + + class MockObject: + def mock_slot(self, msg, metadata): + pass + + cb1 = MockObject() + + bec_dispatcher.connect_slot(cb1.mock_slot, "topic1", {"metadata": "test"}) + bec_dispatcher.connect_slot(cb1.mock_slot, "topic1", {"metadata": "test"}) + + def _get_slots(): + return list( + filter( + lambda slot: slot == QtThreadSafeCallback(cb1.mock_slot, {"metadata": "test"}), + bec_dispatcher._registered_slots.values(), + ) + ) + + assert len(bec_dispatcher.client.connector._topics_cb) == 1 + assert len(_get_slots()) == 1 + bec_dispatcher.disconnect_slot(cb1.mock_slot, "topic1") + assert len(bec_dispatcher.client.connector._topics_cb) == 0 + assert len(_get_slots()) == 0 + + send_msg_event.set() + qtbot.wait(10)