diff --git a/scilog/snippet.py b/scilog/snippet.py index 4b82a9f..31a3586 100644 --- a/scilog/snippet.py +++ b/scilog/snippet.py @@ -5,12 +5,14 @@ from .utils import typename def typechecked(func): @functools.wraps(func) - def typechecked_call(*args, **kwargs): - func_types = get_type_hints(func) - for index, key in enumerate(func_types.keys()): - if key != "return": - assert func_types[key] == type(args[index+1]), f"{repr(func)} expected to receive input of type {func_types[key].__name__} but received {type(args[index+1]).__name__}" - return func(*args, **kwargs) + def typechecked_call(obj, *args, **kwargs): + type_hints = get_type_hints(func) + del type_hints["return"] + for arg, dtype in zip(args, type_hints): + arg_type = type(arg) + if dtype != arg_type: + raise TypeError(f"{func} expected to receive input of type {dtype.__name__} but received {arg_type.__name__}") + return func(obj, *args, **kwargs) return typechecked_call