task can receive bound and unbound functions now

This commit is contained in:
Mose Müller 2024-08-06 09:29:18 +02:00
parent 7ddcd97f68
commit e4a3cf341f

View File

@ -3,7 +3,16 @@ import inspect
import logging import logging
from collections.abc import Callable, Coroutine from collections.abc import Callable, Coroutine
from enum import Enum from enum import Enum
from typing import Any, Concatenate, Generic, ParamSpec, Self, TypeVar from typing import (
Any,
Concatenate,
Generic,
ParamSpec,
Self,
TypeVar,
)
from typing_extensions import TypeIs
import pydase import pydase
@ -14,6 +23,14 @@ P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
def is_bound_method(
method: Callable[P, Coroutine[None, None, R | None]]
| Callable[Concatenate[Any, P], Coroutine[None, None, R | None]],
) -> TypeIs[Callable[P, Coroutine[None, None, R | None]]]:
"""Check if instance method is bound to an object."""
return inspect.ismethod(method)
class TaskStatus(Enum): class TaskStatus(Enum):
RUNNING = "running" RUNNING = "running"
NOT_RUNNING = "not_running" NOT_RUNNING = "not_running"
@ -22,11 +39,17 @@ class TaskStatus(Enum):
class Task(pydase.DataService, Generic[P, R]): class Task(pydase.DataService, Generic[P, R]):
def __init__( def __init__(
self, self,
func: Callable[Concatenate[Any, P], Coroutine[None, None, R | None]], func: Callable[Concatenate[Any, P], Coroutine[None, None, R | None]]
| Callable[P, Coroutine[None, None, R | None]],
) -> None: ) -> None:
super().__init__() super().__init__()
self._func = func self._func_name = func.__name__
self._bound_func: Callable[P, Coroutine[None, None, R | None]] | None = None self._bound_func: Callable[P, Coroutine[None, None, R | None]] | None = None
if is_bound_method(func):
self._func = func
self._bound_func = func
else:
self._func = func
self._task: asyncio.Task[R] | None = None self._task: asyncio.Task[R] | None = None
self._status = TaskStatus.NOT_RUNNING self._status = TaskStatus.NOT_RUNNING
self._result: R | None = None self._result: R | None = None
@ -56,7 +79,7 @@ class Task(pydase.DataService, Generic[P, R]):
# Handle the exception, or you can re-raise it. # Handle the exception, or you can re-raise it.
logger.error( logger.error(
"Task '%s' encountered an exception: %s: %s", "Task '%s' encountered an exception: %s: %s",
self._func.__name__, self._func_name,
type(exception).__name__, type(exception).__name__,
exception, exception,
) )
@ -64,7 +87,7 @@ class Task(pydase.DataService, Generic[P, R]):
self._result = task.result() self._result = task.result()
logger.info("Starting task %s", self._func.__name__) logger.info("Starting task %s", self._func_name)
if inspect.iscoroutinefunction(self._bound_func): if inspect.iscoroutinefunction(self._bound_func):
res: Coroutine[None, None, R] = self._bound_func(*args, **kwargs) res: Coroutine[None, None, R] = self._bound_func(*args, **kwargs)
self._task = asyncio.create_task(res) self._task = asyncio.create_task(res)
@ -78,10 +101,12 @@ class Task(pydase.DataService, Generic[P, R]):
def __get__(self, instance: Any, owner: Any) -> Self: def __get__(self, instance: Any, owner: Any) -> Self:
# need to use this descriptor to bind the function to the instance of the class # need to use this descriptor to bind the function to the instance of the class
# containing the function # containing the function
if instance: if instance and self._bound_func is not None:
async def bound_func(*args: P.args, **kwargs: P.kwargs) -> R | None: async def bound_func(*args: P.args, **kwargs: P.kwargs) -> R | None:
return await self._func(instance, *args, **kwargs) if not is_bound_method(self._func):
return await self._func(instance, *args, **kwargs)
return None
self._bound_func = bound_func self._bound_func = bound_func
return self return self