refactor: use BECClient for cb on scan_segment

This commit is contained in:
2023-07-24 10:16:23 +02:00
parent 87163fde32
commit ff534ad67f
3 changed files with 28 additions and 33 deletions
+13 -17
View File
@@ -1,19 +1,21 @@
from collections import defaultdict
from threading import RLock
from bec_lib.core import BECMessage, MessageEndpoints, RedisConnector
from bec_lib import BECClient
from bec_lib.core import BECMessage, MessageEndpoints
from PyQt5.QtCore import QObject, pyqtSignal
bec_connector = RedisConnector("localhost:6379")
class _BECDispatcher(QObject):
scan_segment = pyqtSignal("PyQt_PyObject")
new_scan = pyqtSignal(dict, dict)
scan_segment = pyqtSignal(dict, dict)
new_dap_data = pyqtSignal(dict)
new_scan = pyqtSignal("PyQt_PyObject")
def __init__(self):
super().__init__()
self.client = BECClient()
self.client.start()
# TODO: dap might not be a good fit to predefined slots, fix this inconsistency
self._slot_signal_map = {
"on_scan_segment": self.scan_segment,
@@ -25,22 +27,16 @@ class _BECDispatcher(QObject):
scan_lock = RLock()
self._dap_threads = []
def _scan_cb(msg):
msg = BECMessage.ScanMessage.loads(msg.value)[0]
def _scan_segment_cb(scan_segment, metadata):
with scan_lock:
# TODO: use ScanStatusMessage instead?
scan_id = msg.content["scanID"]
scan_id = metadata["scanID"]
if self._scan_id != scan_id:
self._scan_id = scan_id
self.new_scan.emit(msg)
self.scan_segment.emit(msg)
self.new_scan.emit(scan_segment, metadata)
self.scan_segment.emit(scan_segment, metadata)
scan_readback = MessageEndpoints.scan_segment()
self._scan_thread = bec_connector.consumer(
topics=scan_readback,
cb=_scan_cb,
)
self._scan_thread.start()
self.client.callbacks.register("scan_segment", _scan_segment_cb, sync=False)
def connect(self, widget):
for slot_name, signal in self._slot_signal_map.items():
@@ -56,7 +52,7 @@ class _BECDispatcher(QObject):
self.new_dap_data.emit(msg.content["data"])
dap_ep = MessageEndpoints.processed_data(dap_name)
dap_thread = bec_connector.consumer(topics=dap_ep, cb=_dap_cb)
dap_thread = self.client.connector.consumer(topics=dap_ep, cb=_dap_cb)
dap_thread.start()
self._dap_threads.append(dap_thread)
+10 -11
View File
@@ -33,18 +33,18 @@ class BECScanPlot2D(pg.GraphicsView):
self.imageItem = pg.ImageItem()
self.plot_item.addItem(self.imageItem)
@pyqtSlot("PyQt_PyObject")
def on_new_scan(self, msg):
@pyqtSlot(dict, dict)
def on_new_scan(self, _scan_segment, metadata):
# TODO: Do we reset in case of a scan type change?
self.imageItem.clear()
# TODO: better to check the number of coordinates in metadata["positions"]?
if msg.metadata["scan_name"] != "grid_scan":
if metadata["scan_name"] != "grid_scan":
return
positions = [sorted(set(pos)) for pos in zip(*msg.metadata["positions"])]
positions = [sorted(set(pos)) for pos in zip(*metadata["positions"])]
motors = msg.metadata["scan_motors"]
motors = metadata["scan_motors"]
if self.x_channel and self.y_channel:
self._x_ind = motors.index(self.x_channel) if self.x_channel in motors else None
self._y_ind = motors.index(self.y_channel) if self.y_channel in motors else None
@@ -77,21 +77,20 @@ class BECScanPlot2D(pg.GraphicsView):
self.plot_item.setLabel("bottom", motors[self._x_ind])
self.plot_item.setLabel("left", motors[self._y_ind])
@pyqtSlot("PyQt_PyObject")
def on_scan_segment(self, msg):
if not self.z_channel or msg.metadata["scan_name"] != "grid_scan":
@pyqtSlot(dict, dict)
def on_scan_segment(self, scan_segment, metadata):
if not self.z_channel or metadata["scan_name"] != "grid_scan":
return
if self._x_ind is None or self._y_ind is None:
return
point_id = msg.content["point_id"]
point_coord = msg.metadata["positions"][point_id]
point_coord = metadata["positions"][scan_segment["point_id"]]
x_coord_ind = self._xpos.index(point_coord[self._x_ind])
y_coord_ind = self._ypos.index(point_coord[self._y_ind])
data = msg.content["data"]
data = scan_segment["data"]
z_new = data[self.z_channel][self.z_channel]["value"]
image = self.imageItem.image
+5 -5
View File
@@ -27,17 +27,17 @@ class BECScanPlot(pg.GraphicsView):
self.scan_curves = {}
self.dap_curves = {}
@pyqtSlot("PyQt_PyObject")
def on_new_scan(self, _msg):
@pyqtSlot(dict, dict)
def on_new_scan(self, _scan_segment, _metadata):
for plot_curve in {**self.scan_curves, **self.dap_curves}.values():
plot_curve.setData(x=[], y=[])
@pyqtSlot("PyQt_PyObject")
def on_scan_segment(self, msg):
@pyqtSlot(dict, dict)
def on_scan_segment(self, scan_segment, _metadata):
if not self.x_channel:
return
data = msg.content["data"]
data = scan_segment["data"]
if self.x_channel not in data:
logger.warning(f"Unknown channel `{self.x_channel}` for X data in {self.objectName()}")