0
0
mirror of https://github.com/bec-project/bec_widgets.git synced 2025-07-14 11:41:49 +02:00

fix: bec_dispatcher.py can partially disconnect topics from slot

This commit is contained in:
wyzula-jan
2024-01-16 16:02:22 +01:00
parent e51be04b95
commit 7607d7a3b6
2 changed files with 86 additions and 102 deletions

View File

@ -1,3 +1,7 @@
# TODO last backup
# todo super last refactor
from __future__ import annotations from __future__ import annotations
import argparse import argparse
@ -10,7 +14,7 @@ from bec_lib import BECClient, messages, ServiceConfig
from bec_lib.redis_connector import RedisConsumerThreaded from bec_lib.redis_connector import RedisConsumerThreaded
from qtpy.QtCore import QObject, Signal as pyqtSignal from qtpy.QtCore import QObject, Signal as pyqtSignal
# Adding a new pyqt signal requres a class factory, as they must be part of the class definition # Adding a new pyqt signal requires a class factory, as they must be part of the class definition
# and cannot be dynamically added as class attributes after the class has been defined. # and cannot be dynamically added as class attributes after the class has been defined.
_signal_class_factory = ( _signal_class_factory = (
type(f"Signal{i}", (QObject,), dict(signal=pyqtSignal(dict, dict))) for i in itertools.count() type(f"Signal{i}", (QObject,), dict(signal=pyqtSignal(dict, dict))) for i in itertools.count()
@ -29,6 +33,8 @@ class _Connection:
class _BECDispatcher(QObject): class _BECDispatcher(QObject):
"""Utility class to keep track of slots connected to a particular redis consumer"""
def __init__(self, bec_config=None): def __init__(self, bec_config=None):
super().__init__() super().__init__()
self.client = BECClient() self.client = BECClient()
@ -41,64 +47,6 @@ class _BECDispatcher(QObject):
self.client.initialize(config=ServiceConfig(config_path=bec_config)) self.client.initialize(config=ServiceConfig(config_path=bec_config))
self._connections = {} self._connections = {}
def _connect_slot_to_topic(self, slot: Callable, topic: str) -> None:
"""A helper method to connect a slot to a specific topic
Args:
slot (Callable): A slot method/function that accepts two inputs: content and metadata of
the corresponding pub/sub message
topic (str): A topic that can typically be acquired via bec_lib.MessageEndpoints
"""
# create new connection for topic if it doesn't exist
if topic not in self._connections:
def cb(msg):
msg = messages.MessageReader.loads(msg.value)
# TODO: this can could be replaced with a simple
# self._connections[topic].signal.emit(msg.content, msg.metadata)
# once all dispatcher.connect_slot calls are made with a single topic only
if not isinstance(msg, list):
msg = [msg]
for msg_i in msg:
self._connections[topic].signal.emit(msg_i.content, msg_i.metadata)
consumer = self.client.connector.consumer(topics=topic, cb=cb)
consumer.start()
self._connections[topic] = _Connection(consumer)
# connect slot if it's not connected
if slot not in self._connections[topic].slots:
self._connections[topic].signal.connect(slot)
self._connections[topic].slots.add(slot)
def _connect_slot_to_multiple_topics(self, slot: Callable, topics: list) -> None:
"""
A helper method to connect a slot to multiple topics using a single callback.
Args:
slot (Callable): A slot method/function that accepts two inputs: content and metadata of
the corresponding pub/sub message
topics (list): A list of topics that can typically be acquired via bec_lib.MessageEndpoints.
"""
# Creating a unique key for this combination of topics
topics_key = tuple(sorted(topics))
if topics_key not in self._connections:
def cb(msg):
msg = messages.MessageReader.loads(msg.value)
self._connections[topics_key].signal.emit(msg.content, msg.metadata)
consumer = self.client.connector.consumer(topics=topics, cb=cb)
consumer.start()
self._connections[topics_key] = _Connection(consumer)
if slot not in self._connections[topics_key].slots:
self._connections[topics_key].signal.connect(slot)
self._connections[topics_key].slots.add(slot)
def connect_slot( def connect_slot(
self, slot: Callable, topics: Union[str, list], single_callback_for_all_topics=False self, slot: Callable, topics: Union[str, list], single_callback_for_all_topics=False
) -> None: ) -> None:
@ -114,11 +62,33 @@ class _BECDispatcher(QObject):
if isinstance(topics, str): if isinstance(topics, str):
topics = [topics] topics = [topics]
if single_callback_for_all_topics: # Ensure topics_key is a tuple, whether single_callback_for_all_topics is True or False.
self._connect_slot_to_multiple_topics(slot, topics) topics_key = tuple(sorted(topics)) if single_callback_for_all_topics else tuple(topics)
else:
for topic in topics: if topics_key not in self._connections:
self._connect_slot_to_topic(slot, topic) self._connections[topics_key] = self._create_connection(topics)
connection = self._connections[topics_key]
if slot not in connection.slots:
connection.signal.connect(slot)
connection.slots.add(slot)
def _create_connection(self, topics: list) -> _Connection:
"""Creates a new connection for given topics."""
def cb(msg):
msg = messages.MessageReader.loads(msg.value)
if not isinstance(msg, list):
msg = [msg]
for msg_i in msg:
for connection_key, connection in self._connections.items():
if set(topics).intersection(
connection_key if isinstance(connection_key, tuple) else [connection_key]
):
connection.signal.emit(msg_i.content, msg_i.metadata)
consumer = self.client.connector.consumer(topics=topics, cb=cb)
consumer.start()
return _Connection(consumer)
def _disconnect_slot_from_topic(self, slot: Callable, topic: str) -> None: def _disconnect_slot_from_topic(self, slot: Callable, topic: str) -> None:
"""A helper method to disconnect a slot from a specific topic. """A helper method to disconnect a slot from a specific topic.
@ -128,19 +98,13 @@ class _BECDispatcher(QObject):
topic (str): A corresponding topic that can typically be acquired via topic (str): A corresponding topic that can typically be acquired via
bec_lib.MessageEndpoints bec_lib.MessageEndpoints
""" """
if topic not in self._connections: connection = self._connections.get(topic)
return if connection and slot in connection.slots:
connection.signal.disconnect(slot)
if slot not in self._connections[topic].slots: connection.slots.remove(slot)
return if not connection.slots:
connection.consumer.shutdown()
self._connections[topic].signal.disconnect(slot) del self._connections[topic]
self._connections[topic].slots.remove(slot)
if not self._connections[topic].slots:
# shutdown consumer if there are no more connected slots
self._connections[topic].consumer.shutdown()
del self._connections[topic]
def disconnect_slot(self, slot: Callable, topics: Union[str, list]) -> None: def disconnect_slot(self, slot: Callable, topics: Union[str, list]) -> None:
"""Disconnect widget's pyqt slot from pub/sub updates on a topic. """Disconnect widget's pyqt slot from pub/sub updates on a topic.
@ -153,16 +117,28 @@ class _BECDispatcher(QObject):
if isinstance(topics, str): if isinstance(topics, str):
topics = [topics] topics = [topics]
for topic in topics: for key, connection in list(self._connections.items()):
self._disconnect_slot_from_topic(slot, topic) if slot in connection.slots:
common_topics = set(topics).intersection(key)
if common_topics:
remaining_topics = set(key) - set(topics)
# Disconnect slot from common topics
connection.signal.disconnect(slot)
connection.slots.remove(slot)
if not connection.slots:
print(f"{connection.consumer} is shutting down")
connection.consumer.shutdown()
del self._connections[key]
# Reconnect slot to remaining topics if any
if remaining_topics:
self.connect_slot(slot, list(remaining_topics), True)
def disconnect_all(self): def disconnect_all(self):
"""Disconnect all slots from all topics.""" """Disconnect all slots from all topics."""
for key, connection in list(self._connections.items()): for key, connection in list(self._connections.items()):
for slot in list(connection.slots): for slot in list(connection.slots):
self._disconnect_slot_from_topic(slot, key) # Updated to pass key self._disconnect_slot_from_topic(slot, key)
# Check if the topic still exists before trying to shutdown and delete
if key in self._connections and not connection.slots: if key in self._connections and not connection.slots:
connection.consumer.shutdown() connection.consumer.shutdown()
del self._connections[key] del self._connections[key]

View File

@ -78,29 +78,40 @@ def test_disconnect_one_slot_one_topic(bec_dispatcher, consumer):
slot1, slot2 = Mock(), Mock() slot1, slot2 = Mock(), Mock()
bec_dispatcher.connect_slot(slot=slot1, topics="topic0") bec_dispatcher.connect_slot(slot=slot1, topics="topic0")
# disconnect using a different slot # disconnect using a different topic
bec_dispatcher.disconnect_slot(slot=slot1, topics="topic1") bec_dispatcher.disconnect_slot(slot=slot1, topics="topic1")
consumer.call_args.kwargs["cb"](msg) consumer.call_args.kwargs["cb"](msg)
assert slot1.call_count == 1 assert slot1.call_count == 1
# disconnect using a different topics # disconnect using a different slot
bec_dispatcher.disconnect_slot(slot=slot2, topics="topic0") bec_dispatcher.disconnect_slot(slot=slot2, topics="topic0")
consumer.call_args.kwargs["cb"](msg) consumer.call_args.kwargs["cb"](msg)
assert slot1.call_count == 2 assert slot1.call_count == 2
# disconnect using the right slot and topics # disconnect using the right slot and topics
bec_dispatcher.disconnect_slot(slot=slot1, topics="topic0") bec_dispatcher.disconnect_slot(slot=slot1, topics="topic0")
with pytest.raises(KeyError): # reset count to 0 for slot
consumer.call_args.kwargs["cb"](msg) slot1.reset_mock()
consumer.call_args.kwargs["cb"](msg)
assert slot1.call_count == 0
def test_disconnect_identical(bec_dispatcher, consumer): def test_disconnect_identical(bec_dispatcher, consumer):
slot1 = Mock() slot1 = Mock()
# Try to connect slot twice
bec_dispatcher.connect_slot(slot=slot1, topics="topic0") bec_dispatcher.connect_slot(slot=slot1, topics="topic0")
bec_dispatcher.connect_slot(slot=slot1, topics="topic0") bec_dispatcher.connect_slot(slot=slot1, topics="topic0")
# Test to call the slot once (slot should be not connected twice)
consumer.call_args.kwargs["cb"](msg)
assert slot1.call_count == 1
# Disconnect the slot
bec_dispatcher.disconnect_slot(slot=slot1, topics="topic0") bec_dispatcher.disconnect_slot(slot=slot1, topics="topic0")
with pytest.raises(KeyError):
consumer.call_args.kwargs["cb"](msg) # Test to call the slot once (slot should be not connected anymore), count remains 1
consumer.call_args.kwargs["cb"](msg)
assert slot1.call_count == 1
def test_disconnect_many_slots_one_topic(bec_dispatcher, consumer): def test_disconnect_many_slots_one_topic(bec_dispatcher, consumer):
@ -148,16 +159,17 @@ def test_disconnect_one_slot_many_topics(bec_dispatcher, consumer):
# disconnect using the right slot and topics # disconnect using the right slot and topics
bec_dispatcher.disconnect_slot(slot=slot1, topics="topic0") bec_dispatcher.disconnect_slot(slot=slot1, topics="topic0")
with pytest.raises(KeyError): # Calling disconnected topic0 should not call slot1
consumer.call_args_list[0].kwargs["cb"](msg) consumer.call_args_list[0].kwargs["cb"](msg)
assert slot1.call_count == 4
# Calling topic1 should still call slot1
consumer.call_args_list[1].kwargs["cb"](msg) consumer.call_args_list[1].kwargs["cb"](msg)
assert slot1.call_count == 5 assert slot1.call_count == 5
# disconnect remaining topic1 from slot1, calling any topic should not increase count
bec_dispatcher.disconnect_slot(slot=slot1, topics="topic1") bec_dispatcher.disconnect_slot(slot=slot1, topics="topic1")
with pytest.raises(KeyError): consumer.call_args_list[0].kwargs["cb"](msg)
consumer.call_args_list[0].kwargs["cb"](msg) consumer.call_args_list[1].kwargs["cb"](msg)
with pytest.raises(KeyError):
consumer.call_args_list[1].kwargs["cb"](msg)
assert slot1.call_count == 5 assert slot1.call_count == 5
@ -174,12 +186,9 @@ def test_disconnect_all(bec_dispatcher, consumer):
bec_dispatcher.disconnect_all() bec_dispatcher.disconnect_all()
# Simulate messages and verify that none of the slots are called # Simulate messages and verify that none of the slots are called
with pytest.raises(KeyError): consumer.call_args_list[0].kwargs["cb"](msg)
consumer.call_args_list[0].kwargs["cb"](msg) consumer.call_args_list[1].kwargs["cb"](msg)
with pytest.raises(KeyError): consumer.call_args_list[2].kwargs["cb"](msg)
consumer.call_args_list[1].kwargs["cb"](msg)
with pytest.raises(KeyError):
consumer.call_args_list[2].kwargs["cb"](msg)
# Ensure that the slots have not been called # Ensure that the slots have not been called
assert slot1.call_count == 0 assert slot1.call_count == 0
@ -236,6 +245,5 @@ def test_disconnect_all_with_single_callback_for_multiple_topics(bec_dispatcher,
assert slot1.call_count == 0 # Slot has not been called assert slot1.call_count == 0 # Slot has not been called
# Simulate messages and verify that the slot is not called # Simulate messages and verify that the slot is not called
msg = MessageObject(topic="topic1", value=ScanMessage(point_id=0, scanID=0, data={}).dumps()) consumer.call_args.kwargs["cb"](msg)
with pytest.raises(KeyError): assert slot1.call_count == 0 # Slot has not been called
consumer.call_args.kwargs["cb"](msg)