task can receive bound and unbound functions now

This commit is contained in:
Mose Müller 2024-08-06 09:29:18 +02:00
parent 7ddcd97f68
commit e4a3cf341f

View File

@ -3,7 +3,16 @@ import inspect
import logging
from collections.abc import Callable, Coroutine
from enum import Enum
from typing import Any, Concatenate, Generic, ParamSpec, Self, TypeVar
from typing import (
Any,
Concatenate,
Generic,
ParamSpec,
Self,
TypeVar,
)
from typing_extensions import TypeIs
import pydase
@ -14,6 +23,14 @@ P = ParamSpec("P")
R = TypeVar("R")
def is_bound_method(
method: Callable[P, Coroutine[None, None, R | None]]
| Callable[Concatenate[Any, P], Coroutine[None, None, R | None]],
) -> TypeIs[Callable[P, Coroutine[None, None, R | None]]]:
"""Check if instance method is bound to an object."""
return inspect.ismethod(method)
class TaskStatus(Enum):
RUNNING = "running"
NOT_RUNNING = "not_running"
@ -22,11 +39,17 @@ class TaskStatus(Enum):
class Task(pydase.DataService, Generic[P, R]):
def __init__(
self,
func: Callable[Concatenate[Any, P], Coroutine[None, None, R | None]],
func: Callable[Concatenate[Any, P], Coroutine[None, None, R | None]]
| Callable[P, Coroutine[None, None, R | None]],
) -> None:
super().__init__()
self._func = func
self._func_name = func.__name__
self._bound_func: Callable[P, Coroutine[None, None, R | None]] | None = None
if is_bound_method(func):
self._func = func
self._bound_func = func
else:
self._func = func
self._task: asyncio.Task[R] | None = None
self._status = TaskStatus.NOT_RUNNING
self._result: R | None = None
@ -56,7 +79,7 @@ class Task(pydase.DataService, Generic[P, R]):
# Handle the exception, or you can re-raise it.
logger.error(
"Task '%s' encountered an exception: %s: %s",
self._func.__name__,
self._func_name,
type(exception).__name__,
exception,
)
@ -64,7 +87,7 @@ class Task(pydase.DataService, Generic[P, R]):
self._result = task.result()
logger.info("Starting task %s", self._func.__name__)
logger.info("Starting task %s", self._func_name)
if inspect.iscoroutinefunction(self._bound_func):
res: Coroutine[None, None, R] = self._bound_func(*args, **kwargs)
self._task = asyncio.create_task(res)
@ -78,10 +101,12 @@ class Task(pydase.DataService, Generic[P, R]):
def __get__(self, instance: Any, owner: Any) -> Self:
# need to use this descriptor to bind the function to the instance of the class
# containing the function
if instance:
if instance and self._bound_func is not None:
async def bound_func(*args: P.args, **kwargs: P.kwargs) -> R | None:
if not is_bound_method(self._func):
return await self._func(instance, *args, **kwargs)
return None
self._bound_func = bound_func
return self