task can only wrap async functions without arguments

This commit is contained in:
Mose Müller 2024-08-06 09:36:12 +02:00
parent 1e02f12794
commit 3cd7198747
2 changed files with 17 additions and 21 deletions

View File

@ -2,26 +2,25 @@ import asyncio
import functools
import logging
from collections.abc import Callable, Coroutine
from typing import Any, Concatenate, ParamSpec, TypeVar
from typing import Any, TypeVar
from pydase.task.task import Task
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
def task(
*, autostart: bool = False
) -> Callable[[Callable[Concatenate[Any, P], Coroutine[None, None, R]]], Task[P, R]]:
) -> Callable[[Callable[[Any], Coroutine[None, None, R]]], Task[R]]:
def decorator(
func: Callable[Concatenate[Any, P], Coroutine[None, None, R]],
) -> Task[P, R]:
func: Callable[[Any], Coroutine[None, None, R]],
) -> Task[R]:
@functools.wraps(func)
async def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> R | None:
async def wrapper(self: Any) -> R | None:
try:
return await func(self, *args, **kwargs)
return await func(self)
except asyncio.CancelledError:
logger.info("Task '%s' was cancelled", func.__name__)
return None

View File

@ -5,9 +5,7 @@ from collections.abc import Callable, Coroutine
from enum import Enum
from typing import (
Any,
Concatenate,
Generic,
ParamSpec,
Self,
TypeVar,
)
@ -19,14 +17,13 @@ import pydase
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
def is_bound_method(
method: Callable[P, Coroutine[None, None, R | None]]
| Callable[Concatenate[Any, P], Coroutine[None, None, R | None]],
) -> TypeIs[Callable[P, Coroutine[None, None, R | None]]]:
method: Callable[[], Coroutine[None, None, R | None]]
| Callable[[Any], Coroutine[None, None, R | None]],
) -> TypeIs[Callable[[], Coroutine[None, None, R | None]]]:
"""Check if instance method is bound to an object."""
return inspect.ismethod(method)
@ -36,17 +33,17 @@ class TaskStatus(Enum):
NOT_RUNNING = "not_running"
class Task(pydase.DataService, Generic[P, R]):
class Task(pydase.DataService, Generic[R]):
def __init__(
self,
func: Callable[Concatenate[Any, P], Coroutine[None, None, R | None]]
| Callable[P, Coroutine[None, None, R | None]],
func: Callable[[Any], Coroutine[None, None, R | None]]
| Callable[[], Coroutine[None, None, R | None]],
*,
autostart: bool = False,
) -> None:
super().__init__()
self._func_name = func.__name__
self._bound_func: Callable[P, Coroutine[None, None, R | None]] | None = None
self._bound_func: Callable[[], Coroutine[None, None, R | None]] | None = None
if is_bound_method(func):
self._func = func
self._bound_func = func
@ -62,7 +59,7 @@ class Task(pydase.DataService, Generic[P, R]):
def status(self) -> TaskStatus:
return self._status
def start(self, *args: P.args, **kwargs: P.kwargs) -> None:
def start(self) -> None:
if self._task:
return
@ -93,7 +90,7 @@ class Task(pydase.DataService, Generic[P, R]):
logger.info("Starting task %s", self._func_name)
if inspect.iscoroutinefunction(self._bound_func):
res: Coroutine[None, None, R] = self._bound_func(*args, **kwargs)
res: Coroutine[None, None, R] = self._bound_func()
self._task = asyncio.create_task(res)
self._task.add_done_callback(task_done_callback)
self._status = TaskStatus.RUNNING
@ -107,9 +104,9 @@ class Task(pydase.DataService, Generic[P, R]):
# containing the function
if instance and self._bound_func is not None:
async def bound_func(*args: P.args, **kwargs: P.kwargs) -> R | None:
async def bound_func() -> R | None:
if not is_bound_method(self._func):
return await self._func(instance, *args, **kwargs)
return await self._func(instance)
return None
self._bound_func = bound_func