diff --git a/src/pydase/task/task.py b/src/pydase/task/task.py index 257ef98..9242145 100644 --- a/src/pydase/task/task.py +++ b/src/pydase/task/task.py @@ -3,7 +3,16 @@ import inspect import logging from collections.abc import Callable, Coroutine from enum import Enum -from typing import Any, Concatenate, Generic, ParamSpec, Self, TypeVar +from typing import ( + Any, + Concatenate, + Generic, + ParamSpec, + Self, + TypeVar, +) + +from typing_extensions import TypeIs import pydase @@ -14,6 +23,14 @@ P = ParamSpec("P") R = TypeVar("R") +def is_bound_method( + method: Callable[P, Coroutine[None, None, R | None]] + | Callable[Concatenate[Any, P], Coroutine[None, None, R | None]], +) -> TypeIs[Callable[P, Coroutine[None, None, R | None]]]: + """Check if instance method is bound to an object.""" + return inspect.ismethod(method) + + class TaskStatus(Enum): RUNNING = "running" NOT_RUNNING = "not_running" @@ -22,11 +39,17 @@ class TaskStatus(Enum): class Task(pydase.DataService, Generic[P, R]): def __init__( self, - func: Callable[Concatenate[Any, P], Coroutine[None, None, R | None]], + func: Callable[Concatenate[Any, P], Coroutine[None, None, R | None]] + | Callable[P, Coroutine[None, None, R | None]], ) -> None: super().__init__() - self._func = func + self._func_name = func.__name__ self._bound_func: Callable[P, Coroutine[None, None, R | None]] | None = None + if is_bound_method(func): + self._func = func + self._bound_func = func + else: + self._func = func self._task: asyncio.Task[R] | None = None self._status = TaskStatus.NOT_RUNNING self._result: R | None = None @@ -56,7 +79,7 @@ class Task(pydase.DataService, Generic[P, R]): # Handle the exception, or you can re-raise it. logger.error( "Task '%s' encountered an exception: %s: %s", - self._func.__name__, + self._func_name, type(exception).__name__, exception, ) @@ -64,7 +87,7 @@ class Task(pydase.DataService, Generic[P, R]): self._result = task.result() - logger.info("Starting task %s", self._func.__name__) + logger.info("Starting task %s", self._func_name) if inspect.iscoroutinefunction(self._bound_func): res: Coroutine[None, None, R] = self._bound_func(*args, **kwargs) self._task = asyncio.create_task(res) @@ -78,10 +101,12 @@ class Task(pydase.DataService, Generic[P, R]): def __get__(self, instance: Any, owner: Any) -> Self: # need to use this descriptor to bind the function to the instance of the class # containing the function - if instance: + if instance and self._bound_func is not None: async def bound_func(*args: P.args, **kwargs: P.kwargs) -> R | None: - return await self._func(instance, *args, **kwargs) + if not is_bound_method(self._func): + return await self._func(instance, *args, **kwargs) + return None self._bound_func = bound_func return self