diff --git a/src/pydase/task/task.py b/src/pydase/task/task.py index d029f46..7b9a628 100644 --- a/src/pydase/task/task.py +++ b/src/pydase/task/task.py @@ -26,6 +26,7 @@ class Task(pydase.DataService, Generic[P, R]): ) -> None: super().__init__() self._func = func + self._bound_func: Callable[P, Coroutine[None, None, R | None]] | None = None self._task: asyncio.Task[R] | None = None self._status = TaskStatus.NOT_RUNNING self._result: R | None = None @@ -61,25 +62,23 @@ class Task(pydase.DataService, Generic[P, R]): self._result = task.result() logger.info("Starting task") - if inspect.iscoroutinefunction(self._func): - logger.info("Is coroutine function.") - res: Coroutine[None, None, R] = self._func( - self._parent_obj, *args, **kwargs - ) + if inspect.iscoroutinefunction(self._bound_func): + res: Coroutine[None, None, R] = self._bound_func(*args, **kwargs) self._task = asyncio.create_task(res) self._task.add_done_callback(task_done_callback) self._status = TaskStatus.RUNNING - else: - logger.info("Is not a coroutine function.") def stop(self) -> None: if self._task: self._task.cancel() - def __get__(self, obj: Any, obj_type: Any) -> Self: + def __get__(self, instance: Any, owner: Any) -> Self: # need to use this descriptor to bind the function to the instance of the class # containing the function + if instance: - if obj is not None: - self._parent_obj = obj + async def bound_func(*args, **kwargs) -> R | None: + return await self._func(instance, *args, **kwargs) + + self._bound_func = bound_func return self