mirror of
https://github.com/tiqi-group/pydase.git
synced 2025-04-21 16:50:02 +02:00
feat: adding task status change callbacks
When a task (async function) is started / stopped, this will emit a notification via socketio.
This commit is contained in:
parent
3d07a5c9dd
commit
bc50f99e18
@ -4,7 +4,7 @@ from collections.abc import Callable
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
import rpyc
|
import rpyc
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@ -16,6 +16,11 @@ from pyDataInterface.utils import (
|
|||||||
from .data_service_list import DataServiceList
|
from .data_service_list import DataServiceList
|
||||||
|
|
||||||
|
|
||||||
|
class TaskDict(TypedDict):
|
||||||
|
task: asyncio.Task[None]
|
||||||
|
kwargs: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
class DataService(rpyc.Service):
|
class DataService(rpyc.Service):
|
||||||
_list_mapping: dict[int, DataServiceList] = {}
|
_list_mapping: dict[int, DataServiceList] = {}
|
||||||
"""
|
"""
|
||||||
@ -63,6 +68,13 @@ class DataService(rpyc.Service):
|
|||||||
self._autostart_tasks = {}
|
self._autostart_tasks = {}
|
||||||
|
|
||||||
self._callbacks: set[Callable[[str, Any], None]] = set()
|
self._callbacks: set[Callable[[str, Any], None]] = set()
|
||||||
|
|
||||||
|
self._task_status_change_callbacks: list[
|
||||||
|
Callable[[str, dict[str, Any] | None], Any]
|
||||||
|
] = []
|
||||||
|
"""A list of callback functions to be invoked when the status of a task (start
|
||||||
|
or stop) changes."""
|
||||||
|
|
||||||
self._set_start_and_stop_for_async_methods()
|
self._set_start_and_stop_for_async_methods()
|
||||||
|
|
||||||
self._register_callbacks()
|
self._register_callbacks()
|
||||||
@ -133,7 +145,39 @@ class DataService(rpyc.Service):
|
|||||||
print(f"Task {name} was cancelled")
|
print(f"Task {name} was cancelled")
|
||||||
|
|
||||||
if not self.__tasks.get(name):
|
if not self.__tasks.get(name):
|
||||||
self.__tasks[name] = self.__loop.create_task(task(*args, **kwargs))
|
# Get the signature of the coroutine method to start
|
||||||
|
sig = inspect.signature(method)
|
||||||
|
|
||||||
|
# Create a list of the parameter names from the method signature.
|
||||||
|
parameter_names = list(sig.parameters.keys())
|
||||||
|
|
||||||
|
# Extend the list of positional arguments with None values to match
|
||||||
|
# the length of the parameter names list. This is done to ensure
|
||||||
|
# that zip can pair each parameter name with a corresponding value.
|
||||||
|
args_padded = list(args) + [None] * (
|
||||||
|
len(parameter_names) - len(args)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a dictionary of keyword arguments by pairing the parameter
|
||||||
|
# names with the values in 'args_padded'. Then merge this dictionary
|
||||||
|
# with the 'kwargs' dictionary. If a parameter is specified in both
|
||||||
|
# 'args_padded' and 'kwargs', the value from 'kwargs' is used.
|
||||||
|
kwargs_updated = {
|
||||||
|
**dict(zip(parameter_names, args_padded)),
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Store the task and its arguments in the '__tasks' dictionary. The
|
||||||
|
# key is the name of the method, and the value is a dictionary
|
||||||
|
# containing the task object and the updated keyword arguments.
|
||||||
|
self.__tasks[name] = {
|
||||||
|
"task": self.__loop.create_task(task(*args, **kwargs)),
|
||||||
|
"kwargs": kwargs_updated,
|
||||||
|
}
|
||||||
|
|
||||||
|
# emit the notification that the task was started
|
||||||
|
for callback in self._task_status_change_callbacks:
|
||||||
|
callback(name, kwargs_updated)
|
||||||
else:
|
else:
|
||||||
logger.error(f"Task `{name}` is already running!")
|
logger.error(f"Task `{name}` is already running!")
|
||||||
|
|
||||||
@ -141,7 +185,11 @@ class DataService(rpyc.Service):
|
|||||||
# cancel the task
|
# cancel the task
|
||||||
task = self.__tasks.pop(name)
|
task = self.__tasks.pop(name)
|
||||||
if task is not None:
|
if task is not None:
|
||||||
self.__loop.call_soon_threadsafe(task.cancel)
|
self.__loop.call_soon_threadsafe(task["task"].cancel)
|
||||||
|
|
||||||
|
# emit the notification that the task was stopped
|
||||||
|
for callback in self._task_status_change_callbacks:
|
||||||
|
callback(name, None)
|
||||||
|
|
||||||
# create start and stop methods for each coroutine
|
# create start and stop methods for each coroutine
|
||||||
setattr(self, f"start_{name}", start_task)
|
setattr(self, f"start_{name}", start_task)
|
||||||
@ -153,6 +201,46 @@ class DataService(rpyc.Service):
|
|||||||
self, f"{self.__class__.__name__}"
|
self, f"{self.__class__.__name__}"
|
||||||
)
|
)
|
||||||
self._register_property_callbacks(self, f"{self.__class__.__name__}")
|
self._register_property_callbacks(self, f"{self.__class__.__name__}")
|
||||||
|
self._register_start_stop_task_callbacks(self, f"{self.__class__.__name__}")
|
||||||
|
|
||||||
|
def _register_start_stop_task_callbacks(
|
||||||
|
self, obj: "DataService", parent_path: str
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
This function registers callbacks for start and stop methods of async functions.
|
||||||
|
These callbacks are stored in the '_task_status_change_callbacks' attribute and
|
||||||
|
are called when the status of a task changes.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
obj: DataService
|
||||||
|
The target object on which callbacks are to be registered.
|
||||||
|
parent_path: str
|
||||||
|
The access path for the parent object. This is used to construct the full
|
||||||
|
access path for the notifications.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create and register a callback for the object
|
||||||
|
# only emit the notification when the call was registered by the root object
|
||||||
|
callback: Callable[[str, dict[str, Any] | None], None] = (
|
||||||
|
lambda name, status: obj._emit_notification(
|
||||||
|
parent_path=parent_path, name=name, value=status
|
||||||
|
)
|
||||||
|
if self == obj.__root__
|
||||||
|
and not name.startswith("_") # we are only interested in public attributes
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
obj._task_status_change_callbacks.append(callback)
|
||||||
|
|
||||||
|
# Recursively register callbacks for all nested attributes of the object
|
||||||
|
attrs = obj.__get_class_and_instance_attributes()
|
||||||
|
|
||||||
|
for nested_attr_name, nested_attr in attrs.items():
|
||||||
|
if isinstance(nested_attr, DataService):
|
||||||
|
self._register_start_stop_task_callbacks(
|
||||||
|
nested_attr, parent_path=f"{parent_path}.{nested_attr_name}"
|
||||||
|
)
|
||||||
|
|
||||||
def _register_list_change_callbacks(
|
def _register_list_change_callbacks(
|
||||||
self, obj: "DataService", parent_path: str
|
self, obj: "DataService", parent_path: str
|
||||||
@ -545,11 +633,17 @@ class DataService(rpyc.Service):
|
|||||||
else None
|
else None
|
||||||
for k, v in sig.parameters.items()
|
for k, v in sig.parameters.items()
|
||||||
}
|
}
|
||||||
|
running_task_info = None
|
||||||
|
if key in self.__tasks: # If there's a running task for this method
|
||||||
|
task_info = self.__tasks[key]
|
||||||
|
running_task_info = task_info["kwargs"]
|
||||||
|
|
||||||
result[key] = {
|
result[key] = {
|
||||||
"type": "method",
|
"type": "method",
|
||||||
"async": asyncio.iscoroutinefunction(value),
|
"async": asyncio.iscoroutinefunction(value),
|
||||||
"parameters": parameters,
|
"parameters": parameters,
|
||||||
"doc": inspect.getdoc(value),
|
"doc": inspect.getdoc(value),
|
||||||
|
"value": running_task_info,
|
||||||
}
|
}
|
||||||
elif isinstance(getattr(self.__class__, key, None), property):
|
elif isinstance(getattr(self.__class__, key, None), property):
|
||||||
prop: property = getattr(self.__class__, key)
|
prop: property = getattr(self.__class__, key)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user