mirror of
https://github.com/bec-project/bec_widgets.git
synced 2025-07-14 03:31:50 +02:00
fix(bec-dispatcher): fix reference to boundmethods to avoid duplicated subscriptions
This commit is contained in:
@ -4,8 +4,9 @@ import collections
|
|||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import TYPE_CHECKING, Union
|
from typing import TYPE_CHECKING, DefaultDict, Hashable, Union
|
||||||
|
|
||||||
|
import louie
|
||||||
import redis
|
import redis
|
||||||
from bec_lib.client import BECClient
|
from bec_lib.client import BECClient
|
||||||
from bec_lib.logger import bec_logger
|
from bec_lib.logger import bec_logger
|
||||||
@ -41,15 +42,25 @@ class QtThreadSafeCallback(QObject):
|
|||||||
self.cb_info = cb_info
|
self.cb_info = cb_info
|
||||||
|
|
||||||
self.cb = cb
|
self.cb = cb
|
||||||
|
self.cb_ref = louie.saferef.safe_ref(cb)
|
||||||
self.cb_signal.connect(self.cb)
|
self.cb_signal.connect(self.cb)
|
||||||
|
self.topics = set()
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
# make 2 differents QtThreadSafeCallback to look
|
# make 2 differents QtThreadSafeCallback to look
|
||||||
# identical when used as dictionary keys, if the
|
# identical when used as dictionary keys, if the
|
||||||
# callback is the same
|
# callback is the same
|
||||||
return f"{id(self.cb)}{self.cb_info}".__hash__()
|
return f"{id(self.cb_ref)}{self.cb_info}".__hash__()
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if not isinstance(other, QtThreadSafeCallback):
|
||||||
|
return False
|
||||||
|
return self.cb_ref == other.cb_ref and self.cb_info == other.cb_info
|
||||||
|
|
||||||
def __call__(self, msg_content, metadata):
|
def __call__(self, msg_content, metadata):
|
||||||
|
if self.cb_ref() is None:
|
||||||
|
# callback has been deleted
|
||||||
|
return
|
||||||
self.cb_signal.emit(msg_content, metadata)
|
self.cb_signal.emit(msg_content, metadata)
|
||||||
|
|
||||||
|
|
||||||
@ -96,7 +107,7 @@ class BECDispatcher:
|
|||||||
cls,
|
cls,
|
||||||
client=None,
|
client=None,
|
||||||
config: str | ServiceConfig | None = None,
|
config: str | ServiceConfig | None = None,
|
||||||
gui_id: str = None,
|
gui_id: str | None = None,
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@ -109,7 +120,9 @@ class BECDispatcher:
|
|||||||
if self._initialized:
|
if self._initialized:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._slots = collections.defaultdict(set)
|
self._registered_slots: DefaultDict[Hashable, QtThreadSafeCallback] = (
|
||||||
|
collections.defaultdict()
|
||||||
|
)
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|
||||||
if self.client is None:
|
if self.client is None:
|
||||||
@ -162,10 +175,13 @@ class BECDispatcher:
|
|||||||
topics (EndpointInfo | str | list): A topic or list of topics that can typically be acquired via bec_lib.MessageEndpoints
|
topics (EndpointInfo | str | list): A topic or list of topics that can typically be acquired via bec_lib.MessageEndpoints
|
||||||
cb_info (dict | None): A dictionary containing information about the callback. Defaults to None.
|
cb_info (dict | None): A dictionary containing information about the callback. Defaults to None.
|
||||||
"""
|
"""
|
||||||
slot = QtThreadSafeCallback(cb=slot, cb_info=cb_info)
|
qt_slot = QtThreadSafeCallback(cb=slot, cb_info=cb_info)
|
||||||
self.client.connector.register(topics, cb=slot, **kwargs)
|
if qt_slot not in self._registered_slots:
|
||||||
|
self._registered_slots[qt_slot] = qt_slot
|
||||||
|
qt_slot = self._registered_slots[qt_slot]
|
||||||
|
self.client.connector.register(topics, cb=qt_slot, **kwargs)
|
||||||
topics_str, _ = self.client.connector._convert_endpointinfo(topics)
|
topics_str, _ = self.client.connector._convert_endpointinfo(topics)
|
||||||
self._slots[slot].update(set(topics_str))
|
qt_slot.topics.update(set(topics_str))
|
||||||
|
|
||||||
def disconnect_slot(self, slot: Callable, topics: Union[str, list]):
|
def disconnect_slot(self, slot: Callable, topics: Union[str, list]):
|
||||||
"""
|
"""
|
||||||
@ -178,16 +194,16 @@ class BECDispatcher:
|
|||||||
# find the right slot to disconnect from ;
|
# find the right slot to disconnect from ;
|
||||||
# slot callbacks are wrapped in QtThreadSafeCallback objects,
|
# slot callbacks are wrapped in QtThreadSafeCallback objects,
|
||||||
# but the slot we receive here is the original callable
|
# but the slot we receive here is the original callable
|
||||||
for connected_slot in self._slots:
|
for connected_slot in self._registered_slots.values():
|
||||||
if connected_slot.cb == slot:
|
if connected_slot.cb == slot:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
self.client.connector.unregister(topics, cb=connected_slot)
|
self.client.connector.unregister(topics, cb=connected_slot)
|
||||||
topics_str, _ = self.client.connector._convert_endpointinfo(topics)
|
topics_str, _ = self.client.connector._convert_endpointinfo(topics)
|
||||||
self._slots[connected_slot].difference_update(set(topics_str))
|
self._registered_slots[connected_slot].topics.difference_update(set(topics_str))
|
||||||
if not self._slots[connected_slot]:
|
if not self._registered_slots[connected_slot].topics:
|
||||||
del self._slots[connected_slot]
|
del self._registered_slots[connected_slot]
|
||||||
|
|
||||||
def disconnect_topics(self, topics: Union[str, list]):
|
def disconnect_topics(self, topics: Union[str, list]):
|
||||||
"""
|
"""
|
||||||
@ -198,11 +214,16 @@ class BECDispatcher:
|
|||||||
"""
|
"""
|
||||||
self.client.connector.unregister(topics)
|
self.client.connector.unregister(topics)
|
||||||
topics_str, _ = self.client.connector._convert_endpointinfo(topics)
|
topics_str, _ = self.client.connector._convert_endpointinfo(topics)
|
||||||
for slot in list(self._slots.keys()):
|
|
||||||
slot_topics = self._slots[slot]
|
remove_slots = []
|
||||||
slot_topics.difference_update(set(topics_str))
|
for connected_slot in self._registered_slots.values():
|
||||||
if not slot_topics:
|
connected_slot.topics.difference_update(set(topics_str))
|
||||||
del self._slots[slot]
|
|
||||||
|
if not connected_slot.topics:
|
||||||
|
remove_slots.append(connected_slot)
|
||||||
|
|
||||||
|
for connected_slot in remove_slots:
|
||||||
|
self._registered_slots.pop(connected_slot, None)
|
||||||
|
|
||||||
def disconnect_all(self, *args, **kwargs):
|
def disconnect_all(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -7,7 +7,7 @@ import pytest
|
|||||||
from bec_lib.messages import ScanMessage
|
from bec_lib.messages import ScanMessage
|
||||||
from bec_lib.serialization import MsgpackSerialization
|
from bec_lib.serialization import MsgpackSerialization
|
||||||
|
|
||||||
from bec_widgets.utils.bec_dispatcher import QtRedisConnector
|
from bec_widgets.utils.bec_dispatcher import QtRedisConnector, QtThreadSafeCallback
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -27,6 +27,7 @@ def bec_dispatcher_w_connector(bec_dispatcher, topics_msg_list, send_msg_event):
|
|||||||
connector = QtRedisConnector("localhost:1", redis_class_mock)
|
connector = QtRedisConnector("localhost:1", redis_class_mock)
|
||||||
bec_dispatcher.client.connector = connector
|
bec_dispatcher.client.connector = connector
|
||||||
yield bec_dispatcher
|
yield bec_dispatcher
|
||||||
|
connector.shutdown()
|
||||||
|
|
||||||
|
|
||||||
dummy_msg = MsgpackSerialization.dumps(ScanMessage(point_id=0, scan_id="0", data={}))
|
dummy_msg = MsgpackSerialization.dumps(ScanMessage(point_id=0, scan_id="0", data={}))
|
||||||
@ -62,7 +63,6 @@ def test_dispatcher_disconnect_all(bec_dispatcher_w_connector, qtbot, send_msg_e
|
|||||||
|
|
||||||
@pytest.mark.parametrize("topics_msg_list", [(("topic1", dummy_msg), ("topic2", dummy_msg))])
|
@pytest.mark.parametrize("topics_msg_list", [(("topic1", dummy_msg), ("topic2", dummy_msg))])
|
||||||
def test_dispatcher_disconnect_one(bec_dispatcher_w_connector, qtbot, send_msg_event):
|
def test_dispatcher_disconnect_one(bec_dispatcher_w_connector, qtbot, send_msg_event):
|
||||||
# test for BEC issue #276
|
|
||||||
bec_dispatcher = bec_dispatcher_w_connector
|
bec_dispatcher = bec_dispatcher_w_connector
|
||||||
cb1 = mock.Mock(spec=[])
|
cb1 = mock.Mock(spec=[])
|
||||||
cb2 = mock.Mock(spec=[])
|
cb2 = mock.Mock(spec=[])
|
||||||
@ -86,12 +86,21 @@ def test_dispatcher_2_cb_same_topic(bec_dispatcher_w_connector, qtbot, send_msg_
|
|||||||
cb1 = mock.Mock(spec=[])
|
cb1 = mock.Mock(spec=[])
|
||||||
cb2 = mock.Mock(spec=[])
|
cb2 = mock.Mock(spec=[])
|
||||||
|
|
||||||
|
num_slots = len(bec_dispatcher._registered_slots)
|
||||||
|
|
||||||
bec_dispatcher.connect_slot(cb1, "topic1")
|
bec_dispatcher.connect_slot(cb1, "topic1")
|
||||||
bec_dispatcher.connect_slot(cb2, "topic1")
|
bec_dispatcher.connect_slot(cb2, "topic1")
|
||||||
|
|
||||||
|
# The redis connector should only subscribe once to the topic
|
||||||
assert len(bec_dispatcher.client.connector._topics_cb) == 1
|
assert len(bec_dispatcher.client.connector._topics_cb) == 1
|
||||||
assert len(bec_dispatcher._slots) == 2
|
|
||||||
|
# The the given topic, two callbacks should be registered
|
||||||
|
assert len(bec_dispatcher.client.connector._topics_cb["topic1"]) == 2
|
||||||
|
|
||||||
|
# The dispatcher should have two slots
|
||||||
|
assert len(bec_dispatcher._registered_slots) == num_slots + 2
|
||||||
bec_dispatcher.disconnect_slot(cb1, "topic1")
|
bec_dispatcher.disconnect_slot(cb1, "topic1")
|
||||||
assert len(bec_dispatcher._slots) == 1
|
assert len(bec_dispatcher._registered_slots) == num_slots + 1
|
||||||
|
|
||||||
send_msg_event.set()
|
send_msg_event.set()
|
||||||
qtbot.wait(10)
|
qtbot.wait(10)
|
||||||
@ -99,9 +108,31 @@ def test_dispatcher_2_cb_same_topic(bec_dispatcher_w_connector, qtbot, send_msg_
|
|||||||
cb2.assert_called_once()
|
cb2.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("topics_msg_list", [(("topic1", dummy_msg),)])
|
||||||
|
def test_dispatcher_2_cb_same_topic_same_slot(bec_dispatcher_w_connector, qtbot, send_msg_event):
|
||||||
|
bec_dispatcher = bec_dispatcher_w_connector
|
||||||
|
cb1 = mock.Mock(spec=[])
|
||||||
|
|
||||||
|
bec_dispatcher.connect_slot(cb1, "topic1")
|
||||||
|
bec_dispatcher.connect_slot(cb1, "topic1")
|
||||||
|
assert len(bec_dispatcher.client.connector._topics_cb) == 1
|
||||||
|
assert (
|
||||||
|
len(list(filter(lambda slot: slot.cb == cb1, bec_dispatcher._registered_slots.values())))
|
||||||
|
== 1
|
||||||
|
)
|
||||||
|
|
||||||
|
send_msg_event.set()
|
||||||
|
qtbot.wait(10)
|
||||||
|
assert cb1.call_count == 1
|
||||||
|
bec_dispatcher.disconnect_slot(cb1, "topic1")
|
||||||
|
assert (
|
||||||
|
len(list(filter(lambda slot: slot.cb == cb1, bec_dispatcher._registered_slots.values())))
|
||||||
|
== 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("topics_msg_list", [(("topic1", dummy_msg), ("topic2", dummy_msg))])
|
@pytest.mark.parametrize("topics_msg_list", [(("topic1", dummy_msg), ("topic2", dummy_msg))])
|
||||||
def test_dispatcher_2_topic_same_cb(bec_dispatcher_w_connector, qtbot, send_msg_event):
|
def test_dispatcher_2_topic_same_cb(bec_dispatcher_w_connector, qtbot, send_msg_event):
|
||||||
# test for BEC issue #276
|
|
||||||
bec_dispatcher = bec_dispatcher_w_connector
|
bec_dispatcher = bec_dispatcher_w_connector
|
||||||
cb1 = mock.Mock(spec=[])
|
cb1 = mock.Mock(spec=[])
|
||||||
|
|
||||||
@ -114,3 +145,36 @@ def test_dispatcher_2_topic_same_cb(bec_dispatcher_w_connector, qtbot, send_msg_
|
|||||||
send_msg_event.set()
|
send_msg_event.set()
|
||||||
qtbot.wait(10)
|
qtbot.wait(10)
|
||||||
cb1.assert_called_once()
|
cb1.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("topics_msg_list", [(("topic1", dummy_msg), ("topic2", dummy_msg))])
|
||||||
|
def test_dispatcher_2_topic_same_cb_with_boundmethod(
|
||||||
|
bec_dispatcher_w_connector, qtbot, send_msg_event
|
||||||
|
):
|
||||||
|
bec_dispatcher = bec_dispatcher_w_connector
|
||||||
|
|
||||||
|
class MockObject:
|
||||||
|
def mock_slot(self, msg, metadata):
|
||||||
|
pass
|
||||||
|
|
||||||
|
cb1 = MockObject()
|
||||||
|
|
||||||
|
bec_dispatcher.connect_slot(cb1.mock_slot, "topic1", {"metadata": "test"})
|
||||||
|
bec_dispatcher.connect_slot(cb1.mock_slot, "topic1", {"metadata": "test"})
|
||||||
|
|
||||||
|
def _get_slots():
|
||||||
|
return list(
|
||||||
|
filter(
|
||||||
|
lambda slot: slot == QtThreadSafeCallback(cb1.mock_slot, {"metadata": "test"}),
|
||||||
|
bec_dispatcher._registered_slots.values(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(bec_dispatcher.client.connector._topics_cb) == 1
|
||||||
|
assert len(_get_slots()) == 1
|
||||||
|
bec_dispatcher.disconnect_slot(cb1.mock_slot, "topic1")
|
||||||
|
assert len(bec_dispatcher.client.connector._topics_cb) == 0
|
||||||
|
assert len(_get_slots()) == 0
|
||||||
|
|
||||||
|
send_msg_event.set()
|
||||||
|
qtbot.wait(10)
|
||||||
|
Reference in New Issue
Block a user