task can only wrap async functions without arguments

This commit is contained in:
Mose Müller 2024-08-06 09:36:12 +02:00
parent 1e02f12794
commit 3cd7198747
2 changed files with 17 additions and 21 deletions

View File

@ -2,26 +2,25 @@ import asyncio
import functools import functools
import logging import logging
from collections.abc import Callable, Coroutine from collections.abc import Callable, Coroutine
from typing import Any, Concatenate, ParamSpec, TypeVar from typing import Any, TypeVar
from pydase.task.task import Task from pydase.task.task import Task
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
def task( def task(
*, autostart: bool = False *, autostart: bool = False
) -> Callable[[Callable[Concatenate[Any, P], Coroutine[None, None, R]]], Task[P, R]]: ) -> Callable[[Callable[[Any], Coroutine[None, None, R]]], Task[R]]:
def decorator( def decorator(
func: Callable[Concatenate[Any, P], Coroutine[None, None, R]], func: Callable[[Any], Coroutine[None, None, R]],
) -> Task[P, R]: ) -> Task[R]:
@functools.wraps(func) @functools.wraps(func)
async def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> R | None: async def wrapper(self: Any) -> R | None:
try: try:
return await func(self, *args, **kwargs) return await func(self)
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info("Task '%s' was cancelled", func.__name__) logger.info("Task '%s' was cancelled", func.__name__)
return None return None

View File

@ -5,9 +5,7 @@ from collections.abc import Callable, Coroutine
from enum import Enum from enum import Enum
from typing import ( from typing import (
Any, Any,
Concatenate,
Generic, Generic,
ParamSpec,
Self, Self,
TypeVar, TypeVar,
) )
@ -19,14 +17,13 @@ import pydase
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
def is_bound_method( def is_bound_method(
method: Callable[P, Coroutine[None, None, R | None]] method: Callable[[], Coroutine[None, None, R | None]]
| Callable[Concatenate[Any, P], Coroutine[None, None, R | None]], | Callable[[Any], Coroutine[None, None, R | None]],
) -> TypeIs[Callable[P, Coroutine[None, None, R | None]]]: ) -> TypeIs[Callable[[], Coroutine[None, None, R | None]]]:
"""Check if instance method is bound to an object.""" """Check if instance method is bound to an object."""
return inspect.ismethod(method) return inspect.ismethod(method)
@ -36,17 +33,17 @@ class TaskStatus(Enum):
NOT_RUNNING = "not_running" NOT_RUNNING = "not_running"
class Task(pydase.DataService, Generic[P, R]): class Task(pydase.DataService, Generic[R]):
def __init__( def __init__(
self, self,
func: Callable[Concatenate[Any, P], Coroutine[None, None, R | None]] func: Callable[[Any], Coroutine[None, None, R | None]]
| Callable[P, Coroutine[None, None, R | None]], | Callable[[], Coroutine[None, None, R | None]],
*, *,
autostart: bool = False, autostart: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
self._func_name = func.__name__ self._func_name = func.__name__
self._bound_func: Callable[P, Coroutine[None, None, R | None]] | None = None self._bound_func: Callable[[], Coroutine[None, None, R | None]] | None = None
if is_bound_method(func): if is_bound_method(func):
self._func = func self._func = func
self._bound_func = func self._bound_func = func
@ -62,7 +59,7 @@ class Task(pydase.DataService, Generic[P, R]):
def status(self) -> TaskStatus: def status(self) -> TaskStatus:
return self._status return self._status
def start(self, *args: P.args, **kwargs: P.kwargs) -> None: def start(self) -> None:
if self._task: if self._task:
return return
@ -93,7 +90,7 @@ class Task(pydase.DataService, Generic[P, R]):
logger.info("Starting task %s", self._func_name) logger.info("Starting task %s", self._func_name)
if inspect.iscoroutinefunction(self._bound_func): if inspect.iscoroutinefunction(self._bound_func):
res: Coroutine[None, None, R] = self._bound_func(*args, **kwargs) res: Coroutine[None, None, R] = self._bound_func()
self._task = asyncio.create_task(res) self._task = asyncio.create_task(res)
self._task.add_done_callback(task_done_callback) self._task.add_done_callback(task_done_callback)
self._status = TaskStatus.RUNNING self._status = TaskStatus.RUNNING
@ -107,9 +104,9 @@ class Task(pydase.DataService, Generic[P, R]):
# containing the function # containing the function
if instance and self._bound_func is not None: if instance and self._bound_func is not None:
async def bound_func(*args: P.args, **kwargs: P.kwargs) -> R | None: async def bound_func() -> R | None:
if not is_bound_method(self._func): if not is_bound_method(self._func):
return await self._func(instance, *args, **kwargs) return await self._func(instance)
return None return None
self._bound_func = bound_func self._bound_func = bound_func