feat: first Task implementation

This commit is contained in:
Mose Müller 2024-08-05 15:10:04 +02:00
parent 743c18bdd7
commit c34351270c
3 changed files with 118 additions and 0 deletions

View File

@ -0,0 +1,3 @@
from pydase.task.decorator import task
__all__ = ["task"]

View File

@ -0,0 +1,29 @@
import asyncio
import logging
from collections.abc import Callable, Coroutine
from typing import Any, Concatenate, ParamSpec, TypeVar
from pydase.task.task import Task
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
def task(
*, autostart: bool = False
) -> Callable[[Callable[Concatenate[Any, P], Coroutine[None, None, R]]], Task[P, R]]:
def decorator(
func: Callable[Concatenate[Any, P], Coroutine[None, None, R]],
) -> Task[P, R]:
async def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> R | None:
try:
return await func(self, *args, **kwargs)
except asyncio.CancelledError:
logger.info("Task '%s' was cancelled", func.__name__)
return None
return Task(wrapper)
return decorator

86
src/pydase/task/task.py Normal file
View File

@ -0,0 +1,86 @@
import asyncio
import inspect
import logging
from collections.abc import Callable, Coroutine
from enum import Enum
from typing import Any, Concatenate, Generic, ParamSpec, Self, TypeVar
import pydase
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
class TaskStatus(Enum):
RUNNING = "running"
NOT_RUNNING = "not_running"
class Task(pydase.DataService, Generic[P, R]):
def __init__(
self,
func: Callable[Concatenate[Any, P], R | None]
| Callable[Concatenate[Any, P], Coroutine[Any, Any, R | None]],
) -> None:
super().__init__()
self._func = func
self._task: asyncio.Task[R] | None = None
self._status = TaskStatus.NOT_RUNNING
self._result: R | None = None
@property
def status(self) -> TaskStatus:
return self._status
def start(self, *args: P.args, **kwargs: P.kwargs) -> None:
def task_done_callback(task: asyncio.Task[R]) -> None:
"""Handles tasks that have finished.
Removes a task from the tasks dictionary, calls the defined
callbacks, and logs and re-raises exceptions."""
# removing the finished task from the tasks i
self._task = None
# emit the notification that the task was stopped
self._status = TaskStatus.NOT_RUNNING
exception = task.exception()
if exception is not None:
# Handle the exception, or you can re-raise it.
logger.error(
"Task '%s' encountered an exception: %s: %s",
self._func.__name__,
type(exception).__name__,
exception,
)
raise exception
self._result = task.result()
logger.info("Starting task")
if inspect.iscoroutinefunction(self._func):
logger.info("Is coroutine function.")
res: Coroutine[None, None, R] = self._func(
self._parent_obj, *args, **kwargs
)
self._task = asyncio.create_task(res)
self._task.add_done_callback(task_done_callback)
self._status = TaskStatus.RUNNING
else:
logger.info("Is not a coroutine function.")
def stop(self) -> None:
if self._task:
self._task.cancel()
def __get__(self, obj: Any, obj_type: Any) -> Self:
# need to use this descriptor to bind the function to the instance of the class
# containing the function
if obj is not None:
self._parent_obj = obj
return self