implements ruff suggestions

This commit is contained in:
Mose Müller 2023-11-27 17:37:15 +01:00
parent e576f6eb80
commit 9e9d3f17bc
12 changed files with 70 additions and 63 deletions

View File

@ -5,7 +5,7 @@ from pathlib import Path
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from urllib.request import urlopen from urllib.request import urlopen
import PIL.Image # type: ignore import PIL.Image # type: ignore[import-untyped]
from pydase.data_service.data_service import DataService from pydase.data_service.data_service import DataService
@ -33,17 +33,17 @@ class Image(DataService):
def load_from_path(self, path: Path | str) -> None: def load_from_path(self, path: Path | str) -> None:
with PIL.Image.open(path) as image: with PIL.Image.open(path) as image:
self._load_from_PIL(image) self._load_from_pil(image)
def load_from_matplotlib_figure(self, fig: "Figure", format_: str = "png") -> None: def load_from_matplotlib_figure(self, fig: "Figure", format_: str = "png") -> None:
buffer = io.BytesIO() buffer = io.BytesIO()
fig.savefig(buffer, format=format_) # type: ignore fig.savefig(buffer, format=format_) # type: ignore[reportUnknownMemberType]
value_ = base64.b64encode(buffer.getvalue()) value_ = base64.b64encode(buffer.getvalue())
self._load_from_base64(value_, format_) self._load_from_base64(value_, format_)
def load_from_url(self, url: str) -> None: def load_from_url(self, url: str) -> None:
image = PIL.Image.open(urlopen(url)) image = PIL.Image.open(urlopen(url))
self._load_from_PIL(image) self._load_from_pil(image)
def load_from_base64(self, value_: bytes, format_: Optional[str] = None) -> None: def load_from_base64(self, value_: bytes, format_: Optional[str] = None) -> None:
if format_ is None: if format_ is None:
@ -60,7 +60,7 @@ class Image(DataService):
self._value = value self._value = value
self._format = format_ self._format = format_
def _load_from_PIL(self, image: PIL.Image.Image) -> None: def _load_from_pil(self, image: PIL.Image.Image) -> None:
if image.format is not None: if image.format is not None:
format_ = image.format format_ = image.format
buffer = io.BytesIO() buffer = io.BytesIO()

View File

@ -3,7 +3,7 @@ from typing import Literal
from confz import BaseConfig, EnvSource from confz import BaseConfig, EnvSource
class OperationMode(BaseConfig): # type: ignore class OperationMode(BaseConfig): # type: ignore[misc]
environment: Literal["development"] | Literal["production"] = "development" environment: Literal["development", "production"] = "development"
CONFIG_SOURCES = EnvSource(allow=["ENVIRONMENT"]) CONFIG_SOURCES = EnvSource(allow=["ENVIRONMENT"])

View File

@ -2,8 +2,7 @@ from __future__ import annotations
import inspect import inspect
import logging import logging
from collections.abc import Callable from typing import TYPE_CHECKING, Any, ClassVar
from typing import TYPE_CHECKING, Any
from pydase.data_service.abstract_data_service import AbstractDataService from pydase.data_service.abstract_data_service import AbstractDataService
from pydase.utils.helpers import get_class_and_instance_attributes from pydase.utils.helpers import get_class_and_instance_attributes
@ -11,13 +10,15 @@ from pydase.utils.helpers import get_class_and_instance_attributes
from .data_service_list import DataServiceList from .data_service_list import DataServiceList
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Callable
from .data_service import DataService from .data_service import DataService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CallbackManager: class CallbackManager:
_notification_callbacks: list[Callable[[str, str, Any], Any]] = [] _notification_callbacks: ClassVar[list[Callable[[str, str, Any], Any]]] = []
""" """
A list of callback functions that are executed when a change occurs in the A list of callback functions that are executed when a change occurs in the
DataService instance. These functions are intended to handle or respond to these DataService instance. These functions are intended to handle or respond to these
@ -38,7 +39,7 @@ class CallbackManager:
This implementation follows the observer pattern, with the DataService instance as This implementation follows the observer pattern, with the DataService instance as
the "subject" and the callback functions as the "observers". the "subject" and the callback functions as the "observers".
""" """
_list_mapping: dict[int, DataServiceList] = {} _list_mapping: ClassVar[dict[int, DataServiceList]] = {}
""" """
A dictionary mapping the id of the original lists to the corresponding A dictionary mapping the id of the original lists to the corresponding
DataServiceList instances. DataServiceList instances.
@ -53,7 +54,7 @@ class CallbackManager:
self.service = service self.service = service
def _register_list_change_callbacks( # noqa: C901 def _register_list_change_callbacks( # noqa: C901
self, obj: "AbstractDataService", parent_path: str self, obj: AbstractDataService, parent_path: str
) -> None: ) -> None:
""" """
This method ensures that notifications are emitted whenever a public list This method ensures that notifications are emitted whenever a public list
@ -136,7 +137,7 @@ class CallbackManager:
self._register_list_change_callbacks(item, new_path) self._register_list_change_callbacks(item, new_path)
def _register_DataService_instance_callbacks( def _register_DataService_instance_callbacks(
self, obj: "AbstractDataService", parent_path: str self, obj: AbstractDataService, parent_path: str
) -> None: ) -> None:
""" """
This function is a key part of the observer pattern implemented by the This function is a key part of the observer pattern implemented by the
@ -208,7 +209,7 @@ class CallbackManager:
) )
def _register_service_callbacks( def _register_service_callbacks(
self, nested_attr: "AbstractDataService", parent_path: str, attr_name: str self, nested_attr: AbstractDataService, parent_path: str, attr_name: str
) -> None: ) -> None:
"""Handles registration of callbacks for DataService attributes""" """Handles registration of callbacks for DataService attributes"""
@ -221,7 +222,7 @@ class CallbackManager:
def __register_recursive_parameter_callback( def __register_recursive_parameter_callback(
self, self,
obj: "AbstractDataService | DataServiceList", obj: AbstractDataService | DataServiceList,
callback: Callable[[str | int, Any], None], callback: Callable[[str | int, Any], None],
) -> None: ) -> None:
""" """
@ -255,7 +256,7 @@ class CallbackManager:
def _register_property_callbacks( # noqa: C901 def _register_property_callbacks( # noqa: C901
self, self,
obj: "AbstractDataService", obj: AbstractDataService,
parent_path: str, parent_path: str,
) -> None: ) -> None:
""" """
@ -284,8 +285,8 @@ class CallbackManager:
item, parent_path=f"{parent_path}.{attr_name}[{i}]" item, parent_path=f"{parent_path}.{attr_name}[{i}]"
) )
if isinstance(attr_value, property): if isinstance(attr_value, property):
dependencies = attr_value.fget.__code__.co_names # type: ignore dependencies = attr_value.fget.__code__.co_names # type: ignore[union-attr]
source_code_string = inspect.getsource(attr_value.fget) # type: ignore source_code_string = inspect.getsource(attr_value.fget) # type: ignore[arg-type]
for dependency in dependencies: for dependency in dependencies:
# check if the dependencies are attributes of obj # check if the dependencies are attributes of obj
@ -304,7 +305,7 @@ class CallbackManager:
dependency_value = getattr(obj, dependency) dependency_value = getattr(obj, dependency)
if isinstance( if isinstance(
dependency_value, (DataServiceList, AbstractDataService) dependency_value, DataServiceList | AbstractDataService
): ):
def list_or_data_service_callback( def list_or_data_service_callback(
@ -345,8 +346,8 @@ class CallbackManager:
# Add to callbacks # Add to callbacks
obj._callback_manager.callbacks.add(callback) obj._callback_manager.callbacks.add(callback)
def _register_start_stop_task_callbacks( # noqa def _register_start_stop_task_callbacks( # noqa: C901
self, obj: "AbstractDataService", parent_path: str self, obj: AbstractDataService, parent_path: str
) -> None: ) -> None:
""" """
This function registers callbacks for start and stop methods of async functions. This function registers callbacks for start and stop methods of async functions.

View File

@ -1,10 +1,9 @@
import logging import logging
import warnings import warnings
from enum import Enum from enum import Enum
from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, get_type_hints
from typing import Any, Optional, get_type_hints
import rpyc # type: ignore import rpyc # type: ignore[import-untyped]
import pydase.units as u import pydase.units as u
from pydase.data_service.abstract_data_service import AbstractDataService from pydase.data_service.abstract_data_service import AbstractDataService
@ -27,6 +26,9 @@ from pydase.utils.warnings import (
warn_if_instance_class_does_not_inherit_from_DataService, warn_if_instance_class_does_not_inherit_from_DataService,
) )
if TYPE_CHECKING:
from pathlib import Path
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,8 +58,8 @@ class DataService(rpyc.Service, AbstractDataService):
filename = kwargs.pop("filename", None) filename = kwargs.pop("filename", None)
if filename is not None: if filename is not None:
warnings.warn( warnings.warn(
"The 'filename' argument is deprecated and will be removed in a future version. " "The 'filename' argument is deprecated and will be removed in a future "
"Please pass the 'filename' argument to `pydase.Server`.", "version. Please pass the 'filename' argument to `pydase.Server`.",
DeprecationWarning, DeprecationWarning,
stacklevel=2, stacklevel=2,
) )
@ -80,7 +82,7 @@ class DataService(rpyc.Service, AbstractDataService):
super().__setattr__(__name, __value) super().__setattr__(__name, __value)
if self.__dict__.get("_initialised") and not __name == "_initialised": if self.__dict__.get("_initialised") and __name != "_initialised":
for callback in self._callback_manager.callbacks: for callback in self._callback_manager.callbacks:
callback(__name, __value) callback(__name, __value)
elif __name.startswith(f"_{self.__class__.__name__}__"): elif __name.startswith(f"_{self.__class__.__name__}__"):
@ -98,7 +100,7 @@ class DataService(rpyc.Service, AbstractDataService):
if not attr_name.startswith("_"): if not attr_name.startswith("_"):
warn_if_instance_class_does_not_inherit_from_DataService(attr_value) warn_if_instance_class_does_not_inherit_from_DataService(attr_value)
def __set_attribute_based_on_type( # noqa:CFQ002 def __set_attribute_based_on_type( # noqa: PLR0913
self, self,
target_obj: Any, target_obj: Any,
attr_name: str, attr_name: str,
@ -155,9 +157,11 @@ class DataService(rpyc.Service, AbstractDataService):
) )
if hasattr(self, "_state_manager"): if hasattr(self, "_state_manager"):
getattr(self, "_state_manager").save_state() self._state_manager.save_state() # type: ignore[reportGeneralTypeIssue]
def load_DataService_from_JSON(self, json_dict: dict[str, Any]) -> None: def load_DataService_from_JSON( # noqa: N802
self, json_dict: dict[str, Any]
) -> None:
warnings.warn( warnings.warn(
"'load_DataService_from_JSON' is deprecated and will be removed in a " "'load_DataService_from_JSON' is deprecated and will be removed in a "
"future version. " "future version. "
@ -202,7 +206,7 @@ class DataService(rpyc.Service, AbstractDataService):
class_value_type, class_value_type,
) )
def serialize(self) -> dict[str, dict[str, Any]]: # noqa def serialize(self) -> dict[str, dict[str, Any]]:
""" """
Serializes the instance into a dictionary, preserving the structure of the Serializes the instance into a dictionary, preserving the structure of the
instance. instance.
@ -221,7 +225,7 @@ class DataService(rpyc.Service, AbstractDataService):
""" """
return Serializer.serialize_object(self)["value"] return Serializer.serialize_object(self)["value"]
def update_DataService_attribute( def update_DataService_attribute( # noqa: N802
self, self,
path_list: list[str], path_list: list[str],
attr_name: str, attr_name: str,

View File

@ -41,9 +41,9 @@ class DataServiceList(list):
# prevent gc to delete the passed list by keeping a reference # prevent gc to delete the passed list by keeping a reference
self._original_list = args[0] self._original_list = args[0]
super().__init__(*args, **kwargs) # type: ignore super().__init__(*args, **kwargs) # type: ignore[reportUnknownMemberType]
def __setitem__(self, key: int, value: Any) -> None: # type: ignore def __setitem__(self, key: int, value: Any) -> None: # type: ignore[override]
current_value = self.__getitem__(key) current_value = self.__getitem__(key)
# parse ints into floats if current value is a float # parse ints into floats if current value is a float
@ -52,7 +52,7 @@ class DataServiceList(list):
if isinstance(current_value, u.Quantity): if isinstance(current_value, u.Quantity):
value = u.convert_to_quantity(value, str(current_value.u)) value = u.convert_to_quantity(value, str(current_value.u))
super().__setitem__(key, value) # type: ignore super().__setitem__(key, value) # type: ignore[reportUnknownMemberType]
for callback in self._callbacks: for callback in self._callbacks:
callback(key, value) callback(key, value)

View File

@ -41,7 +41,7 @@ def load_state(func: Callable[..., Any]) -> Callable[..., Any]:
... self._name = value ... self._name = value
""" """
func._load_state = True # type: ignore func._load_state = True # type: ignore[attr-defined]
return func return func
@ -51,7 +51,7 @@ def has_load_state_decorator(prop: property) -> bool:
""" """
try: try:
return getattr(prop.fset, "_load_state") return prop.fset._load_state # type: ignore[union-attr]
except AttributeError: except AttributeError:
return False return False
@ -96,7 +96,9 @@ class StateManager:
update. update.
""" """
def __init__(self, service: "DataService", filename: Optional[str | Path] = None): def __init__(
self, service: "DataService", filename: Optional[str | Path] = None
) -> None:
self.filename = getattr(service, "_filename", None) self.filename = getattr(service, "_filename", None)
if filename is not None: if filename is not None:
@ -136,7 +138,7 @@ class StateManager:
""" """
# Traverse the serialized representation and set the attributes of the class # Traverse the serialized representation and set the attributes of the class
json_dict = self._get_state_dict_from_JSON_file() json_dict = self._get_state_dict_from_json_file()
if json_dict == {}: if json_dict == {}:
logger.debug("Could not load the service state.") logger.debug("Could not load the service state.")
return return
@ -162,9 +164,9 @@ class StateManager:
class_attr_value_type, class_attr_value_type,
) )
def _get_state_dict_from_JSON_file(self) -> dict[str, Any]: def _get_state_dict_from_json_file(self) -> dict[str, Any]:
if self.filename is not None and os.path.exists(self.filename): if self.filename is not None and os.path.exists(self.filename):
with open(self.filename, "r") as f: with open(self.filename) as f:
# Load JSON data from file and update class attributes with these # Load JSON data from file and update class attributes with these
# values # values
return cast(dict[str, Any], json.load(f)) return cast(dict[str, Any], json.load(f))

View File

@ -95,7 +95,7 @@ class TaskManager:
self._set_start_and_stop_for_async_methods() self._set_start_and_stop_for_async_methods()
def _set_start_and_stop_for_async_methods(self) -> None: # noqa: C901 def _set_start_and_stop_for_async_methods(self) -> None:
# inspect the methods of the class # inspect the methods of the class
for name, method in inspect.getmembers( for name, method in inspect.getmembers(
self.service, predicate=inspect.iscoroutinefunction self.service, predicate=inspect.iscoroutinefunction
@ -119,11 +119,11 @@ class TaskManager:
self._initiate_task_startup() self._initiate_task_startup()
attrs = get_class_and_instance_attributes(self.service) attrs = get_class_and_instance_attributes(self.service)
for _, attr_value in attrs.items(): for attr_value in attrs.values():
if isinstance(attr_value, AbstractDataService): if isinstance(attr_value, AbstractDataService):
attr_value._task_manager.start_autostart_tasks() attr_value._task_manager.start_autostart_tasks()
elif isinstance(attr_value, DataServiceList): elif isinstance(attr_value, DataServiceList):
for i, item in enumerate(attr_value): for item in attr_value:
if isinstance(item, AbstractDataService): if isinstance(item, AbstractDataService):
item._task_manager.start_autostart_tasks() item._task_manager.start_autostart_tasks()
@ -146,7 +146,7 @@ class TaskManager:
return stop_task return stop_task
def _make_start_task( # noqa def _make_start_task( # noqa: C901
self, name: str, method: Callable[..., Any] self, name: str, method: Callable[..., Any]
) -> Callable[..., Any]: ) -> Callable[..., Any]:
""" """
@ -162,7 +162,7 @@ class TaskManager:
""" """
@wraps(method) @wraps(method)
def start_task(*args: Any, **kwargs: Any) -> None: def start_task(*args: Any, **kwargs: Any) -> None: # noqa: C901
def task_done_callback(task: asyncio.Task[None], name: str) -> None: def task_done_callback(task: asyncio.Task[None], name: str) -> None:
"""Handles tasks that have finished. """Handles tasks that have finished.
@ -210,7 +210,7 @@ class TaskManager:
# with the 'kwargs' dictionary. If a parameter is specified in both # with the 'kwargs' dictionary. If a parameter is specified in both
# 'args_padded' and 'kwargs', the value from 'kwargs' is used. # 'args_padded' and 'kwargs', the value from 'kwargs' is used.
kwargs_updated = { kwargs_updated = {
**dict(zip(parameter_names, args_padded)), **dict(zip(parameter_names, args_padded, strict=True)),
**kwargs, **kwargs,
} }

View File

@ -10,7 +10,7 @@ from types import FrameType
from typing import Any, Optional, Protocol, TypedDict from typing import Any, Optional, Protocol, TypedDict
import uvicorn import uvicorn
from rpyc import ForkingServer, ThreadedServer # type: ignore from rpyc import ForkingServer, ThreadedServer # type: ignore[import-untyped]
from uvicorn.server import HANDLED_SIGNALS from uvicorn.server import HANDLED_SIGNALS
from pydase import DataService from pydase import DataService
@ -164,7 +164,7 @@ class Server:
Additional keyword arguments. Additional keyword arguments.
""" """
def __init__( # noqa: CFQ002 def __init__(
self, self,
service: DataService, service: DataService,
host: str = "0.0.0.0", host: str = "0.0.0.0",
@ -319,7 +319,7 @@ class Server:
async def notify() -> None: async def notify() -> None:
try: try:
await self._wapi.sio.emit( # type: ignore await self._wapi.sio.emit( # type: ignore[reportUnknownMemberType]
"notify", "notify",
{ {
"data": { "data": {
@ -338,7 +338,7 @@ class Server:
# overwrite uvicorn's signal handlers, otherwise it will bogart SIGINT and # overwrite uvicorn's signal handlers, otherwise it will bogart SIGINT and
# SIGTERM, which makes it impossible to escape out of # SIGTERM, which makes it impossible to escape out of
web_server.install_signal_handlers = lambda: None # type: ignore web_server.install_signal_handlers = lambda: None # type: ignore[method-assign]
future_or_task = self._loop.create_task(web_server.serve()) future_or_task = self._loop.create_task(web_server.serve())
self.servers["web"] = future_or_task self.servers["web"] = future_or_task
@ -413,7 +413,7 @@ class Server:
async def emit_exception() -> None: async def emit_exception() -> None:
try: try:
await self._wapi.sio.emit( # type: ignore await self._wapi.sio.emit( # type: ignore[reportUnknownMemberType]
"exception", "exception",
{ {
"data": { "data": {

View File

@ -2,7 +2,7 @@ import logging
from pathlib import Path from pathlib import Path
from typing import Any, TypedDict from typing import Any, TypedDict
import socketio # type: ignore import socketio # type: ignore[import-untyped]
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
@ -70,7 +70,7 @@ class WebAPI:
__sio_app: socketio.ASGIApp __sio_app: socketio.ASGIApp
__fastapi_app: FastAPI __fastapi_app: FastAPI
def __init__( # noqa: CFQ002 def __init__(
self, self,
service: DataService, service: DataService,
state_manager: StateManager, state_manager: StateManager,
@ -80,7 +80,7 @@ class WebAPI:
info: dict[str, Any] = {}, info: dict[str, Any] = {},
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
): ) -> None:
self.service = service self.service = service
self.state_manager = state_manager self.state_manager = state_manager
self.frontend = frontend self.frontend = frontend
@ -105,7 +105,7 @@ class WebAPI:
else: else:
sio = socketio.AsyncServer(async_mode="asgi") sio = socketio.AsyncServer(async_mode="asgi")
@sio.event # type: ignore @sio.event # type: ignore[reportUnknownMemberType]
def set_attribute(sid: str, data: UpdateDict) -> Any: def set_attribute(sid: str, data: UpdateDict) -> Any:
logger.debug("Received frontend update: %s", data) logger.debug("Received frontend update: %s", data)
path_list = [*data["parent_path"].split("."), data["name"]] path_list = [*data["parent_path"].split("."), data["name"]]
@ -115,7 +115,7 @@ class WebAPI:
path=path, value=data["value"] path=path, value=data["value"]
) )
@sio.event # type: ignore @sio.event # type: ignore[reportUnknownMemberType]
def run_method(sid: str, data: RunMethodDict) -> Any: def run_method(sid: str, data: RunMethodDict) -> Any:
logger.debug("Running method: %s", data) logger.debug("Running method: %s", data)
path_list = [*data["parent_path"].split("."), data["name"]] path_list = [*data["parent_path"].split("."), data["name"]]
@ -126,7 +126,7 @@ class WebAPI:
self.__sio = sio self.__sio = sio
self.__sio_app = socketio.ASGIApp(self.__sio) self.__sio_app = socketio.ASGIApp(self.__sio)
def setup_fastapi_app(self) -> None: # noqa def setup_fastapi_app(self) -> None: # noqa: C901
app = FastAPI() app = FastAPI()
if self.enable_CORS: if self.enable_CORS:

View File

@ -15,7 +15,7 @@ class QuantityDict(TypedDict):
def convert_to_quantity( def convert_to_quantity(
value: QuantityDict | float | int | Quantity, unit: str = "" value: QuantityDict | float | Quantity, unit: str = ""
) -> Quantity: ) -> Quantity:
""" """
Convert a given value into a pint.Quantity object with the specified unit. Convert a given value into a pint.Quantity object with the specified unit.
@ -53,4 +53,4 @@ def convert_to_quantity(
quantity = float(value["magnitude"]) * Unit(value["unit"]) quantity = float(value["magnitude"]) * Unit(value["unit"])
else: else:
quantity = value quantity = value
return quantity # type: ignore return quantity # type: ignore[reportUnknownMemberType]

View File

@ -4,7 +4,7 @@ import sys
from copy import copy from copy import copy
from typing import Optional from typing import Optional
import socketio import socketio # type: ignore[import-untyped]
import uvicorn.logging import uvicorn.logging
from uvicorn.config import LOGGING_CONFIG from uvicorn.config import LOGGING_CONFIG
@ -33,7 +33,7 @@ class DefaultFormatter(uvicorn.logging.ColourizedFormatter):
return logging.Formatter.formatMessage(self, recordcopy) return logging.Formatter.formatMessage(self, recordcopy)
def should_use_colors(self) -> bool: def should_use_colors(self) -> bool:
return sys.stderr.isatty() # pragma: no cover return sys.stderr.isatty()
class SocketIOHandler(logging.Handler): class SocketIOHandler(logging.Handler):

View File

@ -1,4 +1,4 @@
from importlib.metadata import distribution from importlib.metadata import distribution
__version__ = distribution("pydase").version __version__ = distribution("pydase").version
__major__, __minor__, __patch__ = [int(v) for v in __version__.split(".")] __major__, __minor__, __patch__ = (int(v) for v in __version__.split("."))