From 5d4cd1c1918b4f417b9ebb51e5a12b5692bd7384 Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Sat, 18 Nov 2023 12:14:17 +0100 Subject: [PATCH] fix: fixed signature serializer for typing.Literal --- bec_lib/bec_lib/signature_serializer.py | 29 ++++++++++++++++------ bec_lib/tests/test_signature_serializer.py | 25 +++++++++++++++++++ 2 files changed, 47 insertions(+), 7 deletions(-) diff --git a/bec_lib/bec_lib/signature_serializer.py b/bec_lib/bec_lib/signature_serializer.py index 6938d161..161b2f5d 100644 --- a/bec_lib/bec_lib/signature_serializer.py +++ b/bec_lib/bec_lib/signature_serializer.py @@ -1,9 +1,9 @@ import builtins import inspect -from typing import Callable, List +from typing import Any, Callable, List, Literal -def dtype_to_str(dtype: type) -> str: +def serialize_dtype(dtype: type) -> Any: """ Convert a dtype to a string. @@ -13,12 +13,17 @@ def dtype_to_str(dtype: type) -> str: Returns: str: String representation of the data type """ - return dtype.__name__ + if hasattr(dtype, "__name__"): + return dtype.__name__ + if hasattr(dtype, "__module__"): + if dtype.__module__ == "typing": + return {"Literal": dtype.__args__} + raise ValueError(f"Unknown dtype {dtype}") -def str_to_dtype(dtype: str) -> type: +def deserialize_dtype(dtype: Any) -> type: """ - Convert a string to a dtype. + Convert a serialized dtype to a type. Args: dtype (str): String representation of the data type @@ -27,7 +32,15 @@ def str_to_dtype(dtype: str) -> type: type: Data type """ if dtype == "_empty": + # pylint: disable=protected-access return inspect._empty + if isinstance(dtype, dict): + # bodge needed for python 3.8 + if "Literal" in dtype: + literal = Literal["dummy"] + literal.__args__ = dtype["Literal"] + return literal + raise ValueError(f"Unknown dtype {dtype}") return builtins.__dict__.get(dtype) @@ -47,12 +60,13 @@ def signature_to_dict(func: Callable, include_class_obj=False) -> dict: for param_name, param in params.items(): if not include_class_obj and param_name == "self" or param_name == "cls": continue + # pylint: disable=protected-access out.append( { "name": param_name, "kind": param.kind.name, "default": param.default if param.default != inspect._empty else "_empty", - "annotation": dtype_to_str(param.annotation), + "annotation": serialize_dtype(param.annotation), } ) return out @@ -70,12 +84,13 @@ def dict_to_signature(params: List[dict]) -> inspect.Signature: """ out = [] for param in params: + # pylint: disable=protected-access out.append( inspect.Parameter( name=param["name"], kind=getattr(inspect.Parameter, param["kind"]), default=param["default"] if param["default"] != "_empty" else inspect._empty, - annotation=str_to_dtype(param["annotation"]), + annotation=deserialize_dtype(param["annotation"]), ) ) return inspect.Signature(out) diff --git a/bec_lib/tests/test_signature_serializer.py b/bec_lib/tests/test_signature_serializer.py index 86e320be..ebf338a6 100644 --- a/bec_lib/tests/test_signature_serializer.py +++ b/bec_lib/tests/test_signature_serializer.py @@ -1,4 +1,5 @@ import inspect +import typing from bec_lib.signature_serializer import dict_to_signature, signature_to_dict @@ -28,3 +29,27 @@ def test_signature_serializer(): sig = dict_to_signature(params) assert sig == inspect.signature(test_func) + + +def test_signature_serializer_with_literals(): + def test_func(a, b: typing.Literal[1, 2, 3] = 1): + pass + + params = signature_to_dict(test_func) + assert params == [ + { + "name": "a", + "kind": "POSITIONAL_OR_KEYWORD", + "default": "_empty", + "annotation": "_empty", + }, + { + "name": "b", + "kind": "POSITIONAL_OR_KEYWORD", + "default": 1, + "annotation": {"Literal": (1, 2, 3)}, + }, + ] + + sig = dict_to_signature(params) + assert sig == inspect.signature(test_func)