updating TaskManager

This commit is contained in:
Mose Müller 2023-08-02 12:06:22 +02:00
parent e8dd332753
commit ac9f39ca56
4 changed files with 13 additions and 19 deletions

View File

@ -25,10 +25,10 @@ class TaskDict(TypedDict):
class AbstractTaskManager(ABC):
_task_status_change_callbacks: list[Callable[[str, dict[str, Any] | None], Any]]
task_status_change_callbacks: list[Callable[[str, dict[str, Any] | None], Any]]
"""A list of callback functions to be invoked when the status of a task (start or
stop) changes."""
_tasks: dict[str, TaskDict]
tasks: dict[str, TaskDict]
"""A dictionary to keep track of running tasks. The keys are the names of the
tasks and the values are TaskDict instances which include the task itself and
its kwargs.

View File

@ -321,7 +321,7 @@ class CallbackManager(AbstractCallbackManager):
else None
)
obj._task_manager._task_status_change_callbacks.append(callback)
obj._task_manager.task_status_change_callbacks.append(callback)
# Recursively register callbacks for all nested attributes of the object
attrs: dict[str, Any] = get_class_and_instance_attributes(obj)

View File

@ -247,9 +247,9 @@ class DataService(rpyc.Service, AbstractDataService):
}
running_task_info = None
if (
key in self._task_manager._tasks
key in self._task_manager.tasks
): # If there's a running task for this method
task_info = self._task_manager._tasks[key]
task_info = self._task_manager.tasks[key]
running_task_info = task_info["kwargs"]
result[key] = {

View File

@ -1,19 +1,13 @@
import asyncio
import inspect
from functools import wraps
from typing import TypedDict
from typing import Any
from loguru import logger
from tiqi_rpc import Any
from .abstract_service_classes import AbstractDataService, AbstractTaskManager
class TaskDict(TypedDict):
task: asyncio.Task[None]
kwargs: dict[str, Any]
class TaskManager(AbstractTaskManager):
"""
The TaskManager class is a utility designed to manage asynchronous tasks. It
@ -72,9 +66,9 @@ class TaskManager(AbstractTaskManager):
self.service = service
self._loop = asyncio.get_event_loop()
self._tasks = {}
self.tasks = {}
self._task_status_change_callbacks = []
self.task_status_change_callbacks = []
self._set_start_and_stop_for_async_methods()
@ -92,7 +86,7 @@ class TaskManager(AbstractTaskManager):
except asyncio.CancelledError:
print(f"Task {name} was cancelled")
if not self._tasks.get(name):
if not self.tasks.get(name):
# Get the signature of the coroutine method to start
sig = inspect.signature(method)
@ -118,25 +112,25 @@ class TaskManager(AbstractTaskManager):
# Store the task and its arguments in the '__tasks' dictionary. The
# key is the name of the method, and the value is a dictionary
# containing the task object and the updated keyword arguments.
self._tasks[name] = {
self.tasks[name] = {
"task": self._loop.create_task(task(*args, **kwargs)),
"kwargs": kwargs_updated,
}
# emit the notification that the task was started
for callback in self._task_status_change_callbacks:
for callback in self.task_status_change_callbacks:
callback(name, kwargs_updated)
else:
logger.error(f"Task `{name}` is already running!")
def stop_task() -> None:
# cancel the task
task = self._tasks.pop(name)
task = self.tasks.pop(name)
if task is not None:
self._loop.call_soon_threadsafe(task["task"].cancel)
# emit the notification that the task was stopped
for callback in self._task_status_change_callbacks:
for callback in self.task_status_change_callbacks:
callback(name, None)
# create start and stop methods for each coroutine