Merge pull request #111 from tiqi-group/refactor/updates_serialized_object_type_hints

updates type hints for serialized objects
This commit is contained in:
Mose Müller 2024-03-06 18:27:21 +01:00 committed by GitHub
commit 390a375777
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 148 additions and 74 deletions

View File

@ -17,6 +17,7 @@ from pydase.utils.helpers import (
is_property_attribute, is_property_attribute,
) )
from pydase.utils.serializer import ( from pydase.utils.serializer import (
SerializedObject,
Serializer, Serializer,
) )
@ -125,7 +126,7 @@ class DataService(rpyc.Service, AbstractDataService):
# allow all other attributes # allow all other attributes
setattr(self, name, value) setattr(self, name, value)
def serialize(self) -> dict[str, dict[str, Any]]: def serialize(self) -> SerializedObject:
""" """
Serializes the instance into a dictionary, preserving the structure of the Serializes the instance into a dictionary, preserving the structure of the
instance. instance.

View File

@ -1,9 +1,10 @@
import logging import logging
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, cast
from pydase.utils.serializer import ( from pydase.utils.serializer import (
SerializationPathError, SerializationPathError,
SerializationValueError, SerializationValueError,
SerializedObject,
get_nested_dict_by_path, get_nested_dict_by_path,
set_nested_value_by_path, set_nested_value_by_path,
) )
@ -16,12 +17,12 @@ logger = logging.getLogger(__name__)
class DataServiceCache: class DataServiceCache:
def __init__(self, service: "DataService") -> None: def __init__(self, service: "DataService") -> None:
self._cache: dict[str, Any] = {} self._cache: SerializedObject
self.service = service self.service = service
self._initialize_cache() self._initialize_cache()
@property @property
def cache(self) -> dict[str, Any]: def cache(self) -> SerializedObject:
return self._cache return self._cache
def _initialize_cache(self) -> None: def _initialize_cache(self) -> None:
@ -30,10 +31,22 @@ class DataServiceCache:
self._cache = self.service.serialize() self._cache = self.service.serialize()
def update_cache(self, full_access_path: str, value: Any) -> None: def update_cache(self, full_access_path: str, value: Any) -> None:
set_nested_value_by_path(self._cache["value"], full_access_path, value) set_nested_value_by_path(
cast(dict[str, SerializedObject], self._cache["value"]),
full_access_path,
value,
)
def get_value_dict_from_cache(self, full_access_path: str) -> dict[str, Any]: def get_value_dict_from_cache(self, full_access_path: str) -> SerializedObject:
try: try:
return get_nested_dict_by_path(self._cache["value"], full_access_path) return get_nested_dict_by_path(
cast(dict[str, SerializedObject], self._cache["value"]),
full_access_path,
)
except (SerializationPathError, SerializationValueError, KeyError): except (SerializationPathError, SerializationValueError, KeyError):
return {} return {
"value": None,
"type": None,
"doc": None,
"readonly": False,
}

View File

@ -9,7 +9,7 @@ from pydase.observer_pattern.observer.property_observer import (
PropertyObserver, PropertyObserver,
) )
from pydase.utils.helpers import get_object_attr_from_path_list from pydase.utils.helpers import get_object_attr_from_path_list
from pydase.utils.serializer import dump from pydase.utils.serializer import SerializedObject, dump
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -18,7 +18,7 @@ class DataServiceObserver(PropertyObserver):
def __init__(self, state_manager: StateManager) -> None: def __init__(self, state_manager: StateManager) -> None:
self.state_manager = state_manager self.state_manager = state_manager
self._notification_callbacks: list[ self._notification_callbacks: list[
Callable[[str, Any, dict[str, Any]], None] Callable[[str, Any, SerializedObject], None]
] = [] ] = []
super().__init__(state_manager.service) super().__init__(state_manager.service)
@ -59,7 +59,10 @@ class DataServiceObserver(PropertyObserver):
self._notify_dependent_property_changes(full_access_path) self._notify_dependent_property_changes(full_access_path)
def _update_cache_value( def _update_cache_value(
self, full_access_path: str, value: Any, cached_value_dict: dict[str, Any] self,
full_access_path: str,
value: Any,
cached_value_dict: SerializedObject | dict[str, Any],
) -> None: ) -> None:
value_dict = dump(value) value_dict = dump(value)
if cached_value_dict != {}: if cached_value_dict != {}:
@ -93,7 +96,7 @@ class DataServiceObserver(PropertyObserver):
) )
def add_notification_callback( def add_notification_callback(
self, callback: Callable[[str, Any, dict[str, Any]], None] self, callback: Callable[[str, Any, SerializedObject], None]
) -> None: ) -> None:
""" """
Registers a callback function to be invoked upon attribute changes in the Registers a callback function to be invoked upon attribute changes in the

View File

@ -13,6 +13,7 @@ from pydase.utils.helpers import (
parse_list_attr_and_index, parse_list_attr_and_index,
) )
from pydase.utils.serializer import ( from pydase.utils.serializer import (
SerializedObject,
dump, dump,
generate_serialized_data_paths, generate_serialized_data_paths,
get_nested_dict_by_path, get_nested_dict_by_path,
@ -114,10 +115,17 @@ class StateManager:
self._data_service_cache = DataServiceCache(self.service) self._data_service_cache = DataServiceCache(self.service)
@property @property
def cache(self) -> dict[str, Any]: def cache(self) -> SerializedObject:
"""Returns the cached DataService state.""" """Returns the cached DataService state."""
return self._data_service_cache.cache return self._data_service_cache.cache
@property
def cache_value(self) -> dict[str, SerializedObject]:
"""Returns the "value" value of the DataService serialization."""
return cast(
dict[str, SerializedObject], self._data_service_cache.cache["value"]
)
def save_state(self) -> None: def save_state(self) -> None:
""" """
Saves the DataService's current state to a JSON file defined by `self.filename`. Saves the DataService's current state to a JSON file defined by `self.filename`.
@ -126,7 +134,7 @@ class StateManager:
if self.filename is not None: if self.filename is not None:
with open(self.filename, "w") as f: with open(self.filename, "w") as f:
json.dump(self.cache["value"], f, indent=4) json.dump(self.cache_value, f, indent=4)
else: else:
logger.info( logger.info(
"State manager was not initialised with a filename. Skipping " "State manager was not initialised with a filename. Skipping "
@ -191,7 +199,7 @@ class StateManager:
value: The new value to set for the attribute. value: The new value to set for the attribute.
""" """
current_value_dict = get_nested_dict_by_path(self.cache["value"], path) current_value_dict = get_nested_dict_by_path(self.cache_value, path)
# This will also filter out methods as they are 'read-only' # This will also filter out methods as they are 'read-only'
if current_value_dict["readonly"]: if current_value_dict["readonly"]:
@ -216,10 +224,12 @@ class StateManager:
return dump(value_object)["value"] != current_value return dump(value_object)["value"] != current_value
def __convert_value_if_needed( def __convert_value_if_needed(
self, value: Any, current_value_dict: dict[str, Any] self, value: Any, current_value_dict: SerializedObject
) -> Any: ) -> Any:
if current_value_dict["type"] == "Quantity": if current_value_dict["type"] == "Quantity":
return u.convert_to_quantity(value, current_value_dict["value"]["unit"]) return u.convert_to_quantity(
value, cast(dict[str, Any], current_value_dict["value"])["unit"]
)
if current_value_dict["type"] == "float" and not isinstance(value, float): if current_value_dict["type"] == "float" and not isinstance(value, float):
return float(value) return float(value)
return value return value
@ -234,7 +244,7 @@ class StateManager:
# Update path to reflect the attribute without list indices # Update path to reflect the attribute without list indices
path = ".".join([*parent_path_list, attr_name]) path = ".".join([*parent_path_list, attr_name])
attr_cache_type = get_nested_dict_by_path(self.cache["value"], path)["type"] attr_cache_type = get_nested_dict_by_path(self.cache_value, path)["type"]
# Traverse the object according to the path parts # Traverse the object according to the path parts
target_obj = get_object_attr_from_path_list(self.service, parent_path_list) target_obj = get_object_attr_from_path_list(self.service, parent_path_list)
@ -273,7 +283,7 @@ class StateManager:
return has_decorator return has_decorator
cached_serialization_dict = get_nested_dict_by_path( cached_serialization_dict = get_nested_dict_by_path(
self.cache["value"], full_access_path self.cache_value, full_access_path
) )
if cached_serialization_dict["value"] == "method": if cached_serialization_dict["value"] == "method":

View File

@ -9,6 +9,7 @@ from pydase.data_service.data_service_observer import DataServiceObserver
from pydase.data_service.state_manager import StateManager from pydase.data_service.state_manager import StateManager
from pydase.utils.helpers import get_object_attr_from_path_list from pydase.utils.helpers import get_object_attr_from_path_list
from pydase.utils.logging import SocketIOHandler from pydase.utils.logging import SocketIOHandler
from pydase.utils.serializer import SerializedObject
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -93,7 +94,7 @@ def setup_sio_server(
# Add notification callback to observer # Add notification callback to observer
def sio_callback( def sio_callback(
full_access_path: str, value: Any, cached_value_dict: dict[str, Any] full_access_path: str, value: Any, cached_value_dict: SerializedObject
) -> None: ) -> None:
if cached_value_dict != {}: if cached_value_dict != {}:

View File

@ -16,7 +16,7 @@ from pydase.data_service.data_service_observer import DataServiceObserver
from pydase.server.web_server.sio_setup import ( from pydase.server.web_server.sio_setup import (
setup_sio_server, setup_sio_server,
) )
from pydase.utils.serializer import generate_serialized_data_paths from pydase.utils.serializer import SerializedObject, generate_serialized_data_paths
from pydase.version import __version__ from pydase.version import __version__
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -126,7 +126,7 @@ class WebServer:
@property @property
def web_settings(self) -> dict[str, dict[str, Any]]: def web_settings(self) -> dict[str, dict[str, Any]]:
current_web_settings = self._get_web_settings_from_file() current_web_settings = self._get_web_settings_from_file()
for path in generate_serialized_data_paths(self.state_manager.cache["value"]): for path in generate_serialized_data_paths(self.state_manager.cache_value):
if path in current_web_settings: if path in current_web_settings:
continue continue
@ -160,7 +160,7 @@ class WebServer:
return type(self.service).__name__ return type(self.service).__name__
@app.get("/service-properties") @app.get("/service-properties")
def service_properties() -> dict[str, Any]: def service_properties() -> SerializedObject:
return self.state_manager.cache return self.state_manager.cache
@app.get("/web-settings") @app.get("/web-settings")

View File

@ -1,9 +1,15 @@
from __future__ import annotations
import inspect import inspect
import logging import logging
import sys import sys
from collections.abc import Callable
from enum import Enum from enum import Enum
from typing import Any, TypedDict from typing import TYPE_CHECKING, Any, TypedDict, cast
if sys.version_info < (3, 11):
from typing_extensions import NotRequired
else:
from typing import NotRequired
import pydase.units as u import pydase.units as u
from pydase.data_service.abstract_data_service import AbstractDataService from pydase.data_service.abstract_data_service import AbstractDataService
@ -16,6 +22,9 @@ from pydase.utils.helpers import (
render_in_frontend, render_in_frontend,
) )
if TYPE_CHECKING:
from collections.abc import Callable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -27,10 +36,31 @@ class SerializationValueError(Exception):
pass pass
class SignatureDict(TypedDict):
parameters: dict[str, dict[str, Any]]
return_annotation: dict[str, Any]
SerializedObject = TypedDict(
"SerializedObject",
{
"name": NotRequired[str],
"value": "list[SerializedObject] | float | int | str | bool | dict[str, Any] | None", # noqa: E501
"type": str | None,
"doc": str | None,
"readonly": bool,
"enum": NotRequired[dict[str, Any]],
"async": NotRequired[bool],
"signature": NotRequired[SignatureDict],
"frontend_render": NotRequired[bool],
},
)
class Serializer: class Serializer:
@staticmethod @staticmethod
def serialize_object(obj: Any) -> dict[str, Any]: def serialize_object(obj: Any) -> SerializedObject:
result: dict[str, Any] = {} result: SerializedObject
if isinstance(obj, AbstractDataService): if isinstance(obj, AbstractDataService):
result = Serializer._serialize_data_service(obj) result = Serializer._serialize_data_service(obj)
@ -67,7 +97,7 @@ class Serializer:
return result return result
@staticmethod @staticmethod
def _serialize_enum(obj: Enum) -> dict[str, Any]: def _serialize_enum(obj: Enum) -> SerializedObject:
import pydase.components.coloured_enum import pydase.components.coloured_enum
value = obj.name value = obj.name
@ -91,7 +121,7 @@ class Serializer:
} }
@staticmethod @staticmethod
def _serialize_quantity(obj: u.Quantity) -> dict[str, Any]: def _serialize_quantity(obj: u.Quantity) -> SerializedObject:
obj_type = "Quantity" obj_type = "Quantity"
readonly = False readonly = False
doc = get_attribute_doc(obj) doc = get_attribute_doc(obj)
@ -104,7 +134,7 @@ class Serializer:
} }
@staticmethod @staticmethod
def _serialize_dict(obj: dict[str, Any]) -> dict[str, Any]: def _serialize_dict(obj: dict[str, Any]) -> SerializedObject:
obj_type = "dict" obj_type = "dict"
readonly = False readonly = False
doc = get_attribute_doc(obj) doc = get_attribute_doc(obj)
@ -117,7 +147,7 @@ class Serializer:
} }
@staticmethod @staticmethod
def _serialize_list(obj: list[Any]) -> dict[str, Any]: def _serialize_list(obj: list[Any]) -> SerializedObject:
obj_type = "list" obj_type = "list"
readonly = False readonly = False
doc = get_attribute_doc(obj) doc = get_attribute_doc(obj)
@ -130,7 +160,7 @@ class Serializer:
} }
@staticmethod @staticmethod
def _serialize_method(obj: Callable[..., Any]) -> dict[str, Any]: def _serialize_method(obj: Callable[..., Any]) -> SerializedObject:
obj_type = "method" obj_type = "method"
value = None value = None
readonly = True readonly = True
@ -141,16 +171,12 @@ class Serializer:
sig = inspect.signature(obj) sig = inspect.signature(obj)
sig.return_annotation sig.return_annotation
class SignatureDict(TypedDict):
parameters: dict[str, dict[str, Any]]
return_annotation: dict[str, Any]
signature: SignatureDict = {"parameters": {}, "return_annotation": {}} signature: SignatureDict = {"parameters": {}, "return_annotation": {}}
for k, v in sig.parameters.items(): for k, v in sig.parameters.items():
signature["parameters"][k] = { signature["parameters"][k] = {
"annotation": str(v.annotation), "annotation": str(v.annotation),
"default": dump(v.default) if v.default != inspect._empty else {}, "default": {} if v.default == inspect._empty else dump(v.default),
} }
return { return {
@ -164,7 +190,7 @@ class Serializer:
} }
@staticmethod @staticmethod
def _serialize_data_service(obj: AbstractDataService) -> dict[str, Any]: def _serialize_data_service(obj: AbstractDataService) -> SerializedObject:
readonly = False readonly = False
doc = get_attribute_doc(obj) doc = get_attribute_doc(obj)
obj_type = "DataService" obj_type = "DataService"
@ -184,7 +210,7 @@ class Serializer:
# Get the difference between the two sets # Get the difference between the two sets
derived_only_attr_set = obj_attr_set - data_service_attr_set derived_only_attr_set = obj_attr_set - data_service_attr_set
value = {} value: dict[str, SerializedObject] = {}
# Iterate over attributes, properties, class attributes, and methods # Iterate over attributes, properties, class attributes, and methods
for key in sorted(derived_only_attr_set): for key in sorted(derived_only_attr_set):
@ -224,12 +250,12 @@ class Serializer:
} }
def dump(obj: Any) -> dict[str, Any]: def dump(obj: Any) -> SerializedObject:
return Serializer.serialize_object(obj) return Serializer.serialize_object(obj)
def set_nested_value_by_path( def set_nested_value_by_path(
serialization_dict: dict[str, Any], path: str, value: Any serialization_dict: dict[str, SerializedObject], path: str, value: Any
) -> None: ) -> None:
""" """
Set a value in a nested dictionary structure, which conforms to the serialization Set a value in a nested dictionary structure, which conforms to the serialization
@ -251,16 +277,18 @@ def set_nested_value_by_path(
""" """
parent_path_parts, attr_name = path.split(".")[:-1], path.split(".")[-1] parent_path_parts, attr_name = path.split(".")[:-1], path.split(".")[-1]
current_dict: dict[str, Any] = serialization_dict current_dict: dict[str, SerializedObject] = serialization_dict
try: try:
for path_part in parent_path_parts: for path_part in parent_path_parts:
current_dict = get_next_level_dict_by_key( next_level_serialized_object = get_next_level_dict_by_key(
current_dict, path_part, allow_append=False current_dict, path_part, allow_append=False
) )
current_dict = current_dict["value"] current_dict = cast(
dict[str, SerializedObject], next_level_serialized_object["value"]
)
current_dict = get_next_level_dict_by_key( next_level_serialized_object = get_next_level_dict_by_key(
current_dict, attr_name, allow_append=True current_dict, attr_name, allow_append=True
) )
except (SerializationPathError, SerializationValueError, KeyError) as e: except (SerializationPathError, SerializationValueError, KeyError) as e:
@ -270,47 +298,53 @@ def set_nested_value_by_path(
serialized_value = dump(value) serialized_value = dump(value)
keys_to_keep = set(serialized_value.keys()) keys_to_keep = set(serialized_value.keys())
if current_dict == {}: # adding an attribute / element to a list or dict if (
next_level_serialized_object == {}
): # adding an attribute / element to a list or dict
pass pass
elif current_dict["type"] == "method": # state change of task elif next_level_serialized_object["type"] == "method": # state change of task
keys_to_keep = set(current_dict.keys()) keys_to_keep = set(next_level_serialized_object.keys())
serialized_value = current_dict serialized_value = {} # type: ignore
serialized_value["value"] = value.name if isinstance(value, Enum) else None next_level_serialized_object["value"] = (
value.name if isinstance(value, Enum) else None
)
else: else:
# attribute-specific information should not be overwritten by new value # attribute-specific information should not be overwritten by new value
serialized_value.pop("readonly") serialized_value.pop("readonly") # type: ignore
serialized_value.pop("doc") serialized_value.pop("doc") # type: ignore
current_dict.update(serialized_value) next_level_serialized_object.update(serialized_value)
# removes keys that are not present in the serialized new value # removes keys that are not present in the serialized new value
for key in list(current_dict.keys()): for key in list(next_level_serialized_object.keys()):
if key not in keys_to_keep: if key not in keys_to_keep:
current_dict.pop(key, None) next_level_serialized_object.pop(key, None) # type: ignore
def get_nested_dict_by_path( def get_nested_dict_by_path(
serialization_dict: dict[str, Any], serialization_dict: dict[str, SerializedObject],
path: str, path: str,
) -> dict[str, Any]: ) -> SerializedObject:
parent_path_parts, attr_name = path.split(".")[:-1], path.split(".")[-1] parent_path_parts, attr_name = path.split(".")[:-1], path.split(".")[-1]
current_dict: dict[str, Any] = serialization_dict current_dict: dict[str, SerializedObject] = serialization_dict
for path_part in parent_path_parts: for path_part in parent_path_parts:
current_dict = get_next_level_dict_by_key( next_level_serialized_object = get_next_level_dict_by_key(
current_dict, path_part, allow_append=False current_dict, path_part, allow_append=False
) )
current_dict = current_dict["value"] current_dict = cast(
dict[str, SerializedObject], next_level_serialized_object["value"]
)
return get_next_level_dict_by_key(current_dict, attr_name, allow_append=False) return get_next_level_dict_by_key(current_dict, attr_name, allow_append=False)
def get_next_level_dict_by_key( def get_next_level_dict_by_key(
serialization_dict: dict[str, Any], serialization_dict: dict[str, SerializedObject],
attr_name: str, attr_name: str,
*, *,
allow_append: bool = False, allow_append: bool = False,
) -> dict[str, Any]: ) -> SerializedObject:
""" """
Retrieve a nested dictionary entry or list item from a data structure serialized Retrieve a nested dictionary entry or list item from a data structure serialized
with `pydase.utils.serializer.Serializer`. with `pydase.utils.serializer.Serializer`.
@ -335,14 +369,25 @@ def get_next_level_dict_by_key(
try: try:
if index is not None: if index is not None:
serialization_dict = serialization_dict[attr_name]["value"][index] next_level_serialized_object = cast(
list[SerializedObject], serialization_dict[attr_name]["value"]
)[index]
else: else:
serialization_dict = serialization_dict[attr_name] next_level_serialized_object = serialization_dict[attr_name]
except IndexError as e: except IndexError as e:
if allow_append and index == len(serialization_dict[attr_name]["value"]): if (
index is not None
and allow_append
and index
== len(cast(list[SerializedObject], serialization_dict[attr_name]["value"]))
):
# Appending to list # Appending to list
serialization_dict[attr_name]["value"].append({}) cast(list[SerializedObject], serialization_dict[attr_name]["value"]).append(
serialization_dict = serialization_dict[attr_name]["value"][index] {} # type: ignore
)
next_level_serialized_object = cast(
list[SerializedObject], serialization_dict[attr_name]["value"]
)[index]
else: else:
raise SerializationPathError( raise SerializationPathError(
f"Error occured trying to change '{attr_name}[{index}]': {e}" f"Error occured trying to change '{attr_name}[{index}]': {e}"
@ -354,17 +399,17 @@ def get_next_level_dict_by_key(
"a 'value' key." "a 'value' key."
) )
if not isinstance(serialization_dict, dict): if not isinstance(next_level_serialized_object, dict):
raise SerializationValueError( raise SerializationValueError(
f"Expected a dictionary at '{attr_name}', but found type " f"Expected a dictionary at '{attr_name}', but found type "
f"'{type(serialization_dict).__name__}' instead." f"'{type(next_level_serialized_object).__name__}' instead."
) )
return serialization_dict return next_level_serialized_object
def generate_serialized_data_paths( def generate_serialized_data_paths(
data: dict[str, dict[str, Any]], parent_path: str = "" data: dict[str, Any], parent_path: str = ""
) -> list[str]: ) -> list[str]:
""" """
Generate a list of access paths for all attributes in a dictionary representing Generate a list of access paths for all attributes in a dictionary representing
@ -404,7 +449,7 @@ def generate_serialized_data_paths(
return paths return paths
def serialized_dict_is_nested_object(serialized_dict: dict[str, Any]) -> bool: def serialized_dict_is_nested_object(serialized_dict: SerializedObject) -> bool:
return ( return (
serialized_dict["type"] != "Quantity" serialized_dict["type"] != "Quantity"
and isinstance(serialized_dict["value"], dict) and isinstance(serialized_dict["value"], dict)

View File

@ -11,6 +11,7 @@ from pydase.data_service.task_manager import TaskStatus
from pydase.utils.decorators import frontend from pydase.utils.decorators import frontend
from pydase.utils.serializer import ( from pydase.utils.serializer import (
SerializationPathError, SerializationPathError,
SerializedObject,
dump, dump,
get_nested_dict_by_path, get_nested_dict_by_path,
get_next_level_dict_by_key, get_next_level_dict_by_key,
@ -464,12 +465,12 @@ def test_update_task_state(setup_dict: dict[str, Any]) -> None:
} }
def test_update_list_entry(setup_dict: dict[str, Any]) -> None: def test_update_list_entry(setup_dict: dict[str, SerializedObject]) -> None:
set_nested_value_by_path(setup_dict, "attr_list[1]", 20) set_nested_value_by_path(setup_dict, "attr_list[1]", 20)
assert setup_dict["attr_list"]["value"][1]["value"] == 20 assert setup_dict["attr_list"]["value"][1]["value"] == 20
def test_update_list_append(setup_dict: dict[str, Any]) -> None: def test_update_list_append(setup_dict: dict[str, SerializedObject]) -> None:
set_nested_value_by_path(setup_dict, "attr_list[3]", 20) set_nested_value_by_path(setup_dict, "attr_list[3]", 20)
assert setup_dict["attr_list"]["value"][3]["value"] == 20 assert setup_dict["attr_list"]["value"][3]["value"] == 20