From 1325704750ebab897e3dcae80c9d455bfbbf886f Mon Sep 17 00:00:00 2001 From: Ivan Usov Date: Mon, 31 Jul 2023 16:52:04 +0200 Subject: [PATCH] feat: add disconnect_dap_slot --- bec_widgets/bec_dispatcher.py | 46 ++++++++++++++++++++++++++--------- bec_widgets/scan_plot.py | 9 +++++-- 2 files changed, 42 insertions(+), 13 deletions(-) diff --git a/bec_widgets/bec_dispatcher.py b/bec_widgets/bec_dispatcher.py index fc72a8d2..4a89cf5e 100644 --- a/bec_widgets/bec_dispatcher.py +++ b/bec_widgets/bec_dispatcher.py @@ -1,11 +1,20 @@ -from collections import defaultdict +from dataclasses import dataclass from threading import RLock from bec_lib import BECClient from bec_lib.core import BECMessage, MessageEndpoints +from bec_lib.core.redis_connector import RedisConsumerThreaded from PyQt5.QtCore import QObject, pyqtSignal +@dataclass +class _BECDap: + """Utility class to keep track of slots associated with a particular dap redis consumer""" + + consumer: RedisConsumerThreaded + slots = set() + + class _BECDispatcher(QObject): new_scan = pyqtSignal(dict, dict) scan_segment = pyqtSignal(dict, dict) @@ -16,16 +25,14 @@ class _BECDispatcher(QObject): self.client = BECClient() self.client.start() - # TODO: dap might not be a good fit to predefined slots, fix this inconsistency self._slot_signal_map = { "on_scan_segment": self.scan_segment, "on_new_scan": self.new_scan, } - self._daps = defaultdict(set) + self._daps = {} self._scan_id = None scan_lock = RLock() - self._dap_threads = [] def _scan_segment_cb(scan_segment, metadata): with scan_lock: @@ -44,26 +51,43 @@ class _BECDispatcher(QObject): if callable(slot): signal.connect(slot) - def connect_dap(self, slot, dap_name): + def connect_dap_slot(self, slot, dap_name): if dap_name not in self._daps: + # create a new consumer and connect slot def _dap_cb(msg): msg = BECMessage.ProcessedDataMessage.loads(msg.value) self.new_dap_data.emit(msg.content["data"]) dap_ep = MessageEndpoints.processed_data(dap_name) - dap_thread = self.client.connector.consumer(topics=dap_ep, cb=_dap_cb) - dap_thread.start() - self._dap_threads.append(dap_thread) + consumer = self.client.connector.consumer(topics=dap_ep, cb=_dap_cb) + consumer.start() self.new_dap_data.connect(slot) - self._daps[dap_name].add(slot) + + self._daps[dap_name] = _BECDap(consumer) + self._daps[dap_name].slots.add(slot) else: # connect slot if it's not yet connected - if slot not in self._daps[dap_name]: - self._daps[dap_name].add(slot) + if slot not in self._daps[dap_name].slots: self.new_dap_data.connect(slot) + self._daps[dap_name].slots.add(slot) + + def disconnect_dap_slot(self, slot, dap_name): + if dap_name not in self._daps: + return + + if slot not in self._daps[dap_name].slots: + return + + self.new_dap_data.disconnect(slot) + self._daps[dap_name].slots.remove(slot) + + if not self._daps[dap_name].slots: + # shutdown consumer if there are no more connected slots + self._daps[dap_name].consumer.shutdown() + del self._daps[dap_name] bec_dispatcher = _BECDispatcher() diff --git a/bec_widgets/scan_plot.py b/bec_widgets/scan_plot.py index 8905c588..b565283f 100644 --- a/bec_widgets/scan_plot.py +++ b/bec_widgets/scan_plot.py @@ -82,6 +82,12 @@ class BECScanPlot(pg.GraphicsView): @y_channel_list.setter def y_channel_list(self, new_list): + # TODO: do we want to care about dap/not dap here? + chan_removed = [chan for chan in self._y_channel_list if chan not in new_list] + if chan_removed and chan_removed[0].startswith("dap."): + chan_removed = chan_removed[0].partition("dap.")[-1] + bec_dispatcher.disconnect_dap_slot(self.redraw_dap, chan_removed) + self._y_channel_list = new_list # Prepare plot for a potentially different list of y channels @@ -91,11 +97,10 @@ class BECScanPlot(pg.GraphicsView): colors = itertools.cycle(COLORS) for y_chan in new_list: - # TODO: ideally, we dont want to care about dap/not dap here if y_chan.startswith("dap."): y_chan = y_chan.partition("dap.")[-1] curves = self.dap_curves - bec_dispatcher.connect_dap(self.redraw_dap, y_chan) + bec_dispatcher.connect_dap_slot(self.redraw_dap, y_chan) else: curves = self.scan_curves