tasks are not allowed to have arguments anymore

This commit is contained in:
Mose Müller 2024-02-27 15:31:25 +01:00
parent b2f828ff6f
commit ca2182c19b

View File

@ -3,10 +3,15 @@ from __future__ import annotations
import asyncio
import inspect
import logging
from typing import TYPE_CHECKING, Any, TypedDict
from enum import Enum
from typing import TYPE_CHECKING, Any
from pydase.data_service.abstract_data_service import AbstractDataService
from pydase.utils.helpers import get_class_and_instance_attributes
from pydase.utils.helpers import (
function_has_arguments,
get_class_and_instance_attributes,
is_property_attribute,
)
if TYPE_CHECKING:
from collections.abc import Callable
@ -16,9 +21,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class TaskDict(TypedDict):
task: asyncio.Task[None]
kwargs: dict[str, Any]
class TaskDefinitionError(Exception):
pass
class TaskStatus(Enum):
RUNNING = "running"
class TaskManager:
@ -78,7 +86,7 @@ class TaskManager:
def __init__(self, service: DataService) -> None:
self.service = service
self.tasks: dict[str, TaskDict] = {}
self.tasks: dict[str, asyncio.Task[None]] = {}
"""A dictionary to keep track of running tasks. The keys are the names of the
tasks and the values are TaskDict instances which include the task itself and
its kwargs.
@ -91,13 +99,26 @@ class TaskManager:
return asyncio.get_running_loop()
def _set_start_and_stop_for_async_methods(self) -> None:
# inspect the methods of the class
for name, method in inspect.getmembers(
self.service, predicate=inspect.iscoroutinefunction
):
# create start and stop methods for each coroutine
setattr(self.service, f"start_{name}", self._make_start_task(name, method))
setattr(self.service, f"stop_{name}", self._make_stop_task(name))
for name in dir(self.service):
# circumvents calling properties
if is_property_attribute(self.service, name):
continue
method = getattr(self.service, name)
if inspect.iscoroutinefunction(method):
if function_has_arguments(method):
raise TaskDefinitionError(
"Asynchronous functions (tasks) should be defined without "
f"arguments. The task '{method.__name__}' has at least one "
"argument. Please remove the argument(s) from this function to "
"use it."
)
# create start and stop methods for each coroutine
setattr(
self.service, f"start_{name}", self._make_start_task(name, method)
)
setattr(self.service, f"stop_{name}", self._make_stop_task(name))
def _initiate_task_startup(self) -> None:
if self.service._autostart_tasks is not None:
@ -137,7 +158,7 @@ class TaskManager:
# cancel the task
task = self.tasks.get(name, None)
if task is not None:
self._loop.call_soon_threadsafe(task["task"].cancel)
self._loop.call_soon_threadsafe(task.cancel)
return stop_task
@ -156,7 +177,7 @@ class TaskManager:
method (callable): The coroutine to be turned into an asyncio task.
"""
def start_task(*args: Any, **kwargs: Any) -> None:
def start_task() -> None:
def task_done_callback(task: asyncio.Task[None], name: str) -> None:
"""Handles tasks that have finished.
@ -180,36 +201,16 @@ class TaskManager:
)
raise exception
async def task(*args: Any, **kwargs: Any) -> None:
async def task() -> None:
try:
await method(*args, **kwargs)
await method()
except asyncio.CancelledError:
logger.info("Task '%s' was cancelled", name)
if not self.tasks.get(name):
# Get the signature of the coroutine method to start
sig = inspect.signature(method)
# Create a list of the parameter names from the method signature.
parameter_names = list(sig.parameters.keys())
# Extend the list of positional arguments with None values to match
# the length of the parameter names list. This is done to ensure
# that zip can pair each parameter name with a corresponding value.
args_padded = list(args) + [None] * (len(parameter_names) - len(args))
# Create a dictionary of keyword arguments by pairing the parameter
# names with the values in 'args_padded'. Then merge this dictionary
# with the 'kwargs' dictionary. If a parameter is specified in both
# 'args_padded' and 'kwargs', the value from 'kwargs' is used.
kwargs_updated = {
**dict(zip(parameter_names, args_padded, strict=True)),
**kwargs,
}
# creating the task and adding the task_done_callback which checks
# if an exception has occured during the task execution
task_object = self._loop.create_task(task(*args, **kwargs))
task_object = self._loop.create_task(task())
task_object.add_done_callback(
lambda task: task_done_callback(task, name)
)
@ -217,13 +218,10 @@ class TaskManager:
# Store the task and its arguments in the '__tasks' dictionary. The
# key is the name of the method, and the value is a dictionary
# containing the task object and the updated keyword arguments.
self.tasks[name] = {
"task": task_object,
"kwargs": kwargs_updated,
}
self.tasks[name] = task_object
# emit the notification that the task was started
self.service._notify_changed(name, kwargs_updated)
self.service._notify_changed(name, TaskStatus.RUNNING)
else:
logger.error("Task '%s' is already running!", name)