mirror of
https://github.com/tiqi-group/pydase.git
synced 2025-06-06 13:30:41 +02:00
feat: using TaskManager as attribute instead of inheriting
This commit is contained in:
parent
8fd1f1822f
commit
e7a0017431
@ -2,7 +2,6 @@ import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, cast, get_type_hints
|
||||
|
||||
@ -22,7 +21,7 @@ from pyDataInterface.utils.warnings import (
|
||||
warn_if_instance_class_does_not_inherit_from_DataService,
|
||||
)
|
||||
|
||||
from .abstract_data_service import AbstractDataService
|
||||
from .abstract_service_classes import AbstractDataService
|
||||
from .callback_manager import CallbackManager
|
||||
from .task_manager import TaskManager
|
||||
|
||||
@ -38,10 +37,14 @@ def process_callable_attribute(attr: Any, args: dict[str, Any]) -> Any:
|
||||
)
|
||||
|
||||
|
||||
class DataService(rpyc.Service, AbstractDataService, TaskManager):
|
||||
class DataService(rpyc.Service, AbstractDataService):
|
||||
def __init__(self, filename: Optional[str] = None) -> None:
|
||||
TaskManager.__init__(self)
|
||||
self._callback_manager: CallbackManager = CallbackManager(self)
|
||||
self._task_manager = TaskManager(self)
|
||||
|
||||
if not hasattr(self, "_autostart_tasks"):
|
||||
self._autostart_tasks = {}
|
||||
|
||||
self.__root__: "DataService" = self
|
||||
"""Keep track of the root object. This helps to filter the emission of
|
||||
notifications. This overwrite the TaksManager's __root__ attribute."""
|
||||
@ -53,6 +56,50 @@ class DataService(rpyc.Service, AbstractDataService, TaskManager):
|
||||
self._initialised = True
|
||||
self._load_values_from_json()
|
||||
|
||||
def __setattr__(self, __name: str, __value: Any) -> None:
|
||||
current_value = getattr(self, __name, None)
|
||||
# parse ints into floats if current value is a float
|
||||
if isinstance(current_value, float) and isinstance(__value, int):
|
||||
__value = float(__value)
|
||||
|
||||
super().__setattr__(__name, __value)
|
||||
|
||||
if self.__dict__.get("_initialised") and not __name == "_initialised":
|
||||
for callback in self._callback_manager.callbacks:
|
||||
callback(__name, __value)
|
||||
elif __name.startswith(f"_{self.__class__.__name__}__"):
|
||||
logger.warning(
|
||||
f"Warning: You should not set private but rather protected attributes! "
|
||||
f"Use {__name.replace(f'_{self.__class__.__name__}__', '_')} instead "
|
||||
f"of {__name.replace(f'_{self.__class__.__name__}__', '__')}."
|
||||
)
|
||||
|
||||
def __check_instance_classes(self) -> None:
|
||||
for attr_name, attr_value in get_class_and_instance_attributes(self).items():
|
||||
# every class defined by the user should inherit from DataService
|
||||
if not attr_name.startswith("_DataService__"):
|
||||
warn_if_instance_class_does_not_inherit_from_DataService(attr_value)
|
||||
|
||||
def _rpyc_getattr(self, name: str) -> Any:
|
||||
if name.startswith("_"):
|
||||
# disallow special and private attributes
|
||||
raise AttributeError("cannot access private/special names")
|
||||
# allow all other attributes
|
||||
return getattr(self, name)
|
||||
|
||||
def _rpyc_setattr(self, name: str, value: Any) -> None:
|
||||
if name.startswith("_"):
|
||||
# disallow special and private attributes
|
||||
raise AttributeError("cannot access private/special names")
|
||||
|
||||
# check if the attribute has a setter method
|
||||
attr = getattr(self, name, None)
|
||||
if isinstance(attr, property) and attr.fset is None:
|
||||
raise AttributeError(f"{name} attribute does not have a setter method")
|
||||
|
||||
# allow all other attributes
|
||||
setattr(self, name, value)
|
||||
|
||||
def _load_values_from_json(self) -> None:
|
||||
if self._filename is not None:
|
||||
# Check if the file specified by the filename exists
|
||||
@ -101,50 +148,6 @@ class DataService(rpyc.Service, AbstractDataService, TaskManager):
|
||||
f'"{class_value_type}". Ignoring value from JSON file...'
|
||||
)
|
||||
|
||||
def __setattr__(self, __name: str, __value: Any) -> None:
|
||||
current_value = getattr(self, __name, None)
|
||||
# parse ints into floats if current value is a float
|
||||
if isinstance(current_value, float) and isinstance(__value, int):
|
||||
__value = float(__value)
|
||||
|
||||
super().__setattr__(__name, __value)
|
||||
|
||||
if self.__dict__.get("_initialised") and not __name == "_initialised":
|
||||
for callback in self._callback_manager.callbacks:
|
||||
callback(__name, __value)
|
||||
elif __name.startswith(f"_{self.__class__.__name__}__"):
|
||||
logger.warning(
|
||||
f"Warning: You should not set private but rather protected attributes! "
|
||||
f"Use {__name.replace(f'_{self.__class__.__name__}__', '_')} instead "
|
||||
f"of {__name.replace(f'_{self.__class__.__name__}__', '__')}."
|
||||
)
|
||||
|
||||
def _rpyc_getattr(self, name: str) -> Any:
|
||||
if name.startswith("_"):
|
||||
# disallow special and private attributes
|
||||
raise AttributeError("cannot access private/special names")
|
||||
# allow all other attributes
|
||||
return getattr(self, name)
|
||||
|
||||
def _rpyc_setattr(self, name: str, value: Any) -> None:
|
||||
if name.startswith("_"):
|
||||
# disallow special and private attributes
|
||||
raise AttributeError("cannot access private/special names")
|
||||
|
||||
# check if the attribute has a setter method
|
||||
attr = getattr(self, name, None)
|
||||
if isinstance(attr, property) and attr.fset is None:
|
||||
raise AttributeError(f"{name} attribute does not have a setter method")
|
||||
|
||||
# allow all other attributes
|
||||
setattr(self, name, value)
|
||||
|
||||
def __check_instance_classes(self) -> None:
|
||||
for attr_name, attr_value in get_class_and_instance_attributes(self).items():
|
||||
# every class defined by the user should inherit from DataService
|
||||
if not attr_name.startswith("_DataService__"):
|
||||
warn_if_instance_class_does_not_inherit_from_DataService(attr_value)
|
||||
|
||||
def serialize(self) -> dict[str, dict[str, Any]]: # noqa
|
||||
"""
|
||||
Serializes the instance into a dictionary, preserving the structure of the
|
||||
@ -243,8 +246,10 @@ class DataService(rpyc.Service, AbstractDataService, TaskManager):
|
||||
for k, v in sig.parameters.items()
|
||||
}
|
||||
running_task_info = None
|
||||
if key in self._tasks: # If there's a running task for this method
|
||||
task_info = self._tasks[key]
|
||||
if (
|
||||
key in self._task_manager._tasks
|
||||
): # If there's a running task for this method
|
||||
task_info = self._task_manager._tasks[key]
|
||||
running_task_info = task_info["kwargs"]
|
||||
|
||||
result[key] = {
|
||||
|
@ -1,14 +1,12 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TypedDict
|
||||
|
||||
from loguru import logger
|
||||
from tiqi_rpc import Any
|
||||
|
||||
from pyDataInterface.utils.helpers import get_class_and_instance_attributes
|
||||
from .abstract_service_classes import AbstractDataService, AbstractTaskManager
|
||||
|
||||
|
||||
class TaskDict(TypedDict):
|
||||
@ -16,7 +14,7 @@ class TaskDict(TypedDict):
|
||||
kwargs: dict[str, Any]
|
||||
|
||||
|
||||
class TaskManager(ABC):
|
||||
class TaskManager(AbstractTaskManager):
|
||||
"""
|
||||
The TaskManager class is a utility designed to manage asynchronous tasks. It
|
||||
provides functionality for starting, stopping, and tracking these tasks. The class
|
||||
@ -70,35 +68,20 @@ class TaskManager(ABC):
|
||||
interfaces, but can also be used to write logs, etc.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.__root__: "TaskManager" = self
|
||||
"""Keep track of the root object. This helps to filter the emission of
|
||||
notifications."""
|
||||
|
||||
def __init__(self, service: AbstractDataService) -> None:
|
||||
self.service = service
|
||||
self._loop = asyncio.get_event_loop()
|
||||
|
||||
self._autostart_tasks: dict[str, tuple[Any]]
|
||||
if "_autostart_tasks" not in self.__dict__:
|
||||
self._autostart_tasks = {}
|
||||
self._tasks = {}
|
||||
|
||||
self._tasks: dict[str, TaskDict] = {}
|
||||
"""A dictionary to keep track of running tasks. The keys are the names of the
|
||||
tasks and the values are TaskDict instances which include the task itself and
|
||||
its kwargs.
|
||||
"""
|
||||
|
||||
self._task_status_change_callbacks: list[
|
||||
Callable[[str, dict[str, Any] | None], Any]
|
||||
] = []
|
||||
"""A list of callback functions to be invoked when the status of a task (start
|
||||
or stop) changes."""
|
||||
self._task_status_change_callbacks = []
|
||||
|
||||
self._set_start_and_stop_for_async_methods()
|
||||
|
||||
def _set_start_and_stop_for_async_methods(self) -> None: # noqa: C901
|
||||
# inspect the methods of the class
|
||||
for name, method in inspect.getmembers(
|
||||
self, predicate=inspect.iscoroutinefunction
|
||||
self.service, predicate=inspect.iscoroutinefunction
|
||||
):
|
||||
|
||||
@wraps(method)
|
||||
@ -157,12 +140,12 @@ class TaskManager(ABC):
|
||||
callback(name, None)
|
||||
|
||||
# create start and stop methods for each coroutine
|
||||
setattr(self, f"start_{name}", start_task)
|
||||
setattr(self, f"stop_{name}", stop_task)
|
||||
setattr(self.service, f"start_{name}", start_task)
|
||||
setattr(self.service, f"stop_{name}", stop_task)
|
||||
|
||||
def _start_autostart_tasks(self) -> None:
|
||||
if self._autostart_tasks is not None:
|
||||
for service_name, args in self._autostart_tasks.items():
|
||||
def start_autostart_tasks(self) -> None:
|
||||
if self.service._autostart_tasks is not None:
|
||||
for service_name, args in self.service._autostart_tasks.items():
|
||||
start_method = getattr(self, f"start_{service_name}", None)
|
||||
if start_method is not None and callable(start_method):
|
||||
start_method(*args)
|
||||
|
@ -99,7 +99,7 @@ class Server:
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self._loop.set_exception_handler(self.custom_exception_handler)
|
||||
self.install_signal_handlers()
|
||||
self._service._start_autostart_tasks()
|
||||
self._service._task_manager.start_autostart_tasks()
|
||||
|
||||
if self._enable_rpc:
|
||||
self.executor = ThreadPoolExecutor()
|
||||
|
Loading…
x
Reference in New Issue
Block a user