diff --git a/debye_bec/bec_widgets/widgets/data_viewer/data_viewer.py b/debye_bec/bec_widgets/widgets/data_viewer/data_viewer.py index ba79919..184a475 100644 --- a/debye_bec/bec_widgets/widgets/data_viewer/data_viewer.py +++ b/debye_bec/bec_widgets/widgets/data_viewer/data_viewer.py @@ -6,8 +6,7 @@ import os import subprocess import sys from datetime import datetime -from functools import partial -from typing import Literal, Optional, cast +from typing import Literal from bec_lib import bec_logger from bec_lib.endpoints import MessageEndpoints @@ -15,27 +14,12 @@ from bec_widgets.utils.bec_dispatcher import BECDispatcher from bec_widgets.utils.bec_widget import BECWidget from bec_widgets.utils.colors import apply_theme, get_accent_colors from bec_widgets.utils.error_popups import SafeSlot -from qtpy.QtCore import Qt # pylint: disable=E0611 -from qtpy.QtGui import QFont +from qtpy.QtWidgets import QApplication, QVBoxLayout, QWidget -# pylint: disable=E0611 -from qtpy.QtWidgets import ( - QApplication, - QComboBox, - QDoubleSpinBox, - QGroupBox, - QHBoxLayout, - QLabel, - QLayout, - QPushButton, - QVBoxLayout, - QWidget, -) - -from debye_bec.bec_widgets.widgets.data_viewer.qt_widgets import TaggedListWidget -from debye_bec.bec_widgets.widgets.data_viewer.viewer import HDF5Viewer +from .panels.input_panel import InputPanel +from .panels.scan_view import ScanViewer logger = bec_logger.logger @@ -54,20 +38,15 @@ class DataViewer(BECWidget, QWidget): super().__init__(parent=parent, theme_update=True, *arg, **kwargs) self.get_bec_shortcuts() - logger.info(f"Type of self.client: {type(self.client)}") - central = QWidget() self.root_layout = QVBoxLayout(central) - self.input = InputPanel() - self.viewer = HDF5Viewer() - + self.viewer = ScanViewer() self.root_layout.addWidget(self.input, 0) self.root_layout.addWidget(self.viewer, 1) self.setLayout(self.root_layout) self.setWindowTitle("Data Viewer") - # self.resize(1800, 800) self.history = [] self.bec_dispatcher.connect_slot(self.on_history_update, MessageEndpoints.scan_history()) @@ -75,14 +54,9 @@ class DataViewer(BECWidget, QWidget): self.current_row = 0 - logger.info(self.client.acl) - logger.info(self.client.active_account) - logger.info(self.client.proc) - logger.info(self.client.username) - self.input.scan_sel.currentItemChanged_connect(self.scan_sel_changed) - self.input.load_button.clicked_connect(self.load_dataset) - self.input.unload_button.clicked_connect(self.unload_all_datasets) + self.input.load_button.clicked_connect(self.load_scan) + self.input.unload_button.clicked_connect(self.unload_all_scans) self.input.open_fm_button.clicked_connect(self.open_in_file_manager) def apply_theme(self, theme: Literal["dark", "light"]): @@ -97,10 +71,12 @@ class DataViewer(BECWidget, QWidget): @SafeSlot() def scan_sel_changed(self, *_, **kwargs): + """Updates the current row value of the scan selection list""" self.current_row = kwargs["value"]().row() @SafeSlot() def open_in_file_manager(self, *_): + """Open the scan folder in the systems default file manager""" if len(self.history) > 0: scan = self.history[self.current_row] filepath = scan["file_components"][0].decode().rsplit("/", 1)[0] @@ -113,54 +89,93 @@ class DataViewer(BECWidget, QWidget): ) @SafeSlot() - def load_dataset( - self, *_ - ): # TODO: Check scan file components for combined xas/xrd scans. Is the Pilatus file in there as well? + def load_scan(self, *_): + """ + Loads a scan. Find all files within the scan folder, sort them and + then load the files in the scan view + """ if len(self.history) > 0: scan = self.history[self.current_row] - file = scan["file_components"][0] + b"_master." + scan["file_components"][1] - logger.info(file.decode()) - self.viewer.load_files([file.decode()]) + base_filepath = scan["file_components"][0].decode().rsplit("/", 1)[0] + filenames = [ + f + for f in os.listdir(base_filepath) + if os.path.isfile(os.path.join(base_filepath, f)) + ] + + def sort_priority(name): + if "master" in name: + return 0 + if name.endswith(".h5"): + return 1 + return 2 + + sorted_files = [ + f"{base_filepath}/{name}" for name in sorted(filenames, key=sort_priority) + ] + self.viewer.load_files(sorted_files) @SafeSlot() - def unload_all_datasets(self, *_): + def unload_all_scans(self, *_): + """Removes all scans from the scan view""" self.viewer.clear_files() - def duration_string(self, start: str, end: str) -> str: + def duration_formatted(self, start: str, end: str) -> str: + """ + Calculates the duration of a scan based on start end end time and + formats it as an easy readable string. + + Args: + start(str): start time in iso-format + end(str): end time in iso-format + + Returns: + str: Formatted duration, e.g. '1min 10s' or '1h 13min' + """ start_dt = datetime.fromisoformat(start) end_dt = datetime.fromisoformat(end) - seconds = abs(int((end_dt - start_dt).total_seconds())) days, remainder = divmod(seconds, 86400) hours, remainder = divmod(remainder, 3600) - minutes, _ = divmod(remainder, 60) + minutes, seconds = divmod(remainder, 60) parts = [] - if days: parts.append(f"{days}d") if hours: parts.append(f"{hours}h") if minutes: parts.append(f"{minutes}min") - + if not days and not hours and minutes < 10: + parts.append(f"{seconds}s") return " ".join(parts) if parts else "<1min" + def time_formatted(self, iso_time: str) -> str: + """ + Formates a time as an easy readable string. + + Args: + iso_time(str): Time in iso-format + + Returns: + str: Time formatted with format '%d.%m.%Y %H:%M', e.g. '14.01.1995 08:12' + """ + dt = datetime.fromisoformat(iso_time) + return dt.strftime("%d.%m.%Y %H:%M") + @SafeSlot() def on_history_update(self, *_): + """Updates the scan list based on the bec scan history.""" self.history = [] self.input.scan_sel.clear() - # Get the length of the scan history, which is 0 when the bec server was started - # and no scan has finished yet. Limit the history to the latest 20 scans. + if self.client.history is None: + return max_scans = min(len(self.client.history), MAX_HIST_LEN) for n in range(1, max_scans): # last scans, limited by MAX_HIST_LEN - # logger.info(self.client.history[-n].metadata["bec"]["status"]) - start_time = self.client.history[-n].metadata["start_time"] - end_time = self.client.history[-n].metadata["end_time"] - # logger.info(type(start_time)) - scan_data = self.client.history[-n].metadata["bec"] - # logger.info(scan_data) + start_time = self.client.history[-n].metadata["start_time"] # type: ignore + end_time = self.client.history[-n].metadata["end_time"] # type: ignore + scan_data = self.client.history[-n].metadata["bec"] # type: ignore scan_number = scan_data["scan_number"] scan_name = scan_data["scan_name"] comment = scan_data["metadata"]["user_metadata"]["comment"] @@ -186,142 +201,13 @@ class DataViewer(BECWidget, QWidget): tags.append((comment, get_accent_colors().warning.name())) if status == "closed": tags.append((status, get_accent_colors().success.name())) - elif status == "halted": + elif status == "halted" or status == "aborted": tags.append((status, get_accent_colors().emergency.name())) else: tags.append((status, "#656365")) - tags.append((self.duration_string(start_time, end_time), "#656365")) + tags.append((self.duration_formatted(start_time, end_time), "#656365")) + tags.append((self.time_formatted(start_time), "#656365")) self.input.scan_sel.addTaggedItem(label=str(scan_number), tags=tags) - # logger.info(f"Scan history: {self.history}") - - -class InputPanel(QWidget): - """Panel for scan selection of the data viewer widget""" - - def __init__(self, parent=None): - super().__init__(parent) - self._layout = QHBoxLayout(self) - # self._layout.setSizeConstraint(QLayout.SetFixedSize) # type: ignore - - # Scan selection - self.scan_sel = ListWidget("scan_sel", "Scan", ["Si", "Rh", "Pt"]) - self.load_button = Button(label_button="Load Dataset", enabled=True) - self.unload_button = Button(label_button="Unload all", enabled=True) - self.open_fm_button = Button(label_button="Open in File Manager", enabled=True) - - self._button_layout = QVBoxLayout() - self._button_layout.addWidget(self.load_button) - self._button_layout.addWidget(self.unload_button) - self._button_layout.addWidget(self.open_fm_button) - self._button_layout.addStretch() - - # Assemble complete scan selection group - self.input_group = Group( - "Scan selection", [self._button_layout, self.scan_sel], orientation="horizontal" - ) - - self._layout.addWidget(self.input_group) - # self._layout.addStretch() - - -class Group(QGroupBox): - def __init__(self, label, objs, orientation="vertical"): - super().__init__(label) - if orientation == "vertical": - self._layout = QVBoxLayout(self) # type: ignore - elif orientation == "horizontal": # assume horizontal - self._layout = QHBoxLayout(self) # type: ignore - else: - raise ValueError(f"Orientation {orientation} is not supported!") - for obj in objs: - if isinstance(obj, QWidget): - self._layout.addWidget(obj) # type: ignore - elif isinstance(obj, QLayout): - self._layout.addLayout(obj) - - -class ListWidget(QWidget): - def __init__(self, identifier="", label="", enums=[]): - super().__init__() - layout = QHBoxLayout(self) - layout.setContentsMargins(10, 0, 0, 0) - layout.setSpacing(0) - self.identifier = identifier - # self.label = QLabel(label) - # self.label.setFixedWidth(140) - # self.label.setContentsMargins(0, 0, 10, 0) - # self.label.setWordWrap(True) - # layout.addWidget(self.label) - self.value = TaggedListWidget() - # self.value.setFixedWidth(400) - # for entry in enums: - # self.value.addItem(entry) - layout.addWidget(self.value) - - def clear(self): - self.value.clear() - - def addTaggedItem(self, label, tags): - self.value.addTaggedItem(label, tags) - - def setCurrentIndex(self, text): - self.value.setCurrentIndex(text) - - # def currentIndex(self) -> int: - # return self.value.currentIndex() - - # def has_focus(self) -> bool: - # return QApplication.focusWidget() is self.value.view() - - def currentItemChanged_connect(self, func): - """Connect a function to the Enter/Return key press.""" - self.value.currentItemChanged.connect( - partial( - func, - identifier=self.identifier, - value_obj=self.value, - value=lambda: self.value.currentIndex(), - ) - ) - - def setDisabled(self, disable): - self.value.setDisabled(disable) - - -class Button(QWidget): - def __init__(self, label=None, label_button: str = "", enabled=False): - super().__init__() - layout = QHBoxLayout(self) - layout.setContentsMargins(10, 0, 0, 0) - layout.setSpacing(0) - if label is not None: - self.label = QLabel(label) - self.label.setFixedWidth(140) - layout.addWidget(self.label) - self.button = QPushButton(label_button) - if label is not None: - self.button.setFixedWidth(160) - self.enable_button(enabled) - layout.addWidget(self.button) - - def clicked_connect(self, func): - """Connect a function to the button press.""" - self.button.clicked.connect(func) - - def enable_button(self, enable: bool = False): - if enable: - self.button.setStyleSheet( - f"QPushButton {{background-color: {get_accent_colors().default.name()}; color: white;}}" - ) - self.button.setEnabled(True) - else: # disabled - self.button.setStyleSheet( - "QPushButton {{background-color: rgb(120, 120, 120); color: white;}}" - ) - self.button.setDisabled(True) - - def setText(self, text): - self.button.setText(text) if __name__ == "__main__": diff --git a/debye_bec/bec_widgets/widgets/data_viewer/data_viewer_plugin.py b/debye_bec/bec_widgets/widgets/data_viewer/data_viewer_plugin.py index cf5b92f..41560d4 100644 --- a/debye_bec/bec_widgets/widgets/data_viewer/data_viewer_plugin.py +++ b/debye_bec/bec_widgets/widgets/data_viewer/data_viewer_plugin.py @@ -5,7 +5,7 @@ from bec_widgets.utils.bec_designer import designer_material_icon from qtpy.QtDesigner import QDesignerCustomWidgetInterface from qtpy.QtWidgets import QWidget -from debye_bec.bec_widgets.widgets.data_viewer.data_viewer import DataViewer +from .data_viewer import DataViewer DOM_XML = """ @@ -22,7 +22,7 @@ class DataViewerPlugin(QDesignerCustomWidgetInterface): # pragma: no cover def createWidget(self, parent): if parent is None: - return QWidget() + return QWidget() t = DataViewer(parent) return t diff --git a/debye_bec/bec_widgets/widgets/data_viewer/loaders.py b/debye_bec/bec_widgets/widgets/data_viewer/loaders.py new file mode 100644 index 0000000..5b260ab --- /dev/null +++ b/debye_bec/bec_widgets/widgets/data_viewer/loaders.py @@ -0,0 +1,229 @@ +""" +File-format loader (HDF5, images). +""" + +import os +from abc import ABC, abstractmethod +from typing import Iterator, Literal, Optional + +import h5py +import numpy as np + + +class NodeInfo: + """Describes a single node (group or dataset) inside a loaded file.""" + + __slots__ = ("name", "path", "kind", "dtype", "shape") + + def __init__( + self, + name: str, + path: str, + kind: Literal["group", "dataset"], + dtype: str = "", + shape: tuple[int, ...] = (), + ): + self.name = name + self.path = path + self.kind = kind + self.dtype = dtype # e.g. "float", "int", "str", … + self.shape = shape # empty tuple for scalars / groups + + +class BaseFileLoader(ABC): + """ + Abstract base class for file-format loaders. + + Subclass this to add support for a new format. Three things are required: + + 1. ``EXTENSIONS`` — tuple of lowercase extensions this loader handles, + e.g. ``(".h5", ".hdf5")``. + + 2. ``open(filepath)`` — open the file and keep any handles alive. + + 3. ``iter_nodes(path)`` — yield ``NodeInfo`` objects for the direct + children of *path* (depth-1 walk; the tree widget calls this + recursively as the user expands nodes). + + 4. ``read_dataset(path)`` — return the dataset at *path* as a + ``numpy.ndarray``. + + 5. ``close()`` — release any open file handles. + + 6. ``child_count(path)`` — return the number of direct children of a + group node (used for the status label; override if cheap to compute). + """ + + EXTENSIONS: tuple[str, ...] = () + + @abstractmethod + def open(self, filepath: str) -> None: ... + + @abstractmethod + def iter_nodes(self, path: str) -> Iterator[NodeInfo]: ... + + @abstractmethod + def read_dataset(self, path: str) -> np.ndarray: ... + + @abstractmethod + def close(self) -> None: ... + + def child_count(self, path: str) -> int: + return sum(1 for _ in self.iter_nodes(path)) + + +class HDF5Loader(BaseFileLoader): + """Loader for HDF5 / NeXus files (.h5, .hdf5, .nxs, .nx).""" + + EXTENSIONS = (".h5", ".hdf5", ".hdf", ".nxs", ".nx") + + def __init__(self): + self._file: Optional[h5py.File] = None + + def open(self, filepath: str) -> None: + self._file = h5py.File(filepath, "r") + + def close(self) -> None: + if self._file is not None: + self._file.close() + self._file = None + + def iter_nodes(self, path: str) -> Iterator[NodeInfo]: + assert self._file is not None, "File not open" + obj = self._file[path] if path != "/" else self._file + + if not isinstance(obj, h5py.Group): + return + + for key in obj.keys(): + try: + child = obj[key] + except Exception: + continue + + child_path = child.name # h5py always gives the absolute path + + if isinstance(child, h5py.Group): + yield NodeInfo(name=key, path=child_path, kind="group") + + elif isinstance(child, h5py.Dataset): + shape_tuple = child.shape + d = child.dtype + dtype = "unknown" + if np.issubdtype(d, np.integer): + dtype = "int" + if np.issubdtype(d, np.floating): + dtype = "float" + if np.issubdtype(d, np.complexfloating): + dtype = "complex" + if d.kind in ("S", "U", "O"): + dtype = "str" + + yield NodeInfo( + name=key, path=child_path, kind="dataset", dtype=dtype, shape=shape_tuple + ) + + def read_dataset(self, path: str) -> np.ndarray: + assert self._file is not None, "File not open" + return self._file[path][()] + + def child_count(self, path: str) -> int: + assert self._file is not None, "File not open" + obj = self._file[path] if path != "/" else self._file + return len(obj) if isinstance(obj, h5py.Group) else 0 + + +class ImageLoader(BaseFileLoader): + """ + Loader for raster image files. + + The file is treated as a single, flat dataset. ``iter_nodes`` yields one + leaf node. + + Requires: Pillow (``pip install Pillow``) + """ + + EXTENSIONS = ( + ".jpg", + ".jpeg", + ".png", + ".gif", + ".tiff", + ".tif", + ".bmp", + ".webp", + ".ico", + ".ppm", + ".pgm", + ".pbm", + ) + + def __init__(self): + self._filepath: Optional[str] = None + + def open(self, filepath: str) -> None: + # Validate that Pillow can open it; keep only the path. + try: + from PIL import Image as _PILImage # noqa: F401 — existence check + + _PILImage.open(filepath).verify() + except Exception as exc: + raise OSError(f"Cannot open image {filepath!r}: {exc}") from exc + self._filepath = filepath + + def close(self) -> None: + self._filepath = None + + def iter_nodes(self, path: str) -> Iterator[NodeInfo]: + """Images have no internal hierarchy — yield a single leaf node.""" + if path != "/" or self._filepath is None: + return + from PIL import Image as _PILImage + + with _PILImage.open(self._filepath) as img: + w, h = img.size + mode = img.mode # e.g. "RGB", "RGBA", "L", … + + name = os.path.basename(self._filepath) + yield NodeInfo(name=name, path="/image", kind="dataset", dtype=mode, shape=(h, w)) + + def read_dataset(self, path: str) -> np.ndarray: + """Return the image as a uint8 numpy array (H x W x C or H x W).""" + from PIL import Image as _PILImage + + with _PILImage.open(self._filepath) as img: # type: ignore[arg-type] + # Animated GIF → first frame only + if hasattr(img, "n_frames") and img.n_frames > 1: + img.seek(0) + return np.asarray(img) + + def child_count(self, path: str) -> int: + return 1 if path == "/" else 0 + + +class LoaderRegistry: + """Maps file extensions to loader classes.""" + + def __init__(self): + self._registry: dict[str, type[BaseFileLoader]] = {} + + def register(self, loader_cls: type[BaseFileLoader]) -> None: + """Register a loader class for all extensions it declares.""" + for ext in loader_cls.EXTENSIONS: + self._registry[ext.lower()] = loader_cls + + def get_loader(self, filepath: str) -> Optional[BaseFileLoader]: + """Return a fresh loader instance for *filepath*, or None if unsupported.""" + ext = os.path.splitext(filepath)[1].lower() + cls = self._registry.get(ext) + return cls() if cls is not None else None + + @property + def supported_extensions(self) -> list[str]: + return sorted(self._registry) + + +# Default global registry — pre-populated with built-in loaders. +registry = LoaderRegistry() +registry.register(HDF5Loader) +registry.register(ImageLoader) diff --git a/debye_bec/bec_widgets/widgets/data_viewer/panels/__init__,py b/debye_bec/bec_widgets/widgets/data_viewer/panels/__init__,py new file mode 100644 index 0000000..e69de29 diff --git a/debye_bec/bec_widgets/widgets/data_viewer/panels/data_view.py b/debye_bec/bec_widgets/widgets/data_viewer/panels/data_view.py new file mode 100644 index 0000000..e9b368c --- /dev/null +++ b/debye_bec/bec_widgets/widgets/data_viewer/panels/data_view.py @@ -0,0 +1,393 @@ +""" +Data viewer displaying the data +""" + +from typing import Literal, Optional + +import numpy as np +import pyqtgraph as pg +from bec_lib import bec_logger +from bec_widgets.utils.colors import Colors +from qtpy.QtCore import Qt + +# pylint: disable=E0611 +from qtpy.QtGui import QFont, QPixmap + +# pylint: disable=E0611 +from qtpy.QtWidgets import ( + QAbstractItemView, + QApplication, + QGroupBox, + QHBoxLayout, + QHeaderView, + QLabel, + QRadioButton, + QSizePolicy, + QSlider, + QTableWidget, + QTableWidgetItem, + QVBoxLayout, + QWidget, +) + +logger = bec_logger.logger + +MAX_ROWS = 2000 +MAX_COLS = 500 + + +class DataView(QWidget): + def __init__(self): + super().__init__() + self._layout = QVBoxLayout(self) + + header = QHBoxLayout() + self.path_label = QLabel("") + self.path_label.setObjectName("path_label") + self.info_label = QLabel("") + self.info_label.setObjectName("info_label") + header.addWidget(self.path_label, 1) + header.addWidget(self.info_label) + self._layout.addLayout(header) + + mode_box = QGroupBox("View mode") + mode_layout = QHBoxLayout(mode_box) + mode_layout.setContentsMargins(8, 4, 8, 4) + self.rb_auto = QRadioButton("Auto") + self.rb_plot = QRadioButton("Plot") + self.rb_image = QRadioButton("Image") + self.rb_table = QRadioButton("Table") + self.rb_auto.setChecked(True) + for rb in (self.rb_auto, self.rb_image, self.rb_plot, self.rb_table): + mode_layout.addWidget(rb) + rb.toggled.connect(self._on_mode_change) + mode_layout.addStretch() + self._layout.addWidget(mode_box) + + self.content = QWidget() + self.content_layout = QVBoxLayout(self.content) + self.content_layout.setContentsMargins(0, 0, 0, 0) + self._layout.addWidget(self.content, 1) + + self.plot_widget = None + self.image_widget = None + self._current_data = None + self.show_empty() + + def apply_theme(self, theme: Optional[Literal["dark", "light"]] = None): + """ + Apply the theme + + Args: + theme (Optional[str]): Theme, either "dark", "light", or None. Defaults to None. + """ + if theme is None: + app = QApplication.instance() + theme = app.theme.theme # type: ignore + + bg_color = pg.getConfigOption("background") + fg_color = pg.getConfigOption("foreground") + if self.plot_widget is not None: + n_curves = len(self.plot_widget.listDataItems()) + colors = Colors.golden_angle_color( + colormap="plasma", num=max(10, n_curves + 1), format="HEX" + ) + for idx, curve in enumerate(self.plot_widget.listDataItems()): + curve.setPen(pg.mkPen(color=colors[idx])) + # Background + self.plot_widget.setBackground(bg_color) + # Axes (tick marks, tick labels, axis line) + for axis in ["left", "bottom", "right", "top"]: + ax = self.plot_widget.getAxis(axis) + ax.setPen(pg.mkPen(color=fg_color)) + ax.setTextPen(pg.mkPen(color=fg_color)) + + if self.image_widget is not None: + self.image_widget.getView().setBackgroundColor(bg_color) + self.image_widget.ui.histogram.setBackground(bg_color) + + def _clear_stack(self): + while self.content_layout.count(): + item = self.content_layout.takeAt(0) + if item.widget(): + item.widget().deleteLater() + self.plot_widget = None + self.image_widget = None + + def show_empty(self): + """Empties the content area.""" + self._clear_stack() + empty_label = QLabel("No data selected") + empty_label.setObjectName("info_label") + empty_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.path_label.setText("") + self.info_label.setText("") + self.content_layout.addWidget(empty_label) + + def show_unsupported(self, path: str = "") -> None: + """Display a friendly 'not implemented' message for unknown file types.""" + self._clear_stack() + self._current_data = None + self.path_label.setText(path) + self.info_label.setText("") + lbl = QLabel("File type not supported") + lbl.setObjectName("info_label") + lbl.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.content_layout.addWidget(lbl) + + def _on_mode_change(self): + if self._current_data is not None: + self.display(self._current_data, self.path_label.text()) + + def _active_mode(self): + if self.rb_plot.isChecked(): + return "plot" + if self.rb_image.isChecked(): + return "image" + if self.rb_table.isChecked(): + return "table" + return "auto" + + def display(self, data, path: str = "") -> None: + """ + Render *data* in the panel. + + Parameters + ---------- + data: + A ``numpy.ndarray`` (or anything convertible to one). + path: + Human-readable label shown in the header. + """ + self._current_data = data + self.path_label.setText(path) + + if not isinstance(data, np.ndarray): + data = np.array(data) + + self.info_label.setText(f"shape {data.shape}, dtype {data.dtype}, {data.size} elements") + + mode = self._active_mode() + if mode == "auto": + if data.ndim <= 1 and data.size > 1: + mode = "plot" + elif data.ndim <= 4 and min(data.shape, default=0) > 1: + mode = "image" + else: + mode = "table" + + if mode == "plot": + self._show_plot_1d(data) + elif mode == "image": + self._show_image_2d(data) + else: + self._show_table(data) + + def _show_plot_1d(self, data): + self._clear_stack() + + is_2d = data.ndim == 2 + + if is_2d: + n_rows, _ = data.shape + row_data = data[0].astype(np.float32) + else: + row_data = data.reshape(-1).astype(np.float32) + + x = np.arange(row_data.size, dtype=np.float32) + + self.plot_widget = pg.PlotWidget() + plot_item = self.plot_widget.getPlotItem() + assert plot_item is not None, "PlotWidget has no PlotItem" + plot_item.showGrid(x=True, y=True, alpha=0.25) + self.plot_widget.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + + plot_item.setAutoVisible(y=False) # type: ignore[attr-defined] + + curve = pg.PlotDataItem( + x, row_data, pen=pg.mkPen(color="#2980b9", width=1.6), antialias=False + ) + self.plot_widget.addItem(curve) + + curve.setDownsampling(auto=True, method="peak") + curve.setClipToView(True) + curve.setSkipFiniteCheck(True) + + plot_item.enableAutoRange() # type: ignore[attr-defined] + + if is_2d: + slider = QSlider(Qt.Orientation.Vertical) + slider.setMinimum(0) + slider.setMaximum(n_rows - 1) + slider.setValue(0) + slider.setFixedWidth(32) + slider.setPageStep(1) + + current_row_label = QLabel("0") + current_row_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + current_row_label.setFixedWidth(32) + + max_row_label = QLabel(f"0:{n_rows-1}") + max_row_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + max_row_label.setFixedWidth(32) + + def on_row_changed(row): + current_row_label.setText(str(row)) + new_data = data[row].astype(np.float32) + new_x = np.arange(new_data.size, dtype=np.float32) + curve.setData(new_x, new_data) + plot_item.enableAutoRange() # type: ignore[attr-defined] + + slider.valueChanged.connect(on_row_changed) + + slider_col = QWidget() + slider_col.setFixedWidth(36) + col_layout = QVBoxLayout(slider_col) + col_layout.setContentsMargins(0, 0, 0, 0) + col_layout.setSpacing(2) + col_layout.addWidget(max_row_label) + col_layout.addWidget(slider) + col_layout.addWidget(current_row_label) + + container = QWidget() + h_layout = QHBoxLayout(container) + h_layout.setContentsMargins(0, 0, 0, 0) + h_layout.setSpacing(4) + h_layout.addWidget(slider_col) + h_layout.addWidget(self.plot_widget) + + self.content_layout.addWidget(container) + else: + self.content_layout.addWidget(self.plot_widget) + + self.apply_theme() + + def _show_image_2d(self, data): + self._clear_stack() + + stacked = False + n_images = 0 + rgb = False + img = data + if data.ndim == 3: + if data.shape[-1] in (3, 4): + rgb = True + else: + stacked = True + n_images = data.shape[0] + img = data[0, :] + elif data.ndim == 4: + if data.shape[-1] in (3, 4): + rgb = True + stacked = True + n_images = data.shape[0] + img = data[0, :] + + self.image_widget = pg.ImageView() + + self.image_widget.ui.roiBtn.hide() + self.image_widget.ui.menuBtn.hide() + + if not rgb: + self.image_widget.setColorMap(pg.colormap.get("inferno", source="matplotlib")) + + def set_image(img, autoLevels=True, autoHistogramRange=True): + if rgb: + self.image_widget.imageItem.setOpts(axisOrder="row-major") + self.image_widget.setImage(img) + else: + self.image_widget.setImage( + img.T, autoLevels=autoLevels, autoHistogramRange=autoHistogramRange + ) + + set_image(img) + self.image_widget.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + + if stacked: + slider = QSlider(Qt.Orientation.Vertical) + slider.setMinimum(0) + slider.setMaximum(n_images - 1) + slider.setValue(0) + slider.setFixedWidth(32) + slider.setPageStep(1) + + current_image_label = QLabel("0") + current_image_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + current_image_label.setFixedWidth(32) + + max_image_label = QLabel(f"0:{n_images-1}") + max_image_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + max_image_label.setFixedWidth(32) + + def on_image_changed(row): + current_image_label.setText(str(row)) + set_image(data[row, :], autoLevels=False, autoHistogramRange=False) + + slider.valueChanged.connect(on_image_changed) + + slider_col = QWidget() + slider_col.setFixedWidth(36) + col_layout = QVBoxLayout(slider_col) + col_layout.setContentsMargins(0, 0, 0, 0) + col_layout.setSpacing(2) + col_layout.addWidget(max_image_label) + col_layout.addWidget(slider) + col_layout.addWidget(current_image_label) + + container = QWidget() + h_layout = QHBoxLayout(container) + h_layout.setContentsMargins(0, 0, 0, 0) + h_layout.setSpacing(4) + h_layout.addWidget(slider_col) + h_layout.addWidget(self.image_widget) + + self.content_layout.addWidget(container) + else: + self.content_layout.addWidget(self.image_widget) + + self.apply_theme() + + def _show_table(self, data): + self._clear_stack() + + if data.ndim == 0: + flat = data.reshape(1, 1) + elif data.ndim == 1: + flat = data.reshape(-1, 1) + elif data.ndim == 2: + flat = data + else: + flat = data.reshape(-1, data.shape[-1]) + + rows, cols = flat.shape + show_rows = min(rows, MAX_ROWS) + show_cols = min(cols, MAX_COLS) + + if rows > MAX_ROWS or cols > MAX_COLS: + note = QLabel(f"⚠ Showing {show_rows}/{rows} rows x {show_cols}/{cols} columns") + note.setObjectName("info_label") + note.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.content_layout.addWidget(note) + + table = QTableWidget(show_rows, show_cols) + table.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers) + table.setSelectionMode(QAbstractItemView.SelectionMode.ContiguousSelection) + table.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeMode.Interactive) + + is_float = np.issubdtype(flat.dtype, np.floating) + is_complex = np.iscomplexobj(flat) + is_bytes = flat.dtype.kind == "S" + + for r in range(show_rows): + for c in range(show_cols): + val = flat[r, c] + txt = ( + f"{val:.6g}" + if is_float + else f"{val:.4g}" if is_complex else str(val.decode()) if is_bytes else str(val) + ) + cell = QTableWidgetItem(txt) + cell.setTextAlignment(Qt.AlignmentFlag.AlignRight | Qt.AlignmentFlag.AlignVCenter) + table.setItem(r, c, cell) + + self.content_layout.addWidget(table) diff --git a/debye_bec/bec_widgets/widgets/data_viewer/panels/input_panel.py b/debye_bec/bec_widgets/widgets/data_viewer/panels/input_panel.py new file mode 100644 index 0000000..f249b94 --- /dev/null +++ b/debye_bec/bec_widgets/widgets/data_viewer/panels/input_panel.py @@ -0,0 +1,32 @@ +# pylint: disable=E0611 +from qtpy.QtWidgets import QHBoxLayout, QVBoxLayout, QWidget + +# pylint: disable=E0402 +from ..widgets.qt_widgets import Button, Group, ListWidget + + +class InputPanel(QWidget): + """Panel for scan selection of the data viewer widget""" + + def __init__(self, parent=None): + super().__init__(parent) + self._layout = QHBoxLayout(self) + + # Scan selection + self.scan_sel = ListWidget("scan_sel") + self.load_button = Button(label_button="Load Dataset", enabled=True) + self.unload_button = Button(label_button="Unload all", enabled=True) + self.open_fm_button = Button(label_button="Open in File Manager", enabled=True) + + self._button_layout = QVBoxLayout() + self._button_layout.addWidget(self.load_button) + self._button_layout.addWidget(self.unload_button) + self._button_layout.addWidget(self.open_fm_button) + self._button_layout.addStretch() + + # Assemble complete scan selection group + self.input_group = Group( + "Scan selection", [self._button_layout, self.scan_sel], orientation="horizontal" + ) + + self._layout.addWidget(self.input_group) diff --git a/debye_bec/bec_widgets/widgets/data_viewer/panels/scan_view.py b/debye_bec/bec_widgets/widgets/data_viewer/panels/scan_view.py new file mode 100644 index 0000000..aa3ebe7 --- /dev/null +++ b/debye_bec/bec_widgets/widgets/data_viewer/panels/scan_view.py @@ -0,0 +1,213 @@ +""" +Scan viewer. Displays files of one or more scans in a tree view +""" + +import os +from typing import Literal, Optional + +from bec_lib import bec_logger +from bec_qthemes import material_icon +from bec_widgets.utils.colors import get_accent_colors +from qtpy.QtCore import Qt + +# pylint: disable=E0611 +from qtpy.QtGui import QBrush, QColor + +# pylint: disable=E0611 +from qtpy.QtWidgets import ( + QHBoxLayout, + QHeaderView, + QMainWindow, + QSplitter, + QTreeWidget, + QTreeWidgetItem, + QVBoxLayout, + QWidget, +) + +# pylint: disable=E0402 +from ..loaders import BaseFileLoader, registry +from ..widgets.qt_widgets import Group +from .data_view import DataView + +logger = bec_logger.logger + +ICON_SIZE = 20 + + +class ScanViewer(QMainWindow): + """ + Generic scan viewer. Supports any format registered in *registry*. + + Args: + filepath(str): Optional path to open on startup. + """ + + def __init__(self, filepath: Optional[str] = None): + super().__init__() + self.registry = registry + + self._open_files: dict[str, tuple[BaseFileLoader, str]] = {} + + central = QWidget() + self.setCentralWidget(central) + root_layout = QHBoxLayout(central) + + splitter = QSplitter(Qt.Orientation.Horizontal) + splitter.setChildrenCollapsible(False) + + left_pane = QWidget() + left_layout = QVBoxLayout(left_pane) + + self.tree = QTreeWidget() + self.tree.setMinimumWidth(250) + self.tree.setHeaderLabels(["Name", "Type", "Shape"]) + self.tree.header().setStretchLastSection(False) + self.tree.header().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch) + self.tree.header().setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents) + self.tree.header().setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents) + self.tree.itemClicked.connect(self._on_item_clicked) + + left_layout.addWidget(self.tree, 1) + + self.data_panel = DataView() + + splitter.addWidget(left_pane) + splitter.addWidget(self.data_panel) + + splitter.setStretchFactor(0, 1) + splitter.setStretchFactor(1, 3) + splitter.setSizes([300, 900]) + splitter.setHandleWidth(6) + splitter.setChildrenCollapsible(False) + + self.scan_view_group = Group("Scan view", [splitter]) + + root_layout.addWidget(self.scan_view_group) + + if filepath: + self.load_files([filepath]) + + def apply_theme(self, theme: Literal["dark", "light"]): + """ + Apply the theme + + Args: + theme (str): Theme, either "dark" or "light" + """ + self.data_panel.apply_theme(theme) + + def load_files(self, filepaths: list[str]) -> None: + """Open one or more files and add each as a top-level tree node.""" + for fp in filepaths: + if fp in self._open_files: + continue # already loaded + + loader = self.registry.get_loader(fp) + if loader is None: + supported = ", ".join(self.registry.supported_extensions) + logger.warning("No loader found for %r. Supported extensions: %s", fp, supported) + continue + + try: + loader.open(fp) + except Exception as exc: + logger.error("Failed to open %r: %s", fp, exc) + continue + + display_name = os.path.basename(fp) + self._open_files[fp] = (loader, display_name) + self._add_file_to_tree(fp, loader, display_name) + + def clear_files(self) -> None: + """Close all open files and reset the tree.""" + for loader, _ in self._open_files.values(): + loader.close() + self._open_files.clear() + self.tree.clear() + self.data_panel.show_empty() + + def closeEvent(self, event): + """Close all""" + for loader, _ in self._open_files.values(): + loader.close() + self._open_files.clear() + super().closeEvent(event) + + def _add_file_to_tree(self, filepath: str, loader: BaseFileLoader, display_name: str) -> None: + """Add a single file as a new top-level node in the tree.""" + dataset_icon = material_icon( + "dataset", size=(ICON_SIZE, ICON_SIZE), color=get_accent_colors().default.name() + ) + root_item = QTreeWidgetItem(self.tree, [display_name, "Group", ""]) + root_item.setIcon(0, dataset_icon) + root_item.setData(0, Qt.ItemDataRole.UserRole, "/") # path + root_item.setData(0, Qt.ItemDataRole.UserRole + 1, "group") # kind + root_item.setData(0, Qt.ItemDataRole.UserRole + 2, filepath) # file key + + self._populate_tree(root_item, loader, "/") + self.tree.addTopLevelItem(root_item) + + # Expand first 2 levels by default + self.tree.expandItem(root_item) + for i in range(root_item.childCount()): + self.tree.expandItem(root_item.child(i)) + + self.tree.setCurrentItem(root_item) + + def _populate_tree( + self, parent_item: QTreeWidgetItem, loader: BaseFileLoader, path: str + ) -> None: + folder_icon = material_icon("folder", size=(ICON_SIZE, ICON_SIZE), color="#2980b9") + vector_icon = material_icon("show_chart", size=(ICON_SIZE, ICON_SIZE), color="#2980b9") + array_icon = material_icon( + "stacked_line_chart", size=(ICON_SIZE, ICON_SIZE), color="#2980b9" + ) + scalar_icon = material_icon("point_scan", size=(ICON_SIZE, ICON_SIZE), color="#2980b9") + str_icon = material_icon("text_snippet", size=(ICON_SIZE, ICON_SIZE), color="#2980b9") + + for node in loader.iter_nodes(path): + shape_str = "x".join(str(s) for s in node.shape) if node.shape else "scalar" + + if node.kind == "group": + item = QTreeWidgetItem(parent_item, [node.name, "Group", ""]) + item.setIcon(0, folder_icon) + item.setData(0, Qt.ItemDataRole.UserRole, node.path) + item.setData(0, Qt.ItemDataRole.UserRole + 1, "group") + self._populate_tree(item, loader, node.path) + + else: # dataset + if shape_str == "scalar": + icon = str_icon if node.dtype == "str" else scalar_icon + elif "x" in shape_str: + icon = array_icon + else: + icon = vector_icon + + item = QTreeWidgetItem(parent_item, [node.name, node.dtype, shape_str]) + item.setIcon(0, icon) + item.setData(0, Qt.ItemDataRole.UserRole, node.path) + item.setData(0, Qt.ItemDataRole.UserRole + 1, "dataset") + item.setForeground(1, QBrush(QColor(get_accent_colors().success.name()))) + item.setForeground(2, QBrush(QColor("#656365"))) + + def _get_filepath_for_item(self, item: QTreeWidgetItem) -> str: + """Walk up the tree to find the filepath stored on the root node.""" + node = item + while node.parent(): + node = node.parent() + return node.data(0, Qt.ItemDataRole.UserRole + 2) + + def _on_item_clicked(self, item: QTreeWidgetItem, _col: int) -> None: + path = item.data(0, Qt.ItemDataRole.UserRole) + kind = item.data(0, Qt.ItemDataRole.UserRole + 1) + filepath = self._get_filepath_for_item(item) + + if not path or not filepath: + return + + loader, _ = self._open_files[filepath] + + if kind == "dataset": + data = loader.read_dataset(path) + self.data_panel.display(data, path) diff --git a/debye_bec/bec_widgets/widgets/data_viewer/register_data_viewer.py b/debye_bec/bec_widgets/widgets/data_viewer/register_data_viewer.py index 4ead24d..1be16c2 100644 --- a/debye_bec/bec_widgets/widgets/data_viewer/register_data_viewer.py +++ b/debye_bec/bec_widgets/widgets/data_viewer/register_data_viewer.py @@ -6,7 +6,7 @@ def main(): # pragma: no cover return from PySide6.QtDesigner import QPyDesignerCustomWidgetCollection - from debye_bec.bec_widgets.widgets.data_viewer.data_viewer_plugin import DataViewerPlugin + from .data_viewer_plugin import DataViewerPlugin QPyDesignerCustomWidgetCollection.addCustomWidget(DataViewerPlugin()) diff --git a/debye_bec/bec_widgets/widgets/data_viewer/viewer.py b/debye_bec/bec_widgets/widgets/data_viewer/viewer.py deleted file mode 100644 index 93312fe..0000000 --- a/debye_bec/bec_widgets/widgets/data_viewer/viewer.py +++ /dev/null @@ -1,515 +0,0 @@ -""" -HDF5 Viewer — qtpy + pyqtgraph + h5py -""" - -from typing import Literal, Optional, cast - -import h5py -import numpy as np -import pyqtgraph as pg -from bec_lib import bec_logger -from bec_qthemes import material_icon -from bec_widgets.utils.colors import Colors, get_accent_colors -from qtpy.QtCore import Qt - -# pylint: disable=E0611 -from qtpy.QtGui import QBrush, QColor, QFont - -# pylint: disable=E0611 -from qtpy.QtWidgets import ( - QAbstractItemView, - QApplication, - QGroupBox, - QHBoxLayout, - QHeaderView, - QLabel, - QMainWindow, - QRadioButton, - QSizePolicy, - QSlider, - QSplitter, - QTableWidget, - QTableWidgetItem, - QTreeWidget, - QTreeWidgetItem, - QVBoxLayout, - QWidget, -) - -logger = bec_logger.logger - -ICON_SIZE = 20 - - -# ── Data / Plot panel ────────────────────────────────────────────────────── -class DataPanel(QWidget): - def __init__(self): - super().__init__() - self._layout = QVBoxLayout(self) - - # Header - hdr = QHBoxLayout() - self.path_label = QLabel("Select a dataset from the tree") - self.path_label.setObjectName("path_label") - self.info_label = QLabel("") - self.info_label.setObjectName("info_label") - hdr.addWidget(self.path_label, 1) - hdr.addWidget(self.info_label) - self._layout.addLayout(hdr) - - # View-mode selector - mode_box = QGroupBox("View mode") - mode_layout = QHBoxLayout(mode_box) - mode_layout.setContentsMargins(8, 4, 8, 4) - self.rb_auto = QRadioButton("Auto") - self.rb_plot = QRadioButton("Plot") - self.rb_image = QRadioButton("Image") - self.rb_table = QRadioButton("Table") - self.rb_auto.setChecked(True) - for rb in (self.rb_auto, self.rb_plot, self.rb_image, self.rb_table): - mode_layout.addWidget(rb) - rb.toggled.connect(self._on_mode_change) - mode_layout.addStretch() - self._layout.addWidget(mode_box) - - # Content stack - self.stack = QWidget() - self.stack_layout = QVBoxLayout(self.stack) - self.stack_layout.setContentsMargins(0, 0, 0, 0) - self._layout.addWidget(self.stack, 1) - - self.plot_widget = None - self.image_widget = None - - self._current_data = None - self.show_empty() - - # ── helpers ─────────────────────────────────────────────────────────── - def _clear_stack(self): - while self.stack_layout.count(): - item = self.stack_layout.takeAt(0) - if item.widget(): - item.widget().deleteLater() - self.plot_widget = None - self.image_widget = None - - def show_empty(self): - self._clear_stack() - lbl = QLabel("No data selected") - lbl.setObjectName("info_label") - lbl.setAlignment(Qt.AlignmentFlag.AlignCenter) - self.stack_layout.addWidget(lbl) - - def _on_mode_change(self): - if self._current_data is not None: - self.display(self._current_data, self.path_label.text()) - - def _active_mode(self): - if self.rb_plot.isChecked(): - return "plot" - if self.rb_image.isChecked(): - return "image" - if self.rb_table.isChecked(): - return "table" - return "auto" - - # ── public ──────────────────────────────────────────────────────────── - def display(self, data, path=""): - self._current_data = data - self.path_label.setText(path) - - if not isinstance(data, np.ndarray): - data = np.array(data) - - self.info_label.setText( - f"shape {data.shape} · dtype {data.dtype} · {data.size:,} elements" - ) - - mode = self._active_mode() - if mode == "auto": - if data.ndim <= 1 and data.size > 1: - mode = "plot" - elif data.ndim == 2 and min(data.shape) > 1: - mode = "image" - else: - mode = "table" - - if mode == "plot": - self._show_plot_1d(data) - elif mode == "image": - self._show_image_2d(data) - else: - self._show_table(data) - - # ── 1-D line plot ────────────────────────────────────────────────────── - def _show_plot_1d(self, data): - self._clear_stack() - - is_2d = data.ndim == 2 - - if is_2d: - n_rows, n_cols = data.shape - row_data = data[0].astype(np.float32) - else: - row_data = data.reshape(-1).astype(np.float32) - - x = np.arange(row_data.size, dtype=np.float32) - - self.plot_widget = pg.PlotWidget() - plot_item = self.plot_widget.getPlotItem() - assert plot_item is not None, "PlotWidget has no PlotItem" - plot_item.showGrid(x=True, y=True, alpha=0.25) - self.plot_widget.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) - - plot_item.setAutoVisible(y=False) # type: ignore[attr-defined] - - curve = pg.PlotDataItem( - x, row_data, pen=pg.mkPen(color="#2980b9", width=1.6), antialias=False - ) - self.plot_widget.addItem(curve) - - curve.setDownsampling(auto=True, method="peak") - curve.setClipToView(True) - curve.setSkipFiniteCheck(True) - - plot_item.enableAutoRange() # type: ignore[attr-defined] - - if is_2d: - # --- Slider --- - slider = QSlider(Qt.Orientation.Vertical) - slider.setMinimum(0) - slider.setMaximum(n_rows - 1) - slider.setValue(0) - slider.setFixedWidth(32) - slider.setPageStep(1) - - row_label = QLabel("0") - row_label.setAlignment(Qt.AlignmentFlag.AlignCenter) - row_label.setFixedWidth(32) - - def on_row_changed(row): - row_label.setText(str(row)) - new_data = data[row].astype(np.float32) - new_x = np.arange(new_data.size, dtype=np.float32) - curve.setData(new_x, new_data) - plot_item.enableAutoRange() # type: ignore[attr-defined] - - slider.valueChanged.connect(on_row_changed) - - slider_col = QWidget() - slider_col.setFixedWidth(36) - col_layout = QVBoxLayout(slider_col) - col_layout.setContentsMargins(0, 0, 0, 0) - col_layout.setSpacing(2) - col_layout.addWidget(row_label) - col_layout.addWidget(slider) - - container = QWidget() - h_layout = QHBoxLayout(container) - h_layout.setContentsMargins(0, 0, 0, 0) - h_layout.setSpacing(4) - h_layout.addWidget(slider_col) - h_layout.addWidget(self.plot_widget) - - self.stack_layout.addWidget(container) - else: - self.stack_layout.addWidget(self.plot_widget) - - self.apply_theme() - - def apply_theme(self, theme: Optional[Literal["dark", "light"]] = None): - """ - Apply the theme - - Args: - theme (Optional[str]): Theme, either "dark", "light", or None. Defaults to None. - """ - if theme is None: - app = QApplication.instance() - theme = app.theme.theme # type: ignore - - bg_color = pg.getConfigOption("background") - fg_color = pg.getConfigOption("foreground") - if self.plot_widget is not None: - n_curves = len(self.plot_widget.listDataItems()) - colors = Colors.golden_angle_color( - colormap="plasma", num=max(10, n_curves + 1), format="HEX" - ) - for idx, curve in enumerate(self.plot_widget.listDataItems()): - curve.setPen(pg.mkPen(color=colors[idx])) - # Background - self.plot_widget.setBackground(bg_color) - # Axes (tick marks, tick labels, axis line) - for axis in ["left", "bottom", "right", "top"]: - ax = self.plot_widget.getAxis(axis) - ax.setPen(pg.mkPen(color=fg_color)) - ax.setTextPen(pg.mkPen(color=fg_color)) - - if self.image_widget is not None: - self.image_widget.getView().setBackgroundColor(bg_color) - self.image_widget.ui.histogram.setBackground(bg_color) - - # ── 2-D image ────────────────────────────────────────────────────────── - def _show_image_2d(self, data): - self._clear_stack() - - squeezed = np.squeeze(data) - if squeezed.ndim > 2: - squeezed = squeezed.reshape(-1, squeezed.shape[-1]) - - # complex → magnitude - img_data = np.abs(squeezed) if np.iscomplexobj(squeezed) else squeezed.astype(float) - - # ImageView gives us colorbar + histogram + zoom for free - self.image_widget = pg.ImageView() - self.image_widget.ui.roiBtn.hide() - self.image_widget.ui.menuBtn.hide() - - # Use 'inferno'-like LUT - self.image_widget.setColorMap(pg.colormap.get("inferno", source="matplotlib")) - - # pyqtgraph ImageView expects (cols, rows) — transpose so row 0 is at top - self.image_widget.setImage(img_data.T, autoLevels=True, autoHistogramRange=True) - self.image_widget.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) - - self.stack_layout.addWidget(self.image_widget) - - self.apply_theme() - - # ── Table ────────────────────────────────────────────────────────────── - def _show_table(self, data): - self._clear_stack() - - MAX_ROWS, MAX_COLS = 2000, 500 - - if data.ndim == 0: - flat = data.reshape(1, 1) - elif data.ndim == 1: - flat = data.reshape(-1, 1) - elif data.ndim == 2: - flat = data - else: - flat = data.reshape(-1, data.shape[-1]) - - rows, cols = flat.shape - show_rows = min(rows, MAX_ROWS) - show_cols = min(cols, MAX_COLS) - - if rows > MAX_ROWS or cols > MAX_COLS: - note = QLabel(f"⚠ Showing {show_rows}/{rows} rows × {show_cols}/{cols} cols") - note.setObjectName("info_label") - note.setAlignment(Qt.AlignmentFlag.AlignCenter) - self.stack_layout.addWidget(note) - - table = QTableWidget(show_rows, show_cols) - table.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers) - table.setSelectionMode(QAbstractItemView.SelectionMode.ContiguousSelection) - table.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeMode.Interactive) - table.horizontalHeader().setDefaultSectionSize(90) - table.verticalHeader().setDefaultSectionSize(22) - table.setFont(QFont("JetBrains Mono, Consolas, monospace", 10)) - - is_float = np.issubdtype(flat.dtype, np.floating) - is_complex = np.iscomplexobj(flat) - is_bytes = flat.dtype.kind == "S" - - for r in range(show_rows): - for c in range(show_cols): - val = flat[r, c] - txt = ( - f"{val:.6g}" - if is_float - else f"{val:.4g}" if is_complex else str(val.decode()) if is_bytes else str(val) - ) - cell = QTableWidgetItem(txt) - cell.setTextAlignment(Qt.AlignmentFlag.AlignRight | Qt.AlignmentFlag.AlignVCenter) - table.setItem(r, c, cell) - - self.stack_layout.addWidget(table) - - -# ── Main window ──────────────────────────────────────────────────────────── -class HDF5Viewer(QMainWindow): - def __init__(self, filepath=None): - super().__init__() - self.h5files = {} # filepath -> h5py.File - self._build_ui() - if filepath: - self.load_files([filepath]) - - def apply_theme(self, theme: Literal["dark", "light"]): - """ - Apply the theme - - Args: - theme (str): Theme, either "dark" or "light" - """ - self.data_panel.apply_theme(theme) - - def load_files(self, filepaths: list[str]): - """Open one or more HDF5 files and add each as a top-level tree node.""" - for f in filepaths: - if f in self.h5files: - continue # already loaded - self.h5files[f] = h5py.File(f, "r") - self._add_file_to_tree(f) - - def _build_ui(self): - central = QWidget() - self.setCentralWidget(central) - - root_layout = QHBoxLayout(central) - - splitter = QSplitter(Qt.Orientation.Horizontal) - splitter.setChildrenCollapsible(False) - - # ── Left pane ── - left_pane = QWidget() - left_layout = QVBoxLayout(left_pane) - - self.tree = QTreeWidget() - self.tree.setMinimumWidth(250) - self.tree.setHeaderLabels(["Name", "Type", "Shape"]) - self.tree.header().setStretchLastSection(False) - self.tree.header().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch) - self.tree.header().setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents) - self.tree.header().setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents) - self.tree.itemClicked.connect(self._on_item_clicked) - - left_layout.addWidget(self.tree, 1) - - # ── Right pane ── - self.data_panel = DataPanel() - - splitter.addWidget(left_pane) - splitter.addWidget(self.data_panel) - - splitter.setStretchFactor(0, 1) - splitter.setStretchFactor(1, 3) - splitter.setSizes([300, 900]) - splitter.setHandleWidth(6) - splitter.setChildrenCollapsible(False) - - root_layout.addWidget(splitter) - - def _add_file_to_tree(self, filepath: str): - """Add a single file as a new top-level node in the tree.""" - h5file = self.h5files[filepath] - filename = filepath.split("/")[-1] - - dataset_icon = material_icon( - "dataset", size=(ICON_SIZE, ICON_SIZE), color=get_accent_colors().default.name() - ) - root_item = QTreeWidgetItem(self.tree, [filename, "Group", ""]) - root_item.setIcon(0, dataset_icon) - root_item.setData(0, Qt.ItemDataRole.UserRole, "/") - root_item.setData(0, Qt.ItemDataRole.UserRole + 1, "group") - root_item.setData(0, Qt.ItemDataRole.UserRole + 2, filepath) # so clicks know which file - - self.populate_tree(root_item, h5file) - self.tree.addTopLevelItem(root_item) - - # Expand first 2 levels - self.tree.expandItem(root_item) - for i in range(root_item.childCount()): - self.tree.expandItem(root_item.child(i)) - - self.tree.setCurrentItem(root_item) - - def populate_tree(self, parent_item, h5_obj): - folder_icon = material_icon("folder", size=(ICON_SIZE, ICON_SIZE), color="#2980b9") - vector_icon = material_icon("show_chart", size=(ICON_SIZE, ICON_SIZE), color="#2980b9") - array_icon = material_icon( - "stacked_line_chart", size=(ICON_SIZE, ICON_SIZE), color="#2980b9" - ) - scalar_icon = material_icon("point_scan", size=(ICON_SIZE, ICON_SIZE), color="#2980b9") - str_icon = material_icon("text_snippet", size=(ICON_SIZE, ICON_SIZE), color="#2980b9") - - if not isinstance(h5_obj, h5py.Group): - return - - for key in h5_obj.keys(): - try: - child = h5_obj[key] - except Exception: - continue - - if isinstance(child, h5py.Group): - item = QTreeWidgetItem(parent_item, [key, "Group", ""]) - item.setIcon(0, folder_icon) - - item.setData(0, Qt.ItemDataRole.UserRole, child.name) - item.setData(0, Qt.ItemDataRole.UserRole + 1, "group") - - self.populate_tree(item, child) - - elif isinstance(child, h5py.Dataset): - shape_str = "x".join(str(s) for s in child.shape) or "scalar" - - d = child.dtype - dtype = "unknown" - if np.issubdtype(d, np.integer): - dtype = "int" - if np.issubdtype(d, np.floating): - dtype = "float" - if np.issubdtype(d, np.complexfloating): - dtype = "complex" - if d.kind in ("S", "U", "O"): - dtype = "str" - - if shape_str == "scalar": - if dtype == "str": - icon = str_icon - else: - icon = scalar_icon - elif "x" in shape_str: - icon = array_icon - else: - icon = vector_icon - - item = QTreeWidgetItem(parent_item, [key, dtype, shape_str]) - item.setIcon(0, icon) - - item.setData(0, Qt.ItemDataRole.UserRole, child.name) - item.setData(0, Qt.ItemDataRole.UserRole + 1, "dataset") - - item.setForeground(1, QBrush(QColor(get_accent_colors().success.name()))) - item.setForeground(2, QBrush(QColor("#656365"))) - - def _get_filepath_for_item(self, item: QTreeWidgetItem) -> str: - """Walk up the tree to find the filepath stored on the root node.""" - node = item - while node.parent(): - node = node.parent() - return node.data(0, Qt.ItemDataRole.UserRole + 2) - - def _on_item_clicked(self, item, _col): - path = item.data(0, Qt.ItemDataRole.UserRole) - kind = item.data(0, Qt.ItemDataRole.UserRole + 1) - filepath = self._get_filepath_for_item(item) - - if not path or not filepath: - return - - h5file = self.h5files[filepath] - obj = h5file[path] if path != "/" else h5file - - if kind == "dataset": - data = h5file[path][()] - self.data_panel.display(data, path) - else: - n = len(obj) if isinstance(obj, h5py.Group) else 0 - - def clear_files(self): - """Close all open HDF5 files and reset the tree.""" - for h5file in self.h5files.values(): - h5file.close() - self.h5files.clear() - self.tree.clear() - self.data_panel.show_empty() - - def closeEvent(self, event): - for h5file in self.h5files.values(): - h5file.close() - self.h5files.clear() - super().closeEvent(event) diff --git a/debye_bec/bec_widgets/widgets/data_viewer/widgets/__init__,py b/debye_bec/bec_widgets/widgets/data_viewer/widgets/__init__,py new file mode 100644 index 0000000..e69de29 diff --git a/debye_bec/bec_widgets/widgets/data_viewer/qt_widgets.py b/debye_bec/bec_widgets/widgets/data_viewer/widgets/qt_widgets.py similarity index 50% rename from debye_bec/bec_widgets/widgets/data_viewer/qt_widgets.py rename to debye_bec/bec_widgets/widgets/data_viewer/widgets/qt_widgets.py index 9d7b4f7..187157e 100644 --- a/debye_bec/bec_widgets/widgets/data_viewer/qt_widgets.py +++ b/debye_bec/bec_widgets/widgets/data_viewer/widgets/qt_widgets.py @@ -1,17 +1,22 @@ """ -TaggedListWidget — a QListWidget where each item shows a label + styled tag pills. -Selection works natively; no popup, no paintEvent hacks. +Universal Qt widgets """ -import sys +from functools import partial +from bec_widgets.utils.colors import get_accent_colors + +# pylint: disable=E0611 from qtpy.QtCore import QPoint, QRect, QSize, Qt from qtpy.QtGui import QColor, QFont, QFontMetrics, QPainter, QPen from qtpy.QtWidgets import ( - QApplication, + QGroupBox, + QHBoxLayout, QLabel, + QLayout, QListWidget, QListWidgetItem, + QPushButton, QStyle, QStyledItemDelegate, QStyleOptionViewItem, @@ -19,7 +24,117 @@ from qtpy.QtWidgets import ( QWidget, ) -# ── Design tokens ────────────────────────────────────────────────────────────── + +class Group(QGroupBox): + def __init__(self, label, objs, orientation="vertical"): + super().__init__(label) + if orientation == "vertical": + self._layout = QVBoxLayout(self) + elif orientation == "horizontal": + self._layout = QHBoxLayout(self) + else: + raise ValueError(f"Orientation {orientation} is not supported!") + for obj in objs: + if isinstance(obj, QWidget): + self._layout.addWidget(obj) + elif isinstance(obj, QLayout): + self._layout.addLayout(obj) + + +class Button(QWidget): + def __init__(self, label=None, label_button: str = "", enabled=False): + super().__init__() + layout = QHBoxLayout(self) + layout.setContentsMargins(10, 0, 0, 0) + layout.setSpacing(0) + if label is not None: + self.label = QLabel(label) + self.label.setFixedWidth(140) + layout.addWidget(self.label) + self.button = QPushButton(label_button) + if label is not None: + self.button.setFixedWidth(160) + self.enable_button(enabled) + layout.addWidget(self.button) + + def clicked_connect(self, func): + """Connect a function to the button press.""" + self.button.clicked.connect(func) + + def enable_button(self, enable: bool = False): + if enable: + self.button.setStyleSheet( + f"QPushButton {{background-color: {get_accent_colors().default.name()}; color: white;}}" + ) + self.button.setEnabled(True) + else: # disabled + self.button.setStyleSheet( + "QPushButton {{background-color: rgb(120, 120, 120); color: white;}}" + ) + self.button.setDisabled(True) + + def setText(self, text): + self.button.setText(text) + + +class ListWidget(QWidget): + def __init__(self, identifier=""): + super().__init__() + layout = QHBoxLayout(self) + layout.setContentsMargins(10, 0, 0, 0) + layout.setSpacing(0) + self.identifier = identifier + self.value = TaggedListWidget() + layout.addWidget(self.value) + + def clear(self): + self.value.clear() + + def addTaggedItem(self, label, tags): + self.value.addTaggedItem(label, tags) + + def setCurrentIndex(self, text): + self.value.setCurrentIndex(text) + + def currentItemChanged_connect(self, func): + """Connect a function to the Enter/Return key press.""" + self.value.currentItemChanged.connect( + partial( + func, + identifier=self.identifier, + value_obj=self.value, + value=lambda: self.value.currentIndex(), + ) + ) + + def setDisabled(self, disable): + self.value.setDisabled(disable) + + +class TaggedListWidget(QListWidget): + """QListWidget with label + coloured tag pills per row.""" + + def __init__(self, parent=None) -> None: + super().__init__(parent) + self.setItemDelegate(TaggedDelegate(self)) + self.setMouseTracking(True) # enables hover highlight + + def addTaggedItem(self, label: str, tags: list | None = None) -> QListWidgetItem: + """ + Add a row. tags is a list of (text, hex_color) pairs, e.g. + [("v1.26", "#2563EB"), ("stable", "#16A34A")] + """ + item = QListWidgetItem(str(label)) + if tags: + item.setData(_UserRole, list(tags)) + self.addItem(item) + return item + + def currentTags(self) -> list: + item = self.currentItem() + return item.data(_UserRole) or [] if item else [] + + ITEM_HEIGHT = 30 H_PAD = 12 TAG_H_PAD = 7 @@ -28,30 +143,22 @@ TAG_GAP = 5 LABEL_TAG_GAP = 12 CORNER_RADIUS = 4 TAG_TEXT_COLOR = "#FFFFFF" - -# ── Enum compat (PySide6 nested vs PyQt5 flat) ──────────────────────────────── -try: - _DEMIBOLD = QFont.Weight.DemiBold - _MEDIUM = QFont.Weight.Medium -except AttributeError: - _DEMIBOLD = QFont.DemiBold - _MEDIUM = QFont.Medium - -_AlignVCenter = Qt.AlignmentFlag.AlignVCenter if hasattr(Qt, "AlignmentFlag") else Qt.AlignVCenter -_AlignCenter = Qt.AlignmentFlag.AlignCenter if hasattr(Qt, "AlignmentFlag") else Qt.AlignCenter -_UserRole = Qt.ItemDataRole.UserRole if hasattr(Qt, "ItemDataRole") else Qt.UserRole -_NoPen = Qt.PenStyle.NoPen if hasattr(Qt, "PenStyle") else Qt.NoPen -_AA = QPainter.RenderHint.Antialiasing if hasattr(QPainter, "RenderHint") else QPainter.Antialiasing - -try: - _State_Selected = QStyle.StateFlag.State_Selected - _State_MouseOver = QStyle.StateFlag.State_MouseOver -except AttributeError: - _State_Selected = QStyle.State_Selected - _State_MouseOver = QStyle.State_MouseOver +_DEMIBOLD = QFont.Weight.DemiBold +_MEDIUM = QFont.Weight.Medium +_AlignCenter = ( + Qt.AlignmentFlag.AlignCenter if hasattr(Qt, "AlignmentFlag") else Qt.AlignmentFlag.AlignCenter +) +_UserRole = Qt.ItemDataRole.UserRole if hasattr(Qt, "ItemDataRole") else Qt.ItemDataRole.UserRole +_NoPen = Qt.PenStyle.NoPen if hasattr(Qt, "PenStyle") else Qt.PenStyle.NoPen +_AA = ( + QPainter.RenderHint.Antialiasing + if hasattr(QPainter, "RenderHint") + else QPainter.RenderHint.Antialiasing +) +_State_Selected = QStyle.StateFlag.State_Selected +_State_MouseOver = QStyle.StateFlag.State_MouseOver -# ── Delegate ────────────────────────────────────────────────────────────────── class TaggedDelegate(QStyledItemDelegate): def sizeHint(self, option: QStyleOptionViewItem, index) -> QSize: @@ -73,7 +180,9 @@ class TaggedDelegate(QStyledItemDelegate): label_text = ( index.data( - Qt.ItemDataRole.DisplayRole if hasattr(Qt, "ItemDataRole") else Qt.DisplayRole + Qt.ItemDataRole.DisplayRole + if hasattr(Qt, "ItemDataRole") + else Qt.ItemDataRole.DisplayRole ) or "" ) @@ -124,28 +233,3 @@ class TaggedDelegate(QStyledItemDelegate): x += tw + TAG_GAP painter.restore() - - -# ── Widget ──────────────────────────────────────────────────────────────────── -class TaggedListWidget(QListWidget): - """QListWidget with label + coloured tag pills per row.""" - - def __init__(self, parent=None) -> None: - super().__init__(parent) - self.setItemDelegate(TaggedDelegate(self)) - self.setMouseTracking(True) # enables hover highlight - - def addTaggedItem(self, label: str, tags: list | None = None) -> QListWidgetItem: - """ - Add a row. tags is a list of (text, hex_color) pairs, e.g. - [("v1.26", "#2563EB"), ("stable", "#16A34A")] - """ - item = QListWidgetItem(str(label)) - if tags: - item.setData(_UserRole, list(tags)) - self.addItem(item) - return item - - def currentTags(self) -> list: - item = self.currentItem() - return item.data(_UserRole) or [] if item else [] diff --git a/debye_bec/bec_widgets/widgets/digital_twin/__init__.py b/debye_bec/bec_widgets/widgets/digital_twin/__init__.py index e69de29..a42cb09 100644 --- a/debye_bec/bec_widgets/widgets/digital_twin/__init__.py +++ b/debye_bec/bec_widgets/widgets/digital_twin/__init__.py @@ -0,0 +1,3 @@ +from .beamline import get_parameters + +parameters = get_parameters() diff --git a/debye_bec/bec_widgets/widgets/digital_twin/beamline.py b/debye_bec/bec_widgets/widgets/digital_twin/beamline.py new file mode 100644 index 0000000..3653284 --- /dev/null +++ b/debye_bec/bec_widgets/widgets/digital_twin/beamline.py @@ -0,0 +1,51 @@ +import socket + +from bec_lib import bec_logger + +from .types import BeamlineId + +logger = bec_logger.logger + + +def get_beamline_id() -> BeamlineId: + """ + Based on the bec servers hostname, tries to extract the beamline + identifier (e.g. x01da, x10da, etc). + + Raises: + ValueError if beamline cannot be extracted from hostname or beamline not implemented. + """ + bec_hostname = socket.gethostname() + start = bec_hostname.find("x") + if start != -1: + beamline = bec_hostname[start : start + 5] + match beamline: + case "x01da": + return BeamlineId.X01DA + case "x10da": + return BeamlineId.X10DA + case _: + raise ValueError(f"Not implemented beamline {beamline}") + else: + logger.warning(f"Failed to extract beamline from bec server hostname {bec_hostname}") + choice = input("Do you want to manually select a beamline? (yes/no): ").strip().lower() + if choice in ["yes", "y"]: + bl = input(f"Choose from: {[bl.value for bl in BeamlineId]}") + if bl in BeamlineId: + logger.info(f"Manually selected beamline {bl}") + return BeamlineId(bl) + else: + raise ValueError(f"Wrong selection {bl}") + else: + raise ValueError("Cannot open digital twin without a beamline") + + +def get_parameters(): + beamline = get_beamline_id() + if beamline == "x01da": + from . import x01da_parameters as parameters + elif beamline == "x10da": + from . import x10da_parameters as parameters + else: + raise ValueError(f"Unknown beamline: {beamline}") + return parameters diff --git a/debye_bec/bec_widgets/widgets/digital_twin/calculations/calc_positions.py b/debye_bec/bec_widgets/widgets/digital_twin/calculations/calc_positions.py index 885dcb3..6347d4b 100644 --- a/debye_bec/bec_widgets/widgets/digital_twin/calculations/calc_positions.py +++ b/debye_bec/bec_widgets/widgets/digital_twin/calculations/calc_positions.py @@ -5,8 +5,8 @@ Calculates the positions of axes based on a beamline config import numpy as np from bec_lib import bec_logger -import debye_bec.bec_widgets.widgets.digital_twin.x01da_parameters as bl -from debye_bec.bec_widgets.widgets.digital_twin.types import ConfigDict +from .. import parameters as bl +from ..types import ConfigDict logger = bec_logger.logger diff --git a/debye_bec/bec_widgets/widgets/digital_twin/calculations/calc_sideview.py b/debye_bec/bec_widgets/widgets/digital_twin/calculations/calc_sideview.py index 7ec677d..135c76d 100644 --- a/debye_bec/bec_widgets/widgets/digital_twin/calculations/calc_sideview.py +++ b/debye_bec/bec_widgets/widgets/digital_twin/calculations/calc_sideview.py @@ -4,8 +4,8 @@ Calculates the sideview coordinates based on a beamline config. import numpy as np -import debye_bec.bec_widgets.widgets.digital_twin.x01da_parameters as bl -from debye_bec.bec_widgets.widgets.digital_twin.types import ConfigDict, DataDict +from .. import parameters as bl +from ..types import ConfigDict, DataDict def calc_sideview(cfg: ConfigDict) -> DataDict: diff --git a/debye_bec/bec_widgets/widgets/digital_twin/calculations/calc_surfaces.py b/debye_bec/bec_widgets/widgets/digital_twin/calculations/calc_surfaces.py index 0b90b93..9411ded 100644 --- a/debye_bec/bec_widgets/widgets/digital_twin/calculations/calc_surfaces.py +++ b/debye_bec/bec_widgets/widgets/digital_twin/calculations/calc_surfaces.py @@ -7,8 +7,8 @@ import re import numpy as np from bec_lib import bec_logger -import debye_bec.bec_widgets.widgets.digital_twin.x01da_parameters as bl -from debye_bec.bec_widgets.widgets.digital_twin.types import ConfigDict, SurfaceDict +from .. import parameters as bl +from ..types import ConfigDict, SurfaceDict logger = bec_logger.logger diff --git a/debye_bec/bec_widgets/widgets/digital_twin/calculations/calc_varia.py b/debye_bec/bec_widgets/widgets/digital_twin/calculations/calc_varia.py index db7c471..33ad7cf 100644 --- a/debye_bec/bec_widgets/widgets/digital_twin/calculations/calc_varia.py +++ b/debye_bec/bec_widgets/widgets/digital_twin/calculations/calc_varia.py @@ -10,7 +10,7 @@ from bec_lib import bec_logger from scipy.interpolate import UnivariateSpline from xrt.backends.raycing.physconsts import AVOGADRO, CHeVcm -import debye_bec.bec_widgets.widgets.digital_twin.x01da_parameters as bl +from .. import parameters as bl logger = bec_logger.logger diff --git a/debye_bec/bec_widgets/widgets/digital_twin/digital_twin.py b/debye_bec/bec_widgets/widgets/digital_twin/digital_twin.py index 96a6343..180e9fb 100644 --- a/debye_bec/bec_widgets/widgets/digital_twin/digital_twin.py +++ b/debye_bec/bec_widgets/widgets/digital_twin/digital_twin.py @@ -2,7 +2,6 @@ Digital Twin: Custom BEC widget to support the beamline alignment. """ -import socket import sys from pathlib import Path from typing import Literal, cast @@ -36,10 +35,11 @@ from qtpy.QtWidgets import ( QWidget, ) -from debye_bec.bec_widgets.widgets.digital_twin.calculations.calc_positions import calc_positions -from debye_bec.bec_widgets.widgets.digital_twin.calculations.calc_sideview import calc_sideview -from debye_bec.bec_widgets.widgets.digital_twin.calculations.calc_surfaces import calc_surfaces -from debye_bec.bec_widgets.widgets.digital_twin.calculations.calc_varia import ( +from .beamline import get_beamline_id +from .calculations.calc_positions import calc_positions +from .calculations.calc_sideview import calc_sideview +from .calculations.calc_surfaces import calc_surfaces +from .calculations.calc_varia import ( cm_critical_angle, cm_reflectivity, cm_stripe_to_trx, @@ -53,12 +53,12 @@ from debye_bec.bec_widgets.widgets.digital_twin.calculations.calc_varia import ( sldi_gap_to_acc, table_to_smpl_pos, ) -from debye_bec.bec_widgets.widgets.digital_twin.panels.input_panel import InputPanel -from debye_bec.bec_widgets.widgets.digital_twin.panels.mover_panel import MoverPanel -from debye_bec.bec_widgets.widgets.digital_twin.panels.plots import SideviewPlot, SurfacePlots -from debye_bec.bec_widgets.widgets.digital_twin.panels.settings_panel import SettingsPanel -from debye_bec.bec_widgets.widgets.digital_twin.types import BeamlineId, ConfigDict -from debye_bec.bec_widgets.widgets.digital_twin.widgets.qt_widgets import ComboBox, InputNumberField +from .panels.input_panel import InputPanel +from .panels.mover_panel import MoverPanel +from .panels.plots import SideviewPlot, SurfacePlots +from .panels.settings_panel import SettingsPanel +from .types import ConfigDict +from .widgets.qt_widgets import ComboBox, InputNumberField logger = bec_logger.logger @@ -78,7 +78,7 @@ class DigitalTwin(BECWidget, QWidget): super().__init__(parent=parent, theme_update=True, *arg, **kwargs) self.get_bec_shortcuts() - self.beamline = self.get_beamline_id() + self.beamline = get_beamline_id() # Debugging, override beamline! # self.beamline = BeamlineId.X10DA @@ -200,38 +200,6 @@ class DigitalTwin(BECWidget, QWidget): self.surface_plots.apply_theme(theme) self.mover.apply_theme(theme) - def get_beamline_id(self) -> BeamlineId: - """ - Based on the bec servers hostname, tries to extract the beamline - identifier (e.g. x01da, x10da, etc). - - Raises: - ValueError if beamline cannot be extracted from hostname or beamline not implemented. - """ - bec_hostname = socket.gethostname() - start = bec_hostname.find("x") - if start != -1: - beamline = bec_hostname[start : start + 5] - match beamline: - case "x01da": - return BeamlineId.X01DA - case "x10da": - return BeamlineId.X10DA - case _: - raise ValueError(f"Not implemented beamline {beamline}") - else: - logger.warning(f"Failed to extract beamline from bec server hostname {bec_hostname}") - choice = input("Do you want to manually select a beamline? (yes/no): ").strip().lower() - if choice in ["yes", "y"]: - bl = input(f"Choose from: {[bl.value for bl in BeamlineId]}") - if bl in BeamlineId: - logger.info(f"Manually selected beamline {bl}") - return BeamlineId(bl) - else: - raise ValueError(f"Wrong selection {bl}") - else: - raise ValueError("Cannot open digital twin without a beamline") - @SafeSlot() def check_bec_config(self, *args): """ diff --git a/debye_bec/bec_widgets/widgets/digital_twin/digital_twin_plugin.py b/debye_bec/bec_widgets/widgets/digital_twin/digital_twin_plugin.py index 2decf97..921fd6b 100644 --- a/debye_bec/bec_widgets/widgets/digital_twin/digital_twin_plugin.py +++ b/debye_bec/bec_widgets/widgets/digital_twin/digital_twin_plugin.py @@ -5,7 +5,7 @@ from bec_widgets.utils.bec_designer import designer_material_icon from qtpy.QtDesigner import QDesignerCustomWidgetInterface from qtpy.QtWidgets import QWidget -from debye_bec.bec_widgets.widgets.digital_twin.digital_twin import DigitalTwin +from .digital_twin import DigitalTwin DOM_XML = """ @@ -22,7 +22,7 @@ class DigitalTwinPlugin(QDesignerCustomWidgetInterface): # pragma: no cover def createWidget(self, parent): if parent is None: - return QWidget() + return QWidget() t = DigitalTwin(parent) return t diff --git a/debye_bec/bec_widgets/widgets/digital_twin/panels/input_panel.py b/debye_bec/bec_widgets/widgets/digital_twin/panels/input_panel.py index 2620b9b..c0ae458 100644 --- a/debye_bec/bec_widgets/widgets/digital_twin/panels/input_panel.py +++ b/debye_bec/bec_widgets/widgets/digital_twin/panels/input_panel.py @@ -7,14 +7,8 @@ from typing import Union # pylint: disable=E0611 from qtpy.QtWidgets import QVBoxLayout, QWidget -from debye_bec.bec_widgets.widgets.digital_twin.types import BeamlineId -from debye_bec.bec_widgets.widgets.digital_twin.widgets.qt_widgets import ( - Button, - ComboBox, - Group, - InputNumberField, - NumberIndicator, -) +from ..types import BeamlineId +from ..widgets.qt_widgets import Button, ComboBox, Group, InputNumberField, NumberIndicator class InputPanel(QWidget): diff --git a/debye_bec/bec_widgets/widgets/digital_twin/panels/mover_panel.py b/debye_bec/bec_widgets/widgets/digital_twin/panels/mover_panel.py index fcf3703..1e1ba99 100644 --- a/debye_bec/bec_widgets/widgets/digital_twin/panels/mover_panel.py +++ b/debye_bec/bec_widgets/widgets/digital_twin/panels/mover_panel.py @@ -7,11 +7,8 @@ from typing import Literal # pylint: disable=E0611 from qtpy.QtWidgets import QVBoxLayout, QWidget -from debye_bec.bec_widgets.widgets.digital_twin.widgets.move_widget import ( - AbsorberWidget, - MoveWidget, -) -from debye_bec.bec_widgets.widgets.digital_twin.widgets.qt_widgets import Group +from ..widgets.move_widget import AbsorberWidget, MoveWidget +from ..widgets.qt_widgets import Group class MoverPanel(QWidget): diff --git a/debye_bec/bec_widgets/widgets/digital_twin/panels/plots.py b/debye_bec/bec_widgets/widgets/digital_twin/panels/plots.py index 8cbafce..0080ef0 100644 --- a/debye_bec/bec_widgets/widgets/digital_twin/panels/plots.py +++ b/debye_bec/bec_widgets/widgets/digital_twin/panels/plots.py @@ -15,14 +15,14 @@ from qtpy.QtGui import QBrush, QColor # pylint: disable=E0611 from qtpy.QtWidgets import QApplication, QGraphicsRectItem, QHBoxLayout, QVBoxLayout, QWidget -from debye_bec.bec_widgets.widgets.digital_twin.calculations.calc_varia import ( +from ..calculations.calc_varia import ( mirror_surface_geometries, mo_surface_geometries, pipe_geometries, wall_geometries, ) -from debye_bec.bec_widgets.widgets.digital_twin.types import DataDict, SurfaceDict -from debye_bec.bec_widgets.widgets.digital_twin.widgets.qt_widgets import Group +from ..types import DataDict, SurfaceDict +from ..widgets.qt_widgets import Group logger = bec_logger.logger diff --git a/debye_bec/bec_widgets/widgets/digital_twin/panels/settings_panel.py b/debye_bec/bec_widgets/widgets/digital_twin/panels/settings_panel.py index 88ce3e1..2bf74d2 100644 --- a/debye_bec/bec_widgets/widgets/digital_twin/panels/settings_panel.py +++ b/debye_bec/bec_widgets/widgets/digital_twin/panels/settings_panel.py @@ -5,11 +5,7 @@ Settings panel for the digital twin widget # pylint: disable=E0611 from qtpy.QtWidgets import QVBoxLayout, QWidget -from debye_bec.bec_widgets.widgets.digital_twin.widgets.qt_widgets import ( - Button, - Group, - TextIndicator, -) +from ..widgets.qt_widgets import Button, Group, TextIndicator class SettingsPanel(QWidget): diff --git a/debye_bec/bec_widgets/widgets/digital_twin/register_digital_twin.py b/debye_bec/bec_widgets/widgets/digital_twin/register_digital_twin.py index 0c5d315..024c888 100644 --- a/debye_bec/bec_widgets/widgets/digital_twin/register_digital_twin.py +++ b/debye_bec/bec_widgets/widgets/digital_twin/register_digital_twin.py @@ -6,7 +6,7 @@ def main(): # pragma: no cover return from PySide6.QtDesigner import QPyDesignerCustomWidgetCollection - from debye_bec.bec_widgets.widgets.digital_twin.digital_twin_plugin import DigitalTwinPlugin + from .digital_twin_plugin import DigitalTwinPlugin QPyDesignerCustomWidgetCollection.addCustomWidget(DigitalTwinPlugin()) diff --git a/debye_bec/bec_widgets/widgets/digital_twin/widgets/move_widget.py b/debye_bec/bec_widgets/widgets/digital_twin/widgets/move_widget.py index 65efd74..4bd529c 100644 --- a/debye_bec/bec_widgets/widgets/digital_twin/widgets/move_widget.py +++ b/debye_bec/bec_widgets/widgets/digital_twin/widgets/move_widget.py @@ -15,7 +15,8 @@ from qtpy.QtCore import QObject, QPropertyAnimation, Qt, QThread from qtpy.QtGui import QTransform from qtpy.QtWidgets import QApplication, QHBoxLayout, QLabel, QPushButton, QWidget -from debye_bec.devices.absorber import STATUS as ABS_STATUS +# pylint: disable=E0402 +from .....devices.absorber import STATUS as ABS_STATUS logger = bec_logger.logger