mirror of
https://github.com/tiqi-group/pydase.git
synced 2025-04-21 16:50:02 +02:00
tasks are not allowed to have arguments anymore
This commit is contained in:
parent
b2f828ff6f
commit
ca2182c19b
@ -3,10 +3,15 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, TypedDict
|
from enum import Enum
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from pydase.data_service.abstract_data_service import AbstractDataService
|
from pydase.data_service.abstract_data_service import AbstractDataService
|
||||||
from pydase.utils.helpers import get_class_and_instance_attributes
|
from pydase.utils.helpers import (
|
||||||
|
function_has_arguments,
|
||||||
|
get_class_and_instance_attributes,
|
||||||
|
is_property_attribute,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
@ -16,9 +21,12 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TaskDict(TypedDict):
|
class TaskDefinitionError(Exception):
|
||||||
task: asyncio.Task[None]
|
pass
|
||||||
kwargs: dict[str, Any]
|
|
||||||
|
|
||||||
|
class TaskStatus(Enum):
|
||||||
|
RUNNING = "running"
|
||||||
|
|
||||||
|
|
||||||
class TaskManager:
|
class TaskManager:
|
||||||
@ -78,7 +86,7 @@ class TaskManager:
|
|||||||
def __init__(self, service: DataService) -> None:
|
def __init__(self, service: DataService) -> None:
|
||||||
self.service = service
|
self.service = service
|
||||||
|
|
||||||
self.tasks: dict[str, TaskDict] = {}
|
self.tasks: dict[str, asyncio.Task[None]] = {}
|
||||||
"""A dictionary to keep track of running tasks. The keys are the names of the
|
"""A dictionary to keep track of running tasks. The keys are the names of the
|
||||||
tasks and the values are TaskDict instances which include the task itself and
|
tasks and the values are TaskDict instances which include the task itself and
|
||||||
its kwargs.
|
its kwargs.
|
||||||
@ -91,12 +99,25 @@ class TaskManager:
|
|||||||
return asyncio.get_running_loop()
|
return asyncio.get_running_loop()
|
||||||
|
|
||||||
def _set_start_and_stop_for_async_methods(self) -> None:
|
def _set_start_and_stop_for_async_methods(self) -> None:
|
||||||
# inspect the methods of the class
|
for name in dir(self.service):
|
||||||
for name, method in inspect.getmembers(
|
# circumvents calling properties
|
||||||
self.service, predicate=inspect.iscoroutinefunction
|
if is_property_attribute(self.service, name):
|
||||||
):
|
continue
|
||||||
|
|
||||||
|
method = getattr(self.service, name)
|
||||||
|
if inspect.iscoroutinefunction(method):
|
||||||
|
if function_has_arguments(method):
|
||||||
|
raise TaskDefinitionError(
|
||||||
|
"Asynchronous functions (tasks) should be defined without "
|
||||||
|
f"arguments. The task '{method.__name__}' has at least one "
|
||||||
|
"argument. Please remove the argument(s) from this function to "
|
||||||
|
"use it."
|
||||||
|
)
|
||||||
|
|
||||||
# create start and stop methods for each coroutine
|
# create start and stop methods for each coroutine
|
||||||
setattr(self.service, f"start_{name}", self._make_start_task(name, method))
|
setattr(
|
||||||
|
self.service, f"start_{name}", self._make_start_task(name, method)
|
||||||
|
)
|
||||||
setattr(self.service, f"stop_{name}", self._make_stop_task(name))
|
setattr(self.service, f"stop_{name}", self._make_stop_task(name))
|
||||||
|
|
||||||
def _initiate_task_startup(self) -> None:
|
def _initiate_task_startup(self) -> None:
|
||||||
@ -137,7 +158,7 @@ class TaskManager:
|
|||||||
# cancel the task
|
# cancel the task
|
||||||
task = self.tasks.get(name, None)
|
task = self.tasks.get(name, None)
|
||||||
if task is not None:
|
if task is not None:
|
||||||
self._loop.call_soon_threadsafe(task["task"].cancel)
|
self._loop.call_soon_threadsafe(task.cancel)
|
||||||
|
|
||||||
return stop_task
|
return stop_task
|
||||||
|
|
||||||
@ -156,7 +177,7 @@ class TaskManager:
|
|||||||
method (callable): The coroutine to be turned into an asyncio task.
|
method (callable): The coroutine to be turned into an asyncio task.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def start_task(*args: Any, **kwargs: Any) -> None:
|
def start_task() -> None:
|
||||||
def task_done_callback(task: asyncio.Task[None], name: str) -> None:
|
def task_done_callback(task: asyncio.Task[None], name: str) -> None:
|
||||||
"""Handles tasks that have finished.
|
"""Handles tasks that have finished.
|
||||||
|
|
||||||
@ -180,36 +201,16 @@ class TaskManager:
|
|||||||
)
|
)
|
||||||
raise exception
|
raise exception
|
||||||
|
|
||||||
async def task(*args: Any, **kwargs: Any) -> None:
|
async def task() -> None:
|
||||||
try:
|
try:
|
||||||
await method(*args, **kwargs)
|
await method()
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info("Task '%s' was cancelled", name)
|
logger.info("Task '%s' was cancelled", name)
|
||||||
|
|
||||||
if not self.tasks.get(name):
|
if not self.tasks.get(name):
|
||||||
# Get the signature of the coroutine method to start
|
|
||||||
sig = inspect.signature(method)
|
|
||||||
|
|
||||||
# Create a list of the parameter names from the method signature.
|
|
||||||
parameter_names = list(sig.parameters.keys())
|
|
||||||
|
|
||||||
# Extend the list of positional arguments with None values to match
|
|
||||||
# the length of the parameter names list. This is done to ensure
|
|
||||||
# that zip can pair each parameter name with a corresponding value.
|
|
||||||
args_padded = list(args) + [None] * (len(parameter_names) - len(args))
|
|
||||||
|
|
||||||
# Create a dictionary of keyword arguments by pairing the parameter
|
|
||||||
# names with the values in 'args_padded'. Then merge this dictionary
|
|
||||||
# with the 'kwargs' dictionary. If a parameter is specified in both
|
|
||||||
# 'args_padded' and 'kwargs', the value from 'kwargs' is used.
|
|
||||||
kwargs_updated = {
|
|
||||||
**dict(zip(parameter_names, args_padded, strict=True)),
|
|
||||||
**kwargs,
|
|
||||||
}
|
|
||||||
|
|
||||||
# creating the task and adding the task_done_callback which checks
|
# creating the task and adding the task_done_callback which checks
|
||||||
# if an exception has occured during the task execution
|
# if an exception has occured during the task execution
|
||||||
task_object = self._loop.create_task(task(*args, **kwargs))
|
task_object = self._loop.create_task(task())
|
||||||
task_object.add_done_callback(
|
task_object.add_done_callback(
|
||||||
lambda task: task_done_callback(task, name)
|
lambda task: task_done_callback(task, name)
|
||||||
)
|
)
|
||||||
@ -217,13 +218,10 @@ class TaskManager:
|
|||||||
# Store the task and its arguments in the '__tasks' dictionary. The
|
# Store the task and its arguments in the '__tasks' dictionary. The
|
||||||
# key is the name of the method, and the value is a dictionary
|
# key is the name of the method, and the value is a dictionary
|
||||||
# containing the task object and the updated keyword arguments.
|
# containing the task object and the updated keyword arguments.
|
||||||
self.tasks[name] = {
|
self.tasks[name] = task_object
|
||||||
"task": task_object,
|
|
||||||
"kwargs": kwargs_updated,
|
|
||||||
}
|
|
||||||
|
|
||||||
# emit the notification that the task was started
|
# emit the notification that the task was started
|
||||||
self.service._notify_changed(name, kwargs_updated)
|
self.service._notify_changed(name, TaskStatus.RUNNING)
|
||||||
else:
|
else:
|
||||||
logger.error("Task '%s' is already running!", name)
|
logger.error("Task '%s' is already running!", name)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user