feat: using TaskManager as attribute instead of inheriting

This commit is contained in:
Mose Müller 2023-08-02 12:06:22 +02:00
parent 8fd1f1822f
commit e7a0017431
3 changed files with 68 additions and 80 deletions

View File

@ -2,7 +2,6 @@ import asyncio
import inspect
import json
import os
from collections.abc import Callable
from enum import Enum
from typing import Any, Optional, cast, get_type_hints
@ -22,7 +21,7 @@ from pyDataInterface.utils.warnings import (
warn_if_instance_class_does_not_inherit_from_DataService,
)
from .abstract_data_service import AbstractDataService
from .abstract_service_classes import AbstractDataService
from .callback_manager import CallbackManager
from .task_manager import TaskManager
@ -38,10 +37,14 @@ def process_callable_attribute(attr: Any, args: dict[str, Any]) -> Any:
)
class DataService(rpyc.Service, AbstractDataService, TaskManager):
class DataService(rpyc.Service, AbstractDataService):
def __init__(self, filename: Optional[str] = None) -> None:
TaskManager.__init__(self)
self._callback_manager: CallbackManager = CallbackManager(self)
self._task_manager = TaskManager(self)
if not hasattr(self, "_autostart_tasks"):
self._autostart_tasks = {}
self.__root__: "DataService" = self
"""Keep track of the root object. This helps to filter the emission of
notifications. This overwrite the TaksManager's __root__ attribute."""
@ -53,6 +56,50 @@ class DataService(rpyc.Service, AbstractDataService, TaskManager):
self._initialised = True
self._load_values_from_json()
def __setattr__(self, __name: str, __value: Any) -> None:
current_value = getattr(self, __name, None)
# parse ints into floats if current value is a float
if isinstance(current_value, float) and isinstance(__value, int):
__value = float(__value)
super().__setattr__(__name, __value)
if self.__dict__.get("_initialised") and not __name == "_initialised":
for callback in self._callback_manager.callbacks:
callback(__name, __value)
elif __name.startswith(f"_{self.__class__.__name__}__"):
logger.warning(
f"Warning: You should not set private but rather protected attributes! "
f"Use {__name.replace(f'_{self.__class__.__name__}__', '_')} instead "
f"of {__name.replace(f'_{self.__class__.__name__}__', '__')}."
)
def __check_instance_classes(self) -> None:
for attr_name, attr_value in get_class_and_instance_attributes(self).items():
# every class defined by the user should inherit from DataService
if not attr_name.startswith("_DataService__"):
warn_if_instance_class_does_not_inherit_from_DataService(attr_value)
def _rpyc_getattr(self, name: str) -> Any:
if name.startswith("_"):
# disallow special and private attributes
raise AttributeError("cannot access private/special names")
# allow all other attributes
return getattr(self, name)
def _rpyc_setattr(self, name: str, value: Any) -> None:
if name.startswith("_"):
# disallow special and private attributes
raise AttributeError("cannot access private/special names")
# check if the attribute has a setter method
attr = getattr(self, name, None)
if isinstance(attr, property) and attr.fset is None:
raise AttributeError(f"{name} attribute does not have a setter method")
# allow all other attributes
setattr(self, name, value)
def _load_values_from_json(self) -> None:
if self._filename is not None:
# Check if the file specified by the filename exists
@ -101,50 +148,6 @@ class DataService(rpyc.Service, AbstractDataService, TaskManager):
f'"{class_value_type}". Ignoring value from JSON file...'
)
def __setattr__(self, __name: str, __value: Any) -> None:
current_value = getattr(self, __name, None)
# parse ints into floats if current value is a float
if isinstance(current_value, float) and isinstance(__value, int):
__value = float(__value)
super().__setattr__(__name, __value)
if self.__dict__.get("_initialised") and not __name == "_initialised":
for callback in self._callback_manager.callbacks:
callback(__name, __value)
elif __name.startswith(f"_{self.__class__.__name__}__"):
logger.warning(
f"Warning: You should not set private but rather protected attributes! "
f"Use {__name.replace(f'_{self.__class__.__name__}__', '_')} instead "
f"of {__name.replace(f'_{self.__class__.__name__}__', '__')}."
)
def _rpyc_getattr(self, name: str) -> Any:
if name.startswith("_"):
# disallow special and private attributes
raise AttributeError("cannot access private/special names")
# allow all other attributes
return getattr(self, name)
def _rpyc_setattr(self, name: str, value: Any) -> None:
if name.startswith("_"):
# disallow special and private attributes
raise AttributeError("cannot access private/special names")
# check if the attribute has a setter method
attr = getattr(self, name, None)
if isinstance(attr, property) and attr.fset is None:
raise AttributeError(f"{name} attribute does not have a setter method")
# allow all other attributes
setattr(self, name, value)
def __check_instance_classes(self) -> None:
for attr_name, attr_value in get_class_and_instance_attributes(self).items():
# every class defined by the user should inherit from DataService
if not attr_name.startswith("_DataService__"):
warn_if_instance_class_does_not_inherit_from_DataService(attr_value)
def serialize(self) -> dict[str, dict[str, Any]]: # noqa
"""
Serializes the instance into a dictionary, preserving the structure of the
@ -243,8 +246,10 @@ class DataService(rpyc.Service, AbstractDataService, TaskManager):
for k, v in sig.parameters.items()
}
running_task_info = None
if key in self._tasks: # If there's a running task for this method
task_info = self._tasks[key]
if (
key in self._task_manager._tasks
): # If there's a running task for this method
task_info = self._task_manager._tasks[key]
running_task_info = task_info["kwargs"]
result[key] = {

View File

@ -1,14 +1,12 @@
import asyncio
import inspect
from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import wraps
from typing import TypedDict
from loguru import logger
from tiqi_rpc import Any
from pyDataInterface.utils.helpers import get_class_and_instance_attributes
from .abstract_service_classes import AbstractDataService, AbstractTaskManager
class TaskDict(TypedDict):
@ -16,7 +14,7 @@ class TaskDict(TypedDict):
kwargs: dict[str, Any]
class TaskManager(ABC):
class TaskManager(AbstractTaskManager):
"""
The TaskManager class is a utility designed to manage asynchronous tasks. It
provides functionality for starting, stopping, and tracking these tasks. The class
@ -70,35 +68,20 @@ class TaskManager(ABC):
interfaces, but can also be used to write logs, etc.
"""
def __init__(self) -> None:
self.__root__: "TaskManager" = self
"""Keep track of the root object. This helps to filter the emission of
notifications."""
def __init__(self, service: AbstractDataService) -> None:
self.service = service
self._loop = asyncio.get_event_loop()
self._autostart_tasks: dict[str, tuple[Any]]
if "_autostart_tasks" not in self.__dict__:
self._autostart_tasks = {}
self._tasks = {}
self._tasks: dict[str, TaskDict] = {}
"""A dictionary to keep track of running tasks. The keys are the names of the
tasks and the values are TaskDict instances which include the task itself and
its kwargs.
"""
self._task_status_change_callbacks: list[
Callable[[str, dict[str, Any] | None], Any]
] = []
"""A list of callback functions to be invoked when the status of a task (start
or stop) changes."""
self._task_status_change_callbacks = []
self._set_start_and_stop_for_async_methods()
def _set_start_and_stop_for_async_methods(self) -> None: # noqa: C901
# inspect the methods of the class
for name, method in inspect.getmembers(
self, predicate=inspect.iscoroutinefunction
self.service, predicate=inspect.iscoroutinefunction
):
@wraps(method)
@ -157,12 +140,12 @@ class TaskManager(ABC):
callback(name, None)
# create start and stop methods for each coroutine
setattr(self, f"start_{name}", start_task)
setattr(self, f"stop_{name}", stop_task)
setattr(self.service, f"start_{name}", start_task)
setattr(self.service, f"stop_{name}", stop_task)
def _start_autostart_tasks(self) -> None:
if self._autostart_tasks is not None:
for service_name, args in self._autostart_tasks.items():
def start_autostart_tasks(self) -> None:
if self.service._autostart_tasks is not None:
for service_name, args in self.service._autostart_tasks.items():
start_method = getattr(self, f"start_{service_name}", None)
if start_method is not None and callable(start_method):
start_method(*args)

View File

@ -99,7 +99,7 @@ class Server:
self._loop = asyncio.get_running_loop()
self._loop.set_exception_handler(self.custom_exception_handler)
self.install_signal_handlers()
self._service._start_autostart_tasks()
self._service._task_manager.start_autostart_tasks()
if self._enable_rpc:
self.executor = ThreadPoolExecutor()