tasks are not allowed to have arguments anymore

This commit is contained in:
Mose Müller 2024-02-27 15:31:25 +01:00
parent b2f828ff6f
commit ca2182c19b

View File

@ -3,10 +3,15 @@ from __future__ import annotations
import asyncio import asyncio
import inspect import inspect
import logging import logging
from typing import TYPE_CHECKING, Any, TypedDict from enum import Enum
from typing import TYPE_CHECKING, Any
from pydase.data_service.abstract_data_service import AbstractDataService from pydase.data_service.abstract_data_service import AbstractDataService
from pydase.utils.helpers import get_class_and_instance_attributes from pydase.utils.helpers import (
function_has_arguments,
get_class_and_instance_attributes,
is_property_attribute,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Callable from collections.abc import Callable
@ -16,9 +21,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TaskDict(TypedDict): class TaskDefinitionError(Exception):
task: asyncio.Task[None] pass
kwargs: dict[str, Any]
class TaskStatus(Enum):
RUNNING = "running"
class TaskManager: class TaskManager:
@ -78,7 +86,7 @@ class TaskManager:
def __init__(self, service: DataService) -> None: def __init__(self, service: DataService) -> None:
self.service = service self.service = service
self.tasks: dict[str, TaskDict] = {} self.tasks: dict[str, asyncio.Task[None]] = {}
"""A dictionary to keep track of running tasks. The keys are the names of the """A dictionary to keep track of running tasks. The keys are the names of the
tasks and the values are TaskDict instances which include the task itself and tasks and the values are TaskDict instances which include the task itself and
its kwargs. its kwargs.
@ -91,12 +99,25 @@ class TaskManager:
return asyncio.get_running_loop() return asyncio.get_running_loop()
def _set_start_and_stop_for_async_methods(self) -> None: def _set_start_and_stop_for_async_methods(self) -> None:
# inspect the methods of the class for name in dir(self.service):
for name, method in inspect.getmembers( # circumvents calling properties
self.service, predicate=inspect.iscoroutinefunction if is_property_attribute(self.service, name):
): continue
method = getattr(self.service, name)
if inspect.iscoroutinefunction(method):
if function_has_arguments(method):
raise TaskDefinitionError(
"Asynchronous functions (tasks) should be defined without "
f"arguments. The task '{method.__name__}' has at least one "
"argument. Please remove the argument(s) from this function to "
"use it."
)
# create start and stop methods for each coroutine # create start and stop methods for each coroutine
setattr(self.service, f"start_{name}", self._make_start_task(name, method)) setattr(
self.service, f"start_{name}", self._make_start_task(name, method)
)
setattr(self.service, f"stop_{name}", self._make_stop_task(name)) setattr(self.service, f"stop_{name}", self._make_stop_task(name))
def _initiate_task_startup(self) -> None: def _initiate_task_startup(self) -> None:
@ -137,7 +158,7 @@ class TaskManager:
# cancel the task # cancel the task
task = self.tasks.get(name, None) task = self.tasks.get(name, None)
if task is not None: if task is not None:
self._loop.call_soon_threadsafe(task["task"].cancel) self._loop.call_soon_threadsafe(task.cancel)
return stop_task return stop_task
@ -156,7 +177,7 @@ class TaskManager:
method (callable): The coroutine to be turned into an asyncio task. method (callable): The coroutine to be turned into an asyncio task.
""" """
def start_task(*args: Any, **kwargs: Any) -> None: def start_task() -> None:
def task_done_callback(task: asyncio.Task[None], name: str) -> None: def task_done_callback(task: asyncio.Task[None], name: str) -> None:
"""Handles tasks that have finished. """Handles tasks that have finished.
@ -180,36 +201,16 @@ class TaskManager:
) )
raise exception raise exception
async def task(*args: Any, **kwargs: Any) -> None: async def task() -> None:
try: try:
await method(*args, **kwargs) await method()
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info("Task '%s' was cancelled", name) logger.info("Task '%s' was cancelled", name)
if not self.tasks.get(name): if not self.tasks.get(name):
# Get the signature of the coroutine method to start
sig = inspect.signature(method)
# Create a list of the parameter names from the method signature.
parameter_names = list(sig.parameters.keys())
# Extend the list of positional arguments with None values to match
# the length of the parameter names list. This is done to ensure
# that zip can pair each parameter name with a corresponding value.
args_padded = list(args) + [None] * (len(parameter_names) - len(args))
# Create a dictionary of keyword arguments by pairing the parameter
# names with the values in 'args_padded'. Then merge this dictionary
# with the 'kwargs' dictionary. If a parameter is specified in both
# 'args_padded' and 'kwargs', the value from 'kwargs' is used.
kwargs_updated = {
**dict(zip(parameter_names, args_padded, strict=True)),
**kwargs,
}
# creating the task and adding the task_done_callback which checks # creating the task and adding the task_done_callback which checks
# if an exception has occured during the task execution # if an exception has occured during the task execution
task_object = self._loop.create_task(task(*args, **kwargs)) task_object = self._loop.create_task(task())
task_object.add_done_callback( task_object.add_done_callback(
lambda task: task_done_callback(task, name) lambda task: task_done_callback(task, name)
) )
@ -217,13 +218,10 @@ class TaskManager:
# Store the task and its arguments in the '__tasks' dictionary. The # Store the task and its arguments in the '__tasks' dictionary. The
# key is the name of the method, and the value is a dictionary # key is the name of the method, and the value is a dictionary
# containing the task object and the updated keyword arguments. # containing the task object and the updated keyword arguments.
self.tasks[name] = { self.tasks[name] = task_object
"task": task_object,
"kwargs": kwargs_updated,
}
# emit the notification that the task was started # emit the notification that the task was started
self.service._notify_changed(name, kwargs_updated) self.service._notify_changed(name, TaskStatus.RUNNING)
else: else:
logger.error("Task '%s' is already running!", name) logger.error("Task '%s' is already running!", name)