mirror of
https://github.com/bec-project/bec_widgets.git
synced 2025-07-14 11:41:49 +02:00
fix: bec_dispatcher.py can partially disconnect topics from slot
This commit is contained in:
@ -1,3 +1,7 @@
|
|||||||
|
# TODO last backup
|
||||||
|
|
||||||
|
# todo super last refactor
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@ -10,7 +14,7 @@ from bec_lib import BECClient, messages, ServiceConfig
|
|||||||
from bec_lib.redis_connector import RedisConsumerThreaded
|
from bec_lib.redis_connector import RedisConsumerThreaded
|
||||||
from qtpy.QtCore import QObject, Signal as pyqtSignal
|
from qtpy.QtCore import QObject, Signal as pyqtSignal
|
||||||
|
|
||||||
# Adding a new pyqt signal requres a class factory, as they must be part of the class definition
|
# Adding a new pyqt signal requires a class factory, as they must be part of the class definition
|
||||||
# and cannot be dynamically added as class attributes after the class has been defined.
|
# and cannot be dynamically added as class attributes after the class has been defined.
|
||||||
_signal_class_factory = (
|
_signal_class_factory = (
|
||||||
type(f"Signal{i}", (QObject,), dict(signal=pyqtSignal(dict, dict))) for i in itertools.count()
|
type(f"Signal{i}", (QObject,), dict(signal=pyqtSignal(dict, dict))) for i in itertools.count()
|
||||||
@ -29,6 +33,8 @@ class _Connection:
|
|||||||
|
|
||||||
|
|
||||||
class _BECDispatcher(QObject):
|
class _BECDispatcher(QObject):
|
||||||
|
"""Utility class to keep track of slots connected to a particular redis consumer"""
|
||||||
|
|
||||||
def __init__(self, bec_config=None):
|
def __init__(self, bec_config=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.client = BECClient()
|
self.client = BECClient()
|
||||||
@ -41,64 +47,6 @@ class _BECDispatcher(QObject):
|
|||||||
self.client.initialize(config=ServiceConfig(config_path=bec_config))
|
self.client.initialize(config=ServiceConfig(config_path=bec_config))
|
||||||
self._connections = {}
|
self._connections = {}
|
||||||
|
|
||||||
def _connect_slot_to_topic(self, slot: Callable, topic: str) -> None:
|
|
||||||
"""A helper method to connect a slot to a specific topic
|
|
||||||
|
|
||||||
Args:
|
|
||||||
slot (Callable): A slot method/function that accepts two inputs: content and metadata of
|
|
||||||
the corresponding pub/sub message
|
|
||||||
topic (str): A topic that can typically be acquired via bec_lib.MessageEndpoints
|
|
||||||
"""
|
|
||||||
# create new connection for topic if it doesn't exist
|
|
||||||
if topic not in self._connections:
|
|
||||||
|
|
||||||
def cb(msg):
|
|
||||||
msg = messages.MessageReader.loads(msg.value)
|
|
||||||
# TODO: this can could be replaced with a simple
|
|
||||||
# self._connections[topic].signal.emit(msg.content, msg.metadata)
|
|
||||||
# once all dispatcher.connect_slot calls are made with a single topic only
|
|
||||||
if not isinstance(msg, list):
|
|
||||||
msg = [msg]
|
|
||||||
for msg_i in msg:
|
|
||||||
self._connections[topic].signal.emit(msg_i.content, msg_i.metadata)
|
|
||||||
|
|
||||||
consumer = self.client.connector.consumer(topics=topic, cb=cb)
|
|
||||||
consumer.start()
|
|
||||||
|
|
||||||
self._connections[topic] = _Connection(consumer)
|
|
||||||
|
|
||||||
# connect slot if it's not connected
|
|
||||||
if slot not in self._connections[topic].slots:
|
|
||||||
self._connections[topic].signal.connect(slot)
|
|
||||||
self._connections[topic].slots.add(slot)
|
|
||||||
|
|
||||||
def _connect_slot_to_multiple_topics(self, slot: Callable, topics: list) -> None:
|
|
||||||
"""
|
|
||||||
A helper method to connect a slot to multiple topics using a single callback.
|
|
||||||
Args:
|
|
||||||
slot (Callable): A slot method/function that accepts two inputs: content and metadata of
|
|
||||||
the corresponding pub/sub message
|
|
||||||
topics (list): A list of topics that can typically be acquired via bec_lib.MessageEndpoints.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Creating a unique key for this combination of topics
|
|
||||||
topics_key = tuple(sorted(topics))
|
|
||||||
|
|
||||||
if topics_key not in self._connections:
|
|
||||||
|
|
||||||
def cb(msg):
|
|
||||||
msg = messages.MessageReader.loads(msg.value)
|
|
||||||
self._connections[topics_key].signal.emit(msg.content, msg.metadata)
|
|
||||||
|
|
||||||
consumer = self.client.connector.consumer(topics=topics, cb=cb)
|
|
||||||
consumer.start()
|
|
||||||
|
|
||||||
self._connections[topics_key] = _Connection(consumer)
|
|
||||||
|
|
||||||
if slot not in self._connections[topics_key].slots:
|
|
||||||
self._connections[topics_key].signal.connect(slot)
|
|
||||||
self._connections[topics_key].slots.add(slot)
|
|
||||||
|
|
||||||
def connect_slot(
|
def connect_slot(
|
||||||
self, slot: Callable, topics: Union[str, list], single_callback_for_all_topics=False
|
self, slot: Callable, topics: Union[str, list], single_callback_for_all_topics=False
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -114,11 +62,33 @@ class _BECDispatcher(QObject):
|
|||||||
if isinstance(topics, str):
|
if isinstance(topics, str):
|
||||||
topics = [topics]
|
topics = [topics]
|
||||||
|
|
||||||
if single_callback_for_all_topics:
|
# Ensure topics_key is a tuple, whether single_callback_for_all_topics is True or False.
|
||||||
self._connect_slot_to_multiple_topics(slot, topics)
|
topics_key = tuple(sorted(topics)) if single_callback_for_all_topics else tuple(topics)
|
||||||
else:
|
|
||||||
for topic in topics:
|
if topics_key not in self._connections:
|
||||||
self._connect_slot_to_topic(slot, topic)
|
self._connections[topics_key] = self._create_connection(topics)
|
||||||
|
connection = self._connections[topics_key]
|
||||||
|
if slot not in connection.slots:
|
||||||
|
connection.signal.connect(slot)
|
||||||
|
connection.slots.add(slot)
|
||||||
|
|
||||||
|
def _create_connection(self, topics: list) -> _Connection:
|
||||||
|
"""Creates a new connection for given topics."""
|
||||||
|
|
||||||
|
def cb(msg):
|
||||||
|
msg = messages.MessageReader.loads(msg.value)
|
||||||
|
if not isinstance(msg, list):
|
||||||
|
msg = [msg]
|
||||||
|
for msg_i in msg:
|
||||||
|
for connection_key, connection in self._connections.items():
|
||||||
|
if set(topics).intersection(
|
||||||
|
connection_key if isinstance(connection_key, tuple) else [connection_key]
|
||||||
|
):
|
||||||
|
connection.signal.emit(msg_i.content, msg_i.metadata)
|
||||||
|
|
||||||
|
consumer = self.client.connector.consumer(topics=topics, cb=cb)
|
||||||
|
consumer.start()
|
||||||
|
return _Connection(consumer)
|
||||||
|
|
||||||
def _disconnect_slot_from_topic(self, slot: Callable, topic: str) -> None:
|
def _disconnect_slot_from_topic(self, slot: Callable, topic: str) -> None:
|
||||||
"""A helper method to disconnect a slot from a specific topic.
|
"""A helper method to disconnect a slot from a specific topic.
|
||||||
@ -128,19 +98,13 @@ class _BECDispatcher(QObject):
|
|||||||
topic (str): A corresponding topic that can typically be acquired via
|
topic (str): A corresponding topic that can typically be acquired via
|
||||||
bec_lib.MessageEndpoints
|
bec_lib.MessageEndpoints
|
||||||
"""
|
"""
|
||||||
if topic not in self._connections:
|
connection = self._connections.get(topic)
|
||||||
return
|
if connection and slot in connection.slots:
|
||||||
|
connection.signal.disconnect(slot)
|
||||||
if slot not in self._connections[topic].slots:
|
connection.slots.remove(slot)
|
||||||
return
|
if not connection.slots:
|
||||||
|
connection.consumer.shutdown()
|
||||||
self._connections[topic].signal.disconnect(slot)
|
del self._connections[topic]
|
||||||
self._connections[topic].slots.remove(slot)
|
|
||||||
|
|
||||||
if not self._connections[topic].slots:
|
|
||||||
# shutdown consumer if there are no more connected slots
|
|
||||||
self._connections[topic].consumer.shutdown()
|
|
||||||
del self._connections[topic]
|
|
||||||
|
|
||||||
def disconnect_slot(self, slot: Callable, topics: Union[str, list]) -> None:
|
def disconnect_slot(self, slot: Callable, topics: Union[str, list]) -> None:
|
||||||
"""Disconnect widget's pyqt slot from pub/sub updates on a topic.
|
"""Disconnect widget's pyqt slot from pub/sub updates on a topic.
|
||||||
@ -153,16 +117,28 @@ class _BECDispatcher(QObject):
|
|||||||
if isinstance(topics, str):
|
if isinstance(topics, str):
|
||||||
topics = [topics]
|
topics = [topics]
|
||||||
|
|
||||||
for topic in topics:
|
for key, connection in list(self._connections.items()):
|
||||||
self._disconnect_slot_from_topic(slot, topic)
|
if slot in connection.slots:
|
||||||
|
common_topics = set(topics).intersection(key)
|
||||||
|
if common_topics:
|
||||||
|
remaining_topics = set(key) - set(topics)
|
||||||
|
# Disconnect slot from common topics
|
||||||
|
connection.signal.disconnect(slot)
|
||||||
|
connection.slots.remove(slot)
|
||||||
|
if not connection.slots:
|
||||||
|
print(f"{connection.consumer} is shutting down")
|
||||||
|
connection.consumer.shutdown()
|
||||||
|
del self._connections[key]
|
||||||
|
# Reconnect slot to remaining topics if any
|
||||||
|
if remaining_topics:
|
||||||
|
self.connect_slot(slot, list(remaining_topics), True)
|
||||||
|
|
||||||
def disconnect_all(self):
|
def disconnect_all(self):
|
||||||
"""Disconnect all slots from all topics."""
|
"""Disconnect all slots from all topics."""
|
||||||
for key, connection in list(self._connections.items()):
|
for key, connection in list(self._connections.items()):
|
||||||
for slot in list(connection.slots):
|
for slot in list(connection.slots):
|
||||||
self._disconnect_slot_from_topic(slot, key) # Updated to pass key
|
self._disconnect_slot_from_topic(slot, key)
|
||||||
|
|
||||||
# Check if the topic still exists before trying to shutdown and delete
|
|
||||||
if key in self._connections and not connection.slots:
|
if key in self._connections and not connection.slots:
|
||||||
connection.consumer.shutdown()
|
connection.consumer.shutdown()
|
||||||
del self._connections[key]
|
del self._connections[key]
|
||||||
|
@ -78,29 +78,40 @@ def test_disconnect_one_slot_one_topic(bec_dispatcher, consumer):
|
|||||||
slot1, slot2 = Mock(), Mock()
|
slot1, slot2 = Mock(), Mock()
|
||||||
bec_dispatcher.connect_slot(slot=slot1, topics="topic0")
|
bec_dispatcher.connect_slot(slot=slot1, topics="topic0")
|
||||||
|
|
||||||
# disconnect using a different slot
|
# disconnect using a different topic
|
||||||
bec_dispatcher.disconnect_slot(slot=slot1, topics="topic1")
|
bec_dispatcher.disconnect_slot(slot=slot1, topics="topic1")
|
||||||
consumer.call_args.kwargs["cb"](msg)
|
consumer.call_args.kwargs["cb"](msg)
|
||||||
assert slot1.call_count == 1
|
assert slot1.call_count == 1
|
||||||
|
|
||||||
# disconnect using a different topics
|
# disconnect using a different slot
|
||||||
bec_dispatcher.disconnect_slot(slot=slot2, topics="topic0")
|
bec_dispatcher.disconnect_slot(slot=slot2, topics="topic0")
|
||||||
consumer.call_args.kwargs["cb"](msg)
|
consumer.call_args.kwargs["cb"](msg)
|
||||||
assert slot1.call_count == 2
|
assert slot1.call_count == 2
|
||||||
|
|
||||||
# disconnect using the right slot and topics
|
# disconnect using the right slot and topics
|
||||||
bec_dispatcher.disconnect_slot(slot=slot1, topics="topic0")
|
bec_dispatcher.disconnect_slot(slot=slot1, topics="topic0")
|
||||||
with pytest.raises(KeyError):
|
# reset count to 0 for slot
|
||||||
consumer.call_args.kwargs["cb"](msg)
|
slot1.reset_mock()
|
||||||
|
consumer.call_args.kwargs["cb"](msg)
|
||||||
|
assert slot1.call_count == 0
|
||||||
|
|
||||||
|
|
||||||
def test_disconnect_identical(bec_dispatcher, consumer):
|
def test_disconnect_identical(bec_dispatcher, consumer):
|
||||||
slot1 = Mock()
|
slot1 = Mock()
|
||||||
|
# Try to connect slot twice
|
||||||
bec_dispatcher.connect_slot(slot=slot1, topics="topic0")
|
bec_dispatcher.connect_slot(slot=slot1, topics="topic0")
|
||||||
bec_dispatcher.connect_slot(slot=slot1, topics="topic0")
|
bec_dispatcher.connect_slot(slot=slot1, topics="topic0")
|
||||||
|
|
||||||
|
# Test to call the slot once (slot should be not connected twice)
|
||||||
|
consumer.call_args.kwargs["cb"](msg)
|
||||||
|
assert slot1.call_count == 1
|
||||||
|
|
||||||
|
# Disconnect the slot
|
||||||
bec_dispatcher.disconnect_slot(slot=slot1, topics="topic0")
|
bec_dispatcher.disconnect_slot(slot=slot1, topics="topic0")
|
||||||
with pytest.raises(KeyError):
|
|
||||||
consumer.call_args.kwargs["cb"](msg)
|
# Test to call the slot once (slot should be not connected anymore), count remains 1
|
||||||
|
consumer.call_args.kwargs["cb"](msg)
|
||||||
|
assert slot1.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
def test_disconnect_many_slots_one_topic(bec_dispatcher, consumer):
|
def test_disconnect_many_slots_one_topic(bec_dispatcher, consumer):
|
||||||
@ -148,16 +159,17 @@ def test_disconnect_one_slot_many_topics(bec_dispatcher, consumer):
|
|||||||
|
|
||||||
# disconnect using the right slot and topics
|
# disconnect using the right slot and topics
|
||||||
bec_dispatcher.disconnect_slot(slot=slot1, topics="topic0")
|
bec_dispatcher.disconnect_slot(slot=slot1, topics="topic0")
|
||||||
with pytest.raises(KeyError):
|
# Calling disconnected topic0 should not call slot1
|
||||||
consumer.call_args_list[0].kwargs["cb"](msg)
|
consumer.call_args_list[0].kwargs["cb"](msg)
|
||||||
|
assert slot1.call_count == 4
|
||||||
|
# Calling topic1 should still call slot1
|
||||||
consumer.call_args_list[1].kwargs["cb"](msg)
|
consumer.call_args_list[1].kwargs["cb"](msg)
|
||||||
assert slot1.call_count == 5
|
assert slot1.call_count == 5
|
||||||
|
|
||||||
|
# disconnect remaining topic1 from slot1, calling any topic should not increase count
|
||||||
bec_dispatcher.disconnect_slot(slot=slot1, topics="topic1")
|
bec_dispatcher.disconnect_slot(slot=slot1, topics="topic1")
|
||||||
with pytest.raises(KeyError):
|
consumer.call_args_list[0].kwargs["cb"](msg)
|
||||||
consumer.call_args_list[0].kwargs["cb"](msg)
|
consumer.call_args_list[1].kwargs["cb"](msg)
|
||||||
with pytest.raises(KeyError):
|
|
||||||
consumer.call_args_list[1].kwargs["cb"](msg)
|
|
||||||
assert slot1.call_count == 5
|
assert slot1.call_count == 5
|
||||||
|
|
||||||
|
|
||||||
@ -174,12 +186,9 @@ def test_disconnect_all(bec_dispatcher, consumer):
|
|||||||
bec_dispatcher.disconnect_all()
|
bec_dispatcher.disconnect_all()
|
||||||
|
|
||||||
# Simulate messages and verify that none of the slots are called
|
# Simulate messages and verify that none of the slots are called
|
||||||
with pytest.raises(KeyError):
|
consumer.call_args_list[0].kwargs["cb"](msg)
|
||||||
consumer.call_args_list[0].kwargs["cb"](msg)
|
consumer.call_args_list[1].kwargs["cb"](msg)
|
||||||
with pytest.raises(KeyError):
|
consumer.call_args_list[2].kwargs["cb"](msg)
|
||||||
consumer.call_args_list[1].kwargs["cb"](msg)
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
consumer.call_args_list[2].kwargs["cb"](msg)
|
|
||||||
|
|
||||||
# Ensure that the slots have not been called
|
# Ensure that the slots have not been called
|
||||||
assert slot1.call_count == 0
|
assert slot1.call_count == 0
|
||||||
@ -236,6 +245,5 @@ def test_disconnect_all_with_single_callback_for_multiple_topics(bec_dispatcher,
|
|||||||
assert slot1.call_count == 0 # Slot has not been called
|
assert slot1.call_count == 0 # Slot has not been called
|
||||||
|
|
||||||
# Simulate messages and verify that the slot is not called
|
# Simulate messages and verify that the slot is not called
|
||||||
msg = MessageObject(topic="topic1", value=ScanMessage(point_id=0, scanID=0, data={}).dumps())
|
consumer.call_args.kwargs["cb"](msg)
|
||||||
with pytest.raises(KeyError):
|
assert slot1.call_count == 0 # Slot has not been called
|
||||||
consumer.call_args.kwargs["cb"](msg)
|
|
||||||
|
Reference in New Issue
Block a user