feat: separating TaskManager out of DataService

Mose Müller 2023-08-02 12:06:20 +02:00
parent 8460759a31
commit e3211b6000
2 changed files with 176 additions and 149 deletions
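In short: the event loop handle, the running-task registry, and the generated start_<name>/stop_<name> wrappers move out of DataService into the new TaskManager base class, while DataService keeps the rpyc plumbing and the notification callbacks. A condensed sketch of the resulting layout, not the full implementation (method bodies elided; only names that appear in the diff below are real):

import asyncio

import rpyc


class TaskManager:  # new file: owns all task bookkeeping
    def __init__(self) -> None:
        self._loop = asyncio.get_event_loop()         # was DataService.__loop
        self._tasks: dict = {}                        # was DataService.__tasks
        self._task_status_change_callbacks: list = []
        # self._set_start_and_stop_for_async_methods()  # generates start_/stop_ wrappers

    def _emit_notification(self, parent_path: str, name: str, value) -> None:
        raise NotImplementedError  # abstract hook, implemented by DataService


class DataService(rpyc.Service, TaskManager):  # existing class, new base
    def __init__(self) -> None:
        TaskManager.__init__(self)  # explicit call rather than super()
        self.__root__ = self        # overwrites TaskManager's __root__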

View File

@@ -2,26 +2,21 @@ import asyncio
 import inspect
 from collections.abc import Callable
 from enum import Enum
-from functools import wraps
-from itertools import chain
-from typing import Any, TypedDict
+from typing import Any
 import rpyc
 from loguru import logger
 from pyDataInterface.utils import (
-    get_class_and_instance_attributes,
     warn_if_instance_class_does_not_inherit_from_DataService,
 )
 from .data_service_list import DataServiceList
+from .task_manager import TaskManager
-class TaskDict(TypedDict):
-    task: asyncio.Task[None]
-    kwargs: dict[str, Any]
-class DataService(rpyc.Service):
+class DataService(rpyc.Service, TaskManager):
     _list_mapping: dict[int, DataServiceList] = {}
     """
     A dictionary mapping the id of the original lists to the corresponding
@@ -54,14 +49,10 @@ class DataService(rpyc.Service):
     """
     def __init__(self) -> None:
+        TaskManager.__init__(self)
         self.__root__: "DataService" = self
         """Keep track of the root object. This helps to filter the emission of
-        notifications."""
-        self.__loop = asyncio.get_event_loop()
-        self.__tasks: dict[str, TaskDict] = {}
-        """Dictionary to keep track of running tasks."""
+        notifications. This overwrites the TaskManager's __root__ attribute."""
         self._autostart_tasks: dict[str, tuple[Any]]
         if "_autostart_tasks" not in self.__dict__:
@@ -69,14 +60,6 @@ class DataService(rpyc.Service):
         self._callbacks: set[Callable[[str, Any], None]] = set()
-        self._task_status_change_callbacks: list[
-            Callable[[str, dict[str, Any] | None], Any]
-        ] = []
-        """A list of callback functions to be invoked when the status of a task (start
-        or stop) changes."""
-        self._set_start_and_stop_for_async_methods()
         self._register_callbacks()
         self.__check_instance_classes()
         self._initialised = True
@@ -130,71 +113,6 @@ class DataService(rpyc.Service):
                 f"No start method found for service '{service_name}'"
             )
-    def _set_start_and_stop_for_async_methods(self) -> None:  # noqa: C901
-        # inspect the methods of the class
-        for name, method in inspect.getmembers(
-            self, predicate=inspect.iscoroutinefunction
-        ):
-            @wraps(method)
-            def start_task(*args: Any, **kwargs: Any) -> None:
-                async def task(*args: Any, **kwargs: Any) -> None:
-                    try:
-                        await method(*args, **kwargs)
-                    except asyncio.CancelledError:
-                        print(f"Task {name} was cancelled")
-                if not self.__tasks.get(name):
-                    # Get the signature of the coroutine method to start
-                    sig = inspect.signature(method)
-                    # Create a list of the parameter names from the method signature.
-                    parameter_names = list(sig.parameters.keys())
-                    # Extend the list of positional arguments with None values to match
-                    # the length of the parameter names list. This is done to ensure
-                    # that zip can pair each parameter name with a corresponding value.
-                    args_padded = list(args) + [None] * (
-                        len(parameter_names) - len(args)
-                    )
-                    # Create a dictionary of keyword arguments by pairing the parameter
-                    # names with the values in 'args_padded'. Then merge this dictionary
-                    # with the 'kwargs' dictionary. If a parameter is specified in both
-                    # 'args_padded' and 'kwargs', the value from 'kwargs' is used.
-                    kwargs_updated = {
-                        **dict(zip(parameter_names, args_padded)),
-                        **kwargs,
-                    }
-                    # Store the task and its arguments in the '__tasks' dictionary. The
-                    # key is the name of the method, and the value is a dictionary
-                    # containing the task object and the updated keyword arguments.
-                    self.__tasks[name] = {
-                        "task": self.__loop.create_task(task(*args, **kwargs)),
-                        "kwargs": kwargs_updated,
-                    }
-                    # emit the notification that the task was started
-                    for callback in self._task_status_change_callbacks:
-                        callback(name, kwargs_updated)
-                else:
-                    logger.error(f"Task `{name}` is already running!")
-            def stop_task() -> None:
-                # cancel the task
-                task = self.__tasks.pop(name)
-                if task is not None:
-                    self.__loop.call_soon_threadsafe(task["task"].cancel)
-                # emit the notification that the task was stopped
-                for callback in self._task_status_change_callbacks:
-                    callback(name, None)
-            # create start and stop methods for each coroutine
-            setattr(self, f"start_{name}", start_task)
-            setattr(self, f"stop_{name}", stop_task)
     def _register_callbacks(self) -> None:
         self._register_list_change_callbacks(self, f"{self.__class__.__name__}")
         self._register_DataService_instance_callbacks(
@@ -203,45 +121,6 @@ class DataService(rpyc.Service):
         self._register_property_callbacks(self, f"{self.__class__.__name__}")
         self._register_start_stop_task_callbacks(self, f"{self.__class__.__name__}")
-    def _register_start_stop_task_callbacks(
-        self, obj: "DataService", parent_path: str
-    ) -> None:
-        """
-        This function registers callbacks for start and stop methods of async functions.
-        These callbacks are stored in the '_task_status_change_callbacks' attribute and
-        are called when the status of a task changes.
-        Parameters:
-        -----------
-        obj: DataService
-            The target object on which callbacks are to be registered.
-        parent_path: str
-            The access path for the parent object. This is used to construct the full
-            access path for the notifications.
-        """
-        # Create and register a callback for the object
-        # only emit the notification when the call was registered by the root object
-        callback: Callable[[str, dict[str, Any] | None], None] = (
-            lambda name, status: obj._emit_notification(
-                parent_path=parent_path, name=name, value=status
-            )
-            if self == obj.__root__
-            and not name.startswith("_")  # we are only interested in public attributes
-            else None
-        )
-        obj._task_status_change_callbacks.append(callback)
-        # Recursively register callbacks for all nested attributes of the object
-        attrs = obj.__get_class_and_instance_attributes()
-        for nested_attr_name, nested_attr in attrs.items():
-            if isinstance(nested_attr, DataService):
-                self._register_start_stop_task_callbacks(
-                    nested_attr, parent_path=f"{parent_path}.{nested_attr_name}"
-                )
     def _register_list_change_callbacks(
         self, obj: "DataService", parent_path: str
     ) -> None:
@@ -274,7 +153,7 @@ class DataService(rpyc.Service):
         """
         # Convert all list attributes (both class and instance) to DataServiceList
-        attrs = obj.__get_class_and_instance_attributes()
+        attrs = get_class_and_instance_attributes(obj)
         for attr_name, attr_value in attrs.items():
             if isinstance(attr_value, DataService):
@@ -361,7 +240,7 @@ class DataService(rpyc.Service):
         obj._callbacks.add(callback)
         # Recursively register callbacks for all nested attributes of the object
-        attrs = obj.__get_class_and_instance_attributes()
+        attrs = get_class_and_instance_attributes(obj)
         for nested_attr_name, nested_attr in attrs.items():
             if isinstance(nested_attr, DataServiceList):
@@ -446,7 +325,7 @@ class DataService(rpyc.Service):
        propagates it through nested DataService instances.
        """
-        attrs = obj.__get_class_and_instance_attributes()
+        attrs = get_class_and_instance_attributes(obj)
         for attr_name, attr_value in attrs.items():
             if isinstance(attr_value, DataService):
@@ -471,7 +350,7 @@ class DataService(rpyc.Service):
             # >>> return self.class_attr.voltage * self.current
             #
             # The dependencies for this property are:
-            # ('class_attr', 'voltage', 'current')
+            # > ('class_attr', 'voltage', 'current')
             if f"self.{dependency}" not in source_code_string:
                 continue
@@ -496,7 +375,7 @@ class DataService(rpyc.Service):
                )
            else:
                callback = (
-                    lambda name, value, dependent_attr=attr_name, dep=dependency: obj._emit_notification(
+                    lambda name, _, dependent_attr=attr_name, dep=dependency: obj._emit_notification(
                        parent_path=parent_path,
                        name=dependent_attr,
                        value=getattr(obj, dependent_attr),
@@ -507,22 +386,8 @@ class DataService(rpyc.Service):
             # Add to _callbacks
             obj._callbacks.add(callback)
-    def __get_class_and_instance_attributes(self) -> dict[str, Any]:
-        """Dictionary containing all attributes (both instance and class level) of a
-        given object.
-        If an attribute exists at both the instance and class level,the value from the
-        instance attribute takes precedence.
-        The __root__ object is removed as this will lead to endless recursion in the for
-        loops.
-        """
-        attrs = dict(chain(type(self).__dict__.items(), self.__dict__.items()))
-        attrs.pop("__root__")
-        return attrs
     def __check_instance_classes(self) -> None:
-        for attr_name, attr_value in self.__get_class_and_instance_attributes().items():
+        for attr_name, attr_value in get_class_and_instance_attributes(self).items():
             # every class defined by the user should inherit from DataService
             if not attr_name.startswith("_DataService__"):
                 warn_if_instance_class_does_not_inherit_from_DataService(attr_value)
@@ -634,8 +499,8 @@ class DataService(rpyc.Service):
                 for k, v in sig.parameters.items()
             }
             running_task_info = None
-            if key in self.__tasks:  # If there's a running task for this method
-                task_info = self.__tasks[key]
+            if key in self._tasks:  # If there's a running task for this method
+                task_info = self._tasks[key]
                 running_task_info = task_info["kwargs"]
             result[key] = {
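The kwargs entry read from _tasks in the last hunk above is the dictionary that the generated start_<name> wrapper builds: it zips the coroutine's parameter names with the positional arguments (padded with None) and then overlays the keyword arguments (see _set_start_and_stop_for_async_methods in the new file below). A small worked example of that merge, using an illustrative coroutine that is not part of the commit:

import inspect


async def ramp(start: float, stop: float, step: float = 0.1) -> None:
    """Illustrative task; only its signature matters here."""


# What a call like start_ramp(1.0, stop=5.0) would record under _tasks["ramp"]["kwargs"]:
args, kwargs = (1.0,), {"stop": 5.0}
parameter_names = list(inspect.signature(ramp).parameters.keys())  # ['start', 'stop', 'step']
args_padded = list(args) + [None] * (len(parameter_names) - len(args))  # [1.0, None, None]
kwargs_updated = {**dict(zip(parameter_names, args_padded)), **kwargs}
print(kwargs_updated)  # {'start': 1.0, 'stop': 5.0, 'step': None}

Note that unspecified parameters are recorded as None rather than their declared defaults; that is the behaviour of the moved code, not something this commit changes.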

View File

@@ -0,0 +1,162 @@
import asyncio
import inspect
from abc import abstractmethod
from collections.abc import Callable
from functools import wraps
from typing import Any, TypedDict

from loguru import logger

from pyDataInterface.utils import get_class_and_instance_attributes


class TaskDict(TypedDict):
    task: asyncio.Task[None]
    kwargs: dict[str, Any]


class TaskManager:
    """
    The TaskManager class is a utility class designed to manage asynchronous tasks. It
    provides functionality for starting and stopping these tasks. The class is primarily
    used by the DataService class to manage its tasks.

    The TaskManager class has the following responsibilities:

    - Track all running tasks.
    - Provide the ability to start and stop tasks.
    - Emit notifications when the status of a task changes.

    The tasks are asynchronous methods that can be started or stopped with the
    start_<name> and stop_<name> functions generated by this class.
    """

    def __init__(self) -> None:
        self.__root__: "TaskManager" = self
        """Keep track of the root object. This helps to filter the emission of
        notifications."""

        self._loop = asyncio.get_event_loop()

        self._tasks: dict[str, TaskDict] = {}
        """A dictionary to keep track of running tasks. The keys are the names of the
        tasks and the values are TaskDict instances which include the task itself and
        its kwargs.
        """

        self._task_status_change_callbacks: list[
            Callable[[str, dict[str, Any] | None], Any]
        ] = []
        """A list of callback functions to be invoked when the status of a task (start
        or stop) changes."""

        self._set_start_and_stop_for_async_methods()

    def _register_start_stop_task_callbacks(
        self, obj: "TaskManager", parent_path: str
    ) -> None:
        """
        This function registers callbacks for start and stop methods of async functions.
        These callbacks are stored in the '_task_status_change_callbacks' attribute and
        are called when the status of a task changes.

        Parameters:
        -----------
        obj: TaskManager
            The target object on which callbacks are to be registered.
        parent_path: str
            The access path for the parent object. This is used to construct the full
            access path for the notifications.
        """

        # Create and register a callback for the object
        # only emit the notification when the callback was registered by the root object
        callback: Callable[[str, dict[str, Any] | None], None] = (
            lambda name, status: obj._emit_notification(
                parent_path=parent_path, name=name, value=status
            )
            if self == obj.__root__
            and not name.startswith("_")  # we are only interested in public attributes
            else None
        )
        obj._task_status_change_callbacks.append(callback)

        # Recursively register callbacks for all nested attributes of the object
        attrs: dict[str, Any] = get_class_and_instance_attributes(obj)

        for nested_attr_name, nested_attr in attrs.items():
            if isinstance(nested_attr, TaskManager):
                self._register_start_stop_task_callbacks(
                    nested_attr, parent_path=f"{parent_path}.{nested_attr_name}"
                )

    def _set_start_and_stop_for_async_methods(self) -> None:  # noqa: C901
        # inspect the methods of the class
        for name, method in inspect.getmembers(
            self, predicate=inspect.iscoroutinefunction
        ):

            @wraps(method)
            def start_task(*args: Any, **kwargs: Any) -> None:
                async def task(*args: Any, **kwargs: Any) -> None:
                    try:
                        await method(*args, **kwargs)
                    except asyncio.CancelledError:
                        print(f"Task {name} was cancelled")

                if not self._tasks.get(name):
                    # Get the signature of the coroutine method to start
                    sig = inspect.signature(method)

                    # Create a list of the parameter names from the method signature.
                    parameter_names = list(sig.parameters.keys())

                    # Extend the list of positional arguments with None values to match
                    # the length of the parameter names list. This is done to ensure
                    # that zip can pair each parameter name with a corresponding value.
                    args_padded = list(args) + [None] * (
                        len(parameter_names) - len(args)
                    )

                    # Create a dictionary of keyword arguments by pairing the parameter
                    # names with the values in 'args_padded'. Then merge this dictionary
                    # with the 'kwargs' dictionary. If a parameter is specified in both
                    # 'args_padded' and 'kwargs', the value from 'kwargs' is used.
                    kwargs_updated = {
                        **dict(zip(parameter_names, args_padded)),
                        **kwargs,
                    }

                    # Store the task and its arguments in the '_tasks' dictionary. The
                    # key is the name of the method, and the value is a dictionary
                    # containing the task object and the updated keyword arguments.
                    self._tasks[name] = {
                        "task": self._loop.create_task(task(*args, **kwargs)),
                        "kwargs": kwargs_updated,
                    }

                    # emit the notification that the task was started
                    for callback in self._task_status_change_callbacks:
                        callback(name, kwargs_updated)
                else:
                    logger.error(f"Task `{name}` is already running!")

            def stop_task() -> None:
                # cancel the task
                task = self._tasks.pop(name)
                if task is not None:
                    self._loop.call_soon_threadsafe(task["task"].cancel)

                # emit the notification that the task was stopped
                for callback in self._task_status_change_callbacks:
                    callback(name, None)

            # create start and stop methods for each coroutine
            setattr(self, f"start_{name}", start_task)
            setattr(self, f"stop_{name}", stop_task)

    @abstractmethod
    def _emit_notification(self, parent_path: str, name: str, value: Any) -> None:
        raise NotImplementedError
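TaskManager declares _emit_notification as abstract, so whatever mixes it in must supply the emitter; in this commit that is DataService, whose registered status-change callbacks forward a task's start kwargs on start and None on stop. A minimal, hedged usage sketch of the mixin on its own (the MiniService class, the print-based emitter, and the import path are illustrative assumptions, not part of the commit):

import asyncio

from pyDataInterface.data_service.task_manager import TaskManager  # import path assumed


class MiniService(TaskManager):
    def _emit_notification(self, parent_path: str, name: str, value) -> None:
        # DataService would fan this out to its registered callbacks; here we just log.
        print(f"notify: {parent_path}.{name} = {value}")

    async def blink(self, period: float = 0.5) -> None:
        while True:
            print("blink")
            await asyncio.sleep(period)


async def main() -> None:
    service = MiniService()  # __init__ generates start_blink / stop_blink
    service._register_start_stop_task_callbacks(service, "MiniService")
    service.start_blink(period=0.2)  # creates the asyncio.Task, emits {'period': 0.2}
    await asyncio.sleep(1)
    service.stop_blink()  # cancels the task, emits None


if __name__ == "__main__":
    asyncio.run(main())

In DataService this wiring is automatic: _register_callbacks invokes _register_start_stop_task_callbacks during __init__, so users never call it by hand.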