fix: fixed signature serializer for typing.Literal

This commit is contained in:
wakonig_k 2023-11-18 12:14:17 +01:00
parent 968960646c
commit 5d4cd1c191
2 changed files with 47 additions and 7 deletions

View File

@ -1,9 +1,9 @@
import builtins
import inspect
from typing import Callable, List
from typing import Any, Callable, List, Literal
def dtype_to_str(dtype: type) -> str:
def serialize_dtype(dtype: type) -> Any:
"""
Convert a dtype to a string.
@ -13,12 +13,17 @@ def dtype_to_str(dtype: type) -> str:
Returns:
str: String representation of the data type
"""
if hasattr(dtype, "__name__"):
return dtype.__name__
if hasattr(dtype, "__module__"):
if dtype.__module__ == "typing":
return {"Literal": dtype.__args__}
raise ValueError(f"Unknown dtype {dtype}")
def str_to_dtype(dtype: str) -> type:
def deserialize_dtype(dtype: Any) -> type:
"""
Convert a string to a dtype.
Convert a serialized dtype to a type.
Args:
dtype (str): String representation of the data type
@ -27,7 +32,15 @@ def str_to_dtype(dtype: str) -> type:
type: Data type
"""
if dtype == "_empty":
# pylint: disable=protected-access
return inspect._empty
if isinstance(dtype, dict):
# bodge needed for python 3.8
if "Literal" in dtype:
literal = Literal["dummy"]
literal.__args__ = dtype["Literal"]
return literal
raise ValueError(f"Unknown dtype {dtype}")
return builtins.__dict__.get(dtype)
@ -47,12 +60,13 @@ def signature_to_dict(func: Callable, include_class_obj=False) -> dict:
for param_name, param in params.items():
if not include_class_obj and param_name == "self" or param_name == "cls":
continue
# pylint: disable=protected-access
out.append(
{
"name": param_name,
"kind": param.kind.name,
"default": param.default if param.default != inspect._empty else "_empty",
"annotation": dtype_to_str(param.annotation),
"annotation": serialize_dtype(param.annotation),
}
)
return out
@ -70,12 +84,13 @@ def dict_to_signature(params: List[dict]) -> inspect.Signature:
"""
out = []
for param in params:
# pylint: disable=protected-access
out.append(
inspect.Parameter(
name=param["name"],
kind=getattr(inspect.Parameter, param["kind"]),
default=param["default"] if param["default"] != "_empty" else inspect._empty,
annotation=str_to_dtype(param["annotation"]),
annotation=deserialize_dtype(param["annotation"]),
)
)
return inspect.Signature(out)

View File

@ -1,4 +1,5 @@
import inspect
import typing
from bec_lib.signature_serializer import dict_to_signature, signature_to_dict
@ -28,3 +29,27 @@ def test_signature_serializer():
sig = dict_to_signature(params)
assert sig == inspect.signature(test_func)
def test_signature_serializer_with_literals():
def test_func(a, b: typing.Literal[1, 2, 3] = 1):
pass
params = signature_to_dict(test_func)
assert params == [
{
"name": "a",
"kind": "POSITIONAL_OR_KEYWORD",
"default": "_empty",
"annotation": "_empty",
},
{
"name": "b",
"kind": "POSITIONAL_OR_KEYWORD",
"default": 1,
"annotation": {"Literal": (1, 2, 3)},
},
]
sig = dict_to_signature(params)
assert sig == inspect.signature(test_func)