diff --git a/bec_widgets/widgets/plots/image/image.py b/bec_widgets/widgets/plots/image/image.py index a3c01be3..3f2f9310 100644 --- a/bec_widgets/widgets/plots/image/image.py +++ b/bec_widgets/widgets/plots/image/image.py @@ -7,8 +7,8 @@ import numpy as np from bec_lib import bec_logger from bec_lib.endpoints import MessageEndpoints from pydantic import BaseModel, Field, field_validator -from qtpy.QtCore import Qt, QTimer -from qtpy.QtWidgets import QComboBox, QStyledItemDelegate, QWidget +from qtpy.QtCore import QTimer +from qtpy.QtWidgets import QComboBox, QWidget from bec_widgets.utils import ConnectionConfig from bec_widgets.utils.colors import Colors @@ -49,6 +49,9 @@ class ImageLayerConfig(BaseModel): source: Literal["device_monitor_1d", "device_monitor_2d", "auto"] = Field( "auto", description="The source of the image data." ) + async_signal_name: str | None = Field( + None, description="Async signal name (obj_name) used for async endpoints." + ) class Image(ImageBase): @@ -116,7 +119,11 @@ class Image(ImageBase): ) self._init_toolbar_image() self.layer_removed.connect(self._on_layer_removed) + self.old_scan_id = None self.scan_id = None + self.async_update = False + self.bec_dispatcher.connect_slot(self.on_scan_status, MessageEndpoints.scan_status()) + self.bec_dispatcher.connect_slot(self.on_scan_progress, MessageEndpoints.scan_progress()) ################################## ### Toolbar Initialization @@ -181,18 +188,25 @@ class Image(ImageBase): Adjust the size of the device combo box and populate it with preview signals. Has to be done with QTimer.singleShot to ensure the UI is fully initialized, needed for testing. """ - self._populate_preview_signals() + self._populate_signals() self._reverse_device_items() self.device_combo_box.setCurrentText("") # set again default to empty string - def _populate_preview_signals(self) -> None: + def _populate_signals(self) -> None: """ Populate the device combo box with preview-signal devices in the format '_' and store the tuple(device, signal) in the item's userData for later use. """ preview_signals = self.client.device_manager.get_bec_signals("PreviewSignal") - for device, signal, signal_config in preview_signals: + async_signals = self.client.device_manager.get_bec_signals("AsyncSignal") + all_signals = preview_signals + async_signals + for device, signal, signal_config in all_signals: + describe = signal_config.get("describe") or {} + signal_info = describe.get("signal_info") or {} + ndim = signal_info.get("ndim") + if ndim == 0: + continue label = signal_config.get("obj_name", f"{device}_{signal}") self.device_combo_box.addItem(label, (device, signal, signal_config)) @@ -422,51 +436,89 @@ class Image(ImageBase): """ # TODO consider moving connecting and disconnecting logic to Image itself if multiple images + self.async_update = False + config = self.subscriptions["main"] + needs_async_setup = False + config.async_signal_name = None if isinstance(monitor, (list, tuple)): - device = self.dev[monitor[0]] + try: + device = self.dev[monitor[0]] + except KeyError: + logger.warning(f"Device '{monitor[0]}' not found; cannot connect monitor.") + return signal = monitor[1] if len(monitor) == 3: signal_config = monitor[2] else: - signal_config = device._info["signals"][signal] + try: + signal_config = device._info["signals"][signal] + except KeyError: + logger.warning(f"Signal '{signal}' not found on device '{device.name}'.") + return signal_class = signal_config.get("signal_class", None) - if signal_class != "PreviewSignal": - logger.warning(f"Signal '{monitor}' is not a PreviewSignal.") + if signal_class not in ["PreviewSignal", "AsyncSignal"]: + logger.warning(f"Signal '{monitor}' is not a PreviewSignal or AsyncSignal.") return - ndim = signal_config.get("describe", None).get("signal_info", None).get("ndim", None) + describe = signal_config.get("describe") or {} + signal_info = describe.get("signal_info") or {} + ndim = signal_info.get("ndim", None) if ndim is None: logger.warning( f"Signal '{monitor}' does not have a valid 'ndim' in its signal_info." ) return - if ndim == 1: - self.bec_dispatcher.connect_slot( - self.on_image_update_1d, MessageEndpoints.device_preview(device.name, signal) - ) - self.subscriptions["main"].source = "device_monitor_1d" - self.subscriptions["main"].monitor_type = "1d" + config.source = "device_monitor_1d" + config.monitor_type = "1d" + if signal_class == "PreviewSignal": + self.bec_dispatcher.connect_slot( + self.on_image_update_1d, + MessageEndpoints.device_preview(device.name, signal), + ) + elif signal_class == "AsyncSignal": + self.async_update = True + needs_async_setup = True + config.async_signal_name = signal_config.get( + "obj_name", f"{device.name}_{signal}" + ) + else: + logger.warning(f"Unsupported signal class '{signal_class}' for 1D monitor.") + return elif ndim == 2: - self.bec_dispatcher.connect_slot( - self.on_image_update_2d, MessageEndpoints.device_preview(device.name, signal) - ) - self.subscriptions["main"].source = "device_monitor_2d" - self.subscriptions["main"].monitor_type = "2d" + config.source = "device_monitor_2d" + config.monitor_type = "2d" + if signal_class == "PreviewSignal": + self.bec_dispatcher.connect_slot( + self.on_image_update_2d, + MessageEndpoints.device_preview(device.name, signal), + ) + elif signal_class == "AsyncSignal": + self.async_update = True + needs_async_setup = True + config.async_signal_name = signal_config.get( + "obj_name", f"{device.name}_{signal}" + ) + else: + logger.warning(f"Unsupported signal class '{signal_class}' for 2D monitor.") + return + else: + logger.warning(f"Unsupported ndim '{ndim}' for monitor '{monitor}'.") + return else: # FIXME old monitor 1d/2d endpoint handling, present for backwards compatibility, will be removed in future versions if type == "1d": self.bec_dispatcher.connect_slot( self.on_image_update_1d, MessageEndpoints.device_monitor_1d(monitor) ) - self.subscriptions["main"].source = "device_monitor_1d" - self.subscriptions["main"].monitor_type = "1d" + config.source = "device_monitor_1d" + config.monitor_type = "1d" elif type == "2d": self.bec_dispatcher.connect_slot( self.on_image_update_2d, MessageEndpoints.device_monitor_2d(monitor) ) - self.subscriptions["main"].source = "device_monitor_2d" - self.subscriptions["main"].monitor_type = "2d" + config.source = "device_monitor_2d" + config.monitor_type = "2d" elif type == "auto": self.bec_dispatcher.connect_slot( self.on_image_update_1d, MessageEndpoints.device_monitor_1d(monitor) @@ -474,14 +526,121 @@ class Image(ImageBase): self.bec_dispatcher.connect_slot( self.on_image_update_2d, MessageEndpoints.device_monitor_2d(monitor) ) - self.subscriptions["main"].source = "auto" + config.source = "auto" logger.warning( f"Updates for '{monitor}' will be fetch from both 1D and 2D monitor endpoints." ) - self.subscriptions["main"].monitor_type = "auto" + config.monitor_type = "auto" + config.monitor = monitor + if needs_async_setup: + self._setup_async_image(self.scan_id) logger.info(f"Connected to {monitor} with type {type}") - self.subscriptions["main"].monitor = monitor + + @SafeSlot(dict, dict) + def on_scan_status(self, msg: dict, meta: dict): + """ + Initial scan status message handler, which is triggered at the begging and end of scan. + Needed for setup of AsyncSignal connections. + + Args: + msg(dict): The message content. + meta(dict): The message metadata. + """ + current_scan_id = msg.get("scan_id", None) + if current_scan_id is None: + return + self._handle_scan_change(current_scan_id) + + @SafeSlot(dict, dict) + def on_scan_progress(self, msg: dict, meta: dict): + """ + For setting async image readback during scan progress updates if widget is started later than scan. + + Args: + msg(dict): The message content. + meta(dict): The message metadata. + """ + current_scan_id = meta.get("scan_id", None) + if current_scan_id is None: + return + self._handle_scan_change(current_scan_id) + + def _handle_scan_change(self, current_scan_id: str): + """ + Update internal scan ids and refresh async connections if needed. + + Args: + current_scan_id (str): The current scan identifier. + """ + if current_scan_id == self.scan_id: + return + self.old_scan_id = self.scan_id + self.scan_id = current_scan_id + if self.async_update: + self._setup_async_image(scan_id=self.scan_id) + + def _get_async_signal_name(self) -> tuple[str, str] | None: + """ + Returns device name and async signal name used for endpoints/messages. + + Returns: + tuple[str, str] | None: (device_name, async_signal_name) or None if not available. + """ + config = self.subscriptions["main"] + monitor = config.monitor + if monitor is None or not isinstance(monitor, (list, tuple)) or len(monitor) < 2: + return None + device_name = monitor[0] + async_signal = config.async_signal_name or monitor[1] + return device_name, async_signal + + def _setup_async_image(self, scan_id: str | None): + """ + (Re)connect async image readback for the current scan. + + Args: + scan_id (str | None): The scan identifier to subscribe to. + """ + if not self.async_update: + return + + config = self.subscriptions["main"] + async_names = self._get_async_signal_name() + if async_names is None: + logger.info("Async image setup skipped because monitor information is incomplete.") + return + + device_name, async_signal = async_names + if config.monitor_type == "1d": + slot = self.on_image_update_1d + elif config.monitor_type == "2d": + slot = self.on_image_update_2d + else: + logger.warning( + f"Async image setup skipped due to unsupported monitor type '{config.monitor_type}'." + ) + return + + # Disconnect any previous scan subscriptions to avoid stale updates. + for prev_scan_id in (self.old_scan_id, self.scan_id): + if prev_scan_id is None: + continue + self.bec_dispatcher.disconnect_slot( + slot, MessageEndpoints.device_async_signal(prev_scan_id, device_name, async_signal) + ) + + if scan_id is None: + logger.info("Scan ID not available yet; delaying async image subscription.") + return + + self.bec_dispatcher.connect_slot( + slot, + MessageEndpoints.device_async_signal(scan_id, device_name, async_signal), + from_start=True, + cb_info={"scan_id": scan_id}, + ) + logger.info(f"Setup async image for {device_name}.{async_signal} and scan {scan_id}.") def disconnect_monitor(self, monitor: str | tuple): """ @@ -490,20 +649,47 @@ class Image(ImageBase): Args: monitor(str|tuple): The name of the monitor to disconnect, or a tuple of (device, signal) for preview signals. """ + config = self.subscriptions["main"] if isinstance(monitor, (list, tuple)): - if self.subscriptions["main"].source == "device_monitor_1d": - self.bec_dispatcher.disconnect_slot( - self.on_image_update_1d, MessageEndpoints.device_preview(monitor[0], monitor[1]) - ) - elif self.subscriptions["main"].source == "device_monitor_2d": - self.bec_dispatcher.disconnect_slot( - self.on_image_update_2d, MessageEndpoints.device_preview(monitor[0], monitor[1]) - ) + if self.async_update: + async_names = self._get_async_signal_name() + ids_to_check = [self.scan_id, self.old_scan_id] + if config.source == "device_monitor_1d": + for scan_id in ids_to_check: + if scan_id is None or async_names is None: + continue + self.bec_dispatcher.disconnect_slot( + self.on_image_update_1d, + MessageEndpoints.device_async_signal( + scan_id, async_names[0], async_names[1] + ), + ) + elif config.source == "device_monitor_2d": + for scan_id in ids_to_check: + if scan_id is None or async_names is None: + continue + self.bec_dispatcher.disconnect_slot( + self.on_image_update_2d, + MessageEndpoints.device_async_signal( + scan_id, async_names[0], async_names[1] + ), + ) else: - logger.warning( - f"Cannot disconnect monitor {monitor} with source {self.subscriptions['main'].source}" - ) - return + if config.source == "device_monitor_1d": + self.bec_dispatcher.disconnect_slot( + self.on_image_update_1d, + MessageEndpoints.device_preview(monitor[0], monitor[1]), + ) + elif config.source == "device_monitor_2d": + self.bec_dispatcher.disconnect_slot( + self.on_image_update_2d, + MessageEndpoints.device_preview(monitor[0], monitor[1]), + ) + else: + logger.warning( + f"Cannot disconnect monitor {monitor} with source {self.subscriptions['main'].source}" + ) + return else: # FIXME old monitor 1d/2d endpoint handling, present for backwards compatibility, will be removed in future versions self.bec_dispatcher.disconnect_slot( self.on_image_update_1d, MessageEndpoints.device_monitor_1d(monitor) @@ -512,6 +698,8 @@ class Image(ImageBase): self.on_image_update_2d, MessageEndpoints.device_monitor_2d(monitor) ) self.subscriptions["main"].monitor = None + self.subscriptions["main"].async_signal_name = None + self.async_update = False self._sync_device_selection() ######################################## @@ -526,7 +714,7 @@ class Image(ImageBase): msg(dict): The message containing the data. metadata(dict): The metadata associated with the message. """ - data = msg["data"] + data = self._get_payload_data(msg) current_scan_id = metadata.get("scan_id", None) if current_scan_id is None: @@ -538,6 +726,9 @@ class Image(ImageBase): self.main_image.max_len = 0 if self.crosshair is not None: self.crosshair.reset() + if data is None: + logger.warning("No data received for image update.") + return image_buffer = self.adjust_image_buffer(self.main_image, data) if self._color_bar is not None: self._color_bar.blockSignals(True) @@ -590,7 +781,10 @@ class Image(ImageBase): msg(dict): The message containing the data. metadata(dict): The metadata associated with the message. """ - data = msg["data"] + data = self._get_payload_data(msg) + if data is None: + logger.warning("No data received for image update.") + return if self._color_bar is not None: self._color_bar.blockSignals(True) self.main_image.set_data(data) @@ -598,6 +792,22 @@ class Image(ImageBase): self._color_bar.blockSignals(False) self.image_updated.emit() + def _get_payload_data(self, msg: dict) -> np.ndarray | None: + """ + Extract payload from async/preview/monitor1D/2D message structures due to inconsistent formats in backend. + + Args: + msg (dict): The incoming message containing data. + """ + if not self.async_update: + return msg.get("data") + async_names = self._get_async_signal_name() + if async_names is None: + logger.warning("Async payload extraction failed; monitor info incomplete.") + return None + _, async_signal = async_names + return msg.get("signals", {}).get(async_signal, {}).get("value", None) + ################################################################################ # Clean up ################################################################################ @@ -634,6 +844,8 @@ class Image(ImageBase): self.device_combo_box.deleteLater() self.dim_combo_box.close() self.dim_combo_box.deleteLater() + self.bec_dispatcher.disconnect_slot(self.on_scan_status, MessageEndpoints.scan_status()) + self.bec_dispatcher.disconnect_slot(self.on_scan_progress, MessageEndpoints.scan_progress()) super().cleanup() diff --git a/tests/unit_tests/test_image_view_next_gen.py b/tests/unit_tests/test_image_view_next_gen.py index 78e34705..cebb6ae4 100644 --- a/tests/unit_tests/test_image_view_next_gen.py +++ b/tests/unit_tests/test_image_view_next_gen.py @@ -1,6 +1,7 @@ import numpy as np import pyqtgraph as pg import pytest +from bec_lib.endpoints import MessageEndpoints from qtpy.QtCore import QPointF from bec_widgets.widgets.plots.image.image import Image @@ -178,6 +179,114 @@ def test_image_setup_preview_signal_2d(qtbot, mocked_client, monkeypatch): np.testing.assert_array_equal(view.main_image.image, test_data) +def test_preview_signals_skip_0d_entries(qtbot, mocked_client, monkeypatch): + """ + Preview/async combobox should omit 0‑D signals. + """ + view = create_widget(qtbot, Image, client=mocked_client) + + def fake_get(sign_cls): + if sign_cls == "PreviewSignal": + return [ + ( + "dev", + "sig0d", + { + "obj_name": "sig0d", + "signal_class": "PreviewSignal", + "describe": {"signal_info": {"ndim": 0}}, + }, + ), + ( + "dev", + "sig2d", + { + "obj_name": "sig2d", + "signal_class": "PreviewSignal", + "describe": {"signal_info": {"ndim": 2}}, + }, + ), + ] + return [] + + monkeypatch.setattr(view.client.device_manager, "get_bec_signals", fake_get) + view.device_combo_box.clear() + view.device_combo_box.addItem("", None) + view._populate_signals() + + texts = [view.device_combo_box.itemText(i) for i in range(view.device_combo_box.count())] + assert "sig0d" not in texts + assert "sig2d" in texts + + +def test_image_async_signal_uses_obj_name(qtbot, mocked_client, monkeypatch): + """ + Verify async signals use obj_name for endpoints/payloads and reconnect with scan_id. + """ + view = create_widget(qtbot, Image, client=mocked_client) + signal_config = { + "obj_name": "async_obj", + "signal_class": "AsyncSignal", + "describe": {"signal_info": {"ndim": 1}}, + } + + view.image(monitor=("eiger", "img", signal_config)) + assert view.subscriptions["main"].async_signal_name == "async_obj" + + # Prepare scan ids and capture dispatcher calls + view.old_scan_id = "old_scan" + view.scan_id = "new_scan" + connected = [] + disconnected = [] + monkeypatch.setattr( + view.bec_dispatcher, + "connect_slot", + lambda slot, endpoint, from_start=False, cb_info=None: connected.append( + (slot, endpoint, from_start, cb_info) + ), + ) + monkeypatch.setattr( + view.bec_dispatcher, + "disconnect_slot", + lambda slot, endpoint: disconnected.append((slot, endpoint)), + ) + + view._setup_async_image(view.scan_id) + + expected_new = MessageEndpoints.device_async_signal("new_scan", "eiger", "async_obj") + expected_old = MessageEndpoints.device_async_signal("old_scan", "eiger", "async_obj") + assert any(ep == expected_new for _, ep, _, _ in connected) + assert any(ep == expected_old for _, ep in disconnected) + + # Payload extraction should use obj_name + payload = np.array([1, 2, 3]) + msg = {"signals": {"async_obj": {"value": payload}}} + assert view._get_payload_data(msg) is payload + + +def test_disconnect_monitor_clears_async_state(qtbot, mocked_client, monkeypatch): + view = create_widget(qtbot, Image, client=mocked_client) + signal_config = { + "obj_name": "async_obj", + "signal_class": "AsyncSignal", + "describe": {"signal_info": {"ndim": 2}}, + } + + view.image(monitor=("eiger", "img", signal_config)) + view.scan_id = "scan_x" + view.old_scan_id = "scan_y" + view.subscriptions["main"].async_signal_name = "async_obj" + + # Avoid touching real dispatcher + monkeypatch.setattr(view.bec_dispatcher, "disconnect_slot", lambda *args, **kwargs: None) + + view.disconnect_monitor(("eiger", "img", signal_config)) + + assert view.subscriptions["main"].monitor is None + assert view.subscriptions["main"].async_signal_name is None + assert view.async_update is False + + ############################################## # Device monitor endpoint update mechanism @@ -600,9 +709,9 @@ def test_monitor_selection_reverse_device_items(qtbot, mocked_client): assert combo.currentText() == "samy" -def test_monitor_selection_populate_preview_signals(qtbot, mocked_client, monkeypatch): +def test_monitor_selection_populate_signals(qtbot, mocked_client, monkeypatch): """ - Verify that _populate_preview_signals adds preview‑signal devices to the combo‑box + Verify that _populate_signals adds preview‑signal and async-signal devices to the combo‑box with the correct userData. """ view = create_widget(qtbot, Image, client=mocked_client) @@ -610,23 +719,32 @@ def test_monitor_selection_populate_preview_signals(qtbot, mocked_client, monkey # Provide a deterministic fake device_manager with get_bec_signals class _FakeDM: def get_bec_signals(self, _filter): - return [ - ("eiger", "img", {"obj_name": "eiger_img"}), - ("async_device", "img2", {"obj_name": "async_device_img2"}), - ] + if _filter == "PreviewSignal": + return [ + ("eiger", "img", {"obj_name": "eiger_img"}), + ("eiger2", "img2", {"obj_name": "eiger_img2"}), + ] + if _filter == "AsyncSignal": + return [("async_device", "img_async", {"obj_name": "async_device_img_async"})] + return [] monkeypatch.setattr(view.client, "device_manager", _FakeDM()) initial_count = view.device_combo_box.count() - view._populate_preview_signals() + view._populate_signals() - # Two new entries should have been added - assert view.device_combo_box.count() == initial_count + 2 + # PreviewSignal + AsyncSignal entries were added + assert view.device_combo_box.count() == initial_count + 3 # The first newly added item should carry tuple userData describing the device/signal data = view.device_combo_box.itemData(initial_count) assert isinstance(data, tuple) and data[0] == "eiger" + texts = [ + view.device_combo_box.itemText(i) + for i in range(initial_count, view.device_combo_box.count()) + ] + assert "async_device_img_async" in texts def test_monitor_selection_adjust_and_connect(qtbot, mocked_client, monkeypatch): @@ -641,7 +759,9 @@ def test_monitor_selection_adjust_and_connect(qtbot, mocked_client, monkeypatch) # Deterministic fake device_manager class _FakeDM: def get_bec_signals(self, _filter): - return [("eiger", "img", {"obj_name": "eiger_img"})] + if _filter == "PreviewSignal": + return [("eiger", "img", {"obj_name": "eiger_img"})] + return [] monkeypatch.setattr(view.client, "device_manager", _FakeDM())