0
0
mirror of https://github.com/bec-project/bec_widgets.git synced 2025-07-14 11:41:49 +02:00

refactor(bec_dispatcher): new BEC dispatcher - rebased

This commit is contained in:
2024-02-27 13:59:40 +01:00
committed by wyzula-jan
parent 9def3734af
commit 90907e0a9c
3 changed files with 84 additions and 134 deletions

View File

@ -154,6 +154,7 @@ class BECFigureClientMixin:
self._run_rpc("close", (), wait_for_rpc_response=False)
self._process.kill()
self._process = None
self._client.shutdown()
def _start_plot_process(self) -> None:
"""

View File

@ -1,37 +1,71 @@
from __future__ import annotations
import argparse
import itertools
import os
import collections
from collections.abc import Callable
from typing import Union
from typing import TYPE_CHECKING, Union
import redis
from bec_lib import BECClient, ServiceConfig
from bec_lib.endpoints import EndpointInfo
from bec_lib import BECClient
from bec_lib.redis_connector import MessageObject, RedisConnector
from qtpy.QtCore import QObject
from qtpy.QtCore import Signal as pyqtSignal
# Adding a new pyqt signal requires a class factory, as they must be part of the class definition
# and cannot be dynamically added as class attributes after the class has been defined.
_signal_class_factory = (
type(f"Signal{i}", (QObject,), dict(signal=pyqtSignal(dict, dict))) for i in itertools.count()
)
if TYPE_CHECKING:
from bec_lib.endpoints import EndpointInfo
class _Connection:
"""Utility class to keep track of slots connected to a particular redis connector"""
class QtThreadSafeCallback(QObject):
cb_signal = pyqtSignal(dict, dict)
def __init__(self, callback) -> None:
self.callback = callback
def __init__(self, cb):
super().__init__()
self.slots = set()
# keep a reference to a new signal class, so it is not gc'ed
self._signal_container = next(_signal_class_factory)()
self.signal: pyqtSignal = self._signal_container.signal
self.cb = cb
self.cb_signal.connect(self.cb)
def __hash__(self):
# make 2 differents QtThreadSafeCallback to look
# identical when used as dictionary keys, if the
# callback is the same
return id(self.cb)
def __call__(self, msg_content, metadata):
self.cb_signal.emit(msg_content, metadata)
class BECDispatcher(QObject):
class QtRedisConnector(RedisConnector):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def _execute_callback(self, cb, msg, kwargs):
if not isinstance(cb, QtThreadSafeCallback):
return super()._execute_callback(cb, msg, kwargs)
# if msg.msg_type == "bundle_message":
# # big warning: how to handle bundle messages?
# # message with messages inside ; which slot to call?
# # bundle_msg = msg
# # for msg in bundle_msg:
# # ...
# # for now, only consider the 1st message
# msg = msg[0]
# raise RuntimeError(f"
if isinstance(msg, MessageObject):
if isinstance(msg.value, list):
msg = msg.value[0]
else:
msg = msg.value
# we can notice kwargs are lost when passed to Qt slot
metadata = msg.metadata
cb(msg.content, metadata)
else:
# from stream
msg = msg["data"]
cb(msg.content, msg.metadata)
class BECDispatcher:
"""Utility class to keep track of slots connected to a particular redis connector"""
_instance = None
@ -47,14 +81,21 @@ class BECDispatcher(QObject):
if self._initialized:
return
super().__init__()
self.client = BECClient() if client is None else client
self._slots = collections.defaultdict(set)
if client is None:
self.client = BECClient(connector_cls=QtRedisConnector, forced=True)
else:
if self.client.started:
# have to reinitialize client to use proper connector
self.client.shutdown()
self.client._BECClient__init_params["connector_cls"] = QtRedisConnector
try:
self.client.start()
except redis.exceptions.ConnectionError:
print("Could not connect to Redis, skipping start of BECClient.")
self._connections = {}
self._initialized = True
@ -67,7 +108,6 @@ class BECDispatcher(QObject):
self,
slot: Callable,
topics: Union[EndpointInfo, str, list[Union[EndpointInfo, str]]],
single_callback_for_all_topics=False,
) -> None:
"""Connect widget's pyqt slot, so that it is called on new pub/sub topic message.
@ -75,117 +115,27 @@ class BECDispatcher(QObject):
slot (Callable): A slot method/function that accepts two inputs: content and metadata of
the corresponding pub/sub message
topics (EndpointInfo | str | list): A topic or list of topics that can typically be acquired via bec_lib.MessageEndpoints
single_callback_for_all_topics (bool): If True, use the same callback for all topics, otherwise use
separate callbacks.
"""
# Normalise the topics input
if isinstance(topics, (str, EndpointInfo)):
topics = [topics]
slot = QtThreadSafeCallback(slot)
self.client.connector.register(topics, cb=slot)
topics_str, _ = self.client.connector._convert_endpointinfo(topics)
self._slots[slot].update(set(topics_str))
endpoint_to_consumer_type = {
(topic.endpoint if isinstance(topic, EndpointInfo) else topic): (
topic.message_op.name if isinstance(topic, EndpointInfo) else "SEND"
)
for topic in topics
}
def disconnect_slot(self, slot: Callable, topics: Union[str, list]):
self.client.connector.unregister(topics, cb=slot)
topics_str, _ = self.client.connector._convert_endpointinfo(topics)
self._slots[slot].difference_update(set(topics_str))
if not self._slots[slot]:
del self._slots[slot]
# Group topics by consumer type
consumer_type_to_endpoints = {}
for endpoint, consumer_type in endpoint_to_consumer_type.items():
if consumer_type not in consumer_type_to_endpoints:
consumer_type_to_endpoints[consumer_type] = []
consumer_type_to_endpoints[consumer_type].append(endpoint)
def disconnect_topics(self, topics: Union[str, list]):
self.client.connector.unregister(topics)
topics_str, _ = self.client.connector._convert_endpointinfo(topics)
for slot in list(self._slots.keys()):
slot_topics = self._slots[slot]
slot_topics.difference_update(set(topics_str))
if not slot_topics:
del self._slots[slot]
for consumer_type, endpoints in consumer_type_to_endpoints.items():
topics_key = (
tuple(sorted(endpoints)) if single_callback_for_all_topics else tuple(endpoints)
)
if topics_key not in self._connections:
self._connections[topics_key] = self._create_connection(endpoints, consumer_type)
connection = self._connections[topics_key]
if slot not in connection.slots:
connection.signal.connect(slot)
connection.slots.add(slot)
def _create_connection(self, topics: list, consumer_type: str) -> _Connection:
"""Creates a new connection for given topics."""
def cb(msg):
if isinstance(msg, dict):
msg = msg["data"]
else:
msg = msg.value
for connection_key, connection in self._connections.items():
if set(topics).intersection(connection_key):
if isinstance(msg, list):
msg = msg[0]
connection.signal.emit(msg.content, msg.metadata)
try:
if consumer_type == "STREAM":
self.client.connector.register_stream(topics=topics, cb=cb, newest_only=True)
else:
self.client.connector.register(topics=topics, cb=cb)
except redis.exceptions.ConnectionError:
print("Could not connect to Redis, skipping registration of topics.")
return _Connection(cb)
def _do_disconnect_slot(self, topic, slot):
print(f"Disconnecting {slot} from {topic}")
connection = self._connections[topic]
try:
connection.signal.disconnect(slot)
except TypeError:
print(f"Could not disconnect slot:'{slot}' from topic:'{topic}'")
print("Continue to remove slot:'{slot}' from 'connection.slots'.")
connection.slots.remove(slot)
if not connection.slots:
del self._connections[topic]
def _disconnect_slot_from_topic(self, slot: Callable, topic: str) -> None:
"""A helper method to disconnect a slot from a specific topic.
Args:
slot (Callable): A slot to be disconnected
topic (str): A corresponding topic that can typically be acquired via
bec_lib.MessageEndpoints
"""
connection = self._connections.get(topic)
if connection and slot in connection.slots:
self._do_disconnect_slot(topic, slot)
def disconnect_slot(self, slot: Callable, topics: Union[str, list]) -> None:
"""Disconnect widget's pyqt slot from pub/sub updates on a topic.
Args:
slot (Callable): A slot to be disconnected
topics (str | list): A corresponding topic or list of topics that can typically be acquired via
bec_lib.MessageEndpoints
"""
# Normalise the topics input
if isinstance(topics, (str, EndpointInfo)):
topics = [topics]
endpoints = [
topic.endpoint if isinstance(topic, EndpointInfo) else topic for topic in topics
]
for key, connection in list(self._connections.items()):
if slot in connection.slots:
common_topics = set(endpoints).intersection(key)
if common_topics:
remaining_topics = set(key) - set(endpoints)
# Disconnect slot from common topics
self._do_disconnect_slot(key, slot)
# Reconnect slot to remaining topics if any
if remaining_topics:
self.connect_slot(slot, list(remaining_topics), True)
def disconnect_all(self):
"""Disconnect all slots from all topics."""
for key, connection in list(self._connections.items()):
for slot in list(connection.slots):
self._disconnect_slot_from_topic(slot, key)
def disconnect_all(self, *args, **kwargs):
self.disconnect_topics(self.client.connector._topics_cb)

View File

@ -217,7 +217,6 @@ class MotorMap(pg.GraphicsLayoutWidget):
bec_dispatcher.connect_slot(
self.on_device_readback,
endpoints,
single_callback_for_all_topics=True,
)
def _add_limits_to_plot_data(self):