diff --git a/src/pydase/data_service/task_manager.py b/src/pydase/data_service/task_manager.py index 6c33458..6bca414 100644 --- a/src/pydase/data_service/task_manager.py +++ b/src/pydase/data_service/task_manager.py @@ -3,10 +3,15 @@ from __future__ import annotations import asyncio import inspect import logging -from typing import TYPE_CHECKING, Any, TypedDict +from enum import Enum +from typing import TYPE_CHECKING, Any from pydase.data_service.abstract_data_service import AbstractDataService -from pydase.utils.helpers import get_class_and_instance_attributes +from pydase.utils.helpers import ( + function_has_arguments, + get_class_and_instance_attributes, + is_property_attribute, +) if TYPE_CHECKING: from collections.abc import Callable @@ -16,9 +21,12 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class TaskDict(TypedDict): - task: asyncio.Task[None] - kwargs: dict[str, Any] +class TaskDefinitionError(Exception): + pass + + +class TaskStatus(Enum): + RUNNING = "running" class TaskManager: @@ -78,7 +86,7 @@ class TaskManager: def __init__(self, service: DataService) -> None: self.service = service - self.tasks: dict[str, TaskDict] = {} + self.tasks: dict[str, asyncio.Task[None]] = {} """A dictionary to keep track of running tasks. The keys are the names of the tasks and the values are TaskDict instances which include the task itself and its kwargs. @@ -91,13 +99,26 @@ class TaskManager: return asyncio.get_running_loop() def _set_start_and_stop_for_async_methods(self) -> None: - # inspect the methods of the class - for name, method in inspect.getmembers( - self.service, predicate=inspect.iscoroutinefunction - ): - # create start and stop methods for each coroutine - setattr(self.service, f"start_{name}", self._make_start_task(name, method)) - setattr(self.service, f"stop_{name}", self._make_stop_task(name)) + for name in dir(self.service): + # circumvents calling properties + if is_property_attribute(self.service, name): + continue + + method = getattr(self.service, name) + if inspect.iscoroutinefunction(method): + if function_has_arguments(method): + raise TaskDefinitionError( + "Asynchronous functions (tasks) should be defined without " + f"arguments. The task '{method.__name__}' has at least one " + "argument. Please remove the argument(s) from this function to " + "use it." + ) + + # create start and stop methods for each coroutine + setattr( + self.service, f"start_{name}", self._make_start_task(name, method) + ) + setattr(self.service, f"stop_{name}", self._make_stop_task(name)) def _initiate_task_startup(self) -> None: if self.service._autostart_tasks is not None: @@ -137,7 +158,7 @@ class TaskManager: # cancel the task task = self.tasks.get(name, None) if task is not None: - self._loop.call_soon_threadsafe(task["task"].cancel) + self._loop.call_soon_threadsafe(task.cancel) return stop_task @@ -156,7 +177,7 @@ class TaskManager: method (callable): The coroutine to be turned into an asyncio task. """ - def start_task(*args: Any, **kwargs: Any) -> None: + def start_task() -> None: def task_done_callback(task: asyncio.Task[None], name: str) -> None: """Handles tasks that have finished. @@ -180,36 +201,16 @@ class TaskManager: ) raise exception - async def task(*args: Any, **kwargs: Any) -> None: + async def task() -> None: try: - await method(*args, **kwargs) + await method() except asyncio.CancelledError: logger.info("Task '%s' was cancelled", name) if not self.tasks.get(name): - # Get the signature of the coroutine method to start - sig = inspect.signature(method) - - # Create a list of the parameter names from the method signature. - parameter_names = list(sig.parameters.keys()) - - # Extend the list of positional arguments with None values to match - # the length of the parameter names list. This is done to ensure - # that zip can pair each parameter name with a corresponding value. - args_padded = list(args) + [None] * (len(parameter_names) - len(args)) - - # Create a dictionary of keyword arguments by pairing the parameter - # names with the values in 'args_padded'. Then merge this dictionary - # with the 'kwargs' dictionary. If a parameter is specified in both - # 'args_padded' and 'kwargs', the value from 'kwargs' is used. - kwargs_updated = { - **dict(zip(parameter_names, args_padded, strict=True)), - **kwargs, - } - # creating the task and adding the task_done_callback which checks # if an exception has occured during the task execution - task_object = self._loop.create_task(task(*args, **kwargs)) + task_object = self._loop.create_task(task()) task_object.add_done_callback( lambda task: task_done_callback(task, name) ) @@ -217,13 +218,10 @@ class TaskManager: # Store the task and its arguments in the '__tasks' dictionary. The # key is the name of the method, and the value is a dictionary # containing the task object and the updated keyword arguments. - self.tasks[name] = { - "task": task_object, - "kwargs": kwargs_updated, - } + self.tasks[name] = task_object # emit the notification that the task was started - self.service._notify_changed(name, kwargs_updated) + self.service._notify_changed(name, TaskStatus.RUNNING) else: logger.error("Task '%s' is already running!", name)