diff --git a/src/pydase/components/image.py b/src/pydase/components/image.py index 13bb340..4131663 100644 --- a/src/pydase/components/image.py +++ b/src/pydase/components/image.py @@ -1,19 +1,16 @@ import base64 import io from pathlib import Path +from typing import TYPE_CHECKING, Optional +from urllib.request import urlopen import PIL.Image from loguru import logger -from urllib.request import urlopen from pydase.data_service.data_service import DataService - -class Figure: - """Mock class for matplotlib.Figure""" - - def savefig(self, format="png"): - pass +if TYPE_CHECKING: + from matplotlib.figure import Figure class Image(DataService): @@ -36,19 +33,24 @@ class Image(DataService): with PIL.Image.open(path) as image: self._load_from_PIL(image) - def load_from_matplotlib_figure(self, fig: Figure, format_: str = "png") -> None: + def load_from_matplotlib_figure(self, fig: "Figure", format_: str = "png") -> None: buffer = io.BytesIO() - fig.savefig(buffer, format=format_) + fig.savefig(buffer, format=format_) # type: ignore value_ = base64.b64encode(buffer.getvalue()) self._load_from_base64(value_, format_) - def load_from_url(self, url: str): + def load_from_url(self, url: str) -> None: image = PIL.Image.open(urlopen(url)) self._load_from_PIL(image) - def load_from_base64(self, value_: bytes, format_: str | None = None) -> None: + def load_from_base64(self, value_: bytes, format_: Optional[str] = None) -> None: if format_ is None: format_ = self._get_image_format_from_bytes(value_) + if format_ is None: + logger.warning( + "Format of passed byte string could not be determined. Skipping..." + ) + return self._load_from_base64(value_, format_) def _load_from_base64(self, value_: bytes, format_: str) -> None: @@ -66,10 +68,9 @@ class Image(DataService): else: logger.error("Image format is 'None'. Skipping...") - def _get_image_format_from_bytes(self, value_: bytes): + def _get_image_format_from_bytes(self, value_: bytes) -> str | None: image_data = base64.b64decode(value_) # Create a writable memory buffer for the image image_buffer = io.BytesIO(image_data) - # Read the image from the buffer - image = PIL.Image.open(image_buffer) - return image.format + # Read the image from the buffer and return format + return PIL.Image.open(image_buffer).format