From 4aee899dbed858b79feebdc2f38139df0f2f8171 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mose=20M=C3=BCller?= Date: Wed, 6 Mar 2024 17:49:46 +0100 Subject: [PATCH] updates type hints for serialized objects --- src/pydase/data_service/data_service.py | 3 +- src/pydase/data_service/data_service_cache.py | 27 +++- .../data_service/data_service_observer.py | 11 +- src/pydase/data_service/state_manager.py | 24 ++- src/pydase/server/web_server/sio_setup.py | 3 +- src/pydase/server/web_server/web_server.py | 6 +- src/pydase/utils/serializer.py | 143 ++++++++++++------ tests/utils/test_serializer.py | 5 +- 8 files changed, 148 insertions(+), 74 deletions(-) diff --git a/src/pydase/data_service/data_service.py b/src/pydase/data_service/data_service.py index ab195c0..a6b29e6 100644 --- a/src/pydase/data_service/data_service.py +++ b/src/pydase/data_service/data_service.py @@ -17,6 +17,7 @@ from pydase.utils.helpers import ( is_property_attribute, ) from pydase.utils.serializer import ( + SerializedObject, Serializer, ) @@ -125,7 +126,7 @@ class DataService(rpyc.Service, AbstractDataService): # allow all other attributes setattr(self, name, value) - def serialize(self) -> dict[str, dict[str, Any]]: + def serialize(self) -> SerializedObject: """ Serializes the instance into a dictionary, preserving the structure of the instance. diff --git a/src/pydase/data_service/data_service_cache.py b/src/pydase/data_service/data_service_cache.py index d25f352..2a45d68 100644 --- a/src/pydase/data_service/data_service_cache.py +++ b/src/pydase/data_service/data_service_cache.py @@ -1,9 +1,10 @@ import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from pydase.utils.serializer import ( SerializationPathError, SerializationValueError, + SerializedObject, get_nested_dict_by_path, set_nested_value_by_path, ) @@ -16,12 +17,12 @@ logger = logging.getLogger(__name__) class DataServiceCache: def __init__(self, service: "DataService") -> None: - self._cache: dict[str, Any] = {} + self._cache: SerializedObject self.service = service self._initialize_cache() @property - def cache(self) -> dict[str, Any]: + def cache(self) -> SerializedObject: return self._cache def _initialize_cache(self) -> None: @@ -30,10 +31,22 @@ class DataServiceCache: self._cache = self.service.serialize() def update_cache(self, full_access_path: str, value: Any) -> None: - set_nested_value_by_path(self._cache["value"], full_access_path, value) + set_nested_value_by_path( + cast(dict[str, SerializedObject], self._cache["value"]), + full_access_path, + value, + ) - def get_value_dict_from_cache(self, full_access_path: str) -> dict[str, Any]: + def get_value_dict_from_cache(self, full_access_path: str) -> SerializedObject: try: - return get_nested_dict_by_path(self._cache["value"], full_access_path) + return get_nested_dict_by_path( + cast(dict[str, SerializedObject], self._cache["value"]), + full_access_path, + ) except (SerializationPathError, SerializationValueError, KeyError): - return {} + return { + "value": None, + "type": None, + "doc": None, + "readonly": False, + } diff --git a/src/pydase/data_service/data_service_observer.py b/src/pydase/data_service/data_service_observer.py index ec2262e..8d2e673 100644 --- a/src/pydase/data_service/data_service_observer.py +++ b/src/pydase/data_service/data_service_observer.py @@ -9,7 +9,7 @@ from pydase.observer_pattern.observer.property_observer import ( PropertyObserver, ) from pydase.utils.helpers import get_object_attr_from_path_list -from pydase.utils.serializer import dump +from pydase.utils.serializer import SerializedObject, dump logger = logging.getLogger(__name__) @@ -18,7 +18,7 @@ class DataServiceObserver(PropertyObserver): def __init__(self, state_manager: StateManager) -> None: self.state_manager = state_manager self._notification_callbacks: list[ - Callable[[str, Any, dict[str, Any]], None] + Callable[[str, Any, SerializedObject], None] ] = [] super().__init__(state_manager.service) @@ -59,7 +59,10 @@ class DataServiceObserver(PropertyObserver): self._notify_dependent_property_changes(full_access_path) def _update_cache_value( - self, full_access_path: str, value: Any, cached_value_dict: dict[str, Any] + self, + full_access_path: str, + value: Any, + cached_value_dict: SerializedObject | dict[str, Any], ) -> None: value_dict = dump(value) if cached_value_dict != {}: @@ -93,7 +96,7 @@ class DataServiceObserver(PropertyObserver): ) def add_notification_callback( - self, callback: Callable[[str, Any, dict[str, Any]], None] + self, callback: Callable[[str, Any, SerializedObject], None] ) -> None: """ Registers a callback function to be invoked upon attribute changes in the diff --git a/src/pydase/data_service/state_manager.py b/src/pydase/data_service/state_manager.py index 5ad2eac..1b3c25f 100644 --- a/src/pydase/data_service/state_manager.py +++ b/src/pydase/data_service/state_manager.py @@ -13,6 +13,7 @@ from pydase.utils.helpers import ( parse_list_attr_and_index, ) from pydase.utils.serializer import ( + SerializedObject, dump, generate_serialized_data_paths, get_nested_dict_by_path, @@ -114,10 +115,17 @@ class StateManager: self._data_service_cache = DataServiceCache(self.service) @property - def cache(self) -> dict[str, Any]: + def cache(self) -> SerializedObject: """Returns the cached DataService state.""" return self._data_service_cache.cache + @property + def cache_value(self) -> dict[str, SerializedObject]: + """Returns the "value" value of the DataService serialization.""" + return cast( + dict[str, SerializedObject], self._data_service_cache.cache["value"] + ) + def save_state(self) -> None: """ Saves the DataService's current state to a JSON file defined by `self.filename`. @@ -126,7 +134,7 @@ class StateManager: if self.filename is not None: with open(self.filename, "w") as f: - json.dump(self.cache["value"], f, indent=4) + json.dump(self.cache_value, f, indent=4) else: logger.info( "State manager was not initialised with a filename. Skipping " @@ -191,7 +199,7 @@ class StateManager: value: The new value to set for the attribute. """ - current_value_dict = get_nested_dict_by_path(self.cache["value"], path) + current_value_dict = get_nested_dict_by_path(self.cache_value, path) # This will also filter out methods as they are 'read-only' if current_value_dict["readonly"]: @@ -216,10 +224,12 @@ class StateManager: return dump(value_object)["value"] != current_value def __convert_value_if_needed( - self, value: Any, current_value_dict: dict[str, Any] + self, value: Any, current_value_dict: SerializedObject ) -> Any: if current_value_dict["type"] == "Quantity": - return u.convert_to_quantity(value, current_value_dict["value"]["unit"]) + return u.convert_to_quantity( + value, cast(dict[str, Any], current_value_dict["value"])["unit"] + ) if current_value_dict["type"] == "float" and not isinstance(value, float): return float(value) return value @@ -234,7 +244,7 @@ class StateManager: # Update path to reflect the attribute without list indices path = ".".join([*parent_path_list, attr_name]) - attr_cache_type = get_nested_dict_by_path(self.cache["value"], path)["type"] + attr_cache_type = get_nested_dict_by_path(self.cache_value, path)["type"] # Traverse the object according to the path parts target_obj = get_object_attr_from_path_list(self.service, parent_path_list) @@ -273,7 +283,7 @@ class StateManager: return has_decorator cached_serialization_dict = get_nested_dict_by_path( - self.cache["value"], full_access_path + self.cache_value, full_access_path ) if cached_serialization_dict["value"] == "method": diff --git a/src/pydase/server/web_server/sio_setup.py b/src/pydase/server/web_server/sio_setup.py index 4dc40f5..0e8f581 100644 --- a/src/pydase/server/web_server/sio_setup.py +++ b/src/pydase/server/web_server/sio_setup.py @@ -9,6 +9,7 @@ from pydase.data_service.data_service_observer import DataServiceObserver from pydase.data_service.state_manager import StateManager from pydase.utils.helpers import get_object_attr_from_path_list from pydase.utils.logging import SocketIOHandler +from pydase.utils.serializer import SerializedObject logger = logging.getLogger(__name__) @@ -93,7 +94,7 @@ def setup_sio_server( # Add notification callback to observer def sio_callback( - full_access_path: str, value: Any, cached_value_dict: dict[str, Any] + full_access_path: str, value: Any, cached_value_dict: SerializedObject ) -> None: if cached_value_dict != {}: diff --git a/src/pydase/server/web_server/web_server.py b/src/pydase/server/web_server/web_server.py index d72a0db..ce3e8c5 100644 --- a/src/pydase/server/web_server/web_server.py +++ b/src/pydase/server/web_server/web_server.py @@ -16,7 +16,7 @@ from pydase.data_service.data_service_observer import DataServiceObserver from pydase.server.web_server.sio_setup import ( setup_sio_server, ) -from pydase.utils.serializer import generate_serialized_data_paths +from pydase.utils.serializer import SerializedObject, generate_serialized_data_paths from pydase.version import __version__ logger = logging.getLogger(__name__) @@ -126,7 +126,7 @@ class WebServer: @property def web_settings(self) -> dict[str, dict[str, Any]]: current_web_settings = self._get_web_settings_from_file() - for path in generate_serialized_data_paths(self.state_manager.cache["value"]): + for path in generate_serialized_data_paths(self.state_manager.cache_value): if path in current_web_settings: continue @@ -160,7 +160,7 @@ class WebServer: return type(self.service).__name__ @app.get("/service-properties") - def service_properties() -> dict[str, Any]: + def service_properties() -> SerializedObject: return self.state_manager.cache @app.get("/web-settings") diff --git a/src/pydase/utils/serializer.py b/src/pydase/utils/serializer.py index c54504c..9a3e51b 100644 --- a/src/pydase/utils/serializer.py +++ b/src/pydase/utils/serializer.py @@ -1,9 +1,15 @@ +from __future__ import annotations + import inspect import logging import sys -from collections.abc import Callable from enum import Enum -from typing import Any, TypedDict +from typing import TYPE_CHECKING, Any, TypedDict, cast + +if sys.version_info < (3, 11): + from typing_extensions import NotRequired +else: + from typing import NotRequired import pydase.units as u from pydase.data_service.abstract_data_service import AbstractDataService @@ -16,6 +22,9 @@ from pydase.utils.helpers import ( render_in_frontend, ) +if TYPE_CHECKING: + from collections.abc import Callable + logger = logging.getLogger(__name__) @@ -27,10 +36,31 @@ class SerializationValueError(Exception): pass +class SignatureDict(TypedDict): + parameters: dict[str, dict[str, Any]] + return_annotation: dict[str, Any] + + +SerializedObject = TypedDict( + "SerializedObject", + { + "name": NotRequired[str], + "value": "list[SerializedObject] | float | int | str | bool | dict[str, Any] | None", # noqa: E501 + "type": str | None, + "doc": str | None, + "readonly": bool, + "enum": NotRequired[dict[str, Any]], + "async": NotRequired[bool], + "signature": NotRequired[SignatureDict], + "frontend_render": NotRequired[bool], + }, +) + + class Serializer: @staticmethod - def serialize_object(obj: Any) -> dict[str, Any]: - result: dict[str, Any] = {} + def serialize_object(obj: Any) -> SerializedObject: + result: SerializedObject if isinstance(obj, AbstractDataService): result = Serializer._serialize_data_service(obj) @@ -67,7 +97,7 @@ class Serializer: return result @staticmethod - def _serialize_enum(obj: Enum) -> dict[str, Any]: + def _serialize_enum(obj: Enum) -> SerializedObject: import pydase.components.coloured_enum value = obj.name @@ -91,7 +121,7 @@ class Serializer: } @staticmethod - def _serialize_quantity(obj: u.Quantity) -> dict[str, Any]: + def _serialize_quantity(obj: u.Quantity) -> SerializedObject: obj_type = "Quantity" readonly = False doc = get_attribute_doc(obj) @@ -104,7 +134,7 @@ class Serializer: } @staticmethod - def _serialize_dict(obj: dict[str, Any]) -> dict[str, Any]: + def _serialize_dict(obj: dict[str, Any]) -> SerializedObject: obj_type = "dict" readonly = False doc = get_attribute_doc(obj) @@ -117,7 +147,7 @@ class Serializer: } @staticmethod - def _serialize_list(obj: list[Any]) -> dict[str, Any]: + def _serialize_list(obj: list[Any]) -> SerializedObject: obj_type = "list" readonly = False doc = get_attribute_doc(obj) @@ -130,7 +160,7 @@ class Serializer: } @staticmethod - def _serialize_method(obj: Callable[..., Any]) -> dict[str, Any]: + def _serialize_method(obj: Callable[..., Any]) -> SerializedObject: obj_type = "method" value = None readonly = True @@ -141,16 +171,12 @@ class Serializer: sig = inspect.signature(obj) sig.return_annotation - class SignatureDict(TypedDict): - parameters: dict[str, dict[str, Any]] - return_annotation: dict[str, Any] - signature: SignatureDict = {"parameters": {}, "return_annotation": {}} for k, v in sig.parameters.items(): signature["parameters"][k] = { "annotation": str(v.annotation), - "default": dump(v.default) if v.default != inspect._empty else {}, + "default": {} if v.default == inspect._empty else dump(v.default), } return { @@ -164,7 +190,7 @@ class Serializer: } @staticmethod - def _serialize_data_service(obj: AbstractDataService) -> dict[str, Any]: + def _serialize_data_service(obj: AbstractDataService) -> SerializedObject: readonly = False doc = get_attribute_doc(obj) obj_type = "DataService" @@ -184,7 +210,7 @@ class Serializer: # Get the difference between the two sets derived_only_attr_set = obj_attr_set - data_service_attr_set - value = {} + value: dict[str, SerializedObject] = {} # Iterate over attributes, properties, class attributes, and methods for key in sorted(derived_only_attr_set): @@ -224,12 +250,12 @@ class Serializer: } -def dump(obj: Any) -> dict[str, Any]: +def dump(obj: Any) -> SerializedObject: return Serializer.serialize_object(obj) def set_nested_value_by_path( - serialization_dict: dict[str, Any], path: str, value: Any + serialization_dict: dict[str, SerializedObject], path: str, value: Any ) -> None: """ Set a value in a nested dictionary structure, which conforms to the serialization @@ -251,16 +277,18 @@ def set_nested_value_by_path( """ parent_path_parts, attr_name = path.split(".")[:-1], path.split(".")[-1] - current_dict: dict[str, Any] = serialization_dict + current_dict: dict[str, SerializedObject] = serialization_dict try: for path_part in parent_path_parts: - current_dict = get_next_level_dict_by_key( + next_level_serialized_object = get_next_level_dict_by_key( current_dict, path_part, allow_append=False ) - current_dict = current_dict["value"] + current_dict = cast( + dict[str, SerializedObject], next_level_serialized_object["value"] + ) - current_dict = get_next_level_dict_by_key( + next_level_serialized_object = get_next_level_dict_by_key( current_dict, attr_name, allow_append=True ) except (SerializationPathError, SerializationValueError, KeyError) as e: @@ -270,47 +298,53 @@ def set_nested_value_by_path( serialized_value = dump(value) keys_to_keep = set(serialized_value.keys()) - if current_dict == {}: # adding an attribute / element to a list or dict + if ( + next_level_serialized_object == {} + ): # adding an attribute / element to a list or dict pass - elif current_dict["type"] == "method": # state change of task - keys_to_keep = set(current_dict.keys()) + elif next_level_serialized_object["type"] == "method": # state change of task + keys_to_keep = set(next_level_serialized_object.keys()) - serialized_value = current_dict - serialized_value["value"] = value.name if isinstance(value, Enum) else None + serialized_value = {} # type: ignore + next_level_serialized_object["value"] = ( + value.name if isinstance(value, Enum) else None + ) else: # attribute-specific information should not be overwritten by new value - serialized_value.pop("readonly") - serialized_value.pop("doc") + serialized_value.pop("readonly") # type: ignore + serialized_value.pop("doc") # type: ignore - current_dict.update(serialized_value) + next_level_serialized_object.update(serialized_value) # removes keys that are not present in the serialized new value - for key in list(current_dict.keys()): + for key in list(next_level_serialized_object.keys()): if key not in keys_to_keep: - current_dict.pop(key, None) + next_level_serialized_object.pop(key, None) # type: ignore def get_nested_dict_by_path( - serialization_dict: dict[str, Any], + serialization_dict: dict[str, SerializedObject], path: str, -) -> dict[str, Any]: +) -> SerializedObject: parent_path_parts, attr_name = path.split(".")[:-1], path.split(".")[-1] - current_dict: dict[str, Any] = serialization_dict + current_dict: dict[str, SerializedObject] = serialization_dict for path_part in parent_path_parts: - current_dict = get_next_level_dict_by_key( + next_level_serialized_object = get_next_level_dict_by_key( current_dict, path_part, allow_append=False ) - current_dict = current_dict["value"] + current_dict = cast( + dict[str, SerializedObject], next_level_serialized_object["value"] + ) return get_next_level_dict_by_key(current_dict, attr_name, allow_append=False) def get_next_level_dict_by_key( - serialization_dict: dict[str, Any], + serialization_dict: dict[str, SerializedObject], attr_name: str, *, allow_append: bool = False, -) -> dict[str, Any]: +) -> SerializedObject: """ Retrieve a nested dictionary entry or list item from a data structure serialized with `pydase.utils.serializer.Serializer`. @@ -335,14 +369,25 @@ def get_next_level_dict_by_key( try: if index is not None: - serialization_dict = serialization_dict[attr_name]["value"][index] + next_level_serialized_object = cast( + list[SerializedObject], serialization_dict[attr_name]["value"] + )[index] else: - serialization_dict = serialization_dict[attr_name] + next_level_serialized_object = serialization_dict[attr_name] except IndexError as e: - if allow_append and index == len(serialization_dict[attr_name]["value"]): + if ( + index is not None + and allow_append + and index + == len(cast(list[SerializedObject], serialization_dict[attr_name]["value"])) + ): # Appending to list - serialization_dict[attr_name]["value"].append({}) - serialization_dict = serialization_dict[attr_name]["value"][index] + cast(list[SerializedObject], serialization_dict[attr_name]["value"]).append( + {} # type: ignore + ) + next_level_serialized_object = cast( + list[SerializedObject], serialization_dict[attr_name]["value"] + )[index] else: raise SerializationPathError( f"Error occured trying to change '{attr_name}[{index}]': {e}" @@ -354,17 +399,17 @@ def get_next_level_dict_by_key( "a 'value' key." ) - if not isinstance(serialization_dict, dict): + if not isinstance(next_level_serialized_object, dict): raise SerializationValueError( f"Expected a dictionary at '{attr_name}', but found type " - f"'{type(serialization_dict).__name__}' instead." + f"'{type(next_level_serialized_object).__name__}' instead." ) - return serialization_dict + return next_level_serialized_object def generate_serialized_data_paths( - data: dict[str, dict[str, Any]], parent_path: str = "" + data: dict[str, Any], parent_path: str = "" ) -> list[str]: """ Generate a list of access paths for all attributes in a dictionary representing @@ -404,7 +449,7 @@ def generate_serialized_data_paths( return paths -def serialized_dict_is_nested_object(serialized_dict: dict[str, Any]) -> bool: +def serialized_dict_is_nested_object(serialized_dict: SerializedObject) -> bool: return ( serialized_dict["type"] != "Quantity" and isinstance(serialized_dict["value"], dict) diff --git a/tests/utils/test_serializer.py b/tests/utils/test_serializer.py index cc47707..93e7f32 100644 --- a/tests/utils/test_serializer.py +++ b/tests/utils/test_serializer.py @@ -11,6 +11,7 @@ from pydase.data_service.task_manager import TaskStatus from pydase.utils.decorators import frontend from pydase.utils.serializer import ( SerializationPathError, + SerializedObject, dump, get_nested_dict_by_path, get_next_level_dict_by_key, @@ -464,12 +465,12 @@ def test_update_task_state(setup_dict: dict[str, Any]) -> None: } -def test_update_list_entry(setup_dict: dict[str, Any]) -> None: +def test_update_list_entry(setup_dict: dict[str, SerializedObject]) -> None: set_nested_value_by_path(setup_dict, "attr_list[1]", 20) assert setup_dict["attr_list"]["value"][1]["value"] == 20 -def test_update_list_append(setup_dict: dict[str, Any]) -> None: +def test_update_list_append(setup_dict: dict[str, SerializedObject]) -> None: set_nested_value_by_path(setup_dict, "attr_list[3]", 20) assert setup_dict["attr_list"]["value"][3]["value"] == 20