0
0
mirror of https://github.com/bec-project/bec_widgets.git synced 2025-07-14 11:41:49 +02:00

refactor(bec_dispatcher): new BEC dispatcher - rebased

This commit is contained in:
2024-02-27 13:59:40 +01:00
committed by wyzula-jan
parent 9def3734af
commit 90907e0a9c
3 changed files with 84 additions and 134 deletions

View File

@ -154,6 +154,7 @@ class BECFigureClientMixin:
self._run_rpc("close", (), wait_for_rpc_response=False) self._run_rpc("close", (), wait_for_rpc_response=False)
self._process.kill() self._process.kill()
self._process = None self._process = None
self._client.shutdown()
def _start_plot_process(self) -> None: def _start_plot_process(self) -> None:
""" """

View File

@ -1,37 +1,71 @@
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import itertools import collections
import os
from collections.abc import Callable from collections.abc import Callable
from typing import Union from typing import TYPE_CHECKING, Union
import redis import redis
from bec_lib import BECClient, ServiceConfig from bec_lib import BECClient
from bec_lib.endpoints import EndpointInfo from bec_lib.redis_connector import MessageObject, RedisConnector
from qtpy.QtCore import QObject from qtpy.QtCore import QObject
from qtpy.QtCore import Signal as pyqtSignal from qtpy.QtCore import Signal as pyqtSignal
# Adding a new pyqt signal requires a class factory, as they must be part of the class definition if TYPE_CHECKING:
# and cannot be dynamically added as class attributes after the class has been defined. from bec_lib.endpoints import EndpointInfo
_signal_class_factory = (
type(f"Signal{i}", (QObject,), dict(signal=pyqtSignal(dict, dict))) for i in itertools.count()
)
class _Connection: class QtThreadSafeCallback(QObject):
"""Utility class to keep track of slots connected to a particular redis connector""" cb_signal = pyqtSignal(dict, dict)
def __init__(self, callback) -> None: def __init__(self, cb):
self.callback = callback super().__init__()
self.slots = set() self.cb = cb
# keep a reference to a new signal class, so it is not gc'ed self.cb_signal.connect(self.cb)
self._signal_container = next(_signal_class_factory)()
self.signal: pyqtSignal = self._signal_container.signal def __hash__(self):
# make 2 differents QtThreadSafeCallback to look
# identical when used as dictionary keys, if the
# callback is the same
return id(self.cb)
def __call__(self, msg_content, metadata):
self.cb_signal.emit(msg_content, metadata)
class BECDispatcher(QObject): class QtRedisConnector(RedisConnector):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def _execute_callback(self, cb, msg, kwargs):
if not isinstance(cb, QtThreadSafeCallback):
return super()._execute_callback(cb, msg, kwargs)
# if msg.msg_type == "bundle_message":
# # big warning: how to handle bundle messages?
# # message with messages inside ; which slot to call?
# # bundle_msg = msg
# # for msg in bundle_msg:
# # ...
# # for now, only consider the 1st message
# msg = msg[0]
# raise RuntimeError(f"
if isinstance(msg, MessageObject):
if isinstance(msg.value, list):
msg = msg.value[0]
else:
msg = msg.value
# we can notice kwargs are lost when passed to Qt slot
metadata = msg.metadata
cb(msg.content, metadata)
else:
# from stream
msg = msg["data"]
cb(msg.content, msg.metadata)
class BECDispatcher:
"""Utility class to keep track of slots connected to a particular redis connector""" """Utility class to keep track of slots connected to a particular redis connector"""
_instance = None _instance = None
@ -47,14 +81,21 @@ class BECDispatcher(QObject):
if self._initialized: if self._initialized:
return return
super().__init__() self._slots = collections.defaultdict(set)
self.client = BECClient() if client is None else client
if client is None:
self.client = BECClient(connector_cls=QtRedisConnector, forced=True)
else:
if self.client.started:
# have to reinitialize client to use proper connector
self.client.shutdown()
self.client._BECClient__init_params["connector_cls"] = QtRedisConnector
try: try:
self.client.start() self.client.start()
except redis.exceptions.ConnectionError: except redis.exceptions.ConnectionError:
print("Could not connect to Redis, skipping start of BECClient.") print("Could not connect to Redis, skipping start of BECClient.")
self._connections = {}
self._initialized = True self._initialized = True
@ -67,7 +108,6 @@ class BECDispatcher(QObject):
self, self,
slot: Callable, slot: Callable,
topics: Union[EndpointInfo, str, list[Union[EndpointInfo, str]]], topics: Union[EndpointInfo, str, list[Union[EndpointInfo, str]]],
single_callback_for_all_topics=False,
) -> None: ) -> None:
"""Connect widget's pyqt slot, so that it is called on new pub/sub topic message. """Connect widget's pyqt slot, so that it is called on new pub/sub topic message.
@ -75,117 +115,27 @@ class BECDispatcher(QObject):
slot (Callable): A slot method/function that accepts two inputs: content and metadata of slot (Callable): A slot method/function that accepts two inputs: content and metadata of
the corresponding pub/sub message the corresponding pub/sub message
topics (EndpointInfo | str | list): A topic or list of topics that can typically be acquired via bec_lib.MessageEndpoints topics (EndpointInfo | str | list): A topic or list of topics that can typically be acquired via bec_lib.MessageEndpoints
single_callback_for_all_topics (bool): If True, use the same callback for all topics, otherwise use
separate callbacks.
""" """
# Normalise the topics input slot = QtThreadSafeCallback(slot)
if isinstance(topics, (str, EndpointInfo)): self.client.connector.register(topics, cb=slot)
topics = [topics] topics_str, _ = self.client.connector._convert_endpointinfo(topics)
self._slots[slot].update(set(topics_str))
endpoint_to_consumer_type = { def disconnect_slot(self, slot: Callable, topics: Union[str, list]):
(topic.endpoint if isinstance(topic, EndpointInfo) else topic): ( self.client.connector.unregister(topics, cb=slot)
topic.message_op.name if isinstance(topic, EndpointInfo) else "SEND" topics_str, _ = self.client.connector._convert_endpointinfo(topics)
) self._slots[slot].difference_update(set(topics_str))
for topic in topics if not self._slots[slot]:
} del self._slots[slot]
# Group topics by consumer type def disconnect_topics(self, topics: Union[str, list]):
consumer_type_to_endpoints = {} self.client.connector.unregister(topics)
for endpoint, consumer_type in endpoint_to_consumer_type.items(): topics_str, _ = self.client.connector._convert_endpointinfo(topics)
if consumer_type not in consumer_type_to_endpoints: for slot in list(self._slots.keys()):
consumer_type_to_endpoints[consumer_type] = [] slot_topics = self._slots[slot]
consumer_type_to_endpoints[consumer_type].append(endpoint) slot_topics.difference_update(set(topics_str))
if not slot_topics:
del self._slots[slot]
for consumer_type, endpoints in consumer_type_to_endpoints.items(): def disconnect_all(self, *args, **kwargs):
topics_key = ( self.disconnect_topics(self.client.connector._topics_cb)
tuple(sorted(endpoints)) if single_callback_for_all_topics else tuple(endpoints)
)
if topics_key not in self._connections:
self._connections[topics_key] = self._create_connection(endpoints, consumer_type)
connection = self._connections[topics_key]
if slot not in connection.slots:
connection.signal.connect(slot)
connection.slots.add(slot)
def _create_connection(self, topics: list, consumer_type: str) -> _Connection:
"""Creates a new connection for given topics."""
def cb(msg):
if isinstance(msg, dict):
msg = msg["data"]
else:
msg = msg.value
for connection_key, connection in self._connections.items():
if set(topics).intersection(connection_key):
if isinstance(msg, list):
msg = msg[0]
connection.signal.emit(msg.content, msg.metadata)
try:
if consumer_type == "STREAM":
self.client.connector.register_stream(topics=topics, cb=cb, newest_only=True)
else:
self.client.connector.register(topics=topics, cb=cb)
except redis.exceptions.ConnectionError:
print("Could not connect to Redis, skipping registration of topics.")
return _Connection(cb)
def _do_disconnect_slot(self, topic, slot):
print(f"Disconnecting {slot} from {topic}")
connection = self._connections[topic]
try:
connection.signal.disconnect(slot)
except TypeError:
print(f"Could not disconnect slot:'{slot}' from topic:'{topic}'")
print("Continue to remove slot:'{slot}' from 'connection.slots'.")
connection.slots.remove(slot)
if not connection.slots:
del self._connections[topic]
def _disconnect_slot_from_topic(self, slot: Callable, topic: str) -> None:
"""A helper method to disconnect a slot from a specific topic.
Args:
slot (Callable): A slot to be disconnected
topic (str): A corresponding topic that can typically be acquired via
bec_lib.MessageEndpoints
"""
connection = self._connections.get(topic)
if connection and slot in connection.slots:
self._do_disconnect_slot(topic, slot)
def disconnect_slot(self, slot: Callable, topics: Union[str, list]) -> None:
"""Disconnect widget's pyqt slot from pub/sub updates on a topic.
Args:
slot (Callable): A slot to be disconnected
topics (str | list): A corresponding topic or list of topics that can typically be acquired via
bec_lib.MessageEndpoints
"""
# Normalise the topics input
if isinstance(topics, (str, EndpointInfo)):
topics = [topics]
endpoints = [
topic.endpoint if isinstance(topic, EndpointInfo) else topic for topic in topics
]
for key, connection in list(self._connections.items()):
if slot in connection.slots:
common_topics = set(endpoints).intersection(key)
if common_topics:
remaining_topics = set(key) - set(endpoints)
# Disconnect slot from common topics
self._do_disconnect_slot(key, slot)
# Reconnect slot to remaining topics if any
if remaining_topics:
self.connect_slot(slot, list(remaining_topics), True)
def disconnect_all(self):
"""Disconnect all slots from all topics."""
for key, connection in list(self._connections.items()):
for slot in list(connection.slots):
self._disconnect_slot_from_topic(slot, key)

View File

@ -217,7 +217,6 @@ class MotorMap(pg.GraphicsLayoutWidget):
bec_dispatcher.connect_slot( bec_dispatcher.connect_slot(
self.on_device_readback, self.on_device_readback,
endpoints, endpoints,
single_callback_for_all_topics=True,
) )
def _add_limits_to_plot_data(self): def _add_limits_to_plot_data(self):