From 697b7a7bee8e491b4d9dcbde1bbb4c4506b0ac71 Mon Sep 17 00:00:00 2001 From: wyzula-jan Date: Mon, 15 Dec 2025 23:46:15 +0100 Subject: [PATCH] fix(colors): more benevolent fetching of colormap names, avoid hardcoded wrong colormap mapping from GradientWidget from pg --- bec_widgets/utils/colors.py | 110 ++++++++++++++++-- bec_widgets/widgets/plots/image/image_base.py | 52 +++++++-- bec_widgets/widgets/plots/image/image_item.py | 3 +- tests/unit_tests/test_color_utils.py | 39 +++++++ tests/unit_tests/test_image_layer.py | 1 - 5 files changed, 184 insertions(+), 21 deletions(-) diff --git a/bec_widgets/utils/colors.py b/bec_widgets/utils/colors.py index 789338fc..6d72dd76 100644 --- a/bec_widgets/utils/colors.py +++ b/bec_widgets/utils/colors.py @@ -1,6 +1,7 @@ from __future__ import annotations import re +from functools import lru_cache from typing import Literal import numpy as np @@ -9,6 +10,7 @@ from bec_lib import bec_logger from bec_qthemes import apply_theme as apply_theme_global from bec_qthemes._theme import AccentColors from pydantic_core import PydanticCustomError +from pyqtgraph.graphicsItems.GradientEditorItem import Gradients from qtpy.QtCore import QEvent, QEventLoop from qtpy.QtGui import QColor from qtpy.QtWidgets import QApplication @@ -57,6 +59,91 @@ def apply_theme(theme: Literal["dark", "light"]): class Colors: + @staticmethod + def list_available_colormaps() -> list[str]: + """ + List colormap names available via the pyqtgraph colormap registry. + + Note: This does not include `GradientEditorItem` presets (used by HistogramLUT menus). + """ + + def _list(source: str | None = None) -> list[str]: + try: + return pg.colormap.listMaps() if source is None else pg.colormap.listMaps(source) + except Exception: # pragma: no cover - backend may be missing + return [] + + return [*_list(None), *_list("matplotlib"), *_list("colorcet")] + + @staticmethod + def list_available_gradient_presets() -> list[str]: + """ + List `GradientEditorItem` preset names (HistogramLUT right-click menu entries). + """ + from pyqtgraph.graphicsItems.GradientEditorItem import Gradients + + return list(Gradients.keys()) + + @staticmethod + def canonical_colormap_name(color_map: str) -> str: + """ + Return an available colormap/preset name if a case-insensitive match exists. + """ + requested = (color_map or "").strip() + if not requested: + return requested + + registry = Colors.list_available_colormaps() + presets = Colors.list_available_gradient_presets() + available = set(registry) | set(presets) + + if requested in available: + return requested + + # Case-insensitive match. + lower_to_canonical = {name.lower(): name for name in available} + return lower_to_canonical.get(requested.lower(), requested) + + @staticmethod + def get_colormap(color_map: str) -> pg.ColorMap: + """ + Resolve a string into a `pg.ColorMap` using either: + - the `pg.colormap` registry (optionally including matplotlib/colorcet backends), or + - `GradientEditorItem` presets (HistogramLUT right-click menu). + """ + name = Colors.canonical_colormap_name(color_map) + if not name: + raise ValueError("Empty colormap name") + + return Colors._get_colormap_cached(name) + + @staticmethod + @lru_cache(maxsize=256) + def _get_colormap_cached(name: str) -> pg.ColorMap: + # 1) Registry/backends + try: + cmap = pg.colormap.get(name) + if cmap is not None: + return cmap + except Exception: + pass + for source in ("matplotlib", "colorcet"): + try: + cmap = pg.colormap.get(name, source=source) + if cmap is not None: + return cmap + except Exception: + continue + + # 2) Presets -> ColorMap + + if name not in Gradients: + raise KeyError(f"Colormap '{name}' not found") + + ge = pg.GradientEditorItem() + ge.loadPreset(name) + + return ge.colorMap() @staticmethod def golden_ratio(num: int) -> list: @@ -138,7 +225,7 @@ class Colors: if theme_offset < 0 or theme_offset > 1: raise ValueError("theme_offset must be between 0 and 1") - cmap = pg.colormap.get(colormap) + cmap = Colors.get_colormap(colormap) min_pos, max_pos = Colors.set_theme_offset(theme, theme_offset) # Generate positions that are evenly spaced within the acceptable range @@ -186,7 +273,7 @@ class Colors: ValueError: If theme_offset is not between 0 and 1. """ - cmap = pg.colormap.get(colormap) + cmap = Colors.get_colormap(colormap) phi = (1 + np.sqrt(5)) / 2 # Golden ratio golden_angle_conjugate = 1 - (1 / phi) # Approximately 0.38196601125 @@ -452,21 +539,24 @@ class Colors: Raises: PydanticCustomError: If colormap is invalid. """ - available_pg_maps = pg.colormap.listMaps() - available_mpl_maps = pg.colormap.listMaps("matplotlib") - available_mpl_colorcet = pg.colormap.listMaps("colorcet") - - available_colormaps = available_pg_maps + available_mpl_maps + available_mpl_colorcet - if color_map not in available_colormaps: + normalized = Colors.canonical_colormap_name(color_map) + try: + Colors.get_colormap(normalized) + except Exception as ext: + logger.warning(f"Colormap validation error: {ext}") if return_error: + available_colormaps = sorted( + set(Colors.list_available_colormaps()) + | set(Colors.list_available_gradient_presets()) + ) raise PydanticCustomError( "unsupported colormap", - f"Colormap '{color_map}' not found in the current installation of pyqtgraph. Choose on the following: {available_colormaps}.", + f"Colormap '{color_map}' not found in the current installation of pyqtgraph. Choose from the following: {available_colormaps}.", {"wrong_value": color_map}, ) else: return False - return color_map + return normalized @staticmethod def relative_luminance(color: QColor) -> float: diff --git a/bec_widgets/widgets/plots/image/image_base.py b/bec_widgets/widgets/plots/image/image_base.py index 1b638906..ac14d72a 100644 --- a/bec_widgets/widgets/plots/image/image_base.py +++ b/bec_widgets/widgets/plots/image/image_base.py @@ -9,6 +9,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationError from qtpy.QtCore import QPointF, Signal, SignalInstance from qtpy.QtWidgets import QDialog, QVBoxLayout +from bec_widgets.utils import Colors from bec_widgets.utils.container_utils import WidgetContainerUtils from bec_widgets.utils.error_popups import SafeProperty, SafeSlot from bec_widgets.utils.side_panel import SidePanel @@ -131,8 +132,9 @@ class ImageLayerManager: image.setZValue(z_position) image.removed.connect(self._remove_destroyed_layer) - # FIXME: For now, we hard-code the default color map here. In the future, this should be configurable. - image.color_map = "plasma" + color_map = getattr(getattr(self.parent, "config", None), "color_map", None) + if color_map: + image.color_map = color_map self.layers[name] = ImageLayer(name=name, image=image, sync=sync) self.plot_item.addItem(image) @@ -249,6 +251,8 @@ class ImageBase(PlotBase): Base class for the Image widget. """ + MAX_TICKS_COLORBAR = 10 + sync_colorbar_with_autorange = Signal() image_updated = Signal() layer_added = Signal(str) @@ -460,18 +464,20 @@ class ImageBase(PlotBase): self.setProperty("autorange", False) if style == "simple": - self._color_bar = pg.ColorBarItem(colorMap=self.config.color_map) + cmap = Colors.get_colormap(self.config.color_map) + self._color_bar = pg.ColorBarItem(colorMap=cmap) self._color_bar.setImageItem(self.layer_manager["main"].image) self._color_bar.sigLevelsChangeFinished.connect(disable_autorange) + self.config.color_bar = "simple" elif style == "full": self._color_bar = pg.HistogramLUTItem() self._color_bar.setImageItem(self.layer_manager["main"].image) - self._color_bar.gradient.loadPreset(self.config.color_map) + self.config.color_bar = "full" + self._apply_colormap_to_colorbar(self.config.color_map) self._color_bar.sigLevelsChanged.connect(disable_autorange) self.plot_widget.addItem(self._color_bar, row=0, col=1) - self.config.color_bar = style else: if self._color_bar: self.plot_widget.removeItem(self._color_bar) @@ -484,6 +490,37 @@ class ImageBase(PlotBase): if vrange: # should be at the end to disable the autorange if defined self.v_range = vrange + def _apply_colormap_to_colorbar(self, color_map: str) -> None: + if not self._color_bar: + return + + cmap = Colors.get_colormap(color_map) + + if self.config.color_bar == "simple": + self._color_bar.setColorMap(cmap) + return + + if self.config.color_bar != "full": + return + + gradient = getattr(self._color_bar, "gradient", None) + if gradient is None: + return + + positions = np.linspace(0.0, 1.0, self.MAX_TICKS_COLORBAR) + colors = cmap.map(positions, mode="byte") + + colors = np.asarray(colors) + if colors.ndim != 2: + return + if colors.shape[1] == 3: # add alpha + alpha = np.full((colors.shape[0], 1), 255, dtype=colors.dtype) + colors = np.concatenate([colors, alpha], axis=1) + + ticks = [(float(p), tuple(int(x) for x in c)) for p, c in zip(positions, colors)] + state = {"mode": "rgb", "ticks": ticks} + gradient.restoreState(state) + ################################################################################ # Static rois with roi manager @@ -754,10 +791,7 @@ class ImageBase(PlotBase): layer.image.color_map = value if self._color_bar: - if self.config.color_bar == "simple": - self._color_bar.setColorMap(value) - elif self.config.color_bar == "full": - self._color_bar.gradient.loadPreset(value) + self._apply_colormap_to_colorbar(self.config.color_map) except ValidationError: return diff --git a/bec_widgets/widgets/plots/image/image_item.py b/bec_widgets/widgets/plots/image/image_item.py index df182538..6f24ca3b 100644 --- a/bec_widgets/widgets/plots/image/image_item.py +++ b/bec_widgets/widgets/plots/image/image_item.py @@ -119,7 +119,8 @@ class ImageItem(BECConnector, pg.ImageItem): """Set a new color map.""" try: self.config.color_map = value - self.setColorMap(value) + cmap = Colors.get_colormap(self.config.color_map) + self.setColorMap(cmap) except ValidationError: logger.error(f"Invalid colormap '{value}' provided.") diff --git a/tests/unit_tests/test_color_utils.py b/tests/unit_tests/test_color_utils.py index ac2ef246..39c46473 100644 --- a/tests/unit_tests/test_color_utils.py +++ b/tests/unit_tests/test_color_utils.py @@ -82,6 +82,45 @@ def test_rgba_to_hex(): assert Colors.rgba_to_hex(255, 87, 51) == "#FF5733FF" +def test_canonical_colormap_name_case_insensitive(): + available = Colors.list_available_colormaps() + presets = Colors.list_available_gradient_presets() + if not available and not presets: + pytest.skip("No colormaps or presets available to test canonical mapping.") + + name = (available or presets)[0] + requested = name.swapcase() + assert Colors.canonical_colormap_name(requested) == name + + +def test_validate_color_map_returns_canonical_name(): + available = Colors.list_available_colormaps() + presets = Colors.list_available_gradient_presets() + if not available and not presets: + pytest.skip("No colormaps or presets available to test validation.") + + name = (available or presets)[0] + requested = name.swapcase() + assert Colors.validate_color_map(requested) == name + + +def test_get_colormap_uses_gradient_preset_fallback(monkeypatch): + presets = Colors.list_available_gradient_presets() + if not presets: + pytest.skip("No gradient presets available to test fallback.") + + preset = presets[0] + Colors._get_colormap_cached.cache_clear() + + def _raise(*args, **kwargs): + raise Exception("registry unavailable") + + monkeypatch.setattr(pg.colormap, "get", _raise) + + cmap = Colors._get_colormap_cached(preset) + assert isinstance(cmap, pg.ColorMap) + + @pytest.mark.parametrize("num", [10, 100, 400]) def test_evenly_spaced_colors(num): colors_qcolor = Colors.evenly_spaced_colors(colormap="magma", num=num, format="QColor") diff --git a/tests/unit_tests/test_image_layer.py b/tests/unit_tests/test_image_layer.py index 078ffa0a..c5463b4b 100644 --- a/tests/unit_tests/test_image_layer.py +++ b/tests/unit_tests/test_image_layer.py @@ -4,7 +4,6 @@ import pyqtgraph as pg import pytest from bec_widgets.widgets.plots.image.image_base import ImageLayerManager -from bec_widgets.widgets.plots.image.image_item import ImageItem @pytest.fixture()