mirror of
https://github.com/ivan-usov-org/bec.git
synced 2025-04-21 10:10:02 +02:00
fix: fixed signature serializer for typing.Literal
This commit is contained in:
parent
968960646c
commit
5d4cd1c191
@ -1,9 +1,9 @@
|
||||
import builtins
|
||||
import inspect
|
||||
from typing import Callable, List
|
||||
from typing import Any, Callable, List, Literal
|
||||
|
||||
|
||||
def dtype_to_str(dtype: type) -> str:
|
||||
def serialize_dtype(dtype: type) -> Any:
|
||||
"""
|
||||
Convert a dtype to a string.
|
||||
|
||||
@ -13,12 +13,17 @@ def dtype_to_str(dtype: type) -> str:
|
||||
Returns:
|
||||
str: String representation of the data type
|
||||
"""
|
||||
return dtype.__name__
|
||||
if hasattr(dtype, "__name__"):
|
||||
return dtype.__name__
|
||||
if hasattr(dtype, "__module__"):
|
||||
if dtype.__module__ == "typing":
|
||||
return {"Literal": dtype.__args__}
|
||||
raise ValueError(f"Unknown dtype {dtype}")
|
||||
|
||||
|
||||
def str_to_dtype(dtype: str) -> type:
|
||||
def deserialize_dtype(dtype: Any) -> type:
|
||||
"""
|
||||
Convert a string to a dtype.
|
||||
Convert a serialized dtype to a type.
|
||||
|
||||
Args:
|
||||
dtype (str): String representation of the data type
|
||||
@ -27,7 +32,15 @@ def str_to_dtype(dtype: str) -> type:
|
||||
type: Data type
|
||||
"""
|
||||
if dtype == "_empty":
|
||||
# pylint: disable=protected-access
|
||||
return inspect._empty
|
||||
if isinstance(dtype, dict):
|
||||
# bodge needed for python 3.8
|
||||
if "Literal" in dtype:
|
||||
literal = Literal["dummy"]
|
||||
literal.__args__ = dtype["Literal"]
|
||||
return literal
|
||||
raise ValueError(f"Unknown dtype {dtype}")
|
||||
return builtins.__dict__.get(dtype)
|
||||
|
||||
|
||||
@ -47,12 +60,13 @@ def signature_to_dict(func: Callable, include_class_obj=False) -> dict:
|
||||
for param_name, param in params.items():
|
||||
if not include_class_obj and param_name == "self" or param_name == "cls":
|
||||
continue
|
||||
# pylint: disable=protected-access
|
||||
out.append(
|
||||
{
|
||||
"name": param_name,
|
||||
"kind": param.kind.name,
|
||||
"default": param.default if param.default != inspect._empty else "_empty",
|
||||
"annotation": dtype_to_str(param.annotation),
|
||||
"annotation": serialize_dtype(param.annotation),
|
||||
}
|
||||
)
|
||||
return out
|
||||
@ -70,12 +84,13 @@ def dict_to_signature(params: List[dict]) -> inspect.Signature:
|
||||
"""
|
||||
out = []
|
||||
for param in params:
|
||||
# pylint: disable=protected-access
|
||||
out.append(
|
||||
inspect.Parameter(
|
||||
name=param["name"],
|
||||
kind=getattr(inspect.Parameter, param["kind"]),
|
||||
default=param["default"] if param["default"] != "_empty" else inspect._empty,
|
||||
annotation=str_to_dtype(param["annotation"]),
|
||||
annotation=deserialize_dtype(param["annotation"]),
|
||||
)
|
||||
)
|
||||
return inspect.Signature(out)
|
||||
|
@ -1,4 +1,5 @@
|
||||
import inspect
|
||||
import typing
|
||||
|
||||
from bec_lib.signature_serializer import dict_to_signature, signature_to_dict
|
||||
|
||||
@ -28,3 +29,27 @@ def test_signature_serializer():
|
||||
|
||||
sig = dict_to_signature(params)
|
||||
assert sig == inspect.signature(test_func)
|
||||
|
||||
|
||||
def test_signature_serializer_with_literals():
|
||||
def test_func(a, b: typing.Literal[1, 2, 3] = 1):
|
||||
pass
|
||||
|
||||
params = signature_to_dict(test_func)
|
||||
assert params == [
|
||||
{
|
||||
"name": "a",
|
||||
"kind": "POSITIONAL_OR_KEYWORD",
|
||||
"default": "_empty",
|
||||
"annotation": "_empty",
|
||||
},
|
||||
{
|
||||
"name": "b",
|
||||
"kind": "POSITIONAL_OR_KEYWORD",
|
||||
"default": 1,
|
||||
"annotation": {"Literal": (1, 2, 3)},
|
||||
},
|
||||
]
|
||||
|
||||
sig = dict_to_signature(params)
|
||||
assert sig == inspect.signature(test_func)
|
||||
|
Loading…
x
Reference in New Issue
Block a user