feat: adding task status change callbacks

When a task (async function) is started / stopped, this will emit a
notification via socketio.
This commit is contained in:
Mose Müller 2023-08-02 12:06:20 +02:00
parent 3d07a5c9dd
commit bc50f99e18

View File

@ -4,7 +4,7 @@ from collections.abc import Callable
from enum import Enum from enum import Enum
from functools import wraps from functools import wraps
from itertools import chain from itertools import chain
from typing import Any from typing import Any, TypedDict
import rpyc import rpyc
from loguru import logger from loguru import logger
@ -16,6 +16,11 @@ from pyDataInterface.utils import (
from .data_service_list import DataServiceList from .data_service_list import DataServiceList
class TaskDict(TypedDict):
task: asyncio.Task[None]
kwargs: dict[str, Any]
class DataService(rpyc.Service): class DataService(rpyc.Service):
_list_mapping: dict[int, DataServiceList] = {} _list_mapping: dict[int, DataServiceList] = {}
""" """
@ -63,6 +68,13 @@ class DataService(rpyc.Service):
self._autostart_tasks = {} self._autostart_tasks = {}
self._callbacks: set[Callable[[str, Any], None]] = set() self._callbacks: set[Callable[[str, Any], None]] = set()
self._task_status_change_callbacks: list[
Callable[[str, dict[str, Any] | None], Any]
] = []
"""A list of callback functions to be invoked when the status of a task (start
or stop) changes."""
self._set_start_and_stop_for_async_methods() self._set_start_and_stop_for_async_methods()
self._register_callbacks() self._register_callbacks()
@ -133,7 +145,39 @@ class DataService(rpyc.Service):
print(f"Task {name} was cancelled") print(f"Task {name} was cancelled")
if not self.__tasks.get(name): if not self.__tasks.get(name):
self.__tasks[name] = self.__loop.create_task(task(*args, **kwargs)) # Get the signature of the coroutine method to start
sig = inspect.signature(method)
# Create a list of the parameter names from the method signature.
parameter_names = list(sig.parameters.keys())
# Extend the list of positional arguments with None values to match
# the length of the parameter names list. This is done to ensure
# that zip can pair each parameter name with a corresponding value.
args_padded = list(args) + [None] * (
len(parameter_names) - len(args)
)
# Create a dictionary of keyword arguments by pairing the parameter
# names with the values in 'args_padded'. Then merge this dictionary
# with the 'kwargs' dictionary. If a parameter is specified in both
# 'args_padded' and 'kwargs', the value from 'kwargs' is used.
kwargs_updated = {
**dict(zip(parameter_names, args_padded)),
**kwargs,
}
# Store the task and its arguments in the '__tasks' dictionary. The
# key is the name of the method, and the value is a dictionary
# containing the task object and the updated keyword arguments.
self.__tasks[name] = {
"task": self.__loop.create_task(task(*args, **kwargs)),
"kwargs": kwargs_updated,
}
# emit the notification that the task was started
for callback in self._task_status_change_callbacks:
callback(name, kwargs_updated)
else: else:
logger.error(f"Task `{name}` is already running!") logger.error(f"Task `{name}` is already running!")
@ -141,7 +185,11 @@ class DataService(rpyc.Service):
# cancel the task # cancel the task
task = self.__tasks.pop(name) task = self.__tasks.pop(name)
if task is not None: if task is not None:
self.__loop.call_soon_threadsafe(task.cancel) self.__loop.call_soon_threadsafe(task["task"].cancel)
# emit the notification that the task was stopped
for callback in self._task_status_change_callbacks:
callback(name, None)
# create start and stop methods for each coroutine # create start and stop methods for each coroutine
setattr(self, f"start_{name}", start_task) setattr(self, f"start_{name}", start_task)
@ -153,6 +201,46 @@ class DataService(rpyc.Service):
self, f"{self.__class__.__name__}" self, f"{self.__class__.__name__}"
) )
self._register_property_callbacks(self, f"{self.__class__.__name__}") self._register_property_callbacks(self, f"{self.__class__.__name__}")
self._register_start_stop_task_callbacks(self, f"{self.__class__.__name__}")
def _register_start_stop_task_callbacks(
self, obj: "DataService", parent_path: str
) -> None:
"""
This function registers callbacks for start and stop methods of async functions.
These callbacks are stored in the '_task_status_change_callbacks' attribute and
are called when the status of a task changes.
Parameters:
-----------
obj: DataService
The target object on which callbacks are to be registered.
parent_path: str
The access path for the parent object. This is used to construct the full
access path for the notifications.
"""
# Create and register a callback for the object
# only emit the notification when the call was registered by the root object
callback: Callable[[str, dict[str, Any] | None], None] = (
lambda name, status: obj._emit_notification(
parent_path=parent_path, name=name, value=status
)
if self == obj.__root__
and not name.startswith("_") # we are only interested in public attributes
else None
)
obj._task_status_change_callbacks.append(callback)
# Recursively register callbacks for all nested attributes of the object
attrs = obj.__get_class_and_instance_attributes()
for nested_attr_name, nested_attr in attrs.items():
if isinstance(nested_attr, DataService):
self._register_start_stop_task_callbacks(
nested_attr, parent_path=f"{parent_path}.{nested_attr_name}"
)
def _register_list_change_callbacks( def _register_list_change_callbacks(
self, obj: "DataService", parent_path: str self, obj: "DataService", parent_path: str
@ -545,11 +633,17 @@ class DataService(rpyc.Service):
else None else None
for k, v in sig.parameters.items() for k, v in sig.parameters.items()
} }
running_task_info = None
if key in self.__tasks: # If there's a running task for this method
task_info = self.__tasks[key]
running_task_info = task_info["kwargs"]
result[key] = { result[key] = {
"type": "method", "type": "method",
"async": asyncio.iscoroutinefunction(value), "async": asyncio.iscoroutinefunction(value),
"parameters": parameters, "parameters": parameters,
"doc": inspect.getdoc(value), "doc": inspect.getdoc(value),
"value": running_task_info,
} }
elif isinstance(getattr(self.__class__, key, None), property): elif isinstance(getattr(self.__class__, key, None), property):
prop: property = getattr(self.__class__, key) prop: property = getattr(self.__class__, key)