diff --git a/src/pydase/task/__init__.py b/src/pydase/task/__init__.py new file mode 100644 index 0000000..69c7c2f --- /dev/null +++ b/src/pydase/task/__init__.py @@ -0,0 +1,3 @@ +from pydase.task.decorator import task + +__all__ = ["task"] diff --git a/src/pydase/task/decorator.py b/src/pydase/task/decorator.py new file mode 100644 index 0000000..ee704a9 --- /dev/null +++ b/src/pydase/task/decorator.py @@ -0,0 +1,29 @@ +import asyncio +import logging +from collections.abc import Callable, Coroutine +from typing import Any, Concatenate, ParamSpec, TypeVar + +from pydase.task.task import Task + +logger = logging.getLogger(__name__) + +P = ParamSpec("P") +R = TypeVar("R") + + +def task( + *, autostart: bool = False +) -> Callable[[Callable[Concatenate[Any, P], Coroutine[None, None, R]]], Task[P, R]]: + def decorator( + func: Callable[Concatenate[Any, P], Coroutine[None, None, R]], + ) -> Task[P, R]: + async def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> R | None: + try: + return await func(self, *args, **kwargs) + except asyncio.CancelledError: + logger.info("Task '%s' was cancelled", func.__name__) + return None + + return Task(wrapper) + + return decorator diff --git a/src/pydase/task/task.py b/src/pydase/task/task.py new file mode 100644 index 0000000..fcd2d89 --- /dev/null +++ b/src/pydase/task/task.py @@ -0,0 +1,86 @@ +import asyncio +import inspect +import logging +from collections.abc import Callable, Coroutine +from enum import Enum +from typing import Any, Concatenate, Generic, ParamSpec, Self, TypeVar + +import pydase + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + +P = ParamSpec("P") +R = TypeVar("R") + + +class TaskStatus(Enum): + RUNNING = "running" + NOT_RUNNING = "not_running" + + +class Task(pydase.DataService, Generic[P, R]): + def __init__( + self, + func: Callable[Concatenate[Any, P], R | None] + | Callable[Concatenate[Any, P], Coroutine[Any, Any, R | None]], + ) -> None: + super().__init__() + self._func = func + self._task: asyncio.Task[R] | None = None + self._status = TaskStatus.NOT_RUNNING + self._result: R | None = None + + @property + def status(self) -> TaskStatus: + return self._status + + def start(self, *args: P.args, **kwargs: P.kwargs) -> None: + def task_done_callback(task: asyncio.Task[R]) -> None: + """Handles tasks that have finished. + + Removes a task from the tasks dictionary, calls the defined + callbacks, and logs and re-raises exceptions.""" + + # removing the finished task from the tasks i + self._task = None + + # emit the notification that the task was stopped + self._status = TaskStatus.NOT_RUNNING + + exception = task.exception() + if exception is not None: + # Handle the exception, or you can re-raise it. + logger.error( + "Task '%s' encountered an exception: %s: %s", + self._func.__name__, + type(exception).__name__, + exception, + ) + raise exception + + self._result = task.result() + + logger.info("Starting task") + if inspect.iscoroutinefunction(self._func): + logger.info("Is coroutine function.") + res: Coroutine[None, None, R] = self._func( + self._parent_obj, *args, **kwargs + ) + self._task = asyncio.create_task(res) + self._task.add_done_callback(task_done_callback) + self._status = TaskStatus.RUNNING + else: + logger.info("Is not a coroutine function.") + + def stop(self) -> None: + if self._task: + self._task.cancel() + + def __get__(self, obj: Any, obj_type: Any) -> Self: + # need to use this descriptor to bind the function to the instance of the class + # containing the function + + if obj is not None: + self._parent_obj = obj + return self