mirror of
https://github.com/tiqi-group/pydase.git
synced 2025-04-22 17:10:02 +02:00
feat: first Task implementation
This commit is contained in:
parent
743c18bdd7
commit
c34351270c
3
src/pydase/task/__init__.py
Normal file
3
src/pydase/task/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from pydase.task.decorator import task
|
||||||
|
|
||||||
|
__all__ = ["task"]
|
29
src/pydase/task/decorator.py
Normal file
29
src/pydase/task/decorator.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
|
from typing import Any, Concatenate, ParamSpec, TypeVar
|
||||||
|
|
||||||
|
from pydase.task.task import Task
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
|
def task(
|
||||||
|
*, autostart: bool = False
|
||||||
|
) -> Callable[[Callable[Concatenate[Any, P], Coroutine[None, None, R]]], Task[P, R]]:
|
||||||
|
def decorator(
|
||||||
|
func: Callable[Concatenate[Any, P], Coroutine[None, None, R]],
|
||||||
|
) -> Task[P, R]:
|
||||||
|
async def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> R | None:
|
||||||
|
try:
|
||||||
|
return await func(self, *args, **kwargs)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("Task '%s' was cancelled", func.__name__)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return Task(wrapper)
|
||||||
|
|
||||||
|
return decorator
|
86
src/pydase/task/task.py
Normal file
86
src/pydase/task/task.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
import asyncio
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Concatenate, Generic, ParamSpec, Self, TypeVar
|
||||||
|
|
||||||
|
import pydase
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatus(Enum):
|
||||||
|
RUNNING = "running"
|
||||||
|
NOT_RUNNING = "not_running"
|
||||||
|
|
||||||
|
|
||||||
|
class Task(pydase.DataService, Generic[P, R]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
func: Callable[Concatenate[Any, P], R | None]
|
||||||
|
| Callable[Concatenate[Any, P], Coroutine[Any, Any, R | None]],
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._func = func
|
||||||
|
self._task: asyncio.Task[R] | None = None
|
||||||
|
self._status = TaskStatus.NOT_RUNNING
|
||||||
|
self._result: R | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def status(self) -> TaskStatus:
|
||||||
|
return self._status
|
||||||
|
|
||||||
|
def start(self, *args: P.args, **kwargs: P.kwargs) -> None:
|
||||||
|
def task_done_callback(task: asyncio.Task[R]) -> None:
|
||||||
|
"""Handles tasks that have finished.
|
||||||
|
|
||||||
|
Removes a task from the tasks dictionary, calls the defined
|
||||||
|
callbacks, and logs and re-raises exceptions."""
|
||||||
|
|
||||||
|
# removing the finished task from the tasks i
|
||||||
|
self._task = None
|
||||||
|
|
||||||
|
# emit the notification that the task was stopped
|
||||||
|
self._status = TaskStatus.NOT_RUNNING
|
||||||
|
|
||||||
|
exception = task.exception()
|
||||||
|
if exception is not None:
|
||||||
|
# Handle the exception, or you can re-raise it.
|
||||||
|
logger.error(
|
||||||
|
"Task '%s' encountered an exception: %s: %s",
|
||||||
|
self._func.__name__,
|
||||||
|
type(exception).__name__,
|
||||||
|
exception,
|
||||||
|
)
|
||||||
|
raise exception
|
||||||
|
|
||||||
|
self._result = task.result()
|
||||||
|
|
||||||
|
logger.info("Starting task")
|
||||||
|
if inspect.iscoroutinefunction(self._func):
|
||||||
|
logger.info("Is coroutine function.")
|
||||||
|
res: Coroutine[None, None, R] = self._func(
|
||||||
|
self._parent_obj, *args, **kwargs
|
||||||
|
)
|
||||||
|
self._task = asyncio.create_task(res)
|
||||||
|
self._task.add_done_callback(task_done_callback)
|
||||||
|
self._status = TaskStatus.RUNNING
|
||||||
|
else:
|
||||||
|
logger.info("Is not a coroutine function.")
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
if self._task:
|
||||||
|
self._task.cancel()
|
||||||
|
|
||||||
|
def __get__(self, obj: Any, obj_type: Any) -> Self:
|
||||||
|
# need to use this descriptor to bind the function to the instance of the class
|
||||||
|
# containing the function
|
||||||
|
|
||||||
|
if obj is not None:
|
||||||
|
self._parent_obj = obj
|
||||||
|
return self
|
Loading…
x
Reference in New Issue
Block a user