updating TaskManager

This commit is contained in:
Mose Müller 2023-08-02 12:06:22 +02:00
parent e8dd332753
commit ac9f39ca56
4 changed files with 13 additions and 19 deletions

View File

@ -25,10 +25,10 @@ class TaskDict(TypedDict):
class AbstractTaskManager(ABC): class AbstractTaskManager(ABC):
_task_status_change_callbacks: list[Callable[[str, dict[str, Any] | None], Any]] task_status_change_callbacks: list[Callable[[str, dict[str, Any] | None], Any]]
"""A list of callback functions to be invoked when the status of a task (start or """A list of callback functions to be invoked when the status of a task (start or
stop) changes.""" stop) changes."""
_tasks: dict[str, TaskDict] tasks: dict[str, TaskDict]
"""A dictionary to keep track of running tasks. The keys are the names of the """A dictionary to keep track of running tasks. The keys are the names of the
tasks and the values are TaskDict instances which include the task itself and tasks and the values are TaskDict instances which include the task itself and
its kwargs. its kwargs.

View File

@ -321,7 +321,7 @@ class CallbackManager(AbstractCallbackManager):
else None else None
) )
obj._task_manager._task_status_change_callbacks.append(callback) obj._task_manager.task_status_change_callbacks.append(callback)
# Recursively register callbacks for all nested attributes of the object # Recursively register callbacks for all nested attributes of the object
attrs: dict[str, Any] = get_class_and_instance_attributes(obj) attrs: dict[str, Any] = get_class_and_instance_attributes(obj)

View File

@ -247,9 +247,9 @@ class DataService(rpyc.Service, AbstractDataService):
} }
running_task_info = None running_task_info = None
if ( if (
key in self._task_manager._tasks key in self._task_manager.tasks
): # If there's a running task for this method ): # If there's a running task for this method
task_info = self._task_manager._tasks[key] task_info = self._task_manager.tasks[key]
running_task_info = task_info["kwargs"] running_task_info = task_info["kwargs"]
result[key] = { result[key] = {

View File

@ -1,19 +1,13 @@
import asyncio import asyncio
import inspect import inspect
from functools import wraps from functools import wraps
from typing import TypedDict from typing import Any
from loguru import logger from loguru import logger
from tiqi_rpc import Any
from .abstract_service_classes import AbstractDataService, AbstractTaskManager from .abstract_service_classes import AbstractDataService, AbstractTaskManager
class TaskDict(TypedDict):
task: asyncio.Task[None]
kwargs: dict[str, Any]
class TaskManager(AbstractTaskManager): class TaskManager(AbstractTaskManager):
""" """
The TaskManager class is a utility designed to manage asynchronous tasks. It The TaskManager class is a utility designed to manage asynchronous tasks. It
@ -72,9 +66,9 @@ class TaskManager(AbstractTaskManager):
self.service = service self.service = service
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
self._tasks = {} self.tasks = {}
self._task_status_change_callbacks = [] self.task_status_change_callbacks = []
self._set_start_and_stop_for_async_methods() self._set_start_and_stop_for_async_methods()
@ -92,7 +86,7 @@ class TaskManager(AbstractTaskManager):
except asyncio.CancelledError: except asyncio.CancelledError:
print(f"Task {name} was cancelled") print(f"Task {name} was cancelled")
if not self._tasks.get(name): if not self.tasks.get(name):
# Get the signature of the coroutine method to start # Get the signature of the coroutine method to start
sig = inspect.signature(method) sig = inspect.signature(method)
@ -118,25 +112,25 @@ class TaskManager(AbstractTaskManager):
# Store the task and its arguments in the '__tasks' dictionary. The # Store the task and its arguments in the '__tasks' dictionary. The
# key is the name of the method, and the value is a dictionary # key is the name of the method, and the value is a dictionary
# containing the task object and the updated keyword arguments. # containing the task object and the updated keyword arguments.
self._tasks[name] = { self.tasks[name] = {
"task": self._loop.create_task(task(*args, **kwargs)), "task": self._loop.create_task(task(*args, **kwargs)),
"kwargs": kwargs_updated, "kwargs": kwargs_updated,
} }
# emit the notification that the task was started # emit the notification that the task was started
for callback in self._task_status_change_callbacks: for callback in self.task_status_change_callbacks:
callback(name, kwargs_updated) callback(name, kwargs_updated)
else: else:
logger.error(f"Task `{name}` is already running!") logger.error(f"Task `{name}` is already running!")
def stop_task() -> None: def stop_task() -> None:
# cancel the task # cancel the task
task = self._tasks.pop(name) task = self.tasks.pop(name)
if task is not None: if task is not None:
self._loop.call_soon_threadsafe(task["task"].cancel) self._loop.call_soon_threadsafe(task["task"].cancel)
# emit the notification that the task was stopped # emit the notification that the task was stopped
for callback in self._task_status_change_callbacks: for callback in self.task_status_change_callbacks:
callback(name, None) callback(name, None)
# create start and stop methods for each coroutine # create start and stop methods for each coroutine