fix: fixed signature serializer for typing.Literal

This commit is contained in:
wakonig_k 2023-11-18 12:14:17 +01:00
parent 968960646c
commit 5d4cd1c191
2 changed files with 47 additions and 7 deletions

View File

@ -1,9 +1,9 @@
import builtins import builtins
import inspect import inspect
from typing import Callable, List from typing import Any, Callable, List, Literal
def dtype_to_str(dtype: type) -> str: def serialize_dtype(dtype: type) -> Any:
""" """
Convert a dtype to a string. Convert a dtype to a string.
@ -13,12 +13,17 @@ def dtype_to_str(dtype: type) -> str:
Returns: Returns:
str: String representation of the data type str: String representation of the data type
""" """
if hasattr(dtype, "__name__"):
return dtype.__name__ return dtype.__name__
if hasattr(dtype, "__module__"):
if dtype.__module__ == "typing":
return {"Literal": dtype.__args__}
raise ValueError(f"Unknown dtype {dtype}")
def str_to_dtype(dtype: str) -> type: def deserialize_dtype(dtype: Any) -> type:
""" """
Convert a string to a dtype. Convert a serialized dtype to a type.
Args: Args:
dtype (str): String representation of the data type dtype (str): String representation of the data type
@ -27,7 +32,15 @@ def str_to_dtype(dtype: str) -> type:
type: Data type type: Data type
""" """
if dtype == "_empty": if dtype == "_empty":
# pylint: disable=protected-access
return inspect._empty return inspect._empty
if isinstance(dtype, dict):
# bodge needed for python 3.8
if "Literal" in dtype:
literal = Literal["dummy"]
literal.__args__ = dtype["Literal"]
return literal
raise ValueError(f"Unknown dtype {dtype}")
return builtins.__dict__.get(dtype) return builtins.__dict__.get(dtype)
@ -47,12 +60,13 @@ def signature_to_dict(func: Callable, include_class_obj=False) -> dict:
for param_name, param in params.items(): for param_name, param in params.items():
if not include_class_obj and param_name == "self" or param_name == "cls": if not include_class_obj and param_name == "self" or param_name == "cls":
continue continue
# pylint: disable=protected-access
out.append( out.append(
{ {
"name": param_name, "name": param_name,
"kind": param.kind.name, "kind": param.kind.name,
"default": param.default if param.default != inspect._empty else "_empty", "default": param.default if param.default != inspect._empty else "_empty",
"annotation": dtype_to_str(param.annotation), "annotation": serialize_dtype(param.annotation),
} }
) )
return out return out
@ -70,12 +84,13 @@ def dict_to_signature(params: List[dict]) -> inspect.Signature:
""" """
out = [] out = []
for param in params: for param in params:
# pylint: disable=protected-access
out.append( out.append(
inspect.Parameter( inspect.Parameter(
name=param["name"], name=param["name"],
kind=getattr(inspect.Parameter, param["kind"]), kind=getattr(inspect.Parameter, param["kind"]),
default=param["default"] if param["default"] != "_empty" else inspect._empty, default=param["default"] if param["default"] != "_empty" else inspect._empty,
annotation=str_to_dtype(param["annotation"]), annotation=deserialize_dtype(param["annotation"]),
) )
) )
return inspect.Signature(out) return inspect.Signature(out)

View File

@ -1,4 +1,5 @@
import inspect import inspect
import typing
from bec_lib.signature_serializer import dict_to_signature, signature_to_dict from bec_lib.signature_serializer import dict_to_signature, signature_to_dict
@ -28,3 +29,27 @@ def test_signature_serializer():
sig = dict_to_signature(params) sig = dict_to_signature(params)
assert sig == inspect.signature(test_func) assert sig == inspect.signature(test_func)
def test_signature_serializer_with_literals():
def test_func(a, b: typing.Literal[1, 2, 3] = 1):
pass
params = signature_to_dict(test_func)
assert params == [
{
"name": "a",
"kind": "POSITIONAL_OR_KEYWORD",
"default": "_empty",
"annotation": "_empty",
},
{
"name": "b",
"kind": "POSITIONAL_OR_KEYWORD",
"default": 1,
"annotation": {"Literal": (1, 2, 3)},
},
]
sig = dict_to_signature(params)
assert sig == inspect.signature(test_func)