diff --git a/src/pydase/data_service/data_service_cache.py b/src/pydase/data_service/data_service_cache.py index a84179a..91fa6a0 100644 --- a/src/pydase/data_service/data_service_cache.py +++ b/src/pydase/data_service/data_service_cache.py @@ -47,7 +47,7 @@ class DataServiceCache: return { "full_access_path": full_access_path, "value": None, - "type": None, + "type": "NoneType", "doc": None, "readonly": False, } diff --git a/src/pydase/utils/serializer.py b/src/pydase/utils/serializer.py index a8acbbd..f094f9b 100644 --- a/src/pydase/utils/serializer.py +++ b/src/pydase/utils/serializer.py @@ -4,12 +4,7 @@ import inspect import logging import sys from enum import Enum -from typing import TYPE_CHECKING, Any, TypedDict, cast - -if sys.version_info < (3, 11): - from typing_extensions import NotRequired -else: - from typing import NotRequired +from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast import pydase.units as u from pydase.data_service.abstract_data_service import AbstractDataService @@ -28,6 +23,10 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class SerializationError(Exception): + pass + + class SerializationPathError(Exception): pass @@ -41,23 +40,105 @@ class SignatureDict(TypedDict): return_annotation: dict[str, Any] -SerializedObject = TypedDict( - "SerializedObject", +class SerializedObjectBase(TypedDict): + full_access_path: str + doc: str | None + readonly: bool + + +class SerializedInteger(SerializedObjectBase): + value: int + type: Literal["int"] + + +class SerializedFloat(SerializedObjectBase): + value: float + type: Literal["float"] + + +class SerializedQuantity(SerializedObjectBase): + value: u.QuantityDict + type: Literal["Quantity"] + + +class SerializedBool(SerializedObjectBase): + value: bool + type: Literal["bool"] + + +class SerializedString(SerializedObjectBase): + value: str + type: Literal["str"] + + +class SerializedEnum(SerializedObjectBase): + name: str + value: str + type: Literal["Enum", "ColouredEnum"] + enum: dict[str, Any] + + +class SerializedList(SerializedObjectBase): + value: list[SerializedObject] + type: Literal["list"] + + +class SerializedDict(SerializedObjectBase): + value: dict[str, SerializedObject] + type: Literal["dict"] + + +class SerializedNoneType(SerializedObjectBase): + value: None + type: Literal["NoneType"] + + +SerializedMethod = TypedDict( + "SerializedMethod", { "full_access_path": str, - "name": NotRequired[str], - "value": "list[SerializedObject] | float | int | str | bool | dict[str, Any] | None", # noqa: E501 - "type": str | None, + "value": Literal["RUNNING"] | None, + "type": Literal["method"], "doc": str | None, "readonly": bool, - "enum": NotRequired[dict[str, Any]], - "async": NotRequired[bool], - "signature": NotRequired[SignatureDict], - "frontend_render": NotRequired[bool], + "async": bool, + "signature": SignatureDict, + "frontend_render": bool, }, ) +class SerializedException(SerializedObjectBase): + name: str + value: str + type: Literal["Exception"] + + +DataServiceTypes = Literal["DataService", "Image", "NumberSlider", "DeviceConnection"] + + +class SerializedDataService(SerializedObjectBase): + name: str + value: dict[str, SerializedObject] + type: DataServiceTypes + + +SerializedObject = ( + SerializedBool + | SerializedFloat + | SerializedInteger + | SerializedString + | SerializedList + | SerializedDict + | SerializedNoneType + | SerializedMethod + | SerializedException + | SerializedDataService + | SerializedEnum + | SerializedQuantity +) + + class Serializer: @staticmethod def serialize_object(obj: Any, access_path: str = "") -> SerializedObject: @@ -87,26 +168,41 @@ class Serializer: elif inspect.isfunction(obj) or inspect.ismethod(obj): result = Serializer._serialize_method(obj, access_path=access_path) - else: - obj_type = type(obj).__name__ - value = obj - readonly = False - doc = get_attribute_doc(obj) - result = { - "full_access_path": access_path, - "type": obj_type, - "value": value, - "readonly": readonly, - "doc": doc, - } + elif isinstance(obj, int | float | bool | str | None): + result = Serializer._serialize_primitive(obj, access_path=access_path) - return result + try: + return result + except UnboundLocalError: + raise SerializationError( + f"Could not serialized object of type {type(obj)}." + ) @staticmethod - def _serialize_exception(obj: Exception) -> SerializedObject: + def _serialize_primitive( + obj: float | bool | str | None, + access_path: str, + ) -> ( + SerializedInteger + | SerializedFloat + | SerializedBool + | SerializedString + | SerializedNoneType + ): + doc = get_attribute_doc(obj) + return { # type: ignore + "full_access_path": access_path, + "doc": doc, + "readonly": False, + "type": type(obj).__name__, + "value": obj, + } + + @staticmethod + def _serialize_exception(obj: Exception) -> SerializedException: return { "full_access_path": "", - "doc": "", + "doc": None, "readonly": True, "type": "Exception", "value": obj.args[0], @@ -114,17 +210,16 @@ class Serializer: } @staticmethod - def _serialize_enum(obj: Enum, access_path: str = "") -> SerializedObject: + def _serialize_enum(obj: Enum, access_path: str = "") -> SerializedEnum: import pydase.components.coloured_enum value = obj.name - readonly = False doc = obj.__doc__ class_name = type(obj).__name__ if sys.version_info < (3, 11) and doc == "An enumeration.": doc = None if isinstance(obj, pydase.components.coloured_enum.ColouredEnum): - obj_type = "ColouredEnum" + obj_type: Literal["ColouredEnum", "Enum"] = "ColouredEnum" else: obj_type = "Enum" @@ -133,7 +228,7 @@ class Serializer: "name": class_name, "type": obj_type, "value": value, - "readonly": readonly, + "readonly": False, "doc": doc, "enum": { name: member.value for name, member in obj.__class__.__members__.items() @@ -141,22 +236,21 @@ class Serializer: } @staticmethod - def _serialize_quantity(obj: u.Quantity, access_path: str = "") -> SerializedObject: - obj_type = "Quantity" - readonly = False + def _serialize_quantity( + obj: u.Quantity, access_path: str = "" + ) -> SerializedQuantity: doc = get_attribute_doc(obj) - value = {"magnitude": obj.m, "unit": str(obj.u)} + value: u.QuantityDict = {"magnitude": obj.m, "unit": str(obj.u)} return { "full_access_path": access_path, - "type": obj_type, + "type": "Quantity", "value": value, - "readonly": readonly, + "readonly": False, "doc": doc, } @staticmethod - def _serialize_dict(obj: dict[str, Any], access_path: str = "") -> SerializedObject: - obj_type = "dict" + def _serialize_dict(obj: dict[str, Any], access_path: str = "") -> SerializedDict: readonly = False doc = get_attribute_doc(obj) value = { @@ -165,15 +259,14 @@ class Serializer: } return { "full_access_path": access_path, - "type": obj_type, + "type": "dict", "value": value, "readonly": readonly, "doc": doc, } @staticmethod - def _serialize_list(obj: list[Any], access_path: str = "") -> SerializedObject: - obj_type = "list" + def _serialize_list(obj: list[Any], access_path: str = "") -> SerializedList: readonly = False doc = get_attribute_doc(obj) value = [ @@ -182,7 +275,7 @@ class Serializer: ] return { "full_access_path": access_path, - "type": obj_type, + "type": "list", "value": value, "readonly": readonly, "doc": doc, @@ -191,9 +284,7 @@ class Serializer: @staticmethod def _serialize_method( obj: Callable[..., Any], access_path: str = "" - ) -> SerializedObject: - obj_type = "method" - value = None + ) -> SerializedMethod: readonly = True doc = get_attribute_doc(obj) frontend_render = render_in_frontend(obj) @@ -216,8 +307,8 @@ class Serializer: return { "full_access_path": access_path, - "type": obj_type, - "value": value, + "type": "method", + "value": None, "readonly": readonly, "doc": doc, "async": inspect.iscoroutinefunction(obj), @@ -228,10 +319,10 @@ class Serializer: @staticmethod def _serialize_data_service( obj: AbstractDataService, access_path: str = "" - ) -> SerializedObject: + ) -> SerializedDataService: readonly = False doc = get_attribute_doc(obj) - obj_type = "DataService" + obj_type: DataServiceTypes = "DataService" obj_name = obj.__class__.__name__ # Get component base class if any @@ -239,7 +330,7 @@ class Serializer: (cls for cls in get_component_classes() if isinstance(obj, cls)), None ) if component_base_cls: - obj_type = component_base_cls.__name__ + obj_type = component_base_cls.__name__ # type: ignore # Get the set of DataService class attributes data_service_attr_set = set(dir(get_data_service_class_reference())) @@ -268,11 +359,13 @@ class Serializer: val = getattr(obj, key) path = f"{access_path}.{key}" if access_path else key - value[key] = Serializer.serialize_object(val, access_path=path) + serialized_object = Serializer.serialize_object(val, access_path=path) # If there's a running task for this method - if key in obj._task_manager.tasks: - value[key]["value"] = TaskStatus.RUNNING.name + if serialized_object["type"] == "method" and key in obj._task_manager.tasks: + serialized_object["value"] = TaskStatus.RUNNING.name + + value[key] = serialized_object # If the DataService attribute is a property if isinstance(getattr(obj.__class__, key, None), property): @@ -337,7 +430,7 @@ def set_nested_value_by_path( if next_level_serialized_object["type"] == "method": # state change of task next_level_serialized_object["value"] = ( - value.name if isinstance(value, Enum) else None + value.name if TaskStatus.RUNNING else None ) else: serialized_value = dump(value) @@ -349,7 +442,7 @@ def set_nested_value_by_path( keys_to_keep = set(serialized_value.keys()) - next_level_serialized_object.update(serialized_value) + next_level_serialized_object.update(serialized_value) # type: ignore # removes keys that are not present in the serialized new value for key in list(next_level_serialized_object.keys()): @@ -421,7 +514,7 @@ def get_next_level_dict_by_key( { "full_access_path": "", "value": None, - "type": None, + "type": "NoneType", "doc": None, "readonly": False, }