diff --git a/src/pydase/task/task.py b/src/pydase/task/task.py index 6651f54..dff5da5 100644 --- a/src/pydase/task/task.py +++ b/src/pydase/task/task.py @@ -55,7 +55,7 @@ class Task(pydase.data_service.data_service.DataService, Generic[R]): self._bound_func = func else: self._func = func - self._task: asyncio.Task[R] | None = None + self._task: asyncio.Task[R | None] | None = None self._status = TaskStatus.NOT_RUNNING self._result: R | None = None if autostart: @@ -69,7 +69,7 @@ class Task(pydase.data_service.data_service.DataService, Generic[R]): if self._task: return - def task_done_callback(task: asyncio.Task[R]) -> None: + def task_done_callback(task: asyncio.Task[R | None]) -> None: """Handles tasks that have finished. Removes a task from the tasks dictionary, calls the defined @@ -94,12 +94,17 @@ class Task(pydase.data_service.data_service.DataService, Generic[R]): self._result = task.result() - logger.info("Starting task %s", self._func_name) - if inspect.iscoroutinefunction(self._bound_func): - res: Coroutine[None, None, R] = self._bound_func() - self._task = asyncio.create_task(res) - self._task.add_done_callback(task_done_callback) - self._status = TaskStatus.RUNNING + async def run_task() -> R | None: + if inspect.iscoroutinefunction(self._bound_func): + logger.info("Starting task %r", self._func_name) + self._status = TaskStatus.RUNNING + res: Coroutine[None, None, R] = self._bound_func() + return await res + return None + + logger.info("Creating task %r", self._func_name) + self._task = self._loop.create_task(run_task()) + self._task.add_done_callback(task_done_callback) def stop(self) -> None: if self._task: