mirror of
https://github.com/bec-project/bec_widgets.git
synced 2025-07-14 11:41:49 +02:00
refactor(bec_dispatcher): new BEC dispatcher - rebased
This commit is contained in:
@ -154,6 +154,7 @@ class BECFigureClientMixin:
|
|||||||
self._run_rpc("close", (), wait_for_rpc_response=False)
|
self._run_rpc("close", (), wait_for_rpc_response=False)
|
||||||
self._process.kill()
|
self._process.kill()
|
||||||
self._process = None
|
self._process = None
|
||||||
|
self._client.shutdown()
|
||||||
|
|
||||||
def _start_plot_process(self) -> None:
|
def _start_plot_process(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -1,37 +1,71 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import itertools
|
import collections
|
||||||
import os
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Union
|
from typing import TYPE_CHECKING, Union
|
||||||
|
|
||||||
import redis
|
import redis
|
||||||
from bec_lib import BECClient, ServiceConfig
|
from bec_lib import BECClient
|
||||||
from bec_lib.endpoints import EndpointInfo
|
from bec_lib.redis_connector import MessageObject, RedisConnector
|
||||||
from qtpy.QtCore import QObject
|
from qtpy.QtCore import QObject
|
||||||
from qtpy.QtCore import Signal as pyqtSignal
|
from qtpy.QtCore import Signal as pyqtSignal
|
||||||
|
|
||||||
# Adding a new pyqt signal requires a class factory, as they must be part of the class definition
|
if TYPE_CHECKING:
|
||||||
# and cannot be dynamically added as class attributes after the class has been defined.
|
from bec_lib.endpoints import EndpointInfo
|
||||||
_signal_class_factory = (
|
|
||||||
type(f"Signal{i}", (QObject,), dict(signal=pyqtSignal(dict, dict))) for i in itertools.count()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class _Connection:
|
class QtThreadSafeCallback(QObject):
|
||||||
"""Utility class to keep track of slots connected to a particular redis connector"""
|
cb_signal = pyqtSignal(dict, dict)
|
||||||
|
|
||||||
def __init__(self, callback) -> None:
|
def __init__(self, cb):
|
||||||
self.callback = callback
|
super().__init__()
|
||||||
|
|
||||||
self.slots = set()
|
self.cb = cb
|
||||||
# keep a reference to a new signal class, so it is not gc'ed
|
self.cb_signal.connect(self.cb)
|
||||||
self._signal_container = next(_signal_class_factory)()
|
|
||||||
self.signal: pyqtSignal = self._signal_container.signal
|
def __hash__(self):
|
||||||
|
# make 2 differents QtThreadSafeCallback to look
|
||||||
|
# identical when used as dictionary keys, if the
|
||||||
|
# callback is the same
|
||||||
|
return id(self.cb)
|
||||||
|
|
||||||
|
def __call__(self, msg_content, metadata):
|
||||||
|
self.cb_signal.emit(msg_content, metadata)
|
||||||
|
|
||||||
|
|
||||||
class BECDispatcher(QObject):
|
class QtRedisConnector(RedisConnector):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def _execute_callback(self, cb, msg, kwargs):
|
||||||
|
if not isinstance(cb, QtThreadSafeCallback):
|
||||||
|
return super()._execute_callback(cb, msg, kwargs)
|
||||||
|
# if msg.msg_type == "bundle_message":
|
||||||
|
# # big warning: how to handle bundle messages?
|
||||||
|
# # message with messages inside ; which slot to call?
|
||||||
|
# # bundle_msg = msg
|
||||||
|
# # for msg in bundle_msg:
|
||||||
|
# # ...
|
||||||
|
# # for now, only consider the 1st message
|
||||||
|
# msg = msg[0]
|
||||||
|
# raise RuntimeError(f"
|
||||||
|
if isinstance(msg, MessageObject):
|
||||||
|
if isinstance(msg.value, list):
|
||||||
|
msg = msg.value[0]
|
||||||
|
else:
|
||||||
|
msg = msg.value
|
||||||
|
|
||||||
|
# we can notice kwargs are lost when passed to Qt slot
|
||||||
|
metadata = msg.metadata
|
||||||
|
cb(msg.content, metadata)
|
||||||
|
else:
|
||||||
|
# from stream
|
||||||
|
msg = msg["data"]
|
||||||
|
cb(msg.content, msg.metadata)
|
||||||
|
|
||||||
|
|
||||||
|
class BECDispatcher:
|
||||||
"""Utility class to keep track of slots connected to a particular redis connector"""
|
"""Utility class to keep track of slots connected to a particular redis connector"""
|
||||||
|
|
||||||
_instance = None
|
_instance = None
|
||||||
@ -47,14 +81,21 @@ class BECDispatcher(QObject):
|
|||||||
if self._initialized:
|
if self._initialized:
|
||||||
return
|
return
|
||||||
|
|
||||||
super().__init__()
|
self._slots = collections.defaultdict(set)
|
||||||
self.client = BECClient() if client is None else client
|
|
||||||
|
if client is None:
|
||||||
|
self.client = BECClient(connector_cls=QtRedisConnector, forced=True)
|
||||||
|
else:
|
||||||
|
if self.client.started:
|
||||||
|
# have to reinitialize client to use proper connector
|
||||||
|
self.client.shutdown()
|
||||||
|
self.client._BECClient__init_params["connector_cls"] = QtRedisConnector
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.client.start()
|
self.client.start()
|
||||||
except redis.exceptions.ConnectionError:
|
except redis.exceptions.ConnectionError:
|
||||||
print("Could not connect to Redis, skipping start of BECClient.")
|
print("Could not connect to Redis, skipping start of BECClient.")
|
||||||
|
|
||||||
self._connections = {}
|
|
||||||
|
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
@ -67,7 +108,6 @@ class BECDispatcher(QObject):
|
|||||||
self,
|
self,
|
||||||
slot: Callable,
|
slot: Callable,
|
||||||
topics: Union[EndpointInfo, str, list[Union[EndpointInfo, str]]],
|
topics: Union[EndpointInfo, str, list[Union[EndpointInfo, str]]],
|
||||||
single_callback_for_all_topics=False,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Connect widget's pyqt slot, so that it is called on new pub/sub topic message.
|
"""Connect widget's pyqt slot, so that it is called on new pub/sub topic message.
|
||||||
|
|
||||||
@ -75,117 +115,27 @@ class BECDispatcher(QObject):
|
|||||||
slot (Callable): A slot method/function that accepts two inputs: content and metadata of
|
slot (Callable): A slot method/function that accepts two inputs: content and metadata of
|
||||||
the corresponding pub/sub message
|
the corresponding pub/sub message
|
||||||
topics (EndpointInfo | str | list): A topic or list of topics that can typically be acquired via bec_lib.MessageEndpoints
|
topics (EndpointInfo | str | list): A topic or list of topics that can typically be acquired via bec_lib.MessageEndpoints
|
||||||
single_callback_for_all_topics (bool): If True, use the same callback for all topics, otherwise use
|
|
||||||
separate callbacks.
|
|
||||||
"""
|
"""
|
||||||
# Normalise the topics input
|
slot = QtThreadSafeCallback(slot)
|
||||||
if isinstance(topics, (str, EndpointInfo)):
|
self.client.connector.register(topics, cb=slot)
|
||||||
topics = [topics]
|
topics_str, _ = self.client.connector._convert_endpointinfo(topics)
|
||||||
|
self._slots[slot].update(set(topics_str))
|
||||||
|
|
||||||
endpoint_to_consumer_type = {
|
def disconnect_slot(self, slot: Callable, topics: Union[str, list]):
|
||||||
(topic.endpoint if isinstance(topic, EndpointInfo) else topic): (
|
self.client.connector.unregister(topics, cb=slot)
|
||||||
topic.message_op.name if isinstance(topic, EndpointInfo) else "SEND"
|
topics_str, _ = self.client.connector._convert_endpointinfo(topics)
|
||||||
)
|
self._slots[slot].difference_update(set(topics_str))
|
||||||
for topic in topics
|
if not self._slots[slot]:
|
||||||
}
|
del self._slots[slot]
|
||||||
|
|
||||||
# Group topics by consumer type
|
def disconnect_topics(self, topics: Union[str, list]):
|
||||||
consumer_type_to_endpoints = {}
|
self.client.connector.unregister(topics)
|
||||||
for endpoint, consumer_type in endpoint_to_consumer_type.items():
|
topics_str, _ = self.client.connector._convert_endpointinfo(topics)
|
||||||
if consumer_type not in consumer_type_to_endpoints:
|
for slot in list(self._slots.keys()):
|
||||||
consumer_type_to_endpoints[consumer_type] = []
|
slot_topics = self._slots[slot]
|
||||||
consumer_type_to_endpoints[consumer_type].append(endpoint)
|
slot_topics.difference_update(set(topics_str))
|
||||||
|
if not slot_topics:
|
||||||
|
del self._slots[slot]
|
||||||
|
|
||||||
for consumer_type, endpoints in consumer_type_to_endpoints.items():
|
def disconnect_all(self, *args, **kwargs):
|
||||||
topics_key = (
|
self.disconnect_topics(self.client.connector._topics_cb)
|
||||||
tuple(sorted(endpoints)) if single_callback_for_all_topics else tuple(endpoints)
|
|
||||||
)
|
|
||||||
|
|
||||||
if topics_key not in self._connections:
|
|
||||||
self._connections[topics_key] = self._create_connection(endpoints, consumer_type)
|
|
||||||
connection = self._connections[topics_key]
|
|
||||||
|
|
||||||
if slot not in connection.slots:
|
|
||||||
connection.signal.connect(slot)
|
|
||||||
connection.slots.add(slot)
|
|
||||||
|
|
||||||
def _create_connection(self, topics: list, consumer_type: str) -> _Connection:
|
|
||||||
"""Creates a new connection for given topics."""
|
|
||||||
|
|
||||||
def cb(msg):
|
|
||||||
if isinstance(msg, dict):
|
|
||||||
msg = msg["data"]
|
|
||||||
else:
|
|
||||||
msg = msg.value
|
|
||||||
for connection_key, connection in self._connections.items():
|
|
||||||
if set(topics).intersection(connection_key):
|
|
||||||
if isinstance(msg, list):
|
|
||||||
msg = msg[0]
|
|
||||||
connection.signal.emit(msg.content, msg.metadata)
|
|
||||||
|
|
||||||
try:
|
|
||||||
if consumer_type == "STREAM":
|
|
||||||
self.client.connector.register_stream(topics=topics, cb=cb, newest_only=True)
|
|
||||||
else:
|
|
||||||
self.client.connector.register(topics=topics, cb=cb)
|
|
||||||
except redis.exceptions.ConnectionError:
|
|
||||||
print("Could not connect to Redis, skipping registration of topics.")
|
|
||||||
|
|
||||||
return _Connection(cb)
|
|
||||||
|
|
||||||
def _do_disconnect_slot(self, topic, slot):
|
|
||||||
print(f"Disconnecting {slot} from {topic}")
|
|
||||||
connection = self._connections[topic]
|
|
||||||
try:
|
|
||||||
connection.signal.disconnect(slot)
|
|
||||||
except TypeError:
|
|
||||||
print(f"Could not disconnect slot:'{slot}' from topic:'{topic}'")
|
|
||||||
print("Continue to remove slot:'{slot}' from 'connection.slots'.")
|
|
||||||
connection.slots.remove(slot)
|
|
||||||
if not connection.slots:
|
|
||||||
del self._connections[topic]
|
|
||||||
|
|
||||||
def _disconnect_slot_from_topic(self, slot: Callable, topic: str) -> None:
|
|
||||||
"""A helper method to disconnect a slot from a specific topic.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
slot (Callable): A slot to be disconnected
|
|
||||||
topic (str): A corresponding topic that can typically be acquired via
|
|
||||||
bec_lib.MessageEndpoints
|
|
||||||
"""
|
|
||||||
connection = self._connections.get(topic)
|
|
||||||
if connection and slot in connection.slots:
|
|
||||||
self._do_disconnect_slot(topic, slot)
|
|
||||||
|
|
||||||
def disconnect_slot(self, slot: Callable, topics: Union[str, list]) -> None:
|
|
||||||
"""Disconnect widget's pyqt slot from pub/sub updates on a topic.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
slot (Callable): A slot to be disconnected
|
|
||||||
topics (str | list): A corresponding topic or list of topics that can typically be acquired via
|
|
||||||
bec_lib.MessageEndpoints
|
|
||||||
"""
|
|
||||||
# Normalise the topics input
|
|
||||||
if isinstance(topics, (str, EndpointInfo)):
|
|
||||||
topics = [topics]
|
|
||||||
|
|
||||||
endpoints = [
|
|
||||||
topic.endpoint if isinstance(topic, EndpointInfo) else topic for topic in topics
|
|
||||||
]
|
|
||||||
|
|
||||||
for key, connection in list(self._connections.items()):
|
|
||||||
if slot in connection.slots:
|
|
||||||
common_topics = set(endpoints).intersection(key)
|
|
||||||
if common_topics:
|
|
||||||
remaining_topics = set(key) - set(endpoints)
|
|
||||||
# Disconnect slot from common topics
|
|
||||||
self._do_disconnect_slot(key, slot)
|
|
||||||
# Reconnect slot to remaining topics if any
|
|
||||||
if remaining_topics:
|
|
||||||
self.connect_slot(slot, list(remaining_topics), True)
|
|
||||||
|
|
||||||
def disconnect_all(self):
|
|
||||||
"""Disconnect all slots from all topics."""
|
|
||||||
for key, connection in list(self._connections.items()):
|
|
||||||
for slot in list(connection.slots):
|
|
||||||
self._disconnect_slot_from_topic(slot, key)
|
|
||||||
|
@ -217,7 +217,6 @@ class MotorMap(pg.GraphicsLayoutWidget):
|
|||||||
bec_dispatcher.connect_slot(
|
bec_dispatcher.connect_slot(
|
||||||
self.on_device_readback,
|
self.on_device_readback,
|
||||||
endpoints,
|
endpoints,
|
||||||
single_callback_for_all_topics=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _add_limits_to_plot_data(self):
|
def _add_limits_to_plot_data(self):
|
||||||
|
Reference in New Issue
Block a user