diff --git a/debye_bec/bec_widgets/widgets/data_viewer/data_viewer.py b/debye_bec/bec_widgets/widgets/data_viewer/data_viewer.py index 3ad7c35..dd43d52 100644 --- a/debye_bec/bec_widgets/widgets/data_viewer/data_viewer.py +++ b/debye_bec/bec_widgets/widgets/data_viewer/data_viewer.py @@ -6,8 +6,7 @@ import os import subprocess import sys from datetime import datetime -from functools import partial -from typing import Literal, Optional, cast +from typing import Literal from bec_lib import bec_logger from bec_lib.endpoints import MessageEndpoints @@ -15,27 +14,12 @@ from bec_widgets.utils.bec_dispatcher import BECDispatcher from bec_widgets.utils.bec_widget import BECWidget from bec_widgets.utils.colors import apply_theme, get_accent_colors from bec_widgets.utils.error_popups import SafeSlot -from qtpy.QtCore import Qt # pylint: disable=E0611 -from qtpy.QtGui import QFont +from qtpy.QtWidgets import QApplication, QVBoxLayout, QWidget -# pylint: disable=E0611 -from qtpy.QtWidgets import ( - QApplication, - QComboBox, - QDoubleSpinBox, - QGroupBox, - QHBoxLayout, - QLabel, - QLayout, - QPushButton, - QVBoxLayout, - QWidget, -) - -from debye_bec.bec_widgets.widgets.data_viewer.qt_widgets import TaggedListWidget -from debye_bec.bec_widgets.widgets.data_viewer.viewer import HDF5Viewer +from .panels.input_panel import InputPanel +from .panels.scan_view import ScanViewer logger = bec_logger.logger @@ -54,20 +38,15 @@ class DataViewer(BECWidget, QWidget): super().__init__(parent=parent, theme_update=True, *arg, **kwargs) self.get_bec_shortcuts() - logger.info(f"Type of self.client: {type(self.client)}") - central = QWidget() self.root_layout = QVBoxLayout(central) - self.input = InputPanel() - self.viewer = HDF5Viewer() - + self.viewer = ScanViewer() self.root_layout.addWidget(self.input, 0) self.root_layout.addWidget(self.viewer, 1) self.setLayout(self.root_layout) self.setWindowTitle("Data Viewer") - # self.resize(1800, 800) self.history = [] self.bec_dispatcher.connect_slot(self.on_history_update, MessageEndpoints.scan_history()) @@ -75,14 +54,9 @@ class DataViewer(BECWidget, QWidget): self.current_row = 0 - logger.info(self.client.acl) - logger.info(self.client.active_account) - logger.info(self.client.proc) - logger.info(self.client.username) - self.input.scan_sel.currentItemChanged_connect(self.scan_sel_changed) - self.input.load_button.clicked_connect(self.load_dataset) - self.input.unload_button.clicked_connect(self.unload_all_datasets) + self.input.load_button.clicked_connect(self.load_scan) + self.input.unload_button.clicked_connect(self.unload_all_scans) self.input.open_fm_button.clicked_connect(self.open_in_file_manager) def apply_theme(self, theme: Literal["dark", "light"]): @@ -97,10 +71,12 @@ class DataViewer(BECWidget, QWidget): @SafeSlot() def scan_sel_changed(self, *_, **kwargs): + """Updates the current row value of the scan selection list""" self.current_row = kwargs["value"]().row() @SafeSlot() def open_in_file_manager(self, *_): + """Open the scan folder in the systems default file manager""" if len(self.history) > 0: scan = self.history[self.current_row] filepath = scan["file_components"][0].decode().rsplit("/", 1)[0] @@ -113,9 +89,11 @@ class DataViewer(BECWidget, QWidget): ) @SafeSlot() - def load_dataset( - self, *_ - ): # TODO: Check scan file components for combined xas/xrd scans. Is the Pilatus file in there as well? + def load_scan(self, *_): + """ + Loads a scan. Find all files within the scan folder, sort them and + then load the files in the scan view + """ if len(self.history) > 0: scan = self.history[self.current_row] base_filepath = scan["file_components"][0].decode().rsplit("/", 1)[0] @@ -138,44 +116,67 @@ class DataViewer(BECWidget, QWidget): self.viewer.load_files(sorted_files) @SafeSlot() - def unload_all_datasets(self, *_): + def unload_all_scans(self, *_): + """Removes all scans from the scan view""" self.viewer.clear_files() - def duration_string(self, start: str, end: str) -> str: + def duration_formatted(self, start: str, end: str) -> str: + """ + Calculates the duration of a scan based on start end end time and + formats it as an easy readable string. + + Args: + start(str): start time in iso-format + end(str): end time in iso-format + + Returns: + str: Formatted duration, e.g. '1min 10s' or '1h 13min' + """ start_dt = datetime.fromisoformat(start) end_dt = datetime.fromisoformat(end) - seconds = abs(int((end_dt - start_dt).total_seconds())) days, remainder = divmod(seconds, 86400) hours, remainder = divmod(remainder, 3600) - minutes, _ = divmod(remainder, 60) + minutes, remainder = divmod(remainder, 60) + seconds, _ = divmod(remainder, 60) parts = [] - if days: parts.append(f"{days}d") if hours: parts.append(f"{hours}h") if minutes: parts.append(f"{minutes}min") - + if not days and not hours and minutes < 10: + parts.append(f"{seconds}s") return " ".join(parts) if parts else "<1min" + def time_formatted(self, iso_time: str) -> str: + """ + Formates a time as an easy readable string. + + Args: + iso_time(str): Time in iso-format + + Returns: + str: Time formatted with format '%d.%m.%Y %H:%M', e.g. '14.01.1995 08:12' + """ + dt = datetime.fromisoformat(iso_time) + return dt.strftime("%d.%m.%Y %H:%M") + @SafeSlot() def on_history_update(self, *_): + """Updates the scan list based on the bec scan history.""" self.history = [] self.input.scan_sel.clear() - # Get the length of the scan history, which is 0 when the bec server was started - # and no scan has finished yet. Limit the history to the latest 20 scans. + if self.client.history is None: + return max_scans = min(len(self.client.history), MAX_HIST_LEN) for n in range(1, max_scans): # last scans, limited by MAX_HIST_LEN - # logger.info(self.client.history[-n].metadata["bec"]["status"]) - start_time = self.client.history[-n].metadata["start_time"] - end_time = self.client.history[-n].metadata["end_time"] - # logger.info(type(start_time)) - scan_data = self.client.history[-n].metadata["bec"] - # logger.info(scan_data) + start_time = self.client.history[-n].metadata["start_time"] # type: ignore + end_time = self.client.history[-n].metadata["end_time"] # type: ignore + scan_data = self.client.history[-n].metadata["bec"] # type: ignore scan_number = scan_data["scan_number"] scan_name = scan_data["scan_name"] comment = scan_data["metadata"]["user_metadata"]["comment"] @@ -205,138 +206,9 @@ class DataViewer(BECWidget, QWidget): tags.append((status, get_accent_colors().emergency.name())) else: tags.append((status, "#656365")) - tags.append((self.duration_string(start_time, end_time), "#656365")) + tags.append((self.duration_formatted(start_time, end_time), "#656365")) + tags.append((self.time_formatted(start_time), "#656365")) self.input.scan_sel.addTaggedItem(label=str(scan_number), tags=tags) - # logger.info(f"Scan history: {self.history}") - - -class InputPanel(QWidget): - """Panel for scan selection of the data viewer widget""" - - def __init__(self, parent=None): - super().__init__(parent) - self._layout = QHBoxLayout(self) - # self._layout.setSizeConstraint(QLayout.SetFixedSize) # type: ignore - - # Scan selection - self.scan_sel = ListWidget("scan_sel", "Scan", ["Si", "Rh", "Pt"]) - self.load_button = Button(label_button="Load Dataset", enabled=True) - self.unload_button = Button(label_button="Unload all", enabled=True) - self.open_fm_button = Button(label_button="Open in File Manager", enabled=True) - - self._button_layout = QVBoxLayout() - self._button_layout.addWidget(self.load_button) - self._button_layout.addWidget(self.unload_button) - self._button_layout.addWidget(self.open_fm_button) - self._button_layout.addStretch() - - # Assemble complete scan selection group - self.input_group = Group( - "Scan selection", [self._button_layout, self.scan_sel], orientation="horizontal" - ) - - self._layout.addWidget(self.input_group) - # self._layout.addStretch() - - -class Group(QGroupBox): - def __init__(self, label, objs, orientation="vertical"): - super().__init__(label) - if orientation == "vertical": - self._layout = QVBoxLayout(self) # type: ignore - elif orientation == "horizontal": # assume horizontal - self._layout = QHBoxLayout(self) # type: ignore - else: - raise ValueError(f"Orientation {orientation} is not supported!") - for obj in objs: - if isinstance(obj, QWidget): - self._layout.addWidget(obj) # type: ignore - elif isinstance(obj, QLayout): - self._layout.addLayout(obj) - - -class ListWidget(QWidget): - def __init__(self, identifier="", label="", enums=[]): - super().__init__() - layout = QHBoxLayout(self) - layout.setContentsMargins(10, 0, 0, 0) - layout.setSpacing(0) - self.identifier = identifier - # self.label = QLabel(label) - # self.label.setFixedWidth(140) - # self.label.setContentsMargins(0, 0, 10, 0) - # self.label.setWordWrap(True) - # layout.addWidget(self.label) - self.value = TaggedListWidget() - # self.value.setFixedWidth(400) - # for entry in enums: - # self.value.addItem(entry) - layout.addWidget(self.value) - - def clear(self): - self.value.clear() - - def addTaggedItem(self, label, tags): - self.value.addTaggedItem(label, tags) - - def setCurrentIndex(self, text): - self.value.setCurrentIndex(text) - - # def currentIndex(self) -> int: - # return self.value.currentIndex() - - # def has_focus(self) -> bool: - # return QApplication.focusWidget() is self.value.view() - - def currentItemChanged_connect(self, func): - """Connect a function to the Enter/Return key press.""" - self.value.currentItemChanged.connect( - partial( - func, - identifier=self.identifier, - value_obj=self.value, - value=lambda: self.value.currentIndex(), - ) - ) - - def setDisabled(self, disable): - self.value.setDisabled(disable) - - -class Button(QWidget): - def __init__(self, label=None, label_button: str = "", enabled=False): - super().__init__() - layout = QHBoxLayout(self) - layout.setContentsMargins(10, 0, 0, 0) - layout.setSpacing(0) - if label is not None: - self.label = QLabel(label) - self.label.setFixedWidth(140) - layout.addWidget(self.label) - self.button = QPushButton(label_button) - if label is not None: - self.button.setFixedWidth(160) - self.enable_button(enabled) - layout.addWidget(self.button) - - def clicked_connect(self, func): - """Connect a function to the button press.""" - self.button.clicked.connect(func) - - def enable_button(self, enable: bool = False): - if enable: - self.button.setStyleSheet( - f"QPushButton {{background-color: {get_accent_colors().default.name()}; color: white;}}" - ) - self.button.setEnabled(True) - else: # disabled - self.button.setStyleSheet( - "QPushButton {{background-color: rgb(120, 120, 120); color: white;}}" - ) - self.button.setDisabled(True) - - def setText(self, text): - self.button.setText(text) if __name__ == "__main__": diff --git a/debye_bec/bec_widgets/widgets/data_viewer/data_viewer_plugin.py b/debye_bec/bec_widgets/widgets/data_viewer/data_viewer_plugin.py index cf5b92f..41560d4 100644 --- a/debye_bec/bec_widgets/widgets/data_viewer/data_viewer_plugin.py +++ b/debye_bec/bec_widgets/widgets/data_viewer/data_viewer_plugin.py @@ -5,7 +5,7 @@ from bec_widgets.utils.bec_designer import designer_material_icon from qtpy.QtDesigner import QDesignerCustomWidgetInterface from qtpy.QtWidgets import QWidget -from debye_bec.bec_widgets.widgets.data_viewer.data_viewer import DataViewer +from .data_viewer import DataViewer DOM_XML = """ @@ -22,7 +22,7 @@ class DataViewerPlugin(QDesignerCustomWidgetInterface): # pragma: no cover def createWidget(self, parent): if parent is None: - return QWidget() + return QWidget() t = DataViewer(parent) return t diff --git a/debye_bec/bec_widgets/widgets/data_viewer/loaders.py b/debye_bec/bec_widgets/widgets/data_viewer/loaders.py new file mode 100644 index 0000000..603cad3 --- /dev/null +++ b/debye_bec/bec_widgets/widgets/data_viewer/loaders.py @@ -0,0 +1,250 @@ +""" +File-format loader (HDF5, images). +""" + +import os +from abc import ABC, abstractmethod +from typing import Iterator, Literal, Optional + +import h5py +import numpy as np + + +class NodeInfo: + """Describes a single node (group or dataset) inside a loaded file.""" + + __slots__ = ("name", "path", "kind", "dtype", "shape", "display_hint") + + def __init__( + self, + name: str, + path: str, + kind: Literal["group", "dataset"], + dtype: str = "", + shape: tuple[int, ...] = (), + display_hint: str = "", + ): + self.name = name + self.path = path + self.kind = kind + self.dtype = dtype # e.g. "float", "int", "str", … + self.shape = shape # empty tuple for scalars / groups + self.display_hint = display_hint # e.g. "image_native" — passed through to DataPanel + + +class BaseFileLoader(ABC): + """ + Abstract base class for file-format loaders. + + Subclass this to add support for a new format. Three things are required: + + 1. ``EXTENSIONS`` — tuple of lowercase extensions this loader handles, + e.g. ``(".h5", ".hdf5")``. + + 2. ``open(filepath)`` — open the file and keep any handles alive. + + 3. ``iter_nodes(path)`` — yield ``NodeInfo`` objects for the direct + children of *path* (depth-1 walk; the tree widget calls this + recursively as the user expands nodes). + + 4. ``read_dataset(path)`` — return the dataset at *path* as a + ``numpy.ndarray``. + + 5. ``close()`` — release any open file handles. + + 6. ``child_count(path)`` — return the number of direct children of a + group node (used for the status label; override if cheap to compute). + """ + + EXTENSIONS: tuple[str, ...] = () + + @abstractmethod + def open(self, filepath: str) -> None: ... + + @abstractmethod + def iter_nodes(self, path: str) -> Iterator[NodeInfo]: ... + + @abstractmethod + def read_dataset(self, path: str) -> np.ndarray: ... + + @abstractmethod + def close(self) -> None: ... + + def child_count(self, path: str) -> int: + return sum(1 for _ in self.iter_nodes(path)) + + +# ══════════════════════════════════════════════════════════════════════════════ +# HDF5 loader (replaces the inline h5py logic that was in HDF5Viewer) +# ══════════════════════════════════════════════════════════════════════════════ + + +class HDF5Loader(BaseFileLoader): + """Loader for HDF5 / NeXus files (.h5, .hdf5, .nxs, .nx).""" + + EXTENSIONS = (".h5", ".hdf5", ".hdf", ".nxs", ".nx") + + def __init__(self): + self._file: Optional[h5py.File] = None + + def open(self, filepath: str) -> None: + self._file = h5py.File(filepath, "r") + + def close(self) -> None: + if self._file is not None: + self._file.close() + self._file = None + + def iter_nodes(self, path: str) -> Iterator[NodeInfo]: + assert self._file is not None, "File not open" + obj = self._file[path] if path != "/" else self._file + + if not isinstance(obj, h5py.Group): + return + + for key in obj.keys(): + try: + child = obj[key] + except Exception: + continue + + child_path = child.name # h5py always gives the absolute path + + if isinstance(child, h5py.Group): + yield NodeInfo(name=key, path=child_path, kind="group") + + elif isinstance(child, h5py.Dataset): + shape_tuple = child.shape + d = child.dtype + dtype = "unknown" + if np.issubdtype(d, np.integer): + dtype = "int" + if np.issubdtype(d, np.floating): + dtype = "float" + if np.issubdtype(d, np.complexfloating): + dtype = "complex" + if d.kind in ("S", "U", "O"): + dtype = "str" + + yield NodeInfo( + name=key, path=child_path, kind="dataset", dtype=dtype, shape=shape_tuple + ) + + def read_dataset(self, path: str) -> np.ndarray: + assert self._file is not None, "File not open" + return self._file[path][()], "h5" + + def child_count(self, path: str) -> int: + assert self._file is not None, "File not open" + obj = self._file[path] if path != "/" else self._file + return len(obj) if isinstance(obj, h5py.Group) else 0 + + +class ImageLoader(BaseFileLoader): + """ + Loader for raster image files. + + The file is treated as a single, flat dataset. ``iter_nodes`` yields one + leaf node whose ``display_hint`` is set to ``"image_native"`` so that + ``DataPanel`` knows to render it as a native Qt image rather than pushing + it through the pyqtgraph pipeline. + + Requires: Pillow (``pip install Pillow``) + """ + + EXTENSIONS = ( + ".jpg", + ".jpeg", + ".png", + ".gif", + ".tiff", + ".tif", + ".bmp", + ".webp", + ".ico", + ".ppm", + ".pgm", + ".pbm", + ) + + def __init__(self): + self._filepath: Optional[str] = None + + def open(self, filepath: str) -> None: + # Validate that Pillow can open it; keep only the path. + try: + from PIL import Image as _PILImage # noqa: F401 — existence check + + _PILImage.open(filepath).verify() + except Exception as exc: + raise OSError(f"Cannot open image {filepath!r}: {exc}") from exc + self._filepath = filepath + + def close(self) -> None: + self._filepath = None + + def iter_nodes(self, path: str) -> Iterator[NodeInfo]: + """Images have no internal hierarchy — yield a single leaf node.""" + if path != "/" or self._filepath is None: + return + from PIL import Image as _PILImage + + with _PILImage.open(self._filepath) as img: + w, h = img.size + mode = img.mode # e.g. "RGB", "RGBA", "L", … + + name = os.path.basename(self._filepath) + yield NodeInfo( + name=name, + path="/image", + kind="dataset", + dtype=mode, + shape=(h, w), + display_hint="image_native", + ) + + def read_dataset(self, path: str) -> np.ndarray: + """Return the image as a uint8 numpy array (H × W × C or H × W).""" + from PIL import Image as _PILImage + + with _PILImage.open(self._filepath) as img: # type: ignore[arg-type] + # Animated GIF → first frame only + if hasattr(img, "n_frames") and img.n_frames > 1: + img.seek(0) + return np.asarray(img), "image_native" + + def child_count(self, path: str) -> int: + return 1 if path == "/" else 0 + + +# ══════════════════════════════════════════════════════════════════════════════ +# Loader registry +# ══════════════════════════════════════════════════════════════════════════════ + + +class LoaderRegistry: + """Maps file extensions to loader classes.""" + + def __init__(self): + self._registry: dict[str, type[BaseFileLoader]] = {} + + def register(self, loader_cls: type[BaseFileLoader]) -> None: + """Register a loader class for all extensions it declares.""" + for ext in loader_cls.EXTENSIONS: + self._registry[ext.lower()] = loader_cls + + def get_loader(self, filepath: str) -> Optional[BaseFileLoader]: + """Return a fresh loader instance for *filepath*, or None if unsupported.""" + ext = os.path.splitext(filepath)[1].lower() + cls = self._registry.get(ext) + return cls() if cls is not None else None + + @property + def supported_extensions(self) -> list[str]: + return sorted(self._registry) + + +# Default global registry — pre-populated with built-in loaders. +_registry = LoaderRegistry() +_registry.register(HDF5Loader) +_registry.register(ImageLoader) diff --git a/debye_bec/bec_widgets/widgets/data_viewer/panels/__init__,py b/debye_bec/bec_widgets/widgets/data_viewer/panels/__init__,py new file mode 100644 index 0000000..e69de29 diff --git a/debye_bec/bec_widgets/widgets/data_viewer/panels/data_view.py b/debye_bec/bec_widgets/widgets/data_viewer/panels/data_view.py new file mode 100644 index 0000000..d08bc40 --- /dev/null +++ b/debye_bec/bec_widgets/widgets/data_viewer/panels/data_view.py @@ -0,0 +1,472 @@ +""" +Data viewer +""" + +from typing import Literal, Optional + +import numpy as np +import pyqtgraph as pg +from bec_lib import bec_logger +from bec_widgets.utils.colors import Colors +from qtpy.QtCore import Qt + +# pylint: disable=E0611 +from qtpy.QtGui import QFont, QPixmap + +# pylint: disable=E0611 +from qtpy.QtWidgets import ( + QAbstractItemView, + QApplication, + QGroupBox, + QHBoxLayout, + QHeaderView, + QLabel, + QRadioButton, + QSizePolicy, + QSlider, + QTableWidget, + QTableWidgetItem, + QVBoxLayout, + QWidget, +) + +logger = bec_logger.logger + + +class DataView(QWidget): + def __init__(self): + super().__init__() + self._layout = QVBoxLayout(self) + + # Header + hdr = QHBoxLayout() + self.path_label = QLabel("Select a dataset from the tree") + self.path_label.setObjectName("path_label") + self.info_label = QLabel("") + self.info_label.setObjectName("info_label") + hdr.addWidget(self.path_label, 1) + hdr.addWidget(self.info_label) + self._layout.addLayout(hdr) + + # View-mode selector + mode_box = QGroupBox("View mode") + mode_layout = QHBoxLayout(mode_box) + mode_layout.setContentsMargins(8, 4, 8, 4) + self.rb_auto = QRadioButton("Auto") + self.rb_plot = QRadioButton("Plot") + self.rb_image = QRadioButton("Image") + self.rb_table = QRadioButton("Table") + self.rb_auto.setChecked(True) + for rb in (self.rb_auto, self.rb_plot, self.rb_image, self.rb_table): + mode_layout.addWidget(rb) + rb.toggled.connect(self._on_mode_change) + mode_layout.addStretch() + self._layout.addWidget(mode_box) + + # Content stack + self.stack = QWidget() + self.stack_layout = QVBoxLayout(self.stack) + self.stack_layout.setContentsMargins(0, 0, 0, 0) + self._layout.addWidget(self.stack, 1) + + self.plot_widget = None + self.image_widget = None + + self._current_data = None + self.show_empty() + + # ── helpers ─────────────────────────────────────────────────────────── + def _clear_stack(self): + while self.stack_layout.count(): + item = self.stack_layout.takeAt(0) + if item.widget(): + item.widget().deleteLater() + self.plot_widget = None + self.image_widget = None + + def show_empty(self): + self._clear_stack() + lbl = QLabel("No data selected") + lbl.setObjectName("info_label") + lbl.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.stack_layout.addWidget(lbl) + + def show_unsupported(self, path: str = "") -> None: + """Display a friendly 'not implemented' message for unknown file types.""" + self._clear_stack() + self._current_data = None + self.path_label.setText(path) + self.info_label.setText("") + lbl = QLabel("File type not supported") + lbl.setObjectName("info_label") + lbl.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.stack_layout.addWidget(lbl) + + def _on_mode_change(self): + if self._current_data is not None: + self.display(self._current_data, self.path_label.text()) + + def _active_mode(self): + if self.rb_plot.isChecked(): + return "plot" + if self.rb_image.isChecked(): + return "image" + if self.rb_table.isChecked(): + return "table" + return "auto" + + # ── public ──────────────────────────────────────────────────────────── + def display(self, data, path: str = "", display_hint: str = "") -> None: + """ + Render *data* in the panel. + + Parameters + ---------- + data: + A ``numpy.ndarray`` (or anything convertible to one). + path: + Human-readable label shown in the header. + display_hint: + Optional hint from the loader. ``"image_native"`` bypasses the + pyqtgraph pipeline and renders via a Qt ``QLabel``/``QPixmap`` + instead, which correctly handles RGB/RGBA colour images. + """ + self._current_data = data + self.path_label.setText(path) + + if not isinstance(data, np.ndarray): + data = np.array(data) + + self.info_label.setText( + f"shape {data.shape} · dtype {data.dtype} · {data.size:,} elements" + ) + + mode = self._active_mode() + if mode == "auto": + if data.ndim <= 1 and data.size > 1: + mode = "plot" + elif data.ndim <= 4 and min(data.shape) > 1: + mode = "image" + else: + mode = "table" + + if mode == "plot": + self._show_plot_1d(data) + elif mode == "image": + self._show_image_2d(data) + else: + self._show_table(data) + + # ── Native raster image (RGB / RGBA / greyscale) ─────────────────────── + def _show_native_image(self, data: np.ndarray) -> None: + """ + Render a uint8 numpy array (H×W, H×W×3, or H×W×4) using a plain + Qt QLabel so that colour images display correctly without pyqtgraph's + colourmap pipeline. The image is scaled to fit the available space + while preserving the aspect ratio. + """ + from qtpy.QtGui import QImage + + self._clear_stack() + + arr = data + if arr.dtype != np.uint8: + # Normalise to 0–255 for display + lo, hi = arr.min(), arr.max() + arr = ((arr - lo) / max(hi - lo, 1) * 255).astype(np.uint8) + + if arr.ndim == 2: + # Greyscale → replicate to RGB so QImage is straightforward + arr = np.stack([arr, arr, arr], axis=-1) + + if arr.ndim == 3 and arr.shape[2] == 4: + fmt = QImage.Format.Format_RGBA8888 + else: + if arr.shape[2] != 3: + arr = arr[:, :, :3] + fmt = QImage.Format.Format_RGB888 + + h, w, ch = arr.shape + bytes_per_line = ch * w + arr_contiguous = np.ascontiguousarray(arr) + qimg = QImage(arr_contiguous.data, w, h, bytes_per_line, fmt) + pixmap = QPixmap.fromImage(qimg) + + label = QLabel() + label.setAlignment(Qt.AlignmentFlag.AlignCenter) + label.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + label.setMinimumSize(1, 1) + label.setPixmap( + pixmap.scaled( + label.size(), + Qt.AspectRatioMode.KeepAspectRatio, + Qt.TransformationMode.SmoothTransformation, + ) + ) + + # Re-scale when the panel is resized + def _on_resize(event, _lbl=label, _px=pixmap): + _lbl.setPixmap( + _px.scaled( + _lbl.size(), + Qt.AspectRatioMode.KeepAspectRatio, + Qt.TransformationMode.SmoothTransformation, + ) + ) + super(QLabel, _lbl).resizeEvent(event) # type: ignore[arg-type] + + label.resizeEvent = _on_resize # type: ignore[method-assign] + + self.stack_layout.addWidget(label) + + # ── 1-D line plot ────────────────────────────────────────────────────── + def _show_plot_1d(self, data): + self._clear_stack() + + is_2d = data.ndim == 2 + + if is_2d: + n_rows, n_cols = data.shape + row_data = data[0].astype(np.float32) + else: + row_data = data.reshape(-1).astype(np.float32) + + x = np.arange(row_data.size, dtype=np.float32) + + self.plot_widget = pg.PlotWidget() + plot_item = self.plot_widget.getPlotItem() + assert plot_item is not None, "PlotWidget has no PlotItem" + plot_item.showGrid(x=True, y=True, alpha=0.25) + self.plot_widget.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + + plot_item.setAutoVisible(y=False) # type: ignore[attr-defined] + + curve = pg.PlotDataItem( + x, row_data, pen=pg.mkPen(color="#2980b9", width=1.6), antialias=False + ) + self.plot_widget.addItem(curve) + + curve.setDownsampling(auto=True, method="peak") + curve.setClipToView(True) + curve.setSkipFiniteCheck(True) + + plot_item.enableAutoRange() # type: ignore[attr-defined] + + if is_2d: + # --- Slider --- + slider = QSlider(Qt.Orientation.Vertical) + slider.setMinimum(0) + slider.setMaximum(n_rows - 1) + slider.setValue(0) + slider.setFixedWidth(32) + slider.setPageStep(1) + + current_row_label = QLabel("0") + current_row_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + current_row_label.setFixedWidth(32) + + max_row_label = QLabel(f"0:{n_rows-1}") + max_row_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + max_row_label.setFixedWidth(32) + + def on_row_changed(row): + current_row_label.setText(str(row)) + new_data = data[row].astype(np.float32) + new_x = np.arange(new_data.size, dtype=np.float32) + curve.setData(new_x, new_data) + plot_item.enableAutoRange() # type: ignore[attr-defined] + + slider.valueChanged.connect(on_row_changed) + + slider_col = QWidget() + slider_col.setFixedWidth(36) + col_layout = QVBoxLayout(slider_col) + col_layout.setContentsMargins(0, 0, 0, 0) + col_layout.setSpacing(2) + col_layout.addWidget(max_row_label) + col_layout.addWidget(slider) + col_layout.addWidget(current_row_label) + + container = QWidget() + h_layout = QHBoxLayout(container) + h_layout.setContentsMargins(0, 0, 0, 0) + h_layout.setSpacing(4) + h_layout.addWidget(slider_col) + h_layout.addWidget(self.plot_widget) + + self.stack_layout.addWidget(container) + else: + self.stack_layout.addWidget(self.plot_widget) + + self.apply_theme() + + def apply_theme(self, theme: Optional[Literal["dark", "light"]] = None): + """ + Apply the theme + + Args: + theme (Optional[str]): Theme, either "dark", "light", or None. Defaults to None. + """ + if theme is None: + app = QApplication.instance() + theme = app.theme.theme # type: ignore + + bg_color = pg.getConfigOption("background") + fg_color = pg.getConfigOption("foreground") + if self.plot_widget is not None: + n_curves = len(self.plot_widget.listDataItems()) + colors = Colors.golden_angle_color( + colormap="plasma", num=max(10, n_curves + 1), format="HEX" + ) + for idx, curve in enumerate(self.plot_widget.listDataItems()): + curve.setPen(pg.mkPen(color=colors[idx])) + # Background + self.plot_widget.setBackground(bg_color) + # Axes (tick marks, tick labels, axis line) + for axis in ["left", "bottom", "right", "top"]: + ax = self.plot_widget.getAxis(axis) + ax.setPen(pg.mkPen(color=fg_color)) + ax.setTextPen(pg.mkPen(color=fg_color)) + + if self.image_widget is not None: + self.image_widget.getView().setBackgroundColor(bg_color) + self.image_widget.ui.histogram.setBackground(bg_color) + + # ── 2-D image ────────────────────────────────────────────────────────── + def _show_image_2d(self, data): + self._clear_stack() + + stacked = False + n_images = 0 + rgb = False + img = data + if data.ndim == 3: + if data.shape[-1] in (3, 4): + rgb = True + else: + stacked = True + n_images = data.shape[0] + img = data[0, :] + elif data.ndim == 4: + if data.shape[-1] in (3, 4): + rgb = True + stacked = True + n_images = data.shape[0] + img = data[0, :] + + self.image_widget = pg.ImageView() + + self.image_widget.ui.roiBtn.hide() + self.image_widget.ui.menuBtn.hide() + + if not rgb: + self.image_widget.setColorMap(pg.colormap.get("inferno", source="matplotlib")) + + def set_image(img, autoLevels=True, autoHistogramRange=True): + if rgb: + self.image_widget.imageItem.setOpts(axisOrder="row-major") + self.image_widget.setImage(img) + else: + self.image_widget.setImage( + img.T, autoLevels=autoLevels, autoHistogramRange=autoHistogramRange + ) + + set_image(img) + + self.image_widget.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + + if stacked: + # --- Slider --- + slider = QSlider(Qt.Orientation.Vertical) + slider.setMinimum(0) + slider.setMaximum(n_images - 1) + slider.setValue(0) + slider.setFixedWidth(32) + slider.setPageStep(1) + + current_image_label = QLabel("0") + current_image_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + current_image_label.setFixedWidth(32) + + max_image_label = QLabel(f"0:{n_images-1}") + max_image_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + max_image_label.setFixedWidth(32) + + def on_image_changed(row): + current_image_label.setText(str(row)) + set_image(data[row, :], autoLevels=False, autoHistogramRange=False) + + slider.valueChanged.connect(on_image_changed) + + slider_col = QWidget() + slider_col.setFixedWidth(36) + col_layout = QVBoxLayout(slider_col) + col_layout.setContentsMargins(0, 0, 0, 0) + col_layout.setSpacing(2) + col_layout.addWidget(max_image_label) + col_layout.addWidget(slider) + col_layout.addWidget(current_image_label) + + container = QWidget() + h_layout = QHBoxLayout(container) + h_layout.setContentsMargins(0, 0, 0, 0) + h_layout.setSpacing(4) + h_layout.addWidget(slider_col) + h_layout.addWidget(self.image_widget) + + self.stack_layout.addWidget(container) + else: + self.stack_layout.addWidget(self.image_widget) + + self.apply_theme() + + # ── Table ────────────────────────────────────────────────────────────── + def _show_table(self, data): + self._clear_stack() + + MAX_ROWS, MAX_COLS = 2000, 500 + + if data.ndim == 0: + flat = data.reshape(1, 1) + elif data.ndim == 1: + flat = data.reshape(-1, 1) + elif data.ndim == 2: + flat = data + else: + flat = data.reshape(-1, data.shape[-1]) + + rows, cols = flat.shape + show_rows = min(rows, MAX_ROWS) + show_cols = min(cols, MAX_COLS) + + if rows > MAX_ROWS or cols > MAX_COLS: + note = QLabel(f"⚠ Showing {show_rows}/{rows} rows × {show_cols}/{cols} cols") + note.setObjectName("info_label") + note.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.stack_layout.addWidget(note) + + table = QTableWidget(show_rows, show_cols) + table.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers) + table.setSelectionMode(QAbstractItemView.SelectionMode.ContiguousSelection) + table.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeMode.Interactive) + table.horizontalHeader().setDefaultSectionSize(90) + table.verticalHeader().setDefaultSectionSize(22) + table.setFont(QFont("JetBrains Mono, Consolas, monospace", 10)) + + is_float = np.issubdtype(flat.dtype, np.floating) + is_complex = np.iscomplexobj(flat) + is_bytes = flat.dtype.kind == "S" + + for r in range(show_rows): + for c in range(show_cols): + val = flat[r, c] + txt = ( + f"{val:.6g}" + if is_float + else f"{val:.4g}" if is_complex else str(val.decode()) if is_bytes else str(val) + ) + cell = QTableWidgetItem(txt) + cell.setTextAlignment(Qt.AlignmentFlag.AlignRight | Qt.AlignmentFlag.AlignVCenter) + table.setItem(r, c, cell) + + self.stack_layout.addWidget(table) diff --git a/debye_bec/bec_widgets/widgets/data_viewer/panels/input_panel.py b/debye_bec/bec_widgets/widgets/data_viewer/panels/input_panel.py new file mode 100644 index 0000000..f249b94 --- /dev/null +++ b/debye_bec/bec_widgets/widgets/data_viewer/panels/input_panel.py @@ -0,0 +1,32 @@ +# pylint: disable=E0611 +from qtpy.QtWidgets import QHBoxLayout, QVBoxLayout, QWidget + +# pylint: disable=E0402 +from ..widgets.qt_widgets import Button, Group, ListWidget + + +class InputPanel(QWidget): + """Panel for scan selection of the data viewer widget""" + + def __init__(self, parent=None): + super().__init__(parent) + self._layout = QHBoxLayout(self) + + # Scan selection + self.scan_sel = ListWidget("scan_sel") + self.load_button = Button(label_button="Load Dataset", enabled=True) + self.unload_button = Button(label_button="Unload all", enabled=True) + self.open_fm_button = Button(label_button="Open in File Manager", enabled=True) + + self._button_layout = QVBoxLayout() + self._button_layout.addWidget(self.load_button) + self._button_layout.addWidget(self.unload_button) + self._button_layout.addWidget(self.open_fm_button) + self._button_layout.addStretch() + + # Assemble complete scan selection group + self.input_group = Group( + "Scan selection", [self._button_layout, self.scan_sel], orientation="horizontal" + ) + + self._layout.addWidget(self.input_group) diff --git a/debye_bec/bec_widgets/widgets/data_viewer/panels/scan_view.py b/debye_bec/bec_widgets/widgets/data_viewer/panels/scan_view.py new file mode 100644 index 0000000..8fa327c --- /dev/null +++ b/debye_bec/bec_widgets/widgets/data_viewer/panels/scan_view.py @@ -0,0 +1,234 @@ +""" +File Viewer — qtpy + pyqtgraph +Supports pluggable file-format loaders (HDF5 built-in; extend via BaseFileLoader). +""" + +from __future__ import annotations + +import os +from typing import Literal, Optional + +from bec_lib import bec_logger +from bec_qthemes import material_icon +from bec_widgets.utils.colors import get_accent_colors +from qtpy.QtCore import Qt + +# pylint: disable=E0611 +from qtpy.QtGui import QBrush, QColor + +# pylint: disable=E0611 +from qtpy.QtWidgets import ( + QHBoxLayout, + QHeaderView, + QMainWindow, + QSplitter, + QTreeWidget, + QTreeWidgetItem, + QVBoxLayout, + QWidget, +) + +from ..loaders import BaseFileLoader, LoaderRegistry, _registry +from .data_view import DataView + +logger = bec_logger.logger + +ICON_SIZE = 20 + + +class ScanViewer(QMainWindow): + """ + Generic file viewer. Supports any format registered in *registry*. + + Parameters + ---------- + filepath: + Optional path to open on startup. + registry: + ``LoaderRegistry`` to use. Defaults to the module-level ``_registry`` + which ships with ``HDF5Loader`` pre-registered. + """ + + def __init__(self, filepath: Optional[str] = None, registry: Optional[LoaderRegistry] = None): + super().__init__() + self._registry = registry or _registry + + # filepath -> (loader_instance, display_name) + self._open_files: dict[str, tuple[BaseFileLoader, str]] = {} + + self._build_ui() + if filepath: + self.load_files([filepath]) + + def apply_theme(self, theme: Literal["dark", "light"]): + """ + Apply the theme + + Args: + theme (str): Theme, either "dark" or "light" + """ + self.data_panel.apply_theme(theme) + + # ── public API ──────────────────────────────────────────────────────── + + def load_files(self, filepaths: list[str]) -> None: + """Open one or more files and add each as a top-level tree node.""" + for fp in filepaths: + if fp in self._open_files: + continue # already loaded + + loader = self._registry.get_loader(fp) + if loader is None: + supported = ", ".join(self._registry.supported_extensions) + logger.warning("No loader found for %r. Supported extensions: %s", fp, supported) + continue + + try: + loader.open(fp) + except Exception as exc: + logger.error("Failed to open %r: %s", fp, exc) + continue + + display_name = os.path.basename(fp) + self._open_files[fp] = (loader, display_name) + self._add_file_to_tree(fp, loader, display_name) + + def clear_files(self) -> None: + """Close all open files and reset the tree.""" + for loader, _ in self._open_files.values(): + loader.close() + self._open_files.clear() + self.tree.clear() + self.data_panel.show_empty() + + def closeEvent(self, event): + for loader, _ in self._open_files.values(): + loader.close() + self._open_files.clear() + super().closeEvent(event) + + # ── UI construction ─────────────────────────────────────────────────── + + def _build_ui(self): + central = QWidget() + self.setCentralWidget(central) + + root_layout = QHBoxLayout(central) + + splitter = QSplitter(Qt.Orientation.Horizontal) + splitter.setChildrenCollapsible(False) + + # ── Left pane ── + left_pane = QWidget() + left_layout = QVBoxLayout(left_pane) + + self.tree = QTreeWidget() + self.tree.setMinimumWidth(250) + self.tree.setHeaderLabels(["Name", "Type", "Shape"]) + self.tree.header().setStretchLastSection(False) + self.tree.header().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch) + self.tree.header().setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents) + self.tree.header().setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents) + self.tree.itemClicked.connect(self._on_item_clicked) + + left_layout.addWidget(self.tree, 1) + + # ── Right pane ── + self.data_panel = DataView() + + splitter.addWidget(left_pane) + splitter.addWidget(self.data_panel) + + splitter.setStretchFactor(0, 1) + splitter.setStretchFactor(1, 3) + splitter.setSizes([300, 900]) + splitter.setHandleWidth(6) + splitter.setChildrenCollapsible(False) + + root_layout.addWidget(splitter) + + # ── Tree helpers ────────────────────────────────────────────────────── + + def _add_file_to_tree(self, filepath: str, loader: BaseFileLoader, display_name: str) -> None: + """Add a single file as a new top-level node in the tree.""" + dataset_icon = material_icon( + "dataset", size=(ICON_SIZE, ICON_SIZE), color=get_accent_colors().default.name() + ) + root_item = QTreeWidgetItem(self.tree, [display_name, "Group", ""]) + root_item.setIcon(0, dataset_icon) + root_item.setData(0, Qt.ItemDataRole.UserRole, "/") # path + root_item.setData(0, Qt.ItemDataRole.UserRole + 1, "group") # kind + root_item.setData(0, Qt.ItemDataRole.UserRole + 2, filepath) # file key + + self._populate_tree(root_item, loader, "/") + self.tree.addTopLevelItem(root_item) + + # Expand first 2 levels + self.tree.expandItem(root_item) + for i in range(root_item.childCount()): + self.tree.expandItem(root_item.child(i)) + + self.tree.setCurrentItem(root_item) + + def _populate_tree( + self, parent_item: QTreeWidgetItem, loader: BaseFileLoader, path: str + ) -> None: + folder_icon = material_icon("folder", size=(ICON_SIZE, ICON_SIZE), color="#2980b9") + vector_icon = material_icon("show_chart", size=(ICON_SIZE, ICON_SIZE), color="#2980b9") + array_icon = material_icon( + "stacked_line_chart", size=(ICON_SIZE, ICON_SIZE), color="#2980b9" + ) + scalar_icon = material_icon("point_scan", size=(ICON_SIZE, ICON_SIZE), color="#2980b9") + str_icon = material_icon("text_snippet", size=(ICON_SIZE, ICON_SIZE), color="#2980b9") + + for node in loader.iter_nodes(path): + shape_str = "x".join(str(s) for s in node.shape) if node.shape else "scalar" + + if node.kind == "group": + item = QTreeWidgetItem(parent_item, [node.name, "Group", ""]) + item.setIcon(0, folder_icon) + item.setData(0, Qt.ItemDataRole.UserRole, node.path) + item.setData(0, Qt.ItemDataRole.UserRole + 1, "group") + self._populate_tree(item, loader, node.path) + + else: # dataset + if shape_str == "scalar": + icon = str_icon if node.dtype == "str" else scalar_icon + elif "x" in shape_str: + icon = array_icon + else: + icon = vector_icon + + item = QTreeWidgetItem(parent_item, [node.name, node.dtype, shape_str]) + item.setIcon(0, icon) + item.setData(0, Qt.ItemDataRole.UserRole, node.path) + item.setData(0, Qt.ItemDataRole.UserRole + 1, "dataset") + item.setForeground(1, QBrush(QColor(get_accent_colors().success.name()))) + item.setForeground(2, QBrush(QColor("#656365"))) + + # ── Event handling ──────────────────────────────────────────────────── + + def _get_filepath_for_item(self, item: QTreeWidgetItem) -> str: + """Walk up the tree to find the filepath stored on the root node.""" + node = item + while node.parent(): + node = node.parent() + return node.data(0, Qt.ItemDataRole.UserRole + 2) + + def _on_item_clicked(self, item: QTreeWidgetItem, _col: int) -> None: + path = item.data(0, Qt.ItemDataRole.UserRole) + kind = item.data(0, Qt.ItemDataRole.UserRole + 1) + filepath = self._get_filepath_for_item(item) + + if not path or not filepath: + return + + loader, _ = self._open_files[filepath] + + if kind == "dataset": + data, ftype = loader.read_dataset(path) + self.data_panel.display(data, path, ftype) + else: + # Group selected — show child count in status bar (optional) + count = loader.child_count(path) + self.statusBar().showMessage(f"{path} ({count} items)") diff --git a/debye_bec/bec_widgets/widgets/data_viewer/register_data_viewer.py b/debye_bec/bec_widgets/widgets/data_viewer/register_data_viewer.py index 4ead24d..1be16c2 100644 --- a/debye_bec/bec_widgets/widgets/data_viewer/register_data_viewer.py +++ b/debye_bec/bec_widgets/widgets/data_viewer/register_data_viewer.py @@ -6,7 +6,7 @@ def main(): # pragma: no cover return from PySide6.QtDesigner import QPyDesignerCustomWidgetCollection - from debye_bec.bec_widgets.widgets.data_viewer.data_viewer_plugin import DataViewerPlugin + from .data_viewer_plugin import DataViewerPlugin QPyDesignerCustomWidgetCollection.addCustomWidget(DataViewerPlugin()) diff --git a/debye_bec/bec_widgets/widgets/data_viewer/viewer.py b/debye_bec/bec_widgets/widgets/data_viewer/viewer.py deleted file mode 100644 index a222669..0000000 --- a/debye_bec/bec_widgets/widgets/data_viewer/viewer.py +++ /dev/null @@ -1,990 +0,0 @@ -""" -File Viewer — qtpy + pyqtgraph -Supports pluggable file-format loaders (HDF5 built-in; extend via BaseFileLoader). -""" - -from __future__ import annotations - -import os -from abc import ABC, abstractmethod -from typing import Any, Iterator, Literal, Optional - -import h5py -import numpy as np -import pyqtgraph as pg -from bec_lib import bec_logger -from bec_qthemes import material_icon -from bec_widgets.utils.colors import Colors, get_accent_colors -from qtpy.QtCore import Qt - -# pylint: disable=E0611 -from qtpy.QtGui import QBrush, QColor, QFont, QPixmap - -# pylint: disable=E0611 -from qtpy.QtWidgets import ( - QAbstractItemView, - QApplication, - QGroupBox, - QHBoxLayout, - QHeaderView, - QLabel, - QMainWindow, - QRadioButton, - QSizePolicy, - QSlider, - QSplitter, - QTableWidget, - QTableWidgetItem, - QTreeWidget, - QTreeWidgetItem, - QVBoxLayout, - QWidget, -) - -logger = bec_logger.logger - -ICON_SIZE = 20 - - -# ══════════════════════════════════════════════════════════════════════════════ -# Loader plugin interface -# ══════════════════════════════════════════════════════════════════════════════ - - -class NodeInfo: - """Describes a single node (group or dataset) inside a loaded file.""" - - __slots__ = ("name", "path", "kind", "dtype", "shape", "display_hint") - - def __init__( - self, - name: str, - path: str, - kind: Literal["group", "dataset"], - dtype: str = "", - shape: tuple[int, ...] = (), - display_hint: str = "", - ): - self.name = name - self.path = path - self.kind = kind - self.dtype = dtype # e.g. "float", "int", "str", … - self.shape = shape # empty tuple for scalars / groups - self.display_hint = display_hint # e.g. "image_native" — passed through to DataPanel - - -class BaseFileLoader(ABC): - """ - Abstract base class for file-format loaders. - - Subclass this to add support for a new format. Three things are required: - - 1. ``EXTENSIONS`` — tuple of lowercase extensions this loader handles, - e.g. ``(".h5", ".hdf5")``. - - 2. ``open(filepath)`` — open the file and keep any handles alive. - - 3. ``iter_nodes(path)`` — yield ``NodeInfo`` objects for the direct - children of *path* (depth-1 walk; the tree widget calls this - recursively as the user expands nodes). - - 4. ``read_dataset(path)`` — return the dataset at *path* as a - ``numpy.ndarray``. - - 5. ``close()`` — release any open file handles. - - 6. ``child_count(path)`` — return the number of direct children of a - group node (used for the status label; override if cheap to compute). - """ - - EXTENSIONS: tuple[str, ...] = () - - @abstractmethod - def open(self, filepath: str) -> None: ... - - @abstractmethod - def iter_nodes(self, path: str) -> Iterator[NodeInfo]: ... - - @abstractmethod - def read_dataset(self, path: str) -> np.ndarray: ... - - @abstractmethod - def close(self) -> None: ... - - def child_count(self, path: str) -> int: - return sum(1 for _ in self.iter_nodes(path)) - - -# ══════════════════════════════════════════════════════════════════════════════ -# HDF5 loader (replaces the inline h5py logic that was in HDF5Viewer) -# ══════════════════════════════════════════════════════════════════════════════ - - -class HDF5Loader(BaseFileLoader): - """Loader for HDF5 / NeXus files (.h5, .hdf5, .nxs, .nx).""" - - EXTENSIONS = (".h5", ".hdf5", ".hdf", ".nxs", ".nx") - - def __init__(self): - self._file: Optional[h5py.File] = None - - def open(self, filepath: str) -> None: - self._file = h5py.File(filepath, "r") - - def close(self) -> None: - if self._file is not None: - self._file.close() - self._file = None - - def iter_nodes(self, path: str) -> Iterator[NodeInfo]: - assert self._file is not None, "File not open" - obj = self._file[path] if path != "/" else self._file - - if not isinstance(obj, h5py.Group): - return - - for key in obj.keys(): - try: - child = obj[key] - except Exception: - continue - - child_path = child.name # h5py always gives the absolute path - - if isinstance(child, h5py.Group): - yield NodeInfo(name=key, path=child_path, kind="group") - - elif isinstance(child, h5py.Dataset): - shape_tuple = child.shape - d = child.dtype - dtype = "unknown" - if np.issubdtype(d, np.integer): - dtype = "int" - if np.issubdtype(d, np.floating): - dtype = "float" - if np.issubdtype(d, np.complexfloating): - dtype = "complex" - if d.kind in ("S", "U", "O"): - dtype = "str" - - yield NodeInfo( - name=key, path=child_path, kind="dataset", dtype=dtype, shape=shape_tuple - ) - - def read_dataset(self, path: str) -> np.ndarray: - assert self._file is not None, "File not open" - return self._file[path][()], "h5" - - def child_count(self, path: str) -> int: - assert self._file is not None, "File not open" - obj = self._file[path] if path != "/" else self._file - return len(obj) if isinstance(obj, h5py.Group) else 0 - - -# ══════════════════════════════════════════════════════════════════════════════ -# Example stub: NumPy .npz loader -# (Uncomment / flesh out to add .npz support) -# ══════════════════════════════════════════════════════════════════════════════ - -# class NpzLoader(BaseFileLoader): -# """Loader for NumPy .npz archives.""" -# -# EXTENSIONS = (".npz",) -# -# def __init__(self): -# self._archive: Optional[np.lib.npyio.NpzFile] = None -# -# def open(self, filepath: str) -> None: -# self._archive = np.load(filepath, allow_pickle=False) -# -# def close(self) -> None: -# if self._archive is not None: -# self._archive.close() -# self._archive = None -# -# def iter_nodes(self, path: str) -> Iterator[NodeInfo]: -# assert self._archive is not None -# # .npz has no hierarchy — treat all arrays as top-level datasets -# if path != "/": -# return -# for key in self._archive.files: -# arr = self._archive[key] -# yield NodeInfo(name=key, path=f"/{key}", kind="dataset", -# dtype=str(arr.dtype), shape=arr.shape) -# -# def read_dataset(self, path: str) -> np.ndarray: -# assert self._archive is not None -# key = path.lstrip("/") -# return self._archive[key] - - -# ══════════════════════════════════════════════════════════════════════════════ -# Image loader (.jpg, .png, .gif, .tiff, …) -# ══════════════════════════════════════════════════════════════════════════════ - - -class ImageLoader(BaseFileLoader): - """ - Loader for raster image files. - - The file is treated as a single, flat dataset. ``iter_nodes`` yields one - leaf node whose ``display_hint`` is set to ``"image_native"`` so that - ``DataPanel`` knows to render it as a native Qt image rather than pushing - it through the pyqtgraph pipeline. - - Requires: Pillow (``pip install Pillow``) - """ - - EXTENSIONS = ( - ".jpg", - ".jpeg", - ".png", - ".gif", - ".tiff", - ".tif", - ".bmp", - ".webp", - ".ico", - ".ppm", - ".pgm", - ".pbm", - ) - - def __init__(self): - self._filepath: Optional[str] = None - - def open(self, filepath: str) -> None: - # Validate that Pillow can open it; keep only the path. - try: - from PIL import Image as _PILImage # noqa: F401 — existence check - - _PILImage.open(filepath).verify() - except Exception as exc: - raise OSError(f"Cannot open image {filepath!r}: {exc}") from exc - self._filepath = filepath - - def close(self) -> None: - self._filepath = None - - def iter_nodes(self, path: str) -> Iterator[NodeInfo]: - """Images have no internal hierarchy — yield a single leaf node.""" - if path != "/" or self._filepath is None: - return - from PIL import Image as _PILImage - - with _PILImage.open(self._filepath) as img: - w, h = img.size - mode = img.mode # e.g. "RGB", "RGBA", "L", … - - name = os.path.basename(self._filepath) - yield NodeInfo( - name=name, - path="/image", - kind="dataset", - dtype=mode, - shape=(h, w), - display_hint="image_native", - ) - - def read_dataset(self, path: str) -> np.ndarray: - """Return the image as a uint8 numpy array (H × W × C or H × W).""" - from PIL import Image as _PILImage - - with _PILImage.open(self._filepath) as img: # type: ignore[arg-type] - # Animated GIF → first frame only - if hasattr(img, "n_frames") and img.n_frames > 1: - img.seek(0) - return np.asarray(img), "image_native" - - def child_count(self, path: str) -> int: - return 1 if path == "/" else 0 - - -# ══════════════════════════════════════════════════════════════════════════════ -# Loader registry -# ══════════════════════════════════════════════════════════════════════════════ - - -class LoaderRegistry: - """Maps file extensions to loader classes.""" - - def __init__(self): - self._registry: dict[str, type[BaseFileLoader]] = {} - - def register(self, loader_cls: type[BaseFileLoader]) -> None: - """Register a loader class for all extensions it declares.""" - for ext in loader_cls.EXTENSIONS: - self._registry[ext.lower()] = loader_cls - - def get_loader(self, filepath: str) -> Optional[BaseFileLoader]: - """Return a fresh loader instance for *filepath*, or None if unsupported.""" - ext = os.path.splitext(filepath)[1].lower() - cls = self._registry.get(ext) - return cls() if cls is not None else None - - @property - def supported_extensions(self) -> list[str]: - return sorted(self._registry) - - -# Default global registry — pre-populated with built-in loaders. -_registry = LoaderRegistry() -_registry.register(HDF5Loader) -_registry.register(ImageLoader) -# _registry.register(NpzLoader) ← uncomment once fleshed out - - -# ══════════════════════════════════════════════════════════════════════════════ -# Data / Plot panel (unchanged from original) -# ══════════════════════════════════════════════════════════════════════════════ - - -class DataPanel(QWidget): - def __init__(self): - super().__init__() - self._layout = QVBoxLayout(self) - - # Header - hdr = QHBoxLayout() - self.path_label = QLabel("Select a dataset from the tree") - self.path_label.setObjectName("path_label") - self.info_label = QLabel("") - self.info_label.setObjectName("info_label") - hdr.addWidget(self.path_label, 1) - hdr.addWidget(self.info_label) - self._layout.addLayout(hdr) - - # View-mode selector - mode_box = QGroupBox("View mode") - mode_layout = QHBoxLayout(mode_box) - mode_layout.setContentsMargins(8, 4, 8, 4) - self.rb_auto = QRadioButton("Auto") - self.rb_plot = QRadioButton("Plot") - self.rb_image = QRadioButton("Image") - self.rb_table = QRadioButton("Table") - self.rb_auto.setChecked(True) - for rb in (self.rb_auto, self.rb_plot, self.rb_image, self.rb_table): - mode_layout.addWidget(rb) - rb.toggled.connect(self._on_mode_change) - mode_layout.addStretch() - self._layout.addWidget(mode_box) - - # Content stack - self.stack = QWidget() - self.stack_layout = QVBoxLayout(self.stack) - self.stack_layout.setContentsMargins(0, 0, 0, 0) - self._layout.addWidget(self.stack, 1) - - self.plot_widget = None - self.image_widget = None - - self._current_data = None - self.show_empty() - - # ── helpers ─────────────────────────────────────────────────────────── - def _clear_stack(self): - while self.stack_layout.count(): - item = self.stack_layout.takeAt(0) - if item.widget(): - item.widget().deleteLater() - self.plot_widget = None - self.image_widget = None - - def show_empty(self): - self._clear_stack() - lbl = QLabel("No data selected") - lbl.setObjectName("info_label") - lbl.setAlignment(Qt.AlignmentFlag.AlignCenter) - self.stack_layout.addWidget(lbl) - - def show_unsupported(self, path: str = "") -> None: - """Display a friendly 'not implemented' message for unknown file types.""" - self._clear_stack() - self._current_data = None - self.path_label.setText(path) - self.info_label.setText("") - lbl = QLabel("File type not supported") - lbl.setObjectName("info_label") - lbl.setAlignment(Qt.AlignmentFlag.AlignCenter) - self.stack_layout.addWidget(lbl) - - def _on_mode_change(self): - if self._current_data is not None: - self.display(self._current_data, self.path_label.text()) - - def _active_mode(self): - if self.rb_plot.isChecked(): - return "plot" - if self.rb_image.isChecked(): - return "image" - if self.rb_table.isChecked(): - return "table" - return "auto" - - # ── public ──────────────────────────────────────────────────────────── - def display(self, data, path: str = "", display_hint: str = "") -> None: - """ - Render *data* in the panel. - - Parameters - ---------- - data: - A ``numpy.ndarray`` (or anything convertible to one). - path: - Human-readable label shown in the header. - display_hint: - Optional hint from the loader. ``"image_native"`` bypasses the - pyqtgraph pipeline and renders via a Qt ``QLabel``/``QPixmap`` - instead, which correctly handles RGB/RGBA colour images. - """ - self._current_data = data - self.path_label.setText(path) - - if not isinstance(data, np.ndarray): - data = np.array(data) - - self.info_label.setText( - f"shape {data.shape} · dtype {data.dtype} · {data.size:,} elements" - ) - - # Native-image hint takes priority over the radio-button mode selector. - # if display_hint == "image_native": - # self._show_native_image(data) - # return - - mode = self._active_mode() - if mode == "auto": - if data.ndim <= 1 and data.size > 1: - mode = "plot" - elif data.ndim <= 4 and min(data.shape) > 1: - mode = "image" - else: - mode = "table" - - if mode == "plot": - self._show_plot_1d(data) - elif mode == "image": - self._show_image_2d(data) - else: - self._show_table(data) - - # ── Native raster image (RGB / RGBA / greyscale) ─────────────────────── - def _show_native_image(self, data: np.ndarray) -> None: - """ - Render a uint8 numpy array (H×W, H×W×3, or H×W×4) using a plain - Qt QLabel so that colour images display correctly without pyqtgraph's - colourmap pipeline. The image is scaled to fit the available space - while preserving the aspect ratio. - """ - from qtpy.QtGui import QImage - - self._clear_stack() - - arr = data - if arr.dtype != np.uint8: - # Normalise to 0–255 for display - lo, hi = arr.min(), arr.max() - arr = ((arr - lo) / max(hi - lo, 1) * 255).astype(np.uint8) - - if arr.ndim == 2: - # Greyscale → replicate to RGB so QImage is straightforward - arr = np.stack([arr, arr, arr], axis=-1) - - if arr.ndim == 3 and arr.shape[2] == 4: - fmt = QImage.Format.Format_RGBA8888 - else: - if arr.shape[2] != 3: - arr = arr[:, :, :3] - fmt = QImage.Format.Format_RGB888 - - h, w, ch = arr.shape - bytes_per_line = ch * w - arr_contiguous = np.ascontiguousarray(arr) - qimg = QImage(arr_contiguous.data, w, h, bytes_per_line, fmt) - pixmap = QPixmap.fromImage(qimg) - - label = QLabel() - label.setAlignment(Qt.AlignmentFlag.AlignCenter) - label.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) - label.setMinimumSize(1, 1) - label.setPixmap( - pixmap.scaled( - label.size(), - Qt.AspectRatioMode.KeepAspectRatio, - Qt.TransformationMode.SmoothTransformation, - ) - ) - - # Re-scale when the panel is resized - def _on_resize(event, _lbl=label, _px=pixmap): - _lbl.setPixmap( - _px.scaled( - _lbl.size(), - Qt.AspectRatioMode.KeepAspectRatio, - Qt.TransformationMode.SmoothTransformation, - ) - ) - super(QLabel, _lbl).resizeEvent(event) # type: ignore[arg-type] - - label.resizeEvent = _on_resize # type: ignore[method-assign] - - self.stack_layout.addWidget(label) - - # ── 1-D line plot ────────────────────────────────────────────────────── - def _show_plot_1d(self, data): - self._clear_stack() - - is_2d = data.ndim == 2 - - if is_2d: - n_rows, n_cols = data.shape - row_data = data[0].astype(np.float32) - else: - row_data = data.reshape(-1).astype(np.float32) - - x = np.arange(row_data.size, dtype=np.float32) - - self.plot_widget = pg.PlotWidget() - plot_item = self.plot_widget.getPlotItem() - assert plot_item is not None, "PlotWidget has no PlotItem" - plot_item.showGrid(x=True, y=True, alpha=0.25) - self.plot_widget.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) - - plot_item.setAutoVisible(y=False) # type: ignore[attr-defined] - - curve = pg.PlotDataItem( - x, row_data, pen=pg.mkPen(color="#2980b9", width=1.6), antialias=False - ) - self.plot_widget.addItem(curve) - - curve.setDownsampling(auto=True, method="peak") - curve.setClipToView(True) - curve.setSkipFiniteCheck(True) - - plot_item.enableAutoRange() # type: ignore[attr-defined] - - if is_2d: - # --- Slider --- - slider = QSlider(Qt.Orientation.Vertical) - slider.setMinimum(0) - slider.setMaximum(n_rows - 1) - slider.setValue(0) - slider.setFixedWidth(32) - slider.setPageStep(1) - - current_row_label = QLabel("0") - current_row_label.setAlignment(Qt.AlignmentFlag.AlignCenter) - current_row_label.setFixedWidth(32) - - max_row_label = QLabel(f"0:{n_rows-1}") - max_row_label.setAlignment(Qt.AlignmentFlag.AlignCenter) - max_row_label.setFixedWidth(32) - - def on_row_changed(row): - current_row_label.setText(str(row)) - new_data = data[row].astype(np.float32) - new_x = np.arange(new_data.size, dtype=np.float32) - curve.setData(new_x, new_data) - plot_item.enableAutoRange() # type: ignore[attr-defined] - - slider.valueChanged.connect(on_row_changed) - - slider_col = QWidget() - slider_col.setFixedWidth(36) - col_layout = QVBoxLayout(slider_col) - col_layout.setContentsMargins(0, 0, 0, 0) - col_layout.setSpacing(2) - col_layout.addWidget(max_row_label) - col_layout.addWidget(slider) - col_layout.addWidget(current_row_label) - - container = QWidget() - h_layout = QHBoxLayout(container) - h_layout.setContentsMargins(0, 0, 0, 0) - h_layout.setSpacing(4) - h_layout.addWidget(slider_col) - h_layout.addWidget(self.plot_widget) - - self.stack_layout.addWidget(container) - else: - self.stack_layout.addWidget(self.plot_widget) - - self.apply_theme() - - def apply_theme(self, theme: Optional[Literal["dark", "light"]] = None): - """ - Apply the theme - - Args: - theme (Optional[str]): Theme, either "dark", "light", or None. Defaults to None. - """ - if theme is None: - app = QApplication.instance() - theme = app.theme.theme # type: ignore - - bg_color = pg.getConfigOption("background") - fg_color = pg.getConfigOption("foreground") - if self.plot_widget is not None: - n_curves = len(self.plot_widget.listDataItems()) - colors = Colors.golden_angle_color( - colormap="plasma", num=max(10, n_curves + 1), format="HEX" - ) - for idx, curve in enumerate(self.plot_widget.listDataItems()): - curve.setPen(pg.mkPen(color=colors[idx])) - # Background - self.plot_widget.setBackground(bg_color) - # Axes (tick marks, tick labels, axis line) - for axis in ["left", "bottom", "right", "top"]: - ax = self.plot_widget.getAxis(axis) - ax.setPen(pg.mkPen(color=fg_color)) - ax.setTextPen(pg.mkPen(color=fg_color)) - - if self.image_widget is not None: - self.image_widget.getView().setBackgroundColor(bg_color) - self.image_widget.ui.histogram.setBackground(bg_color) - - # ── 2-D image ────────────────────────────────────────────────────────── - def _show_image_2d(self, data): - self._clear_stack() - - stacked = False - n_images = 0 - rgb = False - img = data - if data.ndim == 3: - if data.shape[-1] in (3, 4): - rgb = True - else: - stacked = True - n_images = data.shape[0] - img = data[0, :] - elif data.ndim == 4: - if data.shape[-1] in (3, 4): - rgb = True - stacked = True - n_images = data.shape[0] - img = data[0, :] - - self.image_widget = pg.ImageView() - - self.image_widget.ui.roiBtn.hide() - self.image_widget.ui.menuBtn.hide() - - if not rgb: - self.image_widget.setColorMap(pg.colormap.get("inferno", source="matplotlib")) - - def set_image(img, autoLevels=True, autoHistogramRange=True): - if rgb: - self.image_widget.imageItem.setOpts(axisOrder="row-major") - self.image_widget.setImage(img) - else: - self.image_widget.setImage( - img.T, autoLevels=autoLevels, autoHistogramRange=autoHistogramRange - ) - - set_image(img) - - self.image_widget.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) - - if stacked: - # --- Slider --- - slider = QSlider(Qt.Orientation.Vertical) - slider.setMinimum(0) - slider.setMaximum(n_images - 1) - slider.setValue(0) - slider.setFixedWidth(32) - slider.setPageStep(1) - - current_image_label = QLabel("0") - current_image_label.setAlignment(Qt.AlignmentFlag.AlignCenter) - current_image_label.setFixedWidth(32) - - max_image_label = QLabel(f"0:{n_images-1}") - max_image_label.setAlignment(Qt.AlignmentFlag.AlignCenter) - max_image_label.setFixedWidth(32) - - def on_image_changed(row): - current_image_label.setText(str(row)) - set_image(data[row, :], autoLevels=False, autoHistogramRange=False) - - slider.valueChanged.connect(on_image_changed) - - slider_col = QWidget() - slider_col.setFixedWidth(36) - col_layout = QVBoxLayout(slider_col) - col_layout.setContentsMargins(0, 0, 0, 0) - col_layout.setSpacing(2) - col_layout.addWidget(max_image_label) - col_layout.addWidget(slider) - col_layout.addWidget(current_image_label) - - container = QWidget() - h_layout = QHBoxLayout(container) - h_layout.setContentsMargins(0, 0, 0, 0) - h_layout.setSpacing(4) - h_layout.addWidget(slider_col) - h_layout.addWidget(self.image_widget) - - self.stack_layout.addWidget(container) - else: - self.stack_layout.addWidget(self.image_widget) - - self.apply_theme() - - # ── Table ────────────────────────────────────────────────────────────── - def _show_table(self, data): - self._clear_stack() - - MAX_ROWS, MAX_COLS = 2000, 500 - - if data.ndim == 0: - flat = data.reshape(1, 1) - elif data.ndim == 1: - flat = data.reshape(-1, 1) - elif data.ndim == 2: - flat = data - else: - flat = data.reshape(-1, data.shape[-1]) - - rows, cols = flat.shape - show_rows = min(rows, MAX_ROWS) - show_cols = min(cols, MAX_COLS) - - if rows > MAX_ROWS or cols > MAX_COLS: - note = QLabel(f"⚠ Showing {show_rows}/{rows} rows × {show_cols}/{cols} cols") - note.setObjectName("info_label") - note.setAlignment(Qt.AlignmentFlag.AlignCenter) - self.stack_layout.addWidget(note) - - table = QTableWidget(show_rows, show_cols) - table.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers) - table.setSelectionMode(QAbstractItemView.SelectionMode.ContiguousSelection) - table.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeMode.Interactive) - table.horizontalHeader().setDefaultSectionSize(90) - table.verticalHeader().setDefaultSectionSize(22) - table.setFont(QFont("JetBrains Mono, Consolas, monospace", 10)) - - is_float = np.issubdtype(flat.dtype, np.floating) - is_complex = np.iscomplexobj(flat) - is_bytes = flat.dtype.kind == "S" - - for r in range(show_rows): - for c in range(show_cols): - val = flat[r, c] - txt = ( - f"{val:.6g}" - if is_float - else f"{val:.4g}" if is_complex else str(val.decode()) if is_bytes else str(val) - ) - cell = QTableWidgetItem(txt) - cell.setTextAlignment(Qt.AlignmentFlag.AlignRight | Qt.AlignmentFlag.AlignVCenter) - table.setItem(r, c, cell) - - self.stack_layout.addWidget(table) - - -# ══════════════════════════════════════════════════════════════════════════════ -# Main window -# ══════════════════════════════════════════════════════════════════════════════ - - -class FileViewer(QMainWindow): - """ - Generic file viewer. Supports any format registered in *registry*. - - Parameters - ---------- - filepath: - Optional path to open on startup. - registry: - ``LoaderRegistry`` to use. Defaults to the module-level ``_registry`` - which ships with ``HDF5Loader`` pre-registered. - """ - - def __init__(self, filepath: Optional[str] = None, registry: Optional[LoaderRegistry] = None): - super().__init__() - self._registry = registry or _registry - - # filepath -> (loader_instance, display_name) - self._open_files: dict[str, tuple[BaseFileLoader, str]] = {} - - self._build_ui() - if filepath: - self.load_files([filepath]) - - def apply_theme(self, theme: Literal["dark", "light"]): - """ - Apply the theme - - Args: - theme (str): Theme, either "dark" or "light" - """ - self.data_panel.apply_theme(theme) - - # ── public API ──────────────────────────────────────────────────────── - - def load_files(self, filepaths: list[str]) -> None: - """Open one or more files and add each as a top-level tree node.""" - for fp in filepaths: - if fp in self._open_files: - continue # already loaded - - loader = self._registry.get_loader(fp) - if loader is None: - supported = ", ".join(self._registry.supported_extensions) - logger.warning("No loader found for %r. Supported extensions: %s", fp, supported) - continue - - try: - loader.open(fp) - except Exception as exc: - logger.error("Failed to open %r: %s", fp, exc) - continue - - display_name = os.path.basename(fp) - self._open_files[fp] = (loader, display_name) - self._add_file_to_tree(fp, loader, display_name) - - def clear_files(self) -> None: - """Close all open files and reset the tree.""" - for loader, _ in self._open_files.values(): - loader.close() - self._open_files.clear() - self.tree.clear() - self.data_panel.show_empty() - - def closeEvent(self, event): - for loader, _ in self._open_files.values(): - loader.close() - self._open_files.clear() - super().closeEvent(event) - - # ── UI construction ─────────────────────────────────────────────────── - - def _build_ui(self): - central = QWidget() - self.setCentralWidget(central) - - root_layout = QHBoxLayout(central) - - splitter = QSplitter(Qt.Orientation.Horizontal) - splitter.setChildrenCollapsible(False) - - # ── Left pane ── - left_pane = QWidget() - left_layout = QVBoxLayout(left_pane) - - self.tree = QTreeWidget() - self.tree.setMinimumWidth(250) - self.tree.setHeaderLabels(["Name", "Type", "Shape"]) - self.tree.header().setStretchLastSection(False) - self.tree.header().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch) - self.tree.header().setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents) - self.tree.header().setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents) - self.tree.itemClicked.connect(self._on_item_clicked) - - left_layout.addWidget(self.tree, 1) - - # ── Right pane ── - self.data_panel = DataPanel() - - splitter.addWidget(left_pane) - splitter.addWidget(self.data_panel) - - splitter.setStretchFactor(0, 1) - splitter.setStretchFactor(1, 3) - splitter.setSizes([300, 900]) - splitter.setHandleWidth(6) - splitter.setChildrenCollapsible(False) - - root_layout.addWidget(splitter) - - # ── Tree helpers ────────────────────────────────────────────────────── - - def _add_file_to_tree(self, filepath: str, loader: BaseFileLoader, display_name: str) -> None: - """Add a single file as a new top-level node in the tree.""" - dataset_icon = material_icon( - "dataset", size=(ICON_SIZE, ICON_SIZE), color=get_accent_colors().default.name() - ) - root_item = QTreeWidgetItem(self.tree, [display_name, "Group", ""]) - root_item.setIcon(0, dataset_icon) - root_item.setData(0, Qt.ItemDataRole.UserRole, "/") # path - root_item.setData(0, Qt.ItemDataRole.UserRole + 1, "group") # kind - root_item.setData(0, Qt.ItemDataRole.UserRole + 2, filepath) # file key - - self._populate_tree(root_item, loader, "/") - self.tree.addTopLevelItem(root_item) - - # Expand first 2 levels - self.tree.expandItem(root_item) - for i in range(root_item.childCount()): - self.tree.expandItem(root_item.child(i)) - - self.tree.setCurrentItem(root_item) - - def _populate_tree( - self, parent_item: QTreeWidgetItem, loader: BaseFileLoader, path: str - ) -> None: - folder_icon = material_icon("folder", size=(ICON_SIZE, ICON_SIZE), color="#2980b9") - vector_icon = material_icon("show_chart", size=(ICON_SIZE, ICON_SIZE), color="#2980b9") - array_icon = material_icon( - "stacked_line_chart", size=(ICON_SIZE, ICON_SIZE), color="#2980b9" - ) - scalar_icon = material_icon("point_scan", size=(ICON_SIZE, ICON_SIZE), color="#2980b9") - str_icon = material_icon("text_snippet", size=(ICON_SIZE, ICON_SIZE), color="#2980b9") - - for node in loader.iter_nodes(path): - shape_str = "x".join(str(s) for s in node.shape) if node.shape else "scalar" - - if node.kind == "group": - item = QTreeWidgetItem(parent_item, [node.name, "Group", ""]) - item.setIcon(0, folder_icon) - item.setData(0, Qt.ItemDataRole.UserRole, node.path) - item.setData(0, Qt.ItemDataRole.UserRole + 1, "group") - self._populate_tree(item, loader, node.path) - - else: # dataset - if shape_str == "scalar": - icon = str_icon if node.dtype == "str" else scalar_icon - elif "x" in shape_str: - icon = array_icon - else: - icon = vector_icon - - item = QTreeWidgetItem(parent_item, [node.name, node.dtype, shape_str]) - item.setIcon(0, icon) - item.setData(0, Qt.ItemDataRole.UserRole, node.path) - item.setData(0, Qt.ItemDataRole.UserRole + 1, "dataset") - item.setForeground(1, QBrush(QColor(get_accent_colors().success.name()))) - item.setForeground(2, QBrush(QColor("#656365"))) - - # ── Event handling ──────────────────────────────────────────────────── - - def _get_filepath_for_item(self, item: QTreeWidgetItem) -> str: - """Walk up the tree to find the filepath stored on the root node.""" - node = item - while node.parent(): - node = node.parent() - return node.data(0, Qt.ItemDataRole.UserRole + 2) - - def _on_item_clicked(self, item: QTreeWidgetItem, _col: int) -> None: - path = item.data(0, Qt.ItemDataRole.UserRole) - kind = item.data(0, Qt.ItemDataRole.UserRole + 1) - filepath = self._get_filepath_for_item(item) - - if not path or not filepath: - return - - loader, _ = self._open_files[filepath] - - if kind == "dataset": - data, ftype = loader.read_dataset(path) - self.data_panel.display(data, path, ftype) - else: - # Group selected — show child count in status bar (optional) - count = loader.child_count(path) - self.statusBar().showMessage(f"{path} ({count} items)") - - -# Backwards-compatible alias for code that imported HDF5Viewer by name. -HDF5Viewer = FileViewer diff --git a/debye_bec/bec_widgets/widgets/data_viewer/widgets/__init__,py b/debye_bec/bec_widgets/widgets/data_viewer/widgets/__init__,py new file mode 100644 index 0000000..e69de29 diff --git a/debye_bec/bec_widgets/widgets/data_viewer/qt_widgets.py b/debye_bec/bec_widgets/widgets/data_viewer/widgets/qt_widgets.py similarity index 50% rename from debye_bec/bec_widgets/widgets/data_viewer/qt_widgets.py rename to debye_bec/bec_widgets/widgets/data_viewer/widgets/qt_widgets.py index 9d7b4f7..187157e 100644 --- a/debye_bec/bec_widgets/widgets/data_viewer/qt_widgets.py +++ b/debye_bec/bec_widgets/widgets/data_viewer/widgets/qt_widgets.py @@ -1,17 +1,22 @@ """ -TaggedListWidget — a QListWidget where each item shows a label + styled tag pills. -Selection works natively; no popup, no paintEvent hacks. +Universal Qt widgets """ -import sys +from functools import partial +from bec_widgets.utils.colors import get_accent_colors + +# pylint: disable=E0611 from qtpy.QtCore import QPoint, QRect, QSize, Qt from qtpy.QtGui import QColor, QFont, QFontMetrics, QPainter, QPen from qtpy.QtWidgets import ( - QApplication, + QGroupBox, + QHBoxLayout, QLabel, + QLayout, QListWidget, QListWidgetItem, + QPushButton, QStyle, QStyledItemDelegate, QStyleOptionViewItem, @@ -19,7 +24,117 @@ from qtpy.QtWidgets import ( QWidget, ) -# ── Design tokens ────────────────────────────────────────────────────────────── + +class Group(QGroupBox): + def __init__(self, label, objs, orientation="vertical"): + super().__init__(label) + if orientation == "vertical": + self._layout = QVBoxLayout(self) + elif orientation == "horizontal": + self._layout = QHBoxLayout(self) + else: + raise ValueError(f"Orientation {orientation} is not supported!") + for obj in objs: + if isinstance(obj, QWidget): + self._layout.addWidget(obj) + elif isinstance(obj, QLayout): + self._layout.addLayout(obj) + + +class Button(QWidget): + def __init__(self, label=None, label_button: str = "", enabled=False): + super().__init__() + layout = QHBoxLayout(self) + layout.setContentsMargins(10, 0, 0, 0) + layout.setSpacing(0) + if label is not None: + self.label = QLabel(label) + self.label.setFixedWidth(140) + layout.addWidget(self.label) + self.button = QPushButton(label_button) + if label is not None: + self.button.setFixedWidth(160) + self.enable_button(enabled) + layout.addWidget(self.button) + + def clicked_connect(self, func): + """Connect a function to the button press.""" + self.button.clicked.connect(func) + + def enable_button(self, enable: bool = False): + if enable: + self.button.setStyleSheet( + f"QPushButton {{background-color: {get_accent_colors().default.name()}; color: white;}}" + ) + self.button.setEnabled(True) + else: # disabled + self.button.setStyleSheet( + "QPushButton {{background-color: rgb(120, 120, 120); color: white;}}" + ) + self.button.setDisabled(True) + + def setText(self, text): + self.button.setText(text) + + +class ListWidget(QWidget): + def __init__(self, identifier=""): + super().__init__() + layout = QHBoxLayout(self) + layout.setContentsMargins(10, 0, 0, 0) + layout.setSpacing(0) + self.identifier = identifier + self.value = TaggedListWidget() + layout.addWidget(self.value) + + def clear(self): + self.value.clear() + + def addTaggedItem(self, label, tags): + self.value.addTaggedItem(label, tags) + + def setCurrentIndex(self, text): + self.value.setCurrentIndex(text) + + def currentItemChanged_connect(self, func): + """Connect a function to the Enter/Return key press.""" + self.value.currentItemChanged.connect( + partial( + func, + identifier=self.identifier, + value_obj=self.value, + value=lambda: self.value.currentIndex(), + ) + ) + + def setDisabled(self, disable): + self.value.setDisabled(disable) + + +class TaggedListWidget(QListWidget): + """QListWidget with label + coloured tag pills per row.""" + + def __init__(self, parent=None) -> None: + super().__init__(parent) + self.setItemDelegate(TaggedDelegate(self)) + self.setMouseTracking(True) # enables hover highlight + + def addTaggedItem(self, label: str, tags: list | None = None) -> QListWidgetItem: + """ + Add a row. tags is a list of (text, hex_color) pairs, e.g. + [("v1.26", "#2563EB"), ("stable", "#16A34A")] + """ + item = QListWidgetItem(str(label)) + if tags: + item.setData(_UserRole, list(tags)) + self.addItem(item) + return item + + def currentTags(self) -> list: + item = self.currentItem() + return item.data(_UserRole) or [] if item else [] + + ITEM_HEIGHT = 30 H_PAD = 12 TAG_H_PAD = 7 @@ -28,30 +143,22 @@ TAG_GAP = 5 LABEL_TAG_GAP = 12 CORNER_RADIUS = 4 TAG_TEXT_COLOR = "#FFFFFF" - -# ── Enum compat (PySide6 nested vs PyQt5 flat) ──────────────────────────────── -try: - _DEMIBOLD = QFont.Weight.DemiBold - _MEDIUM = QFont.Weight.Medium -except AttributeError: - _DEMIBOLD = QFont.DemiBold - _MEDIUM = QFont.Medium - -_AlignVCenter = Qt.AlignmentFlag.AlignVCenter if hasattr(Qt, "AlignmentFlag") else Qt.AlignVCenter -_AlignCenter = Qt.AlignmentFlag.AlignCenter if hasattr(Qt, "AlignmentFlag") else Qt.AlignCenter -_UserRole = Qt.ItemDataRole.UserRole if hasattr(Qt, "ItemDataRole") else Qt.UserRole -_NoPen = Qt.PenStyle.NoPen if hasattr(Qt, "PenStyle") else Qt.NoPen -_AA = QPainter.RenderHint.Antialiasing if hasattr(QPainter, "RenderHint") else QPainter.Antialiasing - -try: - _State_Selected = QStyle.StateFlag.State_Selected - _State_MouseOver = QStyle.StateFlag.State_MouseOver -except AttributeError: - _State_Selected = QStyle.State_Selected - _State_MouseOver = QStyle.State_MouseOver +_DEMIBOLD = QFont.Weight.DemiBold +_MEDIUM = QFont.Weight.Medium +_AlignCenter = ( + Qt.AlignmentFlag.AlignCenter if hasattr(Qt, "AlignmentFlag") else Qt.AlignmentFlag.AlignCenter +) +_UserRole = Qt.ItemDataRole.UserRole if hasattr(Qt, "ItemDataRole") else Qt.ItemDataRole.UserRole +_NoPen = Qt.PenStyle.NoPen if hasattr(Qt, "PenStyle") else Qt.PenStyle.NoPen +_AA = ( + QPainter.RenderHint.Antialiasing + if hasattr(QPainter, "RenderHint") + else QPainter.RenderHint.Antialiasing +) +_State_Selected = QStyle.StateFlag.State_Selected +_State_MouseOver = QStyle.StateFlag.State_MouseOver -# ── Delegate ────────────────────────────────────────────────────────────────── class TaggedDelegate(QStyledItemDelegate): def sizeHint(self, option: QStyleOptionViewItem, index) -> QSize: @@ -73,7 +180,9 @@ class TaggedDelegate(QStyledItemDelegate): label_text = ( index.data( - Qt.ItemDataRole.DisplayRole if hasattr(Qt, "ItemDataRole") else Qt.DisplayRole + Qt.ItemDataRole.DisplayRole + if hasattr(Qt, "ItemDataRole") + else Qt.ItemDataRole.DisplayRole ) or "" ) @@ -124,28 +233,3 @@ class TaggedDelegate(QStyledItemDelegate): x += tw + TAG_GAP painter.restore() - - -# ── Widget ──────────────────────────────────────────────────────────────────── -class TaggedListWidget(QListWidget): - """QListWidget with label + coloured tag pills per row.""" - - def __init__(self, parent=None) -> None: - super().__init__(parent) - self.setItemDelegate(TaggedDelegate(self)) - self.setMouseTracking(True) # enables hover highlight - - def addTaggedItem(self, label: str, tags: list | None = None) -> QListWidgetItem: - """ - Add a row. tags is a list of (text, hex_color) pairs, e.g. - [("v1.26", "#2563EB"), ("stable", "#16A34A")] - """ - item = QListWidgetItem(str(label)) - if tags: - item.setData(_UserRole, list(tags)) - self.addItem(item) - return item - - def currentTags(self) -> list: - item = self.currentItem() - return item.data(_UserRole) or [] if item else []