diff --git a/bec_widgets/widgets/containers/main_window/main_window.py b/bec_widgets/widgets/containers/main_window/main_window.py index 8c4f7d57..312f8919 100644 --- a/bec_widgets/widgets/containers/main_window/main_window.py +++ b/bec_widgets/widgets/containers/main_window/main_window.py @@ -209,6 +209,7 @@ class BECMainWindow(BECWidget, QMainWindow): self._scan_progress_bar_simple.progressbar.label_template = "" self._scan_progress_bar_simple.progressbar.setFixedHeight(self.SCAN_PROGRESS_HEIGHT) self._scan_progress_bar_simple.progressbar.setFixedWidth(self.SCAN_PROGRESS_WIDTH) + # This one do not need dynamic styling on hover ScanProgressBar since user will hover on it probably later, when progress bar is big enough self._scan_progress_bar_full = ScanProgressBar( self, rpc_exposed=False, rpc_passthrough_children=False, enable_dynamic_stylesheet=False ) @@ -237,8 +238,8 @@ class BECMainWindow(BECWidget, QMainWindow): # The actual line line = QFrame() - line.setFrameShape(QFrame.VLine) - line.setFrameShadow(QFrame.Sunken) + line.setFrameShape(QFrame.Shape.VLine) + line.setFrameShadow(QFrame.Shadow.Sunken) line.setFixedHeight(status_bar.sizeHint().height() - 2) # Wrapper to center the line vertically -> work around for QFrame not being able to center itself @@ -246,7 +247,7 @@ class BECMainWindow(BECWidget, QMainWindow): vbox = QVBoxLayout(wrapper) vbox.setContentsMargins(0, 0, 0, 0) vbox.addStretch() - vbox.addWidget(line, alignment=Qt.AlignHCenter) + vbox.addWidget(line, alignment=Qt.AlignmentFlag.AlignHCenter) vbox.addStretch() wrapper.setFixedWidth(line.sizeHint().width()) diff --git a/bec_widgets/widgets/progress/bec_progressbar/bec_progressbar.py b/bec_widgets/widgets/progress/bec_progressbar/bec_progressbar.py index 51d7e158..a1951013 100644 --- a/bec_widgets/widgets/progress/bec_progressbar/bec_progressbar.py +++ b/bec_widgets/widgets/progress/bec_progressbar/bec_progressbar.py @@ -20,7 +20,7 @@ class ProgressState(Enum): @classmethod def from_bec_status(cls, status: str) -> "ProgressState": """ - Map a BEC status string (open, paused, aborted, halted, closed) + Map a BEC status string (open, paused, aborted, halt/halted, closed, user_completed) to the corresponding ProgressState. Any unknown status falls back to NORMAL. """ @@ -28,8 +28,10 @@ class ProgressState(Enum): "open": cls.NORMAL, "paused": cls.PAUSED, "aborted": cls.INTERRUPTED, + "halt": cls.PAUSED, "halted": cls.PAUSED, "closed": cls.COMPLETED, + "user_completed": cls.PAUSED, } return mapping.get(status.lower(), cls.NORMAL) @@ -104,9 +106,6 @@ class BECProgressBar(BECWidget, QWidget): self.progressbar.setMinimumHeight(0) self.progressbar.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Ignored) - # Backwards-compatible alias used by existing tests and downstream code. - self.center_label = self.progressbar - self._layout = QVBoxLayout(self) self._layout.setContentsMargins(self._padding_left_right, 0, self._padding_left_right, 0) self._layout.setSpacing(0) @@ -339,6 +338,7 @@ class BECProgressBar(BECWidget, QWidget): def _setup_style_sheet(self, *, chunk_radius: int) -> None: radius = int(round(self._corner_radius)) + chunk_color = self._state_colors[self._current_visual_state()].name() self.progressbar.setStyleSheet(f""" QProgressBar {{ background-color: palette(mid); @@ -348,7 +348,7 @@ class BECProgressBar(BECWidget, QWidget): text-align: center; }} QProgressBar::chunk {{ - background-color: palette(highlight); + background-color: {chunk_color}; border-radius: {chunk_radius}px; }} """) @@ -377,6 +377,11 @@ class BECProgressBar(BECWidget, QWidget): return 0 if self._enable_dynamic_stylesheet else self._target_chunk_radius() def _calculate_chunk_radius(self, target_radius: int) -> int: + """ + This whole chunk logic is to calculater radius based on the current size. + If the radius is smaller than size of the progressbar it is just not applied. + The chunk stylesheet logic is smoothing it as much as possible. + """ if target_radius <= 0 or self._maximum <= 0: return 0 fill_width = self.progressbar.width() * min(1.0, max(0.0, self._value / self._maximum)) @@ -385,6 +390,16 @@ class BECProgressBar(BECWidget, QWidget): return min(target_radius, max(1, int(fill_width / 2))) def _apply_state_style(self) -> None: + chunk_radius = self._chunk_radius + if chunk_radius is None: + target_radius = self._target_chunk_radius() + chunk_radius = ( + self._calculate_chunk_radius(target_radius) + if self._enable_dynamic_stylesheet + else target_radius + ) + self._chunk_radius = chunk_radius + self._setup_style_sheet(chunk_radius=chunk_radius) color = self._state_colors[self._current_visual_state()] palette = self.progressbar.palette() palette.setColor(QPalette.ColorRole.Highlight, color) @@ -406,20 +421,23 @@ class BECProgressBar(BECWidget, QWidget): if __name__ == "__main__": # pragma: no cover app = QApplication(sys.argv) - progressBar = BECProgressBar() - progressBar.show() - progressBar.set_minimum(-100) - progressBar.set_maximum(0) + progress_bar = BECProgressBar() + progress_bar.setWindowTitle("BEC Progress Bar") + progress_bar.resize(360, 48) + progress_bar.set_minimum(-100) + progress_bar.set_maximum(0) + progress_bar.set_value(-100) + progress_bar.show() # Example of setting values def update_progress(): - value = progressBar._user_value + 2.5 - if value > progressBar._user_maximum: - value = -100 # progressBar._maximum / progressBar._upsampling_factor - progressBar.set_value(value) + value = progress_bar._user_value + 2.5 + if value > progress_bar._user_maximum: + value = progress_bar._user_minimum + progress_bar.set_value(value) - timer = QTimer() + timer = QTimer(progress_bar) timer.timeout.connect(update_progress) - timer.start(200) # Update every half second + timer.start(200) sys.exit(app.exec()) diff --git a/bec_widgets/widgets/progress/progress_backend.py b/bec_widgets/widgets/progress/progress_backend.py new file mode 100644 index 00000000..53b7f5f6 --- /dev/null +++ b/bec_widgets/widgets/progress/progress_backend.py @@ -0,0 +1,244 @@ +from __future__ import annotations + +import time +from dataclasses import dataclass +from typing import Literal + +import numpy as np +from bec_lib.endpoints import MessageEndpoints +from qtpy.QtCore import QObject, QTimer, Signal + + +@dataclass(frozen=True) +class ProgressSnapshot: + value: float + max_value: float + done: bool + status: Literal["open", "paused", "aborted", "halt", "halted", "closed", "user_completed"] + scan_id: str | None = None + scan_number: int | None = None + rid: str | None = None + is_new_scan: bool = False + + +class ProgressTask(QObject): + """ + Class to store progress information. + Inspired by https://github.com/Textualize/rich/blob/master/rich/progress.py + """ + + def __init__( + self, parent: QObject | None, value: float = 0, max_value: float = 0, done: bool = False + ): + super().__init__(parent=parent) + self.start_time = time.monotonic() + self.done = done + self.value = value + self.max_value = max_value + self._elapsed_time = 0 + + self.timer = QTimer(self) + self.timer.timeout.connect(self.update_elapsed_time) + self.timer.start(1000) + + def update(self, value: float, max_value: float, done: bool = False): + """ + Update the progress. + """ + self.max_value = max_value + self.done = done + self.value = value + if done: + self.timer.stop() + + def update_elapsed_time(self): + """ + Update the time estimates. This is called every second by a QTimer. + """ + self._elapsed_time = max(0.0, time.monotonic() - self.start_time) + + @property + def percentage(self) -> float: + """float: Get progress of task as a percentage. If a None total was set, returns 0""" + if not self.max_value: + return 0.0 + completed = (self.value / self.max_value) * 100.0 + completed = min(100.0, max(0.0, completed)) + return completed + + @property + def speed(self) -> float: + """Get the estimated speed in steps per second.""" + if self._elapsed_time == 0: + return 0.0 + + return self.value / self._elapsed_time + + @property + def frequency(self) -> float: + """Get the estimated frequency in steps per second.""" + if self.speed == 0: + return 0.0 + return 1 / self.speed + + @property + def time_elapsed(self) -> str: + return self._format_time(int(self._elapsed_time)) + + @property + def remaining(self) -> float: + """Get the estimated remaining steps.""" + if self.done: + return 0.0 + remaining = self.max_value - self.value + return remaining + + @property + def time_remaining(self) -> str: + """ + Get the estimated remaining time in the format HH:MM:SS. + """ + if self.done or not self.speed or not self.remaining: + return self._format_time(0) + estimate = int(np.round(self.remaining / self.speed)) + + return self._format_time(estimate) + + @staticmethod + def _format_time(seconds: float) -> str: + """ + Format the time in seconds to a string in the format HH:MM:SS. + """ + return f"{seconds // 3600:02}:{(seconds // 60) % 60:02}:{seconds % 60:02}" + + +class BECProgressTracker(QObject): + """ + Shared backend for BEC scan progress messages. + """ + + progress_started = Signal(object) + progress_updated = Signal(object) + progress_finished = Signal(object) + progress_cleared = Signal() + + def __init__(self, bec_dispatcher, parent: QObject | None = None): + super().__init__(parent=parent) + self.bec_dispatcher = bec_dispatcher + self._connected = False + self.task: ProgressTask | None = None + self.scan_number: int | None = None + self._active_scan_id: str | None = None + self._active_rid: str | None = None + + def start(self) -> None: + if self._connected: + return + self.bec_dispatcher.connect_slot( + self.process_progress_message, MessageEndpoints.scan_progress() + ) + self._connected = True + + def _start_task(self, scan_id: str | None, rid: str | None = None) -> None: + if self.task is not None: + self.task.timer.stop() + self.task.deleteLater() + self.task = ProgressTask(parent=self) + self._active_scan_id = scan_id + self._active_rid = rid + self.progress_started.emit( + ProgressSnapshot( + value=0, + max_value=100, + done=False, + status="open", + scan_id=self._active_scan_id, + scan_number=self.scan_number, + rid=self._active_rid, + ) + ) + + def clear_task(self, *, emit_finished: bool = True) -> None: + if self.task is None: + self._active_scan_id = None + self._active_rid = None + self.progress_cleared.emit() + return + self.task.timer.stop() + self.task.deleteLater() + self.task = None + self._active_scan_id = None + self._active_rid = None + self.progress_cleared.emit() + if emit_finished: + self.progress_finished.emit( + ProgressSnapshot( + value=0, + max_value=100, + done=True, + status="open", + scan_id=self._active_scan_id, + scan_number=self.scan_number, + rid=self._active_rid, + ) + ) + + def process_progress_message( + self, msg_content: dict, metadata: dict + ) -> ProgressSnapshot | None: + done = msg_content.get("done", False) + value = msg_content.get("value", 0) + max_value = msg_content.get("max_value", 100) + status: Literal[ + "open", "paused", "aborted", "halt", "halted", "closed", "user_completed" + ] = metadata.get("status", "open") + scan_id = metadata.get("scan_id") or metadata.get("RID") + rid = metadata.get("RID") + scan_number = metadata.get("scan_number") + if scan_number is not None: + self.scan_number = scan_number + is_new_scan = False + previous_scan_id = self._active_scan_id + previous_rid = self._active_rid + identity_changed = ( + (scan_id is not None and scan_id != previous_scan_id) + or (rid is not None and rid != previous_rid) + or (previous_scan_id is None and previous_rid is None) + ) + + if self.task is None: + self._start_task(scan_id, rid=rid) + is_new_scan = identity_changed + elif scan_id is not None and scan_id != self._active_scan_id: + self._start_task(scan_id, rid=rid) + is_new_scan = True + elif rid is not None and rid != self._active_rid: + self._start_task(scan_id or self._active_scan_id, rid=rid) + is_new_scan = True + + if self.task is None: + return None + + self.task.update(value, max_value, done) + snapshot = ProgressSnapshot( + value=value, + max_value=max_value, + done=done, + status=status, + scan_id=self._active_scan_id, + scan_number=self.scan_number, + rid=self._active_rid, + is_new_scan=is_new_scan, + ) + self.progress_updated.emit(snapshot) + if done: + self.clear_task() + return snapshot + + def cleanup(self) -> None: + self.clear_task(emit_finished=False) + if self._connected: + self.bec_dispatcher.disconnect_slot( + self.process_progress_message, MessageEndpoints.scan_progress() + ) + self._connected = False diff --git a/bec_widgets/widgets/progress/ring_progress_bar/ring.py b/bec_widgets/widgets/progress/ring_progress_bar/ring.py index c998c953..1bedcc05 100644 --- a/bec_widgets/widgets/progress/ring_progress_bar/ring.py +++ b/bec_widgets/widgets/progress/ring_progress_bar/ring.py @@ -12,7 +12,9 @@ from qtpy.QtWidgets import QWidget from bec_widgets import BECWidget from bec_widgets.utils.bec_connector import ConnectionConfig from bec_widgets.utils.colors import Colors +from bec_widgets.utils.entry_validator import EntryValidator from bec_widgets.utils.error_popups import SafeProperty, SafeSlot +from bec_widgets.widgets.progress.progress_backend import BECProgressTracker, ProgressSnapshot logger = bec_logger.logger if TYPE_CHECKING: @@ -81,6 +83,8 @@ class Ring(BECWidget, QWidget): self._color: QColor = self.convert_color(self.config.color) self._background_color: QColor = self.convert_color(self.config.background_color) self.registered_slot: tuple[Callable, str | EndpointInfo] | None = None + self.progress_tracker = BECProgressTracker(self.bec_dispatcher, parent=self) + self.progress_tracker.progress_updated.connect(self._on_progress_snapshot) self.RID = None self._gap = 5 self._hovered = False @@ -219,35 +223,32 @@ class Ring(BECWidget, QWidget): case "manual": if self.config.mode == "manual": return - if self.registered_slot is not None: - self.bec_dispatcher.disconnect_slot(*self.registered_slot) + self._disconnect_registered_update() self.config.mode = "manual" - self.registered_slot = None case "scan": if self.config.mode == "scan": return - if self.registered_slot is not None: - self.bec_dispatcher.disconnect_slot(*self.registered_slot) + self._disconnect_registered_update() self.config.mode = "scan" - self.bec_dispatcher.connect_slot( - self.on_scan_progress, MessageEndpoints.scan_progress() - ) - self.registered_slot = (self.on_scan_progress, MessageEndpoints.scan_progress()) + self.progress_tracker.start() case "device": - if self.registered_slot is not None: - self.bec_dispatcher.disconnect_slot(*self.registered_slot) + self._disconnect_registered_update() self.config.mode = "device" if device == "": - self.registered_slot = None return self.config.device = device - # self.config.signal = self._get_signal_from_device(device, signal) signal = self._update_device_connection(device, signal) self.config.signal = signal case _: raise ValueError(f"Unsupported mode: {mode}") + def _disconnect_registered_update(self): + if self.registered_slot is not None: + self.bec_dispatcher.disconnect_slot(*self.registered_slot) + self.registered_slot = None + self.progress_tracker.cleanup() + def set_precision(self, precision: int): """ Set the precision for the ring widget. @@ -268,57 +269,12 @@ class Ring(BECWidget, QWidget): self.config.direction = direction self._request_update() - def _get_signals_for_device(self, device: str) -> dict[str, list[str]]: - """ - Get the signals for the device. - - Args: - device(str): Device name for the device - - Returns: - dict[str, list[str]]: Dictionary with the signals for the device - """ - dm = self.bec_dispatcher.client.device_manager - if not dm: - raise ValueError("Device manager is not available in the BEC client.") - dev_obj = dm.devices.get(device) - if dev_obj is None: - raise ValueError(f"Device '{device}' not found in device manager.") - - progress_signals = [ - obj["component_name"] - for obj in dev_obj._info["signals"].values() - if obj["signal_class"] == "ProgressSignal" - ] - hinted_signals = [ - obj["obj_name"] - for obj in dev_obj._info["signals"].values() - if obj["kind_str"] == "hinted" - and obj["signal_class"] - not in ["ProgressSignal", "AsyncSignal", "AsyncMultiSignal", "DynamicSignal"] - ] - - normal_signals = [ - obj["component_name"] - for obj in dev_obj._info["signals"].values() - if obj["kind_str"] == "normal" - ] - return { - "progress_signals": progress_signals, - "hinted_signals": hinted_signals, - "normal_signals": normal_signals, - } - def _update_device_connection(self, device: str, signal: str | None) -> str: """ Update the device connection for the ring widget. - In general, we support two modes here: - - If signal is provided, we use that directly. - - If signal is not provided, we try to get the signal from the device manager. - We first check for progress signals, then for hinted signals, and finally for normal signals. - - Depending on what type of signal we get (progress or hinted/normal), we subscribe to different endpoints. + Device mode always subscribes to the device readback endpoint. If no signal is provided, + the signal is resolved from the device hints, matching the plot widgets. Args: device(str): Device name for the device mode @@ -335,57 +291,11 @@ class Ring(BECWidget, QWidget): if dev_obj is None: return "" - signals = self._get_signals_for_device(device) - progress_signals = signals["progress_signals"] - hinted_signals = signals["hinted_signals"] - normal_signals = signals["normal_signals"] - - if not signal: - # If signal is not provided, we try to get it from the device manager - if len(progress_signals) > 0: - signal = progress_signals[0] - logger.info( - f"Using progress signal '{signal}' for device '{device}' in ring progress bar." - ) - elif len(hinted_signals) > 0: - signal = hinted_signals[0] - logger.info( - f"Using hinted signal '{signal}' for device '{device}' in ring progress bar." - ) - elif len(normal_signals) > 0: - signal = normal_signals[0] - logger.info( - f"Using normal signal '{signal}' for device '{device}' in ring progress bar." - ) - else: - logger.warning(f"No signals found for device '{device}' in ring progress bar.") - return "" - - if signal in progress_signals: - endpoint = MessageEndpoints.device_progress(device) - self.bec_dispatcher.connect_slot(self.on_device_progress, endpoint) - self.registered_slot = (self.on_device_progress, endpoint) - return signal - if signal in hinted_signals or signal in normal_signals: - endpoint = MessageEndpoints.device_readback(device) - self.bec_dispatcher.connect_slot(self.on_device_readback, endpoint) - self.registered_slot = (self.on_device_readback, endpoint) - return signal - - @SafeSlot(dict, dict) - def on_scan_progress(self, msg, meta): - """ - Update the ring widget with the scan progress. - - Args: - msg(dict): Message with the scan progress - meta(dict): Metadata for the message - """ - current_RID = meta.get("RID", None) - if current_RID != self.RID: - self.set_min_max_values(0, msg.get("max_value", 100)) - self.set_value(msg.get("value", 0)) - self.update() + signal = EntryValidator(dm.devices).validate_signal(device, signal or None) + endpoint = MessageEndpoints.device_readback(device) + self.bec_dispatcher.connect_slot(self.on_device_readback, endpoint) + self.registered_slot = (self.on_device_readback, endpoint) + return signal @SafeSlot(dict, dict) def on_device_readback(self, msg, meta): @@ -406,32 +316,23 @@ class Ring(BECWidget, QWidget): self.set_value(value) self.update() - @SafeSlot(dict, dict) - def on_device_progress(self, msg, meta): - """ - Update the ring widget with the device progress. - - Args: - msg(dict): Message with the device progress - meta(dict): Metadata for the message - """ - device = self.config.device - if device is None: - return - max_val = msg.get("max_value", 100) - self.set_min_max_values(0, max_val) - value = msg.get("value", 0) - if msg.get("done"): - value = max_val - self.set_value(value) + def _on_progress_snapshot(self, snapshot: ProgressSnapshot): + if snapshot.is_new_scan: + self.set_min_max_values(0, snapshot.max_value) + self.RID = snapshot.rid + self.set_value(snapshot.value) self.update() def paintEvent(self, event): if not self.progress_container: return - painter = QtGui.QPainter(self) - painter.setRenderHint(QtGui.QPainter.RenderHint.Antialiasing) size = min(self.width(), self.height()) + if size <= 0 or not self.isVisible(): + return + painter = QtGui.QPainter(self) + if not painter.isActive(): + return + painter.setRenderHint(QtGui.QPainter.RenderHint.Antialiasing) # Center the ring x_offset = (self.width() - size) // 2 @@ -509,15 +410,6 @@ class Ring(BECWidget, QWidget): return QtGui.QColor(*color) raise ValueError(f"Unsupported color format: {color}") - def cleanup(self): - """ - Cleanup the ring widget. - Disconnect any registered slots. - """ - if self.registered_slot is not None: - self.bec_dispatcher.disconnect_slot(*self.registered_slot) - self.registered_slot = None - ############################################### ####### QProperties ########################### ############################################### @@ -666,6 +558,7 @@ class Ring(BECWidget, QWidget): if self.registered_slot is not None: self.bec_dispatcher.disconnect_slot(*self.registered_slot) self.registered_slot = None + self.progress_tracker.cleanup() self._hover_animation.stop() super().cleanup() diff --git a/bec_widgets/widgets/progress/scan_progressbar/scan_progressbar.py b/bec_widgets/widgets/progress/scan_progressbar/scan_progressbar.py index 3b08d33e..3ed44043 100644 --- a/bec_widgets/widgets/progress/scan_progressbar/scan_progressbar.py +++ b/bec_widgets/widgets/progress/scan_progressbar/scan_progressbar.py @@ -1,123 +1,20 @@ from __future__ import annotations -import enum import os -import time -from typing import Literal -import numpy as np -from bec_lib import messages -from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger -from qtpy.QtCore import QObject, QTimer, Signal +from qtpy.QtCore import Signal from qtpy.QtWidgets import QVBoxLayout, QWidget from bec_widgets.utils.bec_widget import BECWidget -from bec_widgets.utils.error_popups import SafeProperty, SafeSlot +from bec_widgets.utils.error_popups import SafeProperty from bec_widgets.utils.ui_loader import UILoader from bec_widgets.widgets.progress.bec_progressbar.bec_progressbar import ProgressState +from bec_widgets.widgets.progress.progress_backend import BECProgressTracker, ProgressSnapshot logger = bec_logger.logger -class ProgressSource(enum.Enum): - """ - Enum to define the source of the progress. - """ - - SCAN_PROGRESS = "scan_progress" - DEVICE_PROGRESS = "device_progress" - - -class ProgressTask(QObject): - """ - Class to store progress information. - Inspired by https://github.com/Textualize/rich/blob/master/rich/progress.py - """ - - def __init__(self, parent: QWidget, value: float = 0, max_value: float = 0, done: bool = False): - super().__init__(parent=parent) - self.start_time = time.monotonic() - self.done = done - self.value = value - self.max_value = max_value - self._elapsed_time = 0 - - self.timer = QTimer(self) - self.timer.timeout.connect(self.update_elapsed_time) - self.timer.start(1000) - - def update(self, value: float, max_value: float, done: bool = False): - """ - Update the progress. - """ - self.max_value = max_value - self.done = done - self.value = value - if done: - self.timer.stop() - - def update_elapsed_time(self): - """ - Update the time estimates. This is called every second by a QTimer. - """ - self._elapsed_time = max(0.0, time.monotonic() - self.start_time) - - @property - def percentage(self) -> float: - """float: Get progress of task as a percentage. If a None total was set, returns 0""" - if not self.max_value: - return 0.0 - completed = (self.value / self.max_value) * 100.0 - completed = min(100.0, max(0.0, completed)) - return completed - - @property - def speed(self) -> float: - """Get the estimated speed in steps per second.""" - if self._elapsed_time == 0: - return 0.0 - - return self.value / self._elapsed_time - - @property - def frequency(self) -> float: - """Get the estimated frequency in steps per second.""" - if self.speed == 0: - return 0.0 - return 1 / self.speed - - @property - def time_elapsed(self) -> str: - # format the elapsed time to a string in the format HH:MM:SS - return self._format_time(int(self._elapsed_time)) - - @property - def remaining(self) -> float: - """Get the estimated remaining steps.""" - if self.done: - return 0.0 - remaining = self.max_value - self.value - return remaining - - @property - def time_remaining(self) -> str: - """ - Get the estimated remaining time in the format HH:MM:SS. - """ - if self.done or not self.speed or not self.remaining: - return self._format_time(0) - estimate = int(np.round(self.remaining / self.speed)) - - return self._format_time(estimate) - - def _format_time(self, seconds: float) -> str: - """ - Format the time in seconds to a string in the format HH:MM:SS. - """ - return f"{seconds // 3600:02}:{(seconds // 60) % 60:02}:{seconds % 60:02}" - - class ScanProgressBar(BECWidget, QWidget): """ Widget to display a progress bar that is hooked up to the scan progress of a scan. @@ -158,101 +55,34 @@ class ScanProgressBar(BECWidget, QWidget): self._show_remaining_time = self.ui.remaining_time_label.isVisible() self._show_source_label = self.ui.source_label.isVisible() - self._progress_source = None - self._progress_device = None - self.task = None - self.scan_number = None - self._active_scan_id = None - self.connect_to_queue() - - def connect_to_queue(self): - """ - Connect to the queue status signal. - """ - self.bec_dispatcher.connect_slot(self.on_queue_update, MessageEndpoints.scan_queue_status()) - - def set_progress_source(self, source: ProgressSource, device=None): - """ - Set the source of the progress. - """ - if self._progress_source == source and self._progress_device == device: - self.update_source_label(source, device=device) - return - if self._progress_source is not None: - self.bec_dispatcher.disconnect_slot( - self.on_progress_update, - ( - MessageEndpoints.scan_progress() - if self._progress_source == ProgressSource.SCAN_PROGRESS - else MessageEndpoints.device_progress(device=self._progress_device) - ), - ) - self._progress_source = source - self._progress_device = None if source == ProgressSource.SCAN_PROGRESS else device - self.bec_dispatcher.connect_slot( - self.on_progress_update, - ( - MessageEndpoints.scan_progress() - if source == ProgressSource.SCAN_PROGRESS - else MessageEndpoints.device_progress(device=device) - ), + self.progress_tracker = BECProgressTracker(self.bec_dispatcher, parent=self) + self.progress_tracker.progress_started.connect(self._on_progress_started) + self.progress_tracker.progress_updated.connect(self._on_progress_snapshot) + self.progress_tracker.progress_finished.connect( + lambda _snapshot: self.progress_finished.emit() ) - self.update_source_label(source, device=device) - # self.progress_started.emit() + self.progress_tracker.start() - def _start_task(self, scan_id: str | None) -> None: - if self.task is not None: - self.task.timer.stop() - self.task.deleteLater() - self.task = ProgressTask(parent=self) - self.task.timer.timeout.connect(self.update_labels) - self._active_scan_id = scan_id + def update_source_label(self): + scan_number = self.progress_tracker.scan_number + scan_text = f"Scan {scan_number}" if scan_number is not None else "Scan" + if self.ui.source_label.text() == scan_text: + return + logger.info(f"Set progress source to {scan_text}") + self.ui.source_label.setText(scan_text) + + def _on_progress_started(self, _snapshot: ProgressSnapshot): + if self.progress_tracker.task is not None: + self.progress_tracker.task.timer.timeout.connect(self.update_labels) self.progress_started.emit() - def _clear_task(self, *, emit_finished: bool = True) -> None: - if self.task is None: - self._active_scan_id = None - return - self.task.timer.stop() - self.task.deleteLater() - self.task = None - self._active_scan_id = None - if emit_finished: - self.progress_finished.emit() - - def update_source_label(self, source: ProgressSource, device=None): - scan_text = f"Scan {self.scan_number}" if self.scan_number is not None else "Scan" - text = scan_text if source == ProgressSource.SCAN_PROGRESS else f"Device {device}" - if self.ui.source_label.text() == text: - return - logger.info(f"Set progress source to {text}") - self.ui.source_label.setText(text) - - @SafeSlot(dict, dict) - def on_progress_update(self, msg_content: dict, metadata: dict): - """ - Update the progress bar based on the progress message. - """ - value = msg_content["value"] - max_value = msg_content.get("max_value", 100) - done = msg_content.get("done", False) - status: Literal["open", "paused", "aborted", "halted", "closed"] = metadata.get( - "status", "open" - ) - - if self.task is None: - return - self.task.update(value, max_value, done) - + def _on_progress_snapshot(self, snapshot: ProgressSnapshot): self.update_labels() - - self.progressbar.set_maximum(self.task.max_value) - self.progressbar.state = ProgressState.from_bec_status(status) - self.progressbar.set_value(self.task.value) - - if done: - self._clear_task() - return + self.update_source_label() + self.progressbar.set_maximum(snapshot.max_value) + state = ProgressState.from_bec_status(snapshot.status) + self.progressbar.state = state + self.progressbar.set_value(snapshot.value) @SafeProperty(bool) def show_elapsed_time(self): @@ -289,86 +119,17 @@ class ScanProgressBar(BECWidget, QWidget): """ Update the labels based on the progress task. """ - if self.task is None: + task = self.progress_tracker.task + if task is None: return - self.ui.elapsed_time_label.setText(self.task.time_elapsed) - self.ui.remaining_time_label.setText(self.task.time_remaining) - - @SafeSlot(dict, dict, verify_sender=True) - def on_queue_update(self, msg_content, metadata): - """ - Update the progress bar based on the queue status. - """ - if not "queue" in msg_content: - self._clear_task() - return - if "primary" not in msg_content["queue"]: - self._clear_task() - return - if (primary_queue := msg_content.get("queue").get("primary")) is None: - self._clear_task() - return - if not isinstance(primary_queue, messages.ScanQueueStatus): - self._clear_task() - return - primary_queue_info = primary_queue.info - if len(primary_queue_info) == 0: - self._clear_task() - return - scan_info = primary_queue_info[0] - if scan_info is None: - self._clear_task() - return - - active_request_block = scan_info.active_request_block - if active_request_block is None: - self._clear_task() - return - - status = scan_info.status.lower() - if status != "running": - self._clear_task() - return - - scan_id = active_request_block.scan_id or str(active_request_block.scan_number) - if self.task is None or self._active_scan_id != scan_id: - self._start_task(scan_id) - - self.scan_number = active_request_block.scan_number - report_instructions = active_request_block.report_instructions - if not report_instructions: - return - - # for now, let's just use the first instruction - instruction = report_instructions[0] - - if "scan_progress" in instruction: - self.set_progress_source(ProgressSource.SCAN_PROGRESS) - elif "device_progress" in instruction: - if not instruction["device_progress"]: - return - device = instruction["device_progress"][0] - self.set_progress_source(ProgressSource.DEVICE_PROGRESS, device=device) + self.ui.elapsed_time_label.setText(task.time_elapsed) + self.ui.remaining_time_label.setText(task.time_remaining) def cleanup(self): - self._clear_task(emit_finished=False) - if self._progress_source is not None: - self.bec_dispatcher.disconnect_slot( - self.on_progress_update, - ( - MessageEndpoints.scan_progress() - if self._progress_source == ProgressSource.SCAN_PROGRESS - else MessageEndpoints.device_progress(device=self._progress_device) - ), - ) - self._progress_source = None - self._progress_device = None + self.progress_tracker.cleanup() self.progressbar.close() self.progressbar.deleteLater() - self.bec_dispatcher.disconnect_slot( - self.on_queue_update, MessageEndpoints.scan_queue_status() - ) super().cleanup() diff --git a/tests/unit_tests/test_bec_progressbar.py b/tests/unit_tests/test_bec_progressbar.py index d0628f8e..9269137e 100644 --- a/tests/unit_tests/test_bec_progressbar.py +++ b/tests/unit_tests/test_bec_progressbar.py @@ -44,7 +44,7 @@ def test_progressbar_label(progressbar): progressbar.label_template = "Test: $value" progressbar.set_value(50) assert progressbar._get_label() == "Test: 50" - assert progressbar.center_label.text() == "Test: 50" + assert progressbar.progressbar.text() == "Test: 50" def test_progressbar_equal_minimum_and_maximum_does_not_raise(progressbar): @@ -63,7 +63,10 @@ def test_progressbar_uses_static_stylesheet_with_palette_state_color(progressbar style_sheet = progressbar.progressbar.styleSheet() assert "QProgressBar::chunk" in style_sheet - assert "background-color: palette(highlight);" in style_sheet + assert ( + f"background-color: {progressbar._state_colors[ProgressState.PAUSED].name()};" + in style_sheet + ) assert "background-color: palette(mid);" in style_sheet assert "border-radius: 7px;" in style_sheet assert ( @@ -171,8 +174,10 @@ def test_progress_state_from_bec_status(): "open": ProgressState.NORMAL, "paused": ProgressState.PAUSED, "aborted": ProgressState.INTERRUPTED, + "halt": ProgressState.PAUSED, "halted": ProgressState.PAUSED, "closed": ProgressState.COMPLETED, + "user_completed": ProgressState.PAUSED, "UNKNOWN": ProgressState.NORMAL, # fallback } for text, expected in mapping.items(): diff --git a/tests/unit_tests/test_device_initialization_progress_bar.py b/tests/unit_tests/test_device_initialization_progress_bar.py index 530b53c1..c08a82b9 100644 --- a/tests/unit_tests/test_device_initialization_progress_bar.py +++ b/tests/unit_tests/test_device_initialization_progress_bar.py @@ -40,7 +40,7 @@ def test_update_device_initialization_progress(progress_bar, qtbot): assert progress_bar.progress_bar._user_value == 1 assert progress_bar.progress_bar._user_maximum == 3 assert progress_bar.progress_label.text() == f"{msg.device} initialization in progress..." - assert "1 / 3 - 33 %" == progress_bar.progress_bar.center_label.text() + assert "1 / 3 - 33 %" == progress_bar.progress_bar.progressbar.text() # II. Update with message of finished DeviceInitializationProgressMessage, finished=True, success=True msg.finished = True @@ -49,7 +49,7 @@ def test_update_device_initialization_progress(progress_bar, qtbot): assert progress_bar.progress_bar._user_value == 1 assert progress_bar.progress_bar._user_maximum == 3 assert progress_bar.progress_label.text() == f"{msg.device} initialization succeeded!" - assert "1 / 3 - 33 %" == progress_bar.progress_bar.center_label.text() + assert "1 / 3 - 33 %" == progress_bar.progress_bar.progressbar.text() # III. Update with message of finished DeviceInitializationProgressMessage, finished=True, success=False msg.finished = True @@ -59,7 +59,7 @@ def test_update_device_initialization_progress(progress_bar, qtbot): with qtbot.waitSignal(progress_bar.failed_devices_changed) as signal_blocker: progress_bar._update_device_initialization_progress(msg.model_dump(), {}) assert progress_bar.progress_label.text() == f"{msg.device} initialization failed!" - assert "2 / 3 - 66 %" == progress_bar.progress_bar.center_label.text() + assert "2 / 3 - 66 %" == progress_bar.progress_bar.progressbar.text() assert progress_bar.progress_bar._user_value == 2 assert progress_bar.progress_bar._user_maximum == 3 diff --git a/tests/unit_tests/test_main_widnow.py b/tests/unit_tests/test_main_widnow.py index 641973cc..2b0fd40e 100644 --- a/tests/unit_tests/test_main_widnow.py +++ b/tests/unit_tests/test_main_widnow.py @@ -117,13 +117,6 @@ def test_hidden_scan_progress_parent_blocks_children_namespace(bec_main_window): assert nested_progress.parent_id == hidden_progress.gui_id -def test_compact_scan_progress_bar_uses_status_bar_sizing(bec_main_window): - progressbar = bec_main_window._scan_progress_bar_simple.progressbar - - assert progressbar.height() == bec_main_window.SCAN_PROGRESS_HEIGHT - assert progressbar.progressbar.minimumHeight() == 0 - - ################################################################# # Tests for BECMainWindow Addons ################################################################# diff --git a/tests/unit_tests/test_progress_backend.py b/tests/unit_tests/test_progress_backend.py new file mode 100644 index 00000000..1f839125 --- /dev/null +++ b/tests/unit_tests/test_progress_backend.py @@ -0,0 +1,92 @@ +from unittest import mock + +from bec_lib.endpoints import MessageEndpoints + +from bec_widgets.widgets.progress.progress_backend import BECProgressTracker + + +def _dispatcher(): + dispatcher = mock.MagicMock() + return dispatcher + + +def test_tracker_subscribes_to_scan_progress_immediately(): + dispatcher = _dispatcher() + tracker = BECProgressTracker(dispatcher) + + tracker.start() + + dispatcher.connect_slot.assert_called_once_with( + tracker.process_progress_message, MessageEndpoints.scan_progress() + ) + tracker.cleanup() + + +def test_tracker_starts_scan_from_scan_progress_metadata(): + dispatcher = _dispatcher() + tracker = BECProgressTracker(dispatcher) + snapshots = [] + tracker.progress_updated.connect(snapshots.append) + + tracker.start() + tracker.process_progress_message( + {"value": 3, "max_value": 10}, + {"scan_id": "scan-2", "RID": "rid-2", "scan_number": 2, "status": "open"}, + ) + + assert tracker.task is not None + assert tracker._active_scan_id == "scan-2" + assert tracker._active_rid == "rid-2" + assert tracker.scan_number == 2 + assert snapshots[-1].scan_number == 2 + + tracker.cleanup() + + +def test_tracker_switches_sources_idempotently(): + dispatcher = _dispatcher() + tracker = BECProgressTracker(dispatcher) + + tracker.start() + tracker.start() + assert dispatcher.connect_slot.call_count == 1 + assert dispatcher.disconnect_slot.call_count == 0 + + tracker.cleanup() + + +def test_tracker_marks_new_scan_only_when_rid_changes(): + dispatcher = _dispatcher() + tracker = BECProgressTracker(dispatcher) + snapshots = [] + tracker.progress_updated.connect(snapshots.append) + tracker.start() + + tracker.process_progress_message({"value": 10, "max_value": 100}, {"RID": "rid-1"}) + tracker.process_progress_message({"value": 20, "max_value": 200}, {"RID": "rid-1"}) + tracker.process_progress_message({"value": 5, "max_value": 50}, {"RID": "rid-2"}) + + assert [snapshot.is_new_scan for snapshot in snapshots] == [True, False, True] + assert tracker._active_rid == "rid-2" + + tracker.cleanup() + + +def test_tracker_keeps_partial_value_for_done_scan_progress(): + dispatcher = _dispatcher() + tracker = BECProgressTracker(dispatcher) + snapshots = [] + tracker.progress_updated.connect(snapshots.append) + tracker.start() + + tracker.process_progress_message( + {"value": 4, "max_value": 10, "done": True}, + {"scan_id": "scan-1", "RID": "rid-1", "status": "aborted"}, + ) + + assert snapshots[-1].value == 4 + assert snapshots[-1].max_value == 10 + assert snapshots[-1].done is True + assert tracker.task is None + + tracker.cleanup() diff --git a/tests/unit_tests/test_ring_progress_bar_ring.py b/tests/unit_tests/test_ring_progress_bar_ring.py index 2700eb66..fbeaad2b 100644 --- a/tests/unit_tests/test_ring_progress_bar_ring.py +++ b/tests/unit_tests/test_ring_progress_bar_ring.py @@ -79,7 +79,7 @@ def test_set_update_to_scan(ring_widget): # Verify that connect_slot was called ring_widget.bec_dispatcher.connect_slot.assert_called_once() call_args = ring_widget.bec_dispatcher.connect_slot.call_args - assert call_args[0][0] == ring_widget.on_scan_progress + assert call_args[0][0] == ring_widget.progress_tracker.process_progress_message assert "scan_progress" in str(call_args[0][1]) @@ -432,12 +432,13 @@ def test_update_device_connection_with_progress_signal(ring_widget_with_device): ring_widget.bec_dispatcher.connect_slot = MagicMock() - ring_widget._update_device_connection("samx", "progress") + signal = ring_widget._update_device_connection("samx", "progress") - # Should connect to device_progress endpoint + # Device mode always connects to device_readback, even if the explicit signal is a ProgressSignal. + assert signal == "samx_progress" ring_widget.bec_dispatcher.connect_slot.assert_called_once() call_args = ring_widget.bec_dispatcher.connect_slot.call_args - assert call_args[0][0] == ring_widget.on_device_progress + assert call_args[0][0] == ring_widget.on_device_readback def test_update_device_connection_with_hinted_signal(ring_widget): @@ -477,39 +478,35 @@ def test_update_device_connection_device_not_found(ring_widget): ################################### -# on_scan_progress tests +# scan progress tests ################################### -def test_on_scan_progress_updates_value(ring_widget): +def test_scan_progress_updates_value(ring_widget): msg = {"value": 42, "max_value": 100} meta = {"RID": "test_rid_123"} - ring_widget.on_scan_progress(msg, meta) + ring_widget.progress_tracker.process_progress_message(msg, meta) assert ring_widget.config.value == 42 -def test_on_scan_progress_updates_min_max_on_new_rid(ring_widget): +def test_scan_progress_updates_min_max_on_new_rid(ring_widget): msg = {"value": 50, "max_value": 200} meta = {"RID": "new_rid"} - ring_widget.RID = "old_rid" - ring_widget.on_scan_progress(msg, meta) + ring_widget.progress_tracker.process_progress_message(msg, meta) assert ring_widget.config.min_value == 0 assert ring_widget.config.max_value == 200 assert ring_widget.config.value == 50 -def test_on_scan_progress_same_rid_no_min_max_update(ring_widget): - msg = {"value": 75, "max_value": 300} +def test_scan_progress_same_rid_no_min_max_update(ring_widget): meta = {"RID": "same_rid"} - ring_widget.RID = "same_rid" - ring_widget.set_min_max_values(0, 100) - - ring_widget.on_scan_progress(msg, meta) + ring_widget.progress_tracker.process_progress_message({"value": 10, "max_value": 100}, meta) + ring_widget.progress_tracker.process_progress_message({"value": 75, "max_value": 300}, meta) # Max value should not be updated when RID is the same assert ring_widget.config.max_value == 100 @@ -570,63 +567,3 @@ def test_on_device_readback_missing_signal_data(ring_widget): # Value should not change when signal is missing assert ring_widget.config.value == initial_value - - -################################### -# on_device_progress tests -################################### - - -def test_on_device_progress_updates_value_and_max(ring_widget): - ring_widget.config.device = "samx" - - msg = {"value": 30, "max_value": 150, "done": False} - meta = {} - - ring_widget.on_device_progress(msg, meta) - - assert ring_widget.config.value == 30 - assert ring_widget.config.max_value == 150 - - -def test_on_device_progress_done_sets_to_max(ring_widget): - ring_widget.config.device = "samx" - - msg = {"value": 80, "max_value": 100, "done": True} - meta = {} - - ring_widget.on_device_progress(msg, meta) - - # When done is True, value should be set to max_value - assert ring_widget.config.value == 100 - assert ring_widget.config.max_value == 100 - - -def test_on_device_progress_no_device_returns_early(ring_widget): - ring_widget.config.device = None - - msg = {"value": 50, "max_value": 100, "done": False} - meta = {} - - initial_value = ring_widget.config.value - initial_max = ring_widget.config.max_value - - ring_widget.on_device_progress(msg, meta) - - # Nothing should change - assert ring_widget.config.value == initial_value - assert ring_widget.config.max_value == initial_max - - -def test_on_device_progress_default_values(ring_widget): - ring_widget.config.device = "samx" - - # Message without value and max_value - msg = {} - meta = {} - - ring_widget.on_device_progress(msg, meta) - - # Should use defaults: value=0, max_value=100 - assert ring_widget.config.value == 0 - assert ring_widget.config.max_value == 100 diff --git a/tests/unit_tests/test_scan_progress_bar.py b/tests/unit_tests/test_scan_progress_bar.py index a9cf554b..0ba1f70c 100644 --- a/tests/unit_tests/test_scan_progress_bar.py +++ b/tests/unit_tests/test_scan_progress_bar.py @@ -2,19 +2,14 @@ from unittest import mock import numpy as np import pytest -from bec_lib import messages -from bec_lib.endpoints import MessageEndpoints from bec_widgets.utils.bec_widget import BECWidget from bec_widgets.widgets.progress.bec_progressbar.bec_progressbar import ( BECProgressBar, ProgressState, ) -from bec_widgets.widgets.progress.scan_progressbar.scan_progressbar import ( - ProgressSource, - ProgressTask, - ScanProgressBar, -) +from bec_widgets.widgets.progress.progress_backend import ProgressTask +from bec_widgets.widgets.progress.scan_progressbar.scan_progressbar import ScanProgressBar from .client_mocks import mocked_client @@ -27,30 +22,6 @@ def scan_progressbar(qtbot, mocked_client): yield widget -@pytest.fixture -def scan_message(): - return messages.ScanQueueMessage( - metadata={ - "file_suffix": None, - "file_directory": None, - "user_metadata": {"sample_name": ""}, - "RID": "94949c6e-d5f2-4f01-837e-a5d36257dd5d", - }, - scan_type="line_scan", - parameter={ - "args": {"samx": [-10.0, 10.0]}, - "kwargs": { - "steps": 20, - "relative": False, - "exp_time": 0.1, - "burst_at_each_point": 1, - "system_config": {"file_suffix": None, "file_directory": None}, - }, - }, - queue="primary", - ) - - def test_progress_task_basic(): """percentage, remaining, and formatted time helpers behave as expected.""" task = ProgressTask(parent=None, value=50, max_value=100, done=False) @@ -74,8 +45,7 @@ def test_progress_task_basic(): def test_progress_task_elapsed_time_uses_monotonic_clock(monkeypatch): times = iter([100.0, 102.5]) monkeypatch.setattr( - "bec_widgets.widgets.progress.scan_progressbar.scan_progressbar.time.monotonic", - lambda: next(times), + "bec_widgets.widgets.progress.progress_backend.time.monotonic", lambda: next(times) ) task = ProgressTask(parent=None) task.timer.stop() @@ -98,13 +68,26 @@ def test_scan_progressbar_passes_dynamic_stylesheet_setting(qtbot, mocked_client assert widget.progressbar.enable_dynamic_stylesheet is False +def test_scan_progressbar_starts_from_scan_progress_before_queue_update(scan_progressbar): + scan_progressbar.progress_tracker.clear_task(emit_finished=False) + + scan_progressbar.progress_tracker.process_progress_message( + {"value": 3, "max_value": 10, "done": False}, metadata={"RID": "live-rid"} + ) + + assert scan_progressbar.progress_tracker.task is not None + assert scan_progressbar.progress_tracker._active_scan_id == "live-rid" + assert scan_progressbar.progressbar._user_value == 3 + assert scan_progressbar.progressbar._user_maximum == 10 + + def test_update_labels_content(scan_progressbar): """update_labels() reflects ProgressTask time strings on the UI.""" # fabricate a task with known timings task = ProgressTask(parent=scan_progressbar, value=30, max_value=100, done=False) task.timer.stop() task._elapsed_time = 50 - scan_progressbar.task = task + scan_progressbar.progress_tracker.task = task scan_progressbar.update_labels() @@ -112,17 +95,17 @@ def test_update_labels_content(scan_progressbar): assert scan_progressbar.ui.remaining_time_label.text() == "00:01:57" -def test_on_progress_update(qtbot, scan_progressbar): +def test_progress_update(qtbot, scan_progressbar): """ - on_progress_update() should forward new values to the embedded - BECProgressBar and keep ProgressTask in sync. + Scan progress updates should update the embedded BECProgressBar + and keep ProgressTask in sync. """ task = ProgressTask(parent=scan_progressbar, value=0, max_value=100, done=False) task.timer.stop() - scan_progressbar.task = task + scan_progressbar.progress_tracker.task = task msg = {"value": 20, "max_value": 100, "done": False} - scan_progressbar.on_progress_update(msg, metadata={"status": "open"}) + scan_progressbar.progress_tracker.process_progress_message(msg, metadata={"status": "open"}) qtbot.wait(200) bar = scan_progressbar.progressbar @@ -138,8 +121,10 @@ def test_on_progress_update(qtbot, scan_progressbar): ("open", 10, 100, ProgressState.NORMAL), ("paused", 25, 100, ProgressState.PAUSED), ("aborted", 30, 100, ProgressState.INTERRUPTED), + ("halt", 40, 100, ProgressState.PAUSED), ("halted", 40, 100, ProgressState.PAUSED), ("closed", 100, 100, ProgressState.COMPLETED), + ("user_completed", 40, 100, ProgressState.PAUSED), ], ) def test_state_mapping_during_updates( @@ -148,9 +133,9 @@ def test_state_mapping_during_updates( """ScanProgressBar should translate BEC status → ProgressState consistently.""" task = ProgressTask(parent=scan_progressbar, value=0, max_value=max_val, done=False) task.timer.stop() - scan_progressbar.task = task + scan_progressbar.progress_tracker.task = task - scan_progressbar.on_progress_update( + scan_progressbar.progress_tracker.process_progress_message( {"value": value, "max_value": max_val, "done": status == "closed"}, metadata={"status": status}, ) @@ -158,110 +143,39 @@ def test_state_mapping_during_updates( assert scan_progressbar.progressbar.state is expected_state -def test_source_label_updates(scan_progressbar): - """update_source_label() renders correct text for both progress sources.""" - # device progress - scan_progressbar.update_source_label(ProgressSource.DEVICE_PROGRESS, device="motor") - assert scan_progressbar.ui.source_label.text() == "Device motor" +def test_aborted_done_scan_keeps_partial_progress(scan_progressbar): + scan_progressbar.progress_tracker.process_progress_message( + {"value": 4, "max_value": 10, "done": True}, + metadata={"scan_id": "scan-1", "RID": "rid-1", "status": "aborted"}, + ) - # scan progress (needs a scan_number for deterministic text) - scan_progressbar.scan_number = 5 - scan_progressbar.update_source_label(ProgressSource.SCAN_PROGRESS) + assert scan_progressbar.progressbar._user_value == 4 + assert scan_progressbar.progressbar._user_maximum == 10 + assert scan_progressbar.progressbar.state is ProgressState.INTERRUPTED + assert scan_progressbar.progress_tracker.task is None + + +def test_source_label_updates(scan_progressbar): + """update_source_label() renders the current scan label.""" + scan_progressbar.progress_tracker.scan_number = 5 + scan_progressbar.update_source_label() assert scan_progressbar.ui.source_label.text() == "Scan 5" def test_source_label_update_logs_only_on_text_change(scan_progressbar): - scan_progressbar.scan_number = 5 + scan_progressbar.progress_tracker.scan_number = 5 with mock.patch( "bec_widgets.widgets.progress.scan_progressbar.scan_progressbar.logger.info" ) as mock_info: - scan_progressbar.set_progress_source(ProgressSource.SCAN_PROGRESS) - scan_progressbar.set_progress_source(ProgressSource.SCAN_PROGRESS) - scan_progressbar.set_progress_source(ProgressSource.SCAN_PROGRESS) + scan_progressbar.update_source_label() + scan_progressbar.update_source_label() + scan_progressbar.update_source_label() mock_info.assert_called_once_with("Set progress source to Scan 5") -def test_set_progress_source_connections(scan_progressbar, monkeypatch): - """ """ - - connect_calls = [] - disconnect_calls = [] - - def fake_connect(slot, endpoint): - connect_calls.append(endpoint) - - def fake_disconnect(slot, endpoint): - disconnect_calls.append(endpoint) - - # Patch dispatcher methods - monkeypatch.setattr(scan_progressbar.bec_dispatcher, "connect_slot", fake_connect) - monkeypatch.setattr(scan_progressbar.bec_dispatcher, "disconnect_slot", fake_disconnect) - - # switch to SCAN_PROGRESS - scan_progressbar.scan_number = 7 - scan_progressbar.set_progress_source(ProgressSource.SCAN_PROGRESS) - - assert scan_progressbar._progress_source == ProgressSource.SCAN_PROGRESS - assert scan_progressbar.ui.source_label.text() == "Scan 7" - assert connect_calls[-1] == MessageEndpoints.scan_progress() - assert disconnect_calls == [] - - # switch to DEVICE_PROGRESS - device = "motor" - scan_progressbar.set_progress_source(ProgressSource.DEVICE_PROGRESS, device=device) - - assert scan_progressbar._progress_source == ProgressSource.DEVICE_PROGRESS - assert scan_progressbar.ui.source_label.text() == f"Device {device}" - assert connect_calls[-1] == MessageEndpoints.device_progress(device=device) - assert disconnect_calls == [MessageEndpoints.scan_progress()] - - # calling again with the SAME source should not add new connect calls - prev_connect_count = len(connect_calls) - scan_progressbar.set_progress_source(ProgressSource.DEVICE_PROGRESS, device=device) - assert len(connect_calls) == prev_connect_count, "No extra connect made for same source" - - -def test_set_progress_source_disconnects_previous_device_subscription( - scan_progressbar, monkeypatch -): - - disconnect_calls = [] - - monkeypatch.setattr(scan_progressbar.bec_dispatcher, "connect_slot", lambda *args: None) - monkeypatch.setattr( - scan_progressbar.bec_dispatcher, - "disconnect_slot", - lambda slot, endpoint: disconnect_calls.append(endpoint), - ) - - scan_progressbar.set_progress_source(ProgressSource.DEVICE_PROGRESS, device="motor1") - scan_progressbar.set_progress_source(ProgressSource.DEVICE_PROGRESS, device="motor2") - - assert disconnect_calls == [MessageEndpoints.device_progress(device="motor1")] - - -def test_set_progress_source_disconnects_device_when_switching_to_scan( - scan_progressbar, monkeypatch -): - - disconnect_calls = [] - - monkeypatch.setattr(scan_progressbar.bec_dispatcher, "connect_slot", lambda *args: None) - monkeypatch.setattr( - scan_progressbar.bec_dispatcher, - "disconnect_slot", - lambda slot, endpoint: disconnect_calls.append(endpoint), - ) - - scan_progressbar.set_progress_source(ProgressSource.DEVICE_PROGRESS, device="motor1") - scan_progressbar.set_progress_source(ProgressSource.SCAN_PROGRESS) - - assert disconnect_calls == [MessageEndpoints.device_progress(device="motor1")] - - -def test_cleanup_disconnects_active_device_subscription(scan_progressbar, monkeypatch): +def test_cleanup_disconnects_active_scan_subscription(scan_progressbar, monkeypatch): disconnect_calls = [] @@ -275,7 +189,6 @@ def test_cleanup_disconnects_active_device_subscription(scan_progressbar, monkey monkeypatch.setattr(scan_progressbar.progressbar, "deleteLater", lambda: None) monkeypatch.setattr(BECWidget, "cleanup", lambda self: None) - scan_progressbar.set_progress_source(ProgressSource.DEVICE_PROGRESS, device="motor1") with ( mock.patch.object(scan_progressbar, "close", wraps=scan_progressbar.close) as close_mock, mock.patch.object( @@ -284,316 +197,20 @@ def test_cleanup_disconnects_active_device_subscription(scan_progressbar, monkey ): ScanProgressBar.cleanup(scan_progressbar) - assert disconnect_calls == [ - MessageEndpoints.device_progress(device="motor1"), - MessageEndpoints.scan_queue_status(), - ] - assert scan_progressbar._progress_source is None - assert scan_progressbar._progress_device is None + assert len(disconnect_calls) == 1 + assert scan_progressbar.progress_tracker._connected is False close_mock.assert_not_called() delete_later_mock.assert_not_called() def test_cleanup_stops_active_task(scan_progressbar, monkeypatch): monkeypatch.setattr(BECWidget, "cleanup", lambda self: None) - scan_progressbar.task = ProgressTask(parent=scan_progressbar) - scan_progressbar._active_scan_id = "scan-1" - timer = scan_progressbar.task.timer + scan_progressbar.progress_tracker.task = ProgressTask(parent=scan_progressbar) + scan_progressbar.progress_tracker._active_scan_id = "scan-1" + timer = scan_progressbar.progress_tracker.task.timer ScanProgressBar.cleanup(scan_progressbar) assert not timer.isActive() - assert scan_progressbar.task is None - assert scan_progressbar._active_scan_id is None - - -def test_progressbar_queue_update(scan_progressbar): - """ - Test that an empty queue update does not change the progress source. - """ - msg = messages.ScanQueueStatusMessage( - queue={"primary": messages.ScanQueueStatus(info=[], status="RUNNING")} - ) - with mock.patch.object(scan_progressbar, "set_progress_source") as mock_set_source: - scan_progressbar.on_queue_update( - msg.content, msg.metadata, _override_slot_params={"verify_sender": False} - ) - mock_set_source.assert_not_called() - - -def test_progressbar_queue_update_clears_task_when_queue_is_empty(scan_progressbar): - scan_progressbar.task = ProgressTask(parent=scan_progressbar) - scan_progressbar._active_scan_id = "scan-1" - timer = scan_progressbar.task.timer - msg = messages.ScanQueueStatusMessage( - queue={"primary": messages.ScanQueueStatus(info=[], status="RUNNING")} - ) - - scan_progressbar.on_queue_update( - msg.content, msg.metadata, _override_slot_params={"verify_sender": False} - ) - - assert not timer.isActive() - assert scan_progressbar.task is None - assert scan_progressbar._active_scan_id is None - - -def test_progressbar_queue_update_clears_task_when_scan_is_not_running( - scan_progressbar, scan_message -): - scan_progressbar.task = ProgressTask(parent=scan_progressbar) - scan_progressbar._active_scan_id = "scan-1" - timer = scan_progressbar.task.timer - request_block = messages.RequestBlock( - msg=scan_message, - RID="some-rid", - scan_motors=["samx"], - readout_priority={"monitored": ["samx"]}, - is_scan=True, - scan_number=1, - scan_id="scan-1", - report_instructions=[{"scan_progress": 20}], - ) - msg = messages.ScanQueueStatusMessage( - metadata={}, - queue={ - "primary": messages.ScanQueueStatus( - info=[ - messages.QueueInfoEntry( - queue_id="queue-1", - scan_id=["scan-1"], - status="completed", - active_request_block=request_block, - is_scan=[True], - request_blocks=[request_block], - scan_number=[1], - ) - ], - status="RUNNING", - ) - }, - ) - - with mock.patch.object(scan_progressbar, "set_progress_source") as mock_set_source: - scan_progressbar.on_queue_update( - msg.content, msg.metadata, _override_slot_params={"verify_sender": False} - ) - - assert not timer.isActive() - assert scan_progressbar.task is None - assert scan_progressbar._active_scan_id is None - mock_set_source.assert_not_called() - - -def test_progressbar_queue_update_with_scan(scan_progressbar, scan_message): - """ - Test that a queue update with a scan changes the progress source to SCAN_PROGRESS. - """ - request_block = messages.RequestBlock( - msg=scan_message, - RID="some-rid", - scan_motors=["samx"], - readout_priority={"monitored": ["samx"]}, - is_scan=True, - scan_number=1, - scan_id="e3f50794-852c-4bb1-965e-41c585ab0aa9", - report_instructions=[{"scan_progress": 20}], - ) - msg = messages.ScanQueueStatusMessage( - metadata={}, - queue={ - "primary": messages.ScanQueueStatus( - info=[ - messages.QueueInfoEntry( - queue_id="40831e2c-fbd1-4432-8072-ad168a7ad964", - scan_id=["e3f50794-852c-4bb1-965e-41c585ab0aa9"], - status="RUNNING", - active_request_block=request_block, - is_scan=[True], - request_blocks=[request_block], - scan_number=[1], - ) - ], - status="RUNNING", - ) - }, - ) - - with mock.patch.object(scan_progressbar, "set_progress_source") as mock_set_source: - scan_progressbar.on_queue_update( - msg.content, msg.metadata, _override_slot_params={"verify_sender": False} - ) - mock_set_source.assert_called_once_with(ProgressSource.SCAN_PROGRESS) - - -def test_progressbar_queue_update_starts_new_task_for_new_scan(scan_progressbar, scan_message): - started = mock.Mock() - scan_progressbar.progress_started.connect(started) - - def queue_msg(scan_id: str, scan_number: int): - request_block = messages.RequestBlock( - msg=scan_message, - RID=f"rid-{scan_number}", - scan_motors=["samx"], - readout_priority={"monitored": ["samx"]}, - is_scan=True, - scan_number=scan_number, - scan_id=scan_id, - report_instructions=[{"scan_progress": 20}], - ) - return messages.ScanQueueStatusMessage( - metadata={}, - queue={ - "primary": messages.ScanQueueStatus( - info=[ - messages.QueueInfoEntry( - queue_id=f"queue-{scan_number}", - scan_id=[scan_id], - status="RUNNING", - active_request_block=request_block, - is_scan=[True], - request_blocks=[request_block], - scan_number=[scan_number], - ) - ], - status="RUNNING", - ) - }, - ) - - first_msg = queue_msg("scan-1", 1) - scan_progressbar.on_queue_update( - first_msg.content, first_msg.metadata, _override_slot_params={"verify_sender": False} - ) - first_task = scan_progressbar.task - assert first_task is not None - assert first_task.timer.isActive() - - second_msg = queue_msg("scan-2", 2) - scan_progressbar.on_queue_update( - second_msg.content, second_msg.metadata, _override_slot_params={"verify_sender": False} - ) - - assert started.call_count == 2 - assert not first_task.timer.isActive() - assert scan_progressbar.task is not first_task - assert scan_progressbar._active_scan_id == "scan-2" - - -def test_progressbar_queue_update_with_device(scan_progressbar, scan_message): - """ - Test that a queue update with a device changes the progress source to DEVICE_PROGRESS. - """ - request_block = messages.RequestBlock( - msg=scan_message, - RID="some-rid", - scan_motors=["samx"], - readout_priority={"monitored": ["samx"]}, - is_scan=True, - scan_number=1, - scan_id="e3f50794-852c-4bb1-965e-41c585ab0aa9", - report_instructions=[{"device_progress": ["samx"]}], - ) - msg = messages.ScanQueueStatusMessage( - metadata={}, - queue={ - "primary": messages.ScanQueueStatus( - info=[ - messages.QueueInfoEntry( - queue_id="40831e2c-fbd1-4432-8072-ad168a7ad964", - scan_id=["e3f50794-852c-4bb1-965e-41c585ab0aa9"], - status="RUNNING", - active_request_block=request_block, - is_scan=[True], - request_blocks=[request_block], - scan_number=[1], - ) - ], - status="RUNNING", - ) - }, - ) - - with mock.patch.object(scan_progressbar, "set_progress_source") as mock_set_source: - scan_progressbar.on_queue_update( - msg.content, msg.metadata, _override_slot_params={"verify_sender": False} - ) - mock_set_source.assert_called_once_with(ProgressSource.DEVICE_PROGRESS, device="samx") - - -def test_progressbar_queue_update_ignores_empty_device_progress(scan_progressbar, scan_message): - request_block = messages.RequestBlock( - msg=scan_message, - RID="some-rid", - scan_motors=["samx"], - readout_priority={"monitored": ["samx"]}, - is_scan=True, - scan_number=1, - scan_id="e3f50794-852c-4bb1-965e-41c585ab0aa9", - report_instructions=[{"device_progress": []}], - ) - msg = messages.ScanQueueStatusMessage( - metadata={}, - queue={ - "primary": messages.ScanQueueStatus( - info=[ - messages.QueueInfoEntry( - queue_id="40831e2c-fbd1-4432-8072-ad168a7ad964", - scan_id=["e3f50794-852c-4bb1-965e-41c585ab0aa9"], - status="RUNNING", - active_request_block=request_block, - is_scan=[True], - request_blocks=[request_block], - scan_number=[1], - ) - ], - status="RUNNING", - ) - }, - ) - - with mock.patch.object(scan_progressbar, "set_progress_source") as mock_set_source: - scan_progressbar.on_queue_update( - msg.content, msg.metadata, _override_slot_params={"verify_sender": False} - ) - mock_set_source.assert_not_called() - - -def test_progressbar_queue_update_with_no_scan_or_device(scan_progressbar, scan_message): - """ - Test that a queue update with neither scan nor device does not change the progress source. - """ - request_block = messages.RequestBlock( - msg=scan_message, - RID="some-rid", - scan_motors=["samx"], - readout_priority={"monitored": ["samx"]}, - is_scan=True, - scan_number=1, - scan_id="e3f50794-852c-4bb1-965e-41c585ab0aa9", - ) - msg = messages.ScanQueueStatusMessage( - metadata={}, - queue={ - "primary": messages.ScanQueueStatus( - info=[ - messages.QueueInfoEntry( - queue_id="40831e2c-fbd1-4432-8072-ad168a7ad964", - scan_id=["e3f50794-852c-4bb1-965e-41c585ab0aa9"], - status="RUNNING", - active_request_block=request_block, - is_scan=[True], - request_blocks=[request_block], - scan_number=[1], - ) - ], - status="RUNNING", - ) - }, - ) - - with mock.patch.object(scan_progressbar, "set_progress_source") as mock_set_source: - scan_progressbar.on_queue_update( - msg.content, msg.metadata, _override_slot_params={"verify_sender": False} - ) - mock_set_source.assert_not_called() + assert scan_progressbar.progress_tracker.task is None + assert scan_progressbar.progress_tracker._active_scan_id is None