From 6a3df34cdfbec2434153362ded630305e5dc5e28 Mon Sep 17 00:00:00 2001 From: Ivan Usov Date: Thu, 10 Aug 2023 16:24:42 +0200 Subject: [PATCH] feat: add generic connect function for slots --- bec_widgets/bec_dispatcher.py | 56 +++++++++++++++++++++++++++++++++++ bec_widgets/scan_plot.py | 12 +++++--- 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/bec_widgets/bec_dispatcher.py b/bec_widgets/bec_dispatcher.py index e77b6f38..dba70b2c 100644 --- a/bec_widgets/bec_dispatcher.py +++ b/bec_widgets/bec_dispatcher.py @@ -1,4 +1,5 @@ import argparse +import itertools import os from dataclasses import dataclass from threading import RLock @@ -17,6 +18,27 @@ class _BECDap: slots = set() +# Adding a new pyqt signal requres a class factory, as they must be part of the class definition +# and cannot be dynamically added as class attributes after the class has been defined. +_signal_class_factory = ( + type(f"Signal{i}", (QObject,), dict(signal=pyqtSignal("PyQt_PyObject"))) + for i in itertools.count() +) + + +@dataclass +class _Connection: + """Utility class to keep track of slots connected to a particular redis consumer""" + + consumer: RedisConsumerThreaded + slots = set() + # keep a reference to a new signal class, so it is not gc'ed + _signal_container = next(_signal_class_factory)() + + def __post_init__(self): + self.signal = self._signal_container.signal + + class _BECDispatcher(QObject): new_scan = pyqtSignal(dict, dict) scan_segment = pyqtSignal(dict, dict) @@ -42,6 +64,7 @@ class _BECDispatcher(QObject): "on_new_scan": self.new_scan, } self._daps = {} + self._connections = {} self._scan_id = None scan_lock = RLock() @@ -65,6 +88,39 @@ class _BECDispatcher(QObject): if callable(slot): signal.connect(slot) + def connect_slot(self, slot, topic): + # create new connection for topic if it doesn't exist + if topic not in self._connections: + + def cb(msg): + msg = BECMessage.MessageReader.loads(msg.value) + self._connections[topic].signal.emit(msg) + + consumer = self.client.connector.consumer(topics=topic, cb=cb) + consumer.start() + + self._connections[topic] = _Connection(consumer) + + # connect slot if it's not connected + if slot not in self._connections[topic].slots: + self._connections[topic].signal.connect(slot) + self._connections[topic].slots.add(slot) + + def disconnect_slot(self, slot, topic): + if topic not in self._connections: + return + + if slot not in self._connections[topic].slots: + return + + self._connections[topic].signal.disconnect(slot) + self._connections[topic].slots.remove(slot) + + if not self._connections[topic].slots: + # shutdown consumer if there are no more connected slots + self._connections[topic].consumer.shutdown() + del self._connections[topic] + def connect_dap_slot(self, slot, dap_name): if dap_name not in self._daps: # create a new consumer and connect slot diff --git a/bec_widgets/scan_plot.py b/bec_widgets/scan_plot.py index 68393298..dbe1eec2 100644 --- a/bec_widgets/scan_plot.py +++ b/bec_widgets/scan_plot.py @@ -1,6 +1,7 @@ import itertools import pyqtgraph as pg +from bec_lib.core import MessageEndpoints from bec_lib.core.logger import bec_logger from PyQt5.QtCore import pyqtProperty, pyqtSlot @@ -61,8 +62,9 @@ class BECScanPlot(pg.GraphicsView): plot_curve.setData(x=[*x, x_new], y=[*y, y_new]) - @pyqtSlot(dict, dict) - def redraw_dap(self, data, _metadata): + @pyqtSlot("PyQt_PyObject") + def redraw_dap(self, msg): + data = msg.content["data"] for chan, plot_curve in self.dap_curves.items(): if not chan: continue @@ -86,7 +88,8 @@ class BECScanPlot(pg.GraphicsView): chan_removed = [chan for chan in self._y_channel_list if chan not in new_list] if chan_removed and chan_removed[0].startswith("dap."): chan_removed = chan_removed[0].partition("dap.")[-1] - bec_dispatcher.disconnect_dap_slot(self.redraw_dap, chan_removed) + chan_removed_ep = MessageEndpoints.processed_data(chan_removed) + bec_dispatcher.disconnect_slot(self.redraw_dap, chan_removed_ep) self._y_channel_list = new_list @@ -100,7 +103,8 @@ class BECScanPlot(pg.GraphicsView): if y_chan.startswith("dap."): y_chan = y_chan.partition("dap.")[-1] curves = self.dap_curves - bec_dispatcher.connect_dap_slot(self.redraw_dap, y_chan) + y_chan_ep = MessageEndpoints.processed_data(y_chan) + bec_dispatcher.connect_slot(self.redraw_dap, y_chan_ep) else: curves = self.scan_curves