mirror of
https://github.com/tiqi-group/pydase.git
synced 2025-04-21 16:50:02 +02:00
feat: separating TaskManager out of DataService
This commit is contained in:
parent
8460759a31
commit
e3211b6000
@ -2,26 +2,21 @@ import asyncio
|
||||
import inspect
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from itertools import chain
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any
|
||||
|
||||
import rpyc
|
||||
from loguru import logger
|
||||
|
||||
from pyDataInterface.utils import (
|
||||
get_class_and_instance_attributes,
|
||||
warn_if_instance_class_does_not_inherit_from_DataService,
|
||||
)
|
||||
|
||||
from .data_service_list import DataServiceList
|
||||
from .task_manager import TaskManager
|
||||
|
||||
|
||||
class TaskDict(TypedDict):
|
||||
task: asyncio.Task[None]
|
||||
kwargs: dict[str, Any]
|
||||
|
||||
|
||||
class DataService(rpyc.Service):
|
||||
class DataService(rpyc.Service, TaskManager):
|
||||
_list_mapping: dict[int, DataServiceList] = {}
|
||||
"""
|
||||
A dictionary mapping the id of the original lists to the corresponding
|
||||
@ -54,14 +49,10 @@ class DataService(rpyc.Service):
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
TaskManager.__init__(self)
|
||||
self.__root__: "DataService" = self
|
||||
"""Keep track of the root object. This helps to filter the emission of
|
||||
notifications."""
|
||||
|
||||
self.__loop = asyncio.get_event_loop()
|
||||
|
||||
self.__tasks: dict[str, TaskDict] = {}
|
||||
"""Dictionary to keep track of running tasks."""
|
||||
notifications. This overwrite the TaksManager's __root__ attribute."""
|
||||
|
||||
self._autostart_tasks: dict[str, tuple[Any]]
|
||||
if "_autostart_tasks" not in self.__dict__:
|
||||
@ -69,14 +60,6 @@ class DataService(rpyc.Service):
|
||||
|
||||
self._callbacks: set[Callable[[str, Any], None]] = set()
|
||||
|
||||
self._task_status_change_callbacks: list[
|
||||
Callable[[str, dict[str, Any] | None], Any]
|
||||
] = []
|
||||
"""A list of callback functions to be invoked when the status of a task (start
|
||||
or stop) changes."""
|
||||
|
||||
self._set_start_and_stop_for_async_methods()
|
||||
|
||||
self._register_callbacks()
|
||||
self.__check_instance_classes()
|
||||
self._initialised = True
|
||||
@ -130,71 +113,6 @@ class DataService(rpyc.Service):
|
||||
f"No start method found for service '{service_name}'"
|
||||
)
|
||||
|
||||
def _set_start_and_stop_for_async_methods(self) -> None: # noqa: C901
|
||||
# inspect the methods of the class
|
||||
for name, method in inspect.getmembers(
|
||||
self, predicate=inspect.iscoroutinefunction
|
||||
):
|
||||
|
||||
@wraps(method)
|
||||
def start_task(*args: Any, **kwargs: Any) -> None:
|
||||
async def task(*args: Any, **kwargs: Any) -> None:
|
||||
try:
|
||||
await method(*args, **kwargs)
|
||||
except asyncio.CancelledError:
|
||||
print(f"Task {name} was cancelled")
|
||||
|
||||
if not self.__tasks.get(name):
|
||||
# Get the signature of the coroutine method to start
|
||||
sig = inspect.signature(method)
|
||||
|
||||
# Create a list of the parameter names from the method signature.
|
||||
parameter_names = list(sig.parameters.keys())
|
||||
|
||||
# Extend the list of positional arguments with None values to match
|
||||
# the length of the parameter names list. This is done to ensure
|
||||
# that zip can pair each parameter name with a corresponding value.
|
||||
args_padded = list(args) + [None] * (
|
||||
len(parameter_names) - len(args)
|
||||
)
|
||||
|
||||
# Create a dictionary of keyword arguments by pairing the parameter
|
||||
# names with the values in 'args_padded'. Then merge this dictionary
|
||||
# with the 'kwargs' dictionary. If a parameter is specified in both
|
||||
# 'args_padded' and 'kwargs', the value from 'kwargs' is used.
|
||||
kwargs_updated = {
|
||||
**dict(zip(parameter_names, args_padded)),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
# Store the task and its arguments in the '__tasks' dictionary. The
|
||||
# key is the name of the method, and the value is a dictionary
|
||||
# containing the task object and the updated keyword arguments.
|
||||
self.__tasks[name] = {
|
||||
"task": self.__loop.create_task(task(*args, **kwargs)),
|
||||
"kwargs": kwargs_updated,
|
||||
}
|
||||
|
||||
# emit the notification that the task was started
|
||||
for callback in self._task_status_change_callbacks:
|
||||
callback(name, kwargs_updated)
|
||||
else:
|
||||
logger.error(f"Task `{name}` is already running!")
|
||||
|
||||
def stop_task() -> None:
|
||||
# cancel the task
|
||||
task = self.__tasks.pop(name)
|
||||
if task is not None:
|
||||
self.__loop.call_soon_threadsafe(task["task"].cancel)
|
||||
|
||||
# emit the notification that the task was stopped
|
||||
for callback in self._task_status_change_callbacks:
|
||||
callback(name, None)
|
||||
|
||||
# create start and stop methods for each coroutine
|
||||
setattr(self, f"start_{name}", start_task)
|
||||
setattr(self, f"stop_{name}", stop_task)
|
||||
|
||||
def _register_callbacks(self) -> None:
|
||||
self._register_list_change_callbacks(self, f"{self.__class__.__name__}")
|
||||
self._register_DataService_instance_callbacks(
|
||||
@ -203,45 +121,6 @@ class DataService(rpyc.Service):
|
||||
self._register_property_callbacks(self, f"{self.__class__.__name__}")
|
||||
self._register_start_stop_task_callbacks(self, f"{self.__class__.__name__}")
|
||||
|
||||
def _register_start_stop_task_callbacks(
|
||||
self, obj: "DataService", parent_path: str
|
||||
) -> None:
|
||||
"""
|
||||
This function registers callbacks for start and stop methods of async functions.
|
||||
These callbacks are stored in the '_task_status_change_callbacks' attribute and
|
||||
are called when the status of a task changes.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
obj: DataService
|
||||
The target object on which callbacks are to be registered.
|
||||
parent_path: str
|
||||
The access path for the parent object. This is used to construct the full
|
||||
access path for the notifications.
|
||||
"""
|
||||
|
||||
# Create and register a callback for the object
|
||||
# only emit the notification when the call was registered by the root object
|
||||
callback: Callable[[str, dict[str, Any] | None], None] = (
|
||||
lambda name, status: obj._emit_notification(
|
||||
parent_path=parent_path, name=name, value=status
|
||||
)
|
||||
if self == obj.__root__
|
||||
and not name.startswith("_") # we are only interested in public attributes
|
||||
else None
|
||||
)
|
||||
|
||||
obj._task_status_change_callbacks.append(callback)
|
||||
|
||||
# Recursively register callbacks for all nested attributes of the object
|
||||
attrs = obj.__get_class_and_instance_attributes()
|
||||
|
||||
for nested_attr_name, nested_attr in attrs.items():
|
||||
if isinstance(nested_attr, DataService):
|
||||
self._register_start_stop_task_callbacks(
|
||||
nested_attr, parent_path=f"{parent_path}.{nested_attr_name}"
|
||||
)
|
||||
|
||||
def _register_list_change_callbacks(
|
||||
self, obj: "DataService", parent_path: str
|
||||
) -> None:
|
||||
@ -274,7 +153,7 @@ class DataService(rpyc.Service):
|
||||
"""
|
||||
|
||||
# Convert all list attributes (both class and instance) to DataServiceList
|
||||
attrs = obj.__get_class_and_instance_attributes()
|
||||
attrs = get_class_and_instance_attributes(obj)
|
||||
|
||||
for attr_name, attr_value in attrs.items():
|
||||
if isinstance(attr_value, DataService):
|
||||
@ -361,7 +240,7 @@ class DataService(rpyc.Service):
|
||||
obj._callbacks.add(callback)
|
||||
|
||||
# Recursively register callbacks for all nested attributes of the object
|
||||
attrs = obj.__get_class_and_instance_attributes()
|
||||
attrs = get_class_and_instance_attributes(obj)
|
||||
|
||||
for nested_attr_name, nested_attr in attrs.items():
|
||||
if isinstance(nested_attr, DataServiceList):
|
||||
@ -446,7 +325,7 @@ class DataService(rpyc.Service):
|
||||
propagates it through nested DataService instances.
|
||||
"""
|
||||
|
||||
attrs = obj.__get_class_and_instance_attributes()
|
||||
attrs = get_class_and_instance_attributes(obj)
|
||||
|
||||
for attr_name, attr_value in attrs.items():
|
||||
if isinstance(attr_value, DataService):
|
||||
@ -471,7 +350,7 @@ class DataService(rpyc.Service):
|
||||
# >>> return self.class_attr.voltage * self.current
|
||||
#
|
||||
# The dependencies for this property are:
|
||||
# ('class_attr', 'voltage', 'current')
|
||||
# > ('class_attr', 'voltage', 'current')
|
||||
if f"self.{dependency}" not in source_code_string:
|
||||
continue
|
||||
|
||||
@ -496,7 +375,7 @@ class DataService(rpyc.Service):
|
||||
)
|
||||
else:
|
||||
callback = (
|
||||
lambda name, value, dependent_attr=attr_name, dep=dependency: obj._emit_notification(
|
||||
lambda name, _, dependent_attr=attr_name, dep=dependency: obj._emit_notification(
|
||||
parent_path=parent_path,
|
||||
name=dependent_attr,
|
||||
value=getattr(obj, dependent_attr),
|
||||
@ -507,22 +386,8 @@ class DataService(rpyc.Service):
|
||||
# Add to _callbacks
|
||||
obj._callbacks.add(callback)
|
||||
|
||||
def __get_class_and_instance_attributes(self) -> dict[str, Any]:
|
||||
"""Dictionary containing all attributes (both instance and class level) of a
|
||||
given object.
|
||||
|
||||
If an attribute exists at both the instance and class level,the value from the
|
||||
instance attribute takes precedence.
|
||||
The __root__ object is removed as this will lead to endless recursion in the for
|
||||
loops.
|
||||
"""
|
||||
|
||||
attrs = dict(chain(type(self).__dict__.items(), self.__dict__.items()))
|
||||
attrs.pop("__root__")
|
||||
return attrs
|
||||
|
||||
def __check_instance_classes(self) -> None:
|
||||
for attr_name, attr_value in self.__get_class_and_instance_attributes().items():
|
||||
for attr_name, attr_value in get_class_and_instance_attributes(self).items():
|
||||
# every class defined by the user should inherit from DataService
|
||||
if not attr_name.startswith("_DataService__"):
|
||||
warn_if_instance_class_does_not_inherit_from_DataService(attr_value)
|
||||
@ -634,8 +499,8 @@ class DataService(rpyc.Service):
|
||||
for k, v in sig.parameters.items()
|
||||
}
|
||||
running_task_info = None
|
||||
if key in self.__tasks: # If there's a running task for this method
|
||||
task_info = self.__tasks[key]
|
||||
if key in self._tasks: # If there's a running task for this method
|
||||
task_info = self._tasks[key]
|
||||
running_task_info = task_info["kwargs"]
|
||||
|
||||
result[key] = {
|
||||
|
162
src/pyDataInterface/data_service/task_manager.py
Normal file
162
src/pyDataInterface/data_service/task_manager.py
Normal file
@ -0,0 +1,162 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TypedDict
|
||||
|
||||
from loguru import logger
|
||||
from tiqi_rpc import Any
|
||||
|
||||
from pyDataInterface.utils import get_class_and_instance_attributes
|
||||
|
||||
|
||||
class TaskDict(TypedDict):
|
||||
task: asyncio.Task[None]
|
||||
kwargs: dict[str, Any]
|
||||
|
||||
|
||||
class TaskManager:
|
||||
"""
|
||||
The TaskManager class is a utility class designed to manage asynchronous tasks. It
|
||||
provides functionality for starting and stopping these tasks. The class is primarily
|
||||
used by the DataService class to manage its tasks.
|
||||
|
||||
The TaskManager class has the following responsibilities:
|
||||
|
||||
- Track all running tasks.
|
||||
- Provide the ability to start and stop tasks.
|
||||
- Emit notifications when the status of a task changes.
|
||||
|
||||
The tasks are asynchronous functions which can be started or stopped with the
|
||||
generated functions in this class.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.__root__: "TaskManager" = self
|
||||
"""Keep track of the root object. This helps to filter the emission of
|
||||
notifications."""
|
||||
|
||||
self._loop = asyncio.get_event_loop()
|
||||
|
||||
self._tasks: dict[str, TaskDict] = {}
|
||||
"""A dictionary to keep track of running tasks. The keys are the names of the
|
||||
tasks and the values are TaskDict instances which include the task itself and
|
||||
its kwargs.
|
||||
"""
|
||||
|
||||
self._task_status_change_callbacks: list[
|
||||
Callable[[str, dict[str, Any] | None], Any]
|
||||
] = []
|
||||
"""A list of callback functions to be invoked when the status of a task (start
|
||||
or stop) changes."""
|
||||
|
||||
self._set_start_and_stop_for_async_methods()
|
||||
|
||||
def _register_start_stop_task_callbacks(
|
||||
self, obj: "TaskManager", parent_path: str
|
||||
) -> None:
|
||||
"""
|
||||
This function registers callbacks for start and stop methods of async functions.
|
||||
These callbacks are stored in the '_task_status_change_callbacks' attribute and
|
||||
are called when the status of a task changes.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
obj: DataService
|
||||
The target object on which callbacks are to be registered.
|
||||
parent_path: str
|
||||
The access path for the parent object. This is used to construct the full
|
||||
access path for the notifications.
|
||||
"""
|
||||
|
||||
# Create and register a callback for the object
|
||||
# only emit the notification when the call was registered by the root object
|
||||
callback: Callable[[str, dict[str, Any] | None], None] = (
|
||||
lambda name, status: obj._emit_notification(
|
||||
parent_path=parent_path, name=name, value=status
|
||||
)
|
||||
if self == obj.__root__
|
||||
and not name.startswith("_") # we are only interested in public attributes
|
||||
else None
|
||||
)
|
||||
|
||||
obj._task_status_change_callbacks.append(callback)
|
||||
|
||||
# Recursively register callbacks for all nested attributes of the object
|
||||
attrs: dict[str, Any] = get_class_and_instance_attributes(obj)
|
||||
|
||||
for nested_attr_name, nested_attr in attrs.items():
|
||||
if isinstance(nested_attr, TaskManager):
|
||||
self._register_start_stop_task_callbacks(
|
||||
nested_attr, parent_path=f"{parent_path}.{nested_attr_name}"
|
||||
)
|
||||
|
||||
def _set_start_and_stop_for_async_methods(self) -> None: # noqa: C901
|
||||
# inspect the methods of the class
|
||||
for name, method in inspect.getmembers(
|
||||
self, predicate=inspect.iscoroutinefunction
|
||||
):
|
||||
|
||||
@wraps(method)
|
||||
def start_task(*args: Any, **kwargs: Any) -> None:
|
||||
async def task(*args: Any, **kwargs: Any) -> None:
|
||||
try:
|
||||
await method(*args, **kwargs)
|
||||
except asyncio.CancelledError:
|
||||
print(f"Task {name} was cancelled")
|
||||
|
||||
if not self._tasks.get(name):
|
||||
# Get the signature of the coroutine method to start
|
||||
sig = inspect.signature(method)
|
||||
|
||||
# Create a list of the parameter names from the method signature.
|
||||
parameter_names = list(sig.parameters.keys())
|
||||
|
||||
# Extend the list of positional arguments with None values to match
|
||||
# the length of the parameter names list. This is done to ensure
|
||||
# that zip can pair each parameter name with a corresponding value.
|
||||
args_padded = list(args) + [None] * (
|
||||
len(parameter_names) - len(args)
|
||||
)
|
||||
|
||||
# Create a dictionary of keyword arguments by pairing the parameter
|
||||
# names with the values in 'args_padded'. Then merge this dictionary
|
||||
# with the 'kwargs' dictionary. If a parameter is specified in both
|
||||
# 'args_padded' and 'kwargs', the value from 'kwargs' is used.
|
||||
kwargs_updated = {
|
||||
**dict(zip(parameter_names, args_padded)),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
# Store the task and its arguments in the '__tasks' dictionary. The
|
||||
# key is the name of the method, and the value is a dictionary
|
||||
# containing the task object and the updated keyword arguments.
|
||||
self._tasks[name] = {
|
||||
"task": self._loop.create_task(task(*args, **kwargs)),
|
||||
"kwargs": kwargs_updated,
|
||||
}
|
||||
|
||||
# emit the notification that the task was started
|
||||
for callback in self._task_status_change_callbacks:
|
||||
callback(name, kwargs_updated)
|
||||
else:
|
||||
logger.error(f"Task `{name}` is already running!")
|
||||
|
||||
def stop_task() -> None:
|
||||
# cancel the task
|
||||
task = self._tasks.pop(name)
|
||||
if task is not None:
|
||||
self._loop.call_soon_threadsafe(task["task"].cancel)
|
||||
|
||||
# emit the notification that the task was stopped
|
||||
for callback in self._task_status_change_callbacks:
|
||||
callback(name, None)
|
||||
|
||||
# create start and stop methods for each coroutine
|
||||
setattr(self, f"start_{name}", start_task)
|
||||
setattr(self, f"stop_{name}", stop_task)
|
||||
|
||||
@abstractmethod
|
||||
def _emit_notification(self, parent_path: str, name: str, value: Any) -> None:
|
||||
raise NotImplementedError
|
Loading…
x
Reference in New Issue
Block a user