From 3cd71987470c2b0e0b974a185b8463406f2a4ec0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mose=20M=C3=BCller?= Date: Tue, 6 Aug 2024 09:36:12 +0200 Subject: [PATCH] task can only wrap async functions without arguments --- src/pydase/task/decorator.py | 13 ++++++------- src/pydase/task/task.py | 25 +++++++++++-------------- 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/src/pydase/task/decorator.py b/src/pydase/task/decorator.py index 58c7d9f..a1d9a70 100644 --- a/src/pydase/task/decorator.py +++ b/src/pydase/task/decorator.py @@ -2,26 +2,25 @@ import asyncio import functools import logging from collections.abc import Callable, Coroutine -from typing import Any, Concatenate, ParamSpec, TypeVar +from typing import Any, TypeVar from pydase.task.task import Task logger = logging.getLogger(__name__) -P = ParamSpec("P") R = TypeVar("R") def task( *, autostart: bool = False -) -> Callable[[Callable[Concatenate[Any, P], Coroutine[None, None, R]]], Task[P, R]]: +) -> Callable[[Callable[[Any], Coroutine[None, None, R]]], Task[R]]: def decorator( - func: Callable[Concatenate[Any, P], Coroutine[None, None, R]], - ) -> Task[P, R]: + func: Callable[[Any], Coroutine[None, None, R]], + ) -> Task[R]: @functools.wraps(func) - async def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> R | None: + async def wrapper(self: Any) -> R | None: try: - return await func(self, *args, **kwargs) + return await func(self) except asyncio.CancelledError: logger.info("Task '%s' was cancelled", func.__name__) return None diff --git a/src/pydase/task/task.py b/src/pydase/task/task.py index 9c07b10..30c7ca1 100644 --- a/src/pydase/task/task.py +++ b/src/pydase/task/task.py @@ -5,9 +5,7 @@ from collections.abc import Callable, Coroutine from enum import Enum from typing import ( Any, - Concatenate, Generic, - ParamSpec, Self, TypeVar, ) @@ -19,14 +17,13 @@ import pydase logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) -P = ParamSpec("P") R = TypeVar("R") def is_bound_method( - method: Callable[P, Coroutine[None, None, R | None]] - | Callable[Concatenate[Any, P], Coroutine[None, None, R | None]], -) -> TypeIs[Callable[P, Coroutine[None, None, R | None]]]: + method: Callable[[], Coroutine[None, None, R | None]] + | Callable[[Any], Coroutine[None, None, R | None]], +) -> TypeIs[Callable[[], Coroutine[None, None, R | None]]]: """Check if instance method is bound to an object.""" return inspect.ismethod(method) @@ -36,17 +33,17 @@ class TaskStatus(Enum): NOT_RUNNING = "not_running" -class Task(pydase.DataService, Generic[P, R]): +class Task(pydase.DataService, Generic[R]): def __init__( self, - func: Callable[Concatenate[Any, P], Coroutine[None, None, R | None]] - | Callable[P, Coroutine[None, None, R | None]], + func: Callable[[Any], Coroutine[None, None, R | None]] + | Callable[[], Coroutine[None, None, R | None]], *, autostart: bool = False, ) -> None: super().__init__() self._func_name = func.__name__ - self._bound_func: Callable[P, Coroutine[None, None, R | None]] | None = None + self._bound_func: Callable[[], Coroutine[None, None, R | None]] | None = None if is_bound_method(func): self._func = func self._bound_func = func @@ -62,7 +59,7 @@ class Task(pydase.DataService, Generic[P, R]): def status(self) -> TaskStatus: return self._status - def start(self, *args: P.args, **kwargs: P.kwargs) -> None: + def start(self) -> None: if self._task: return @@ -93,7 +90,7 @@ class Task(pydase.DataService, Generic[P, R]): logger.info("Starting task %s", self._func_name) if inspect.iscoroutinefunction(self._bound_func): - res: Coroutine[None, None, R] = self._bound_func(*args, **kwargs) + res: Coroutine[None, None, R] = self._bound_func() self._task = asyncio.create_task(res) self._task.add_done_callback(task_done_callback) self._status = TaskStatus.RUNNING @@ -107,9 +104,9 @@ class Task(pydase.DataService, Generic[P, R]): # containing the function if instance and self._bound_func is not None: - async def bound_func(*args: P.args, **kwargs: P.kwargs) -> R | None: + async def bound_func() -> R | None: if not is_bound_method(self._func): - return await self._func(instance, *args, **kwargs) + return await self._func(instance) return None self._bound_func = bound_func