mirror of
https://github.com/ivan-usov-org/bec.git
synced 2025-04-21 18:20:01 +02:00
fix: fixed signature serializer for typing.Literal
This commit is contained in:
parent
968960646c
commit
5d4cd1c191
@ -1,9 +1,9 @@
|
|||||||
import builtins
|
import builtins
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Callable, List
|
from typing import Any, Callable, List, Literal
|
||||||
|
|
||||||
|
|
||||||
def dtype_to_str(dtype: type) -> str:
|
def serialize_dtype(dtype: type) -> Any:
|
||||||
"""
|
"""
|
||||||
Convert a dtype to a string.
|
Convert a dtype to a string.
|
||||||
|
|
||||||
@ -13,12 +13,17 @@ def dtype_to_str(dtype: type) -> str:
|
|||||||
Returns:
|
Returns:
|
||||||
str: String representation of the data type
|
str: String representation of the data type
|
||||||
"""
|
"""
|
||||||
|
if hasattr(dtype, "__name__"):
|
||||||
return dtype.__name__
|
return dtype.__name__
|
||||||
|
if hasattr(dtype, "__module__"):
|
||||||
|
if dtype.__module__ == "typing":
|
||||||
|
return {"Literal": dtype.__args__}
|
||||||
|
raise ValueError(f"Unknown dtype {dtype}")
|
||||||
|
|
||||||
|
|
||||||
def str_to_dtype(dtype: str) -> type:
|
def deserialize_dtype(dtype: Any) -> type:
|
||||||
"""
|
"""
|
||||||
Convert a string to a dtype.
|
Convert a serialized dtype to a type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dtype (str): String representation of the data type
|
dtype (str): String representation of the data type
|
||||||
@ -27,7 +32,15 @@ def str_to_dtype(dtype: str) -> type:
|
|||||||
type: Data type
|
type: Data type
|
||||||
"""
|
"""
|
||||||
if dtype == "_empty":
|
if dtype == "_empty":
|
||||||
|
# pylint: disable=protected-access
|
||||||
return inspect._empty
|
return inspect._empty
|
||||||
|
if isinstance(dtype, dict):
|
||||||
|
# bodge needed for python 3.8
|
||||||
|
if "Literal" in dtype:
|
||||||
|
literal = Literal["dummy"]
|
||||||
|
literal.__args__ = dtype["Literal"]
|
||||||
|
return literal
|
||||||
|
raise ValueError(f"Unknown dtype {dtype}")
|
||||||
return builtins.__dict__.get(dtype)
|
return builtins.__dict__.get(dtype)
|
||||||
|
|
||||||
|
|
||||||
@ -47,12 +60,13 @@ def signature_to_dict(func: Callable, include_class_obj=False) -> dict:
|
|||||||
for param_name, param in params.items():
|
for param_name, param in params.items():
|
||||||
if not include_class_obj and param_name == "self" or param_name == "cls":
|
if not include_class_obj and param_name == "self" or param_name == "cls":
|
||||||
continue
|
continue
|
||||||
|
# pylint: disable=protected-access
|
||||||
out.append(
|
out.append(
|
||||||
{
|
{
|
||||||
"name": param_name,
|
"name": param_name,
|
||||||
"kind": param.kind.name,
|
"kind": param.kind.name,
|
||||||
"default": param.default if param.default != inspect._empty else "_empty",
|
"default": param.default if param.default != inspect._empty else "_empty",
|
||||||
"annotation": dtype_to_str(param.annotation),
|
"annotation": serialize_dtype(param.annotation),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
@ -70,12 +84,13 @@ def dict_to_signature(params: List[dict]) -> inspect.Signature:
|
|||||||
"""
|
"""
|
||||||
out = []
|
out = []
|
||||||
for param in params:
|
for param in params:
|
||||||
|
# pylint: disable=protected-access
|
||||||
out.append(
|
out.append(
|
||||||
inspect.Parameter(
|
inspect.Parameter(
|
||||||
name=param["name"],
|
name=param["name"],
|
||||||
kind=getattr(inspect.Parameter, param["kind"]),
|
kind=getattr(inspect.Parameter, param["kind"]),
|
||||||
default=param["default"] if param["default"] != "_empty" else inspect._empty,
|
default=param["default"] if param["default"] != "_empty" else inspect._empty,
|
||||||
annotation=str_to_dtype(param["annotation"]),
|
annotation=deserialize_dtype(param["annotation"]),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return inspect.Signature(out)
|
return inspect.Signature(out)
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import inspect
|
import inspect
|
||||||
|
import typing
|
||||||
|
|
||||||
from bec_lib.signature_serializer import dict_to_signature, signature_to_dict
|
from bec_lib.signature_serializer import dict_to_signature, signature_to_dict
|
||||||
|
|
||||||
@ -28,3 +29,27 @@ def test_signature_serializer():
|
|||||||
|
|
||||||
sig = dict_to_signature(params)
|
sig = dict_to_signature(params)
|
||||||
assert sig == inspect.signature(test_func)
|
assert sig == inspect.signature(test_func)
|
||||||
|
|
||||||
|
|
||||||
|
def test_signature_serializer_with_literals():
|
||||||
|
def test_func(a, b: typing.Literal[1, 2, 3] = 1):
|
||||||
|
pass
|
||||||
|
|
||||||
|
params = signature_to_dict(test_func)
|
||||||
|
assert params == [
|
||||||
|
{
|
||||||
|
"name": "a",
|
||||||
|
"kind": "POSITIONAL_OR_KEYWORD",
|
||||||
|
"default": "_empty",
|
||||||
|
"annotation": "_empty",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "b",
|
||||||
|
"kind": "POSITIONAL_OR_KEYWORD",
|
||||||
|
"default": 1,
|
||||||
|
"annotation": {"Literal": (1, 2, 3)},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
sig = dict_to_signature(params)
|
||||||
|
assert sig == inspect.signature(test_func)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user