feat: using TaskManager as attribute instead of inheriting

This commit is contained in:
Mose Müller 2023-08-02 12:06:22 +02:00
parent 8fd1f1822f
commit e7a0017431
3 changed files with 68 additions and 80 deletions

View File

@ -2,7 +2,6 @@ import asyncio
import inspect import inspect
import json import json
import os import os
from collections.abc import Callable
from enum import Enum from enum import Enum
from typing import Any, Optional, cast, get_type_hints from typing import Any, Optional, cast, get_type_hints
@ -22,7 +21,7 @@ from pyDataInterface.utils.warnings import (
warn_if_instance_class_does_not_inherit_from_DataService, warn_if_instance_class_does_not_inherit_from_DataService,
) )
from .abstract_data_service import AbstractDataService from .abstract_service_classes import AbstractDataService
from .callback_manager import CallbackManager from .callback_manager import CallbackManager
from .task_manager import TaskManager from .task_manager import TaskManager
@ -38,10 +37,14 @@ def process_callable_attribute(attr: Any, args: dict[str, Any]) -> Any:
) )
class DataService(rpyc.Service, AbstractDataService, TaskManager): class DataService(rpyc.Service, AbstractDataService):
def __init__(self, filename: Optional[str] = None) -> None: def __init__(self, filename: Optional[str] = None) -> None:
TaskManager.__init__(self)
self._callback_manager: CallbackManager = CallbackManager(self) self._callback_manager: CallbackManager = CallbackManager(self)
self._task_manager = TaskManager(self)
if not hasattr(self, "_autostart_tasks"):
self._autostart_tasks = {}
self.__root__: "DataService" = self self.__root__: "DataService" = self
"""Keep track of the root object. This helps to filter the emission of """Keep track of the root object. This helps to filter the emission of
notifications. This overwrite the TaksManager's __root__ attribute.""" notifications. This overwrite the TaksManager's __root__ attribute."""
@ -53,6 +56,50 @@ class DataService(rpyc.Service, AbstractDataService, TaskManager):
self._initialised = True self._initialised = True
self._load_values_from_json() self._load_values_from_json()
def __setattr__(self, __name: str, __value: Any) -> None:
current_value = getattr(self, __name, None)
# parse ints into floats if current value is a float
if isinstance(current_value, float) and isinstance(__value, int):
__value = float(__value)
super().__setattr__(__name, __value)
if self.__dict__.get("_initialised") and not __name == "_initialised":
for callback in self._callback_manager.callbacks:
callback(__name, __value)
elif __name.startswith(f"_{self.__class__.__name__}__"):
logger.warning(
f"Warning: You should not set private but rather protected attributes! "
f"Use {__name.replace(f'_{self.__class__.__name__}__', '_')} instead "
f"of {__name.replace(f'_{self.__class__.__name__}__', '__')}."
)
def __check_instance_classes(self) -> None:
for attr_name, attr_value in get_class_and_instance_attributes(self).items():
# every class defined by the user should inherit from DataService
if not attr_name.startswith("_DataService__"):
warn_if_instance_class_does_not_inherit_from_DataService(attr_value)
def _rpyc_getattr(self, name: str) -> Any:
if name.startswith("_"):
# disallow special and private attributes
raise AttributeError("cannot access private/special names")
# allow all other attributes
return getattr(self, name)
def _rpyc_setattr(self, name: str, value: Any) -> None:
if name.startswith("_"):
# disallow special and private attributes
raise AttributeError("cannot access private/special names")
# check if the attribute has a setter method
attr = getattr(self, name, None)
if isinstance(attr, property) and attr.fset is None:
raise AttributeError(f"{name} attribute does not have a setter method")
# allow all other attributes
setattr(self, name, value)
def _load_values_from_json(self) -> None: def _load_values_from_json(self) -> None:
if self._filename is not None: if self._filename is not None:
# Check if the file specified by the filename exists # Check if the file specified by the filename exists
@ -101,50 +148,6 @@ class DataService(rpyc.Service, AbstractDataService, TaskManager):
f'"{class_value_type}". Ignoring value from JSON file...' f'"{class_value_type}". Ignoring value from JSON file...'
) )
def __setattr__(self, __name: str, __value: Any) -> None:
current_value = getattr(self, __name, None)
# parse ints into floats if current value is a float
if isinstance(current_value, float) and isinstance(__value, int):
__value = float(__value)
super().__setattr__(__name, __value)
if self.__dict__.get("_initialised") and not __name == "_initialised":
for callback in self._callback_manager.callbacks:
callback(__name, __value)
elif __name.startswith(f"_{self.__class__.__name__}__"):
logger.warning(
f"Warning: You should not set private but rather protected attributes! "
f"Use {__name.replace(f'_{self.__class__.__name__}__', '_')} instead "
f"of {__name.replace(f'_{self.__class__.__name__}__', '__')}."
)
def _rpyc_getattr(self, name: str) -> Any:
if name.startswith("_"):
# disallow special and private attributes
raise AttributeError("cannot access private/special names")
# allow all other attributes
return getattr(self, name)
def _rpyc_setattr(self, name: str, value: Any) -> None:
if name.startswith("_"):
# disallow special and private attributes
raise AttributeError("cannot access private/special names")
# check if the attribute has a setter method
attr = getattr(self, name, None)
if isinstance(attr, property) and attr.fset is None:
raise AttributeError(f"{name} attribute does not have a setter method")
# allow all other attributes
setattr(self, name, value)
def __check_instance_classes(self) -> None:
for attr_name, attr_value in get_class_and_instance_attributes(self).items():
# every class defined by the user should inherit from DataService
if not attr_name.startswith("_DataService__"):
warn_if_instance_class_does_not_inherit_from_DataService(attr_value)
def serialize(self) -> dict[str, dict[str, Any]]: # noqa def serialize(self) -> dict[str, dict[str, Any]]: # noqa
""" """
Serializes the instance into a dictionary, preserving the structure of the Serializes the instance into a dictionary, preserving the structure of the
@ -243,8 +246,10 @@ class DataService(rpyc.Service, AbstractDataService, TaskManager):
for k, v in sig.parameters.items() for k, v in sig.parameters.items()
} }
running_task_info = None running_task_info = None
if key in self._tasks: # If there's a running task for this method if (
task_info = self._tasks[key] key in self._task_manager._tasks
): # If there's a running task for this method
task_info = self._task_manager._tasks[key]
running_task_info = task_info["kwargs"] running_task_info = task_info["kwargs"]
result[key] = { result[key] = {

View File

@ -1,14 +1,12 @@
import asyncio import asyncio
import inspect import inspect
from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import TypedDict from typing import TypedDict
from loguru import logger from loguru import logger
from tiqi_rpc import Any from tiqi_rpc import Any
from pyDataInterface.utils.helpers import get_class_and_instance_attributes from .abstract_service_classes import AbstractDataService, AbstractTaskManager
class TaskDict(TypedDict): class TaskDict(TypedDict):
@ -16,7 +14,7 @@ class TaskDict(TypedDict):
kwargs: dict[str, Any] kwargs: dict[str, Any]
class TaskManager(ABC): class TaskManager(AbstractTaskManager):
""" """
The TaskManager class is a utility designed to manage asynchronous tasks. It The TaskManager class is a utility designed to manage asynchronous tasks. It
provides functionality for starting, stopping, and tracking these tasks. The class provides functionality for starting, stopping, and tracking these tasks. The class
@ -70,35 +68,20 @@ class TaskManager(ABC):
interfaces, but can also be used to write logs, etc. interfaces, but can also be used to write logs, etc.
""" """
def __init__(self) -> None: def __init__(self, service: AbstractDataService) -> None:
self.__root__: "TaskManager" = self self.service = service
"""Keep track of the root object. This helps to filter the emission of
notifications."""
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
self._autostart_tasks: dict[str, tuple[Any]] self._tasks = {}
if "_autostart_tasks" not in self.__dict__:
self._autostart_tasks = {}
self._tasks: dict[str, TaskDict] = {} self._task_status_change_callbacks = []
"""A dictionary to keep track of running tasks. The keys are the names of the
tasks and the values are TaskDict instances which include the task itself and
its kwargs.
"""
self._task_status_change_callbacks: list[
Callable[[str, dict[str, Any] | None], Any]
] = []
"""A list of callback functions to be invoked when the status of a task (start
or stop) changes."""
self._set_start_and_stop_for_async_methods() self._set_start_and_stop_for_async_methods()
def _set_start_and_stop_for_async_methods(self) -> None: # noqa: C901 def _set_start_and_stop_for_async_methods(self) -> None: # noqa: C901
# inspect the methods of the class # inspect the methods of the class
for name, method in inspect.getmembers( for name, method in inspect.getmembers(
self, predicate=inspect.iscoroutinefunction self.service, predicate=inspect.iscoroutinefunction
): ):
@wraps(method) @wraps(method)
@ -157,12 +140,12 @@ class TaskManager(ABC):
callback(name, None) callback(name, None)
# create start and stop methods for each coroutine # create start and stop methods for each coroutine
setattr(self, f"start_{name}", start_task) setattr(self.service, f"start_{name}", start_task)
setattr(self, f"stop_{name}", stop_task) setattr(self.service, f"stop_{name}", stop_task)
def _start_autostart_tasks(self) -> None: def start_autostart_tasks(self) -> None:
if self._autostart_tasks is not None: if self.service._autostart_tasks is not None:
for service_name, args in self._autostart_tasks.items(): for service_name, args in self.service._autostart_tasks.items():
start_method = getattr(self, f"start_{service_name}", None) start_method = getattr(self, f"start_{service_name}", None)
if start_method is not None and callable(start_method): if start_method is not None and callable(start_method):
start_method(*args) start_method(*args)

View File

@ -99,7 +99,7 @@ class Server:
self._loop = asyncio.get_running_loop() self._loop = asyncio.get_running_loop()
self._loop.set_exception_handler(self.custom_exception_handler) self._loop.set_exception_handler(self.custom_exception_handler)
self.install_signal_handlers() self.install_signal_handlers()
self._service._start_autostart_tasks() self._service._task_manager.start_autostart_tasks()
if self._enable_rpc: if self._enable_rpc:
self.executor = ThreadPoolExecutor() self.executor = ThreadPoolExecutor()