diff --git a/src/pydase/task/task.py b/src/pydase/task/task.py index 3765a75..3481164 100644 --- a/src/pydase/task/task.py +++ b/src/pydase/task/task.py @@ -4,19 +4,16 @@ import logging import sys from collections.abc import Callable, Coroutine from typing import ( - Any, Generic, TypeVar, ) -from typing_extensions import TypeIs - from pydase.task.task_status import TaskStatus if sys.version_info < (3, 11): - from typing_extensions import Self + pass else: - from typing import Self + pass import pydase.data_service.data_service from pydase.utils.helpers import current_event_loop_exists @@ -27,17 +24,8 @@ logger = logging.getLogger(__name__) R = TypeVar("R") -def is_bound_method( - method: Callable[[], Coroutine[None, None, R | None]] - | Callable[[Any], Coroutine[None, None, R | None]], -) -> TypeIs[Callable[[], Coroutine[None, None, R | None]]]: - """Check if instance method is bound to an object.""" - return inspect.ismethod(method) - - class Task(pydase.data_service.data_service.DataService, Generic[R]): - """ - A class representing a task within the `pydase` framework. + """A class representing a task within the `pydase` framework. The `Task` class wraps an asynchronous function and provides methods to manage its lifecycle, such as `start()` and `stop()`. It is typically used to perform periodic @@ -85,25 +73,24 @@ class Task(pydase.data_service.data_service.DataService, Generic[R]): def __init__( self, - func: Callable[[Any], Coroutine[None, None, R | None]] - | Callable[[], Coroutine[None, None, R | None]], + func: Callable[[], Coroutine[None, None, R | None]], *, autostart: bool = False, ) -> None: super().__init__() self._autostart = autostart self._func_name = func.__name__ - self._bound_func: Callable[[], Coroutine[None, None, R | None]] | None = None - self._set_up = False - if is_bound_method(func): - self._func = func - self._bound_func = func - else: - self._func = func + self._func = func self._task: asyncio.Task[R | None] | None = None self._status = TaskStatus.NOT_RUNNING self._result: R | None = None + if not current_event_loop_exists(): + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + else: + self._loop = asyncio.get_event_loop() + @property def autostart(self) -> bool: """Defines if the task should be started automatically when the @@ -144,10 +131,10 @@ class Task(pydase.data_service.data_service.DataService, Generic[R]): self._result = task.result() async def run_task() -> R | None: - if inspect.iscoroutinefunction(self._bound_func): + if inspect.iscoroutinefunction(self._func): logger.info("Starting task %r", self._func_name) self._status = TaskStatus.RUNNING - res: Coroutine[None, None, R] = self._bound_func() + res: Coroutine[None, None, R | None] = self._func() try: return await res except asyncio.CancelledError: @@ -167,24 +154,3 @@ class Task(pydase.data_service.data_service.DataService, Generic[R]): if self._task: self._task.cancel() - - def __get__(self, instance: Any, owner: Any) -> Self: - """Descriptor method used to correctly set up the task. - - This descriptor method is called by the class instance containing the task. - It binds the task function to that class instance. - - Since the `__init__` function is called when a function is decorated with - [`@task`][pydase.task.decorator.task], some setup is delayed until this - descriptor function is called. - """ - - if instance and not self._set_up: - if not current_event_loop_exists(): - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) - else: - self._loop = asyncio.get_event_loop() - self._bound_func = self._func.__get__(instance, owner) - self._set_up = True - return self