mirror of
https://github.com/tiqi-group/pydase.git
synced 2025-06-07 05:50:41 +02:00
feat: using TaskManager as attribute instead of inheriting
This commit is contained in:
parent
8fd1f1822f
commit
e7a0017431
@ -2,7 +2,6 @@ import asyncio
|
|||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional, cast, get_type_hints
|
from typing import Any, Optional, cast, get_type_hints
|
||||||
|
|
||||||
@ -22,7 +21,7 @@ from pyDataInterface.utils.warnings import (
|
|||||||
warn_if_instance_class_does_not_inherit_from_DataService,
|
warn_if_instance_class_does_not_inherit_from_DataService,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .abstract_data_service import AbstractDataService
|
from .abstract_service_classes import AbstractDataService
|
||||||
from .callback_manager import CallbackManager
|
from .callback_manager import CallbackManager
|
||||||
from .task_manager import TaskManager
|
from .task_manager import TaskManager
|
||||||
|
|
||||||
@ -38,10 +37,14 @@ def process_callable_attribute(attr: Any, args: dict[str, Any]) -> Any:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DataService(rpyc.Service, AbstractDataService, TaskManager):
|
class DataService(rpyc.Service, AbstractDataService):
|
||||||
def __init__(self, filename: Optional[str] = None) -> None:
|
def __init__(self, filename: Optional[str] = None) -> None:
|
||||||
TaskManager.__init__(self)
|
|
||||||
self._callback_manager: CallbackManager = CallbackManager(self)
|
self._callback_manager: CallbackManager = CallbackManager(self)
|
||||||
|
self._task_manager = TaskManager(self)
|
||||||
|
|
||||||
|
if not hasattr(self, "_autostart_tasks"):
|
||||||
|
self._autostart_tasks = {}
|
||||||
|
|
||||||
self.__root__: "DataService" = self
|
self.__root__: "DataService" = self
|
||||||
"""Keep track of the root object. This helps to filter the emission of
|
"""Keep track of the root object. This helps to filter the emission of
|
||||||
notifications. This overwrite the TaksManager's __root__ attribute."""
|
notifications. This overwrite the TaksManager's __root__ attribute."""
|
||||||
@ -53,6 +56,50 @@ class DataService(rpyc.Service, AbstractDataService, TaskManager):
|
|||||||
self._initialised = True
|
self._initialised = True
|
||||||
self._load_values_from_json()
|
self._load_values_from_json()
|
||||||
|
|
||||||
|
def __setattr__(self, __name: str, __value: Any) -> None:
|
||||||
|
current_value = getattr(self, __name, None)
|
||||||
|
# parse ints into floats if current value is a float
|
||||||
|
if isinstance(current_value, float) and isinstance(__value, int):
|
||||||
|
__value = float(__value)
|
||||||
|
|
||||||
|
super().__setattr__(__name, __value)
|
||||||
|
|
||||||
|
if self.__dict__.get("_initialised") and not __name == "_initialised":
|
||||||
|
for callback in self._callback_manager.callbacks:
|
||||||
|
callback(__name, __value)
|
||||||
|
elif __name.startswith(f"_{self.__class__.__name__}__"):
|
||||||
|
logger.warning(
|
||||||
|
f"Warning: You should not set private but rather protected attributes! "
|
||||||
|
f"Use {__name.replace(f'_{self.__class__.__name__}__', '_')} instead "
|
||||||
|
f"of {__name.replace(f'_{self.__class__.__name__}__', '__')}."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __check_instance_classes(self) -> None:
|
||||||
|
for attr_name, attr_value in get_class_and_instance_attributes(self).items():
|
||||||
|
# every class defined by the user should inherit from DataService
|
||||||
|
if not attr_name.startswith("_DataService__"):
|
||||||
|
warn_if_instance_class_does_not_inherit_from_DataService(attr_value)
|
||||||
|
|
||||||
|
def _rpyc_getattr(self, name: str) -> Any:
|
||||||
|
if name.startswith("_"):
|
||||||
|
# disallow special and private attributes
|
||||||
|
raise AttributeError("cannot access private/special names")
|
||||||
|
# allow all other attributes
|
||||||
|
return getattr(self, name)
|
||||||
|
|
||||||
|
def _rpyc_setattr(self, name: str, value: Any) -> None:
|
||||||
|
if name.startswith("_"):
|
||||||
|
# disallow special and private attributes
|
||||||
|
raise AttributeError("cannot access private/special names")
|
||||||
|
|
||||||
|
# check if the attribute has a setter method
|
||||||
|
attr = getattr(self, name, None)
|
||||||
|
if isinstance(attr, property) and attr.fset is None:
|
||||||
|
raise AttributeError(f"{name} attribute does not have a setter method")
|
||||||
|
|
||||||
|
# allow all other attributes
|
||||||
|
setattr(self, name, value)
|
||||||
|
|
||||||
def _load_values_from_json(self) -> None:
|
def _load_values_from_json(self) -> None:
|
||||||
if self._filename is not None:
|
if self._filename is not None:
|
||||||
# Check if the file specified by the filename exists
|
# Check if the file specified by the filename exists
|
||||||
@ -101,50 +148,6 @@ class DataService(rpyc.Service, AbstractDataService, TaskManager):
|
|||||||
f'"{class_value_type}". Ignoring value from JSON file...'
|
f'"{class_value_type}". Ignoring value from JSON file...'
|
||||||
)
|
)
|
||||||
|
|
||||||
def __setattr__(self, __name: str, __value: Any) -> None:
|
|
||||||
current_value = getattr(self, __name, None)
|
|
||||||
# parse ints into floats if current value is a float
|
|
||||||
if isinstance(current_value, float) and isinstance(__value, int):
|
|
||||||
__value = float(__value)
|
|
||||||
|
|
||||||
super().__setattr__(__name, __value)
|
|
||||||
|
|
||||||
if self.__dict__.get("_initialised") and not __name == "_initialised":
|
|
||||||
for callback in self._callback_manager.callbacks:
|
|
||||||
callback(__name, __value)
|
|
||||||
elif __name.startswith(f"_{self.__class__.__name__}__"):
|
|
||||||
logger.warning(
|
|
||||||
f"Warning: You should not set private but rather protected attributes! "
|
|
||||||
f"Use {__name.replace(f'_{self.__class__.__name__}__', '_')} instead "
|
|
||||||
f"of {__name.replace(f'_{self.__class__.__name__}__', '__')}."
|
|
||||||
)
|
|
||||||
|
|
||||||
def _rpyc_getattr(self, name: str) -> Any:
|
|
||||||
if name.startswith("_"):
|
|
||||||
# disallow special and private attributes
|
|
||||||
raise AttributeError("cannot access private/special names")
|
|
||||||
# allow all other attributes
|
|
||||||
return getattr(self, name)
|
|
||||||
|
|
||||||
def _rpyc_setattr(self, name: str, value: Any) -> None:
|
|
||||||
if name.startswith("_"):
|
|
||||||
# disallow special and private attributes
|
|
||||||
raise AttributeError("cannot access private/special names")
|
|
||||||
|
|
||||||
# check if the attribute has a setter method
|
|
||||||
attr = getattr(self, name, None)
|
|
||||||
if isinstance(attr, property) and attr.fset is None:
|
|
||||||
raise AttributeError(f"{name} attribute does not have a setter method")
|
|
||||||
|
|
||||||
# allow all other attributes
|
|
||||||
setattr(self, name, value)
|
|
||||||
|
|
||||||
def __check_instance_classes(self) -> None:
|
|
||||||
for attr_name, attr_value in get_class_and_instance_attributes(self).items():
|
|
||||||
# every class defined by the user should inherit from DataService
|
|
||||||
if not attr_name.startswith("_DataService__"):
|
|
||||||
warn_if_instance_class_does_not_inherit_from_DataService(attr_value)
|
|
||||||
|
|
||||||
def serialize(self) -> dict[str, dict[str, Any]]: # noqa
|
def serialize(self) -> dict[str, dict[str, Any]]: # noqa
|
||||||
"""
|
"""
|
||||||
Serializes the instance into a dictionary, preserving the structure of the
|
Serializes the instance into a dictionary, preserving the structure of the
|
||||||
@ -243,8 +246,10 @@ class DataService(rpyc.Service, AbstractDataService, TaskManager):
|
|||||||
for k, v in sig.parameters.items()
|
for k, v in sig.parameters.items()
|
||||||
}
|
}
|
||||||
running_task_info = None
|
running_task_info = None
|
||||||
if key in self._tasks: # If there's a running task for this method
|
if (
|
||||||
task_info = self._tasks[key]
|
key in self._task_manager._tasks
|
||||||
|
): # If there's a running task for this method
|
||||||
|
task_info = self._task_manager._tasks[key]
|
||||||
running_task_info = task_info["kwargs"]
|
running_task_info = task_info["kwargs"]
|
||||||
|
|
||||||
result[key] = {
|
result[key] = {
|
||||||
|
@ -1,14 +1,12 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from collections.abc import Callable
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from tiqi_rpc import Any
|
from tiqi_rpc import Any
|
||||||
|
|
||||||
from pyDataInterface.utils.helpers import get_class_and_instance_attributes
|
from .abstract_service_classes import AbstractDataService, AbstractTaskManager
|
||||||
|
|
||||||
|
|
||||||
class TaskDict(TypedDict):
|
class TaskDict(TypedDict):
|
||||||
@ -16,7 +14,7 @@ class TaskDict(TypedDict):
|
|||||||
kwargs: dict[str, Any]
|
kwargs: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
class TaskManager(ABC):
|
class TaskManager(AbstractTaskManager):
|
||||||
"""
|
"""
|
||||||
The TaskManager class is a utility designed to manage asynchronous tasks. It
|
The TaskManager class is a utility designed to manage asynchronous tasks. It
|
||||||
provides functionality for starting, stopping, and tracking these tasks. The class
|
provides functionality for starting, stopping, and tracking these tasks. The class
|
||||||
@ -70,35 +68,20 @@ class TaskManager(ABC):
|
|||||||
interfaces, but can also be used to write logs, etc.
|
interfaces, but can also be used to write logs, etc.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, service: AbstractDataService) -> None:
|
||||||
self.__root__: "TaskManager" = self
|
self.service = service
|
||||||
"""Keep track of the root object. This helps to filter the emission of
|
|
||||||
notifications."""
|
|
||||||
|
|
||||||
self._loop = asyncio.get_event_loop()
|
self._loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
self._autostart_tasks: dict[str, tuple[Any]]
|
self._tasks = {}
|
||||||
if "_autostart_tasks" not in self.__dict__:
|
|
||||||
self._autostart_tasks = {}
|
|
||||||
|
|
||||||
self._tasks: dict[str, TaskDict] = {}
|
self._task_status_change_callbacks = []
|
||||||
"""A dictionary to keep track of running tasks. The keys are the names of the
|
|
||||||
tasks and the values are TaskDict instances which include the task itself and
|
|
||||||
its kwargs.
|
|
||||||
"""
|
|
||||||
|
|
||||||
self._task_status_change_callbacks: list[
|
|
||||||
Callable[[str, dict[str, Any] | None], Any]
|
|
||||||
] = []
|
|
||||||
"""A list of callback functions to be invoked when the status of a task (start
|
|
||||||
or stop) changes."""
|
|
||||||
|
|
||||||
self._set_start_and_stop_for_async_methods()
|
self._set_start_and_stop_for_async_methods()
|
||||||
|
|
||||||
def _set_start_and_stop_for_async_methods(self) -> None: # noqa: C901
|
def _set_start_and_stop_for_async_methods(self) -> None: # noqa: C901
|
||||||
# inspect the methods of the class
|
# inspect the methods of the class
|
||||||
for name, method in inspect.getmembers(
|
for name, method in inspect.getmembers(
|
||||||
self, predicate=inspect.iscoroutinefunction
|
self.service, predicate=inspect.iscoroutinefunction
|
||||||
):
|
):
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
@ -157,12 +140,12 @@ class TaskManager(ABC):
|
|||||||
callback(name, None)
|
callback(name, None)
|
||||||
|
|
||||||
# create start and stop methods for each coroutine
|
# create start and stop methods for each coroutine
|
||||||
setattr(self, f"start_{name}", start_task)
|
setattr(self.service, f"start_{name}", start_task)
|
||||||
setattr(self, f"stop_{name}", stop_task)
|
setattr(self.service, f"stop_{name}", stop_task)
|
||||||
|
|
||||||
def _start_autostart_tasks(self) -> None:
|
def start_autostart_tasks(self) -> None:
|
||||||
if self._autostart_tasks is not None:
|
if self.service._autostart_tasks is not None:
|
||||||
for service_name, args in self._autostart_tasks.items():
|
for service_name, args in self.service._autostart_tasks.items():
|
||||||
start_method = getattr(self, f"start_{service_name}", None)
|
start_method = getattr(self, f"start_{service_name}", None)
|
||||||
if start_method is not None and callable(start_method):
|
if start_method is not None and callable(start_method):
|
||||||
start_method(*args)
|
start_method(*args)
|
||||||
|
@ -99,7 +99,7 @@ class Server:
|
|||||||
self._loop = asyncio.get_running_loop()
|
self._loop = asyncio.get_running_loop()
|
||||||
self._loop.set_exception_handler(self.custom_exception_handler)
|
self._loop.set_exception_handler(self.custom_exception_handler)
|
||||||
self.install_signal_handlers()
|
self.install_signal_handlers()
|
||||||
self._service._start_autostart_tasks()
|
self._service._task_manager.start_autostart_tasks()
|
||||||
|
|
||||||
if self._enable_rpc:
|
if self._enable_rpc:
|
||||||
self.executor = ThreadPoolExecutor()
|
self.executor = ThreadPoolExecutor()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user