diff --git a/src/pyDataInterface/data_service/data_service.py b/src/pyDataInterface/data_service/data_service.py index bf2c6e7..d65eda6 100644 --- a/src/pyDataInterface/data_service/data_service.py +++ b/src/pyDataInterface/data_service/data_service.py @@ -2,7 +2,6 @@ import asyncio import inspect import json import os -from collections.abc import Callable from enum import Enum from typing import Any, Optional, cast, get_type_hints @@ -22,7 +21,7 @@ from pyDataInterface.utils.warnings import ( warn_if_instance_class_does_not_inherit_from_DataService, ) -from .abstract_data_service import AbstractDataService +from .abstract_service_classes import AbstractDataService from .callback_manager import CallbackManager from .task_manager import TaskManager @@ -38,10 +37,14 @@ def process_callable_attribute(attr: Any, args: dict[str, Any]) -> Any: ) -class DataService(rpyc.Service, AbstractDataService, TaskManager): +class DataService(rpyc.Service, AbstractDataService): def __init__(self, filename: Optional[str] = None) -> None: - TaskManager.__init__(self) self._callback_manager: CallbackManager = CallbackManager(self) + self._task_manager = TaskManager(self) + + if not hasattr(self, "_autostart_tasks"): + self._autostart_tasks = {} + self.__root__: "DataService" = self """Keep track of the root object. This helps to filter the emission of notifications. This overwrite the TaksManager's __root__ attribute.""" @@ -53,6 +56,50 @@ class DataService(rpyc.Service, AbstractDataService, TaskManager): self._initialised = True self._load_values_from_json() + def __setattr__(self, __name: str, __value: Any) -> None: + current_value = getattr(self, __name, None) + # parse ints into floats if current value is a float + if isinstance(current_value, float) and isinstance(__value, int): + __value = float(__value) + + super().__setattr__(__name, __value) + + if self.__dict__.get("_initialised") and not __name == "_initialised": + for callback in self._callback_manager.callbacks: + callback(__name, __value) + elif __name.startswith(f"_{self.__class__.__name__}__"): + logger.warning( + f"Warning: You should not set private but rather protected attributes! " + f"Use {__name.replace(f'_{self.__class__.__name__}__', '_')} instead " + f"of {__name.replace(f'_{self.__class__.__name__}__', '__')}." + ) + + def __check_instance_classes(self) -> None: + for attr_name, attr_value in get_class_and_instance_attributes(self).items(): + # every class defined by the user should inherit from DataService + if not attr_name.startswith("_DataService__"): + warn_if_instance_class_does_not_inherit_from_DataService(attr_value) + + def _rpyc_getattr(self, name: str) -> Any: + if name.startswith("_"): + # disallow special and private attributes + raise AttributeError("cannot access private/special names") + # allow all other attributes + return getattr(self, name) + + def _rpyc_setattr(self, name: str, value: Any) -> None: + if name.startswith("_"): + # disallow special and private attributes + raise AttributeError("cannot access private/special names") + + # check if the attribute has a setter method + attr = getattr(self, name, None) + if isinstance(attr, property) and attr.fset is None: + raise AttributeError(f"{name} attribute does not have a setter method") + + # allow all other attributes + setattr(self, name, value) + def _load_values_from_json(self) -> None: if self._filename is not None: # Check if the file specified by the filename exists @@ -101,50 +148,6 @@ class DataService(rpyc.Service, AbstractDataService, TaskManager): f'"{class_value_type}". Ignoring value from JSON file...' ) - def __setattr__(self, __name: str, __value: Any) -> None: - current_value = getattr(self, __name, None) - # parse ints into floats if current value is a float - if isinstance(current_value, float) and isinstance(__value, int): - __value = float(__value) - - super().__setattr__(__name, __value) - - if self.__dict__.get("_initialised") and not __name == "_initialised": - for callback in self._callback_manager.callbacks: - callback(__name, __value) - elif __name.startswith(f"_{self.__class__.__name__}__"): - logger.warning( - f"Warning: You should not set private but rather protected attributes! " - f"Use {__name.replace(f'_{self.__class__.__name__}__', '_')} instead " - f"of {__name.replace(f'_{self.__class__.__name__}__', '__')}." - ) - - def _rpyc_getattr(self, name: str) -> Any: - if name.startswith("_"): - # disallow special and private attributes - raise AttributeError("cannot access private/special names") - # allow all other attributes - return getattr(self, name) - - def _rpyc_setattr(self, name: str, value: Any) -> None: - if name.startswith("_"): - # disallow special and private attributes - raise AttributeError("cannot access private/special names") - - # check if the attribute has a setter method - attr = getattr(self, name, None) - if isinstance(attr, property) and attr.fset is None: - raise AttributeError(f"{name} attribute does not have a setter method") - - # allow all other attributes - setattr(self, name, value) - - def __check_instance_classes(self) -> None: - for attr_name, attr_value in get_class_and_instance_attributes(self).items(): - # every class defined by the user should inherit from DataService - if not attr_name.startswith("_DataService__"): - warn_if_instance_class_does_not_inherit_from_DataService(attr_value) - def serialize(self) -> dict[str, dict[str, Any]]: # noqa """ Serializes the instance into a dictionary, preserving the structure of the @@ -243,8 +246,10 @@ class DataService(rpyc.Service, AbstractDataService, TaskManager): for k, v in sig.parameters.items() } running_task_info = None - if key in self._tasks: # If there's a running task for this method - task_info = self._tasks[key] + if ( + key in self._task_manager._tasks + ): # If there's a running task for this method + task_info = self._task_manager._tasks[key] running_task_info = task_info["kwargs"] result[key] = { diff --git a/src/pyDataInterface/data_service/task_manager.py b/src/pyDataInterface/data_service/task_manager.py index a22d762..327b8eb 100644 --- a/src/pyDataInterface/data_service/task_manager.py +++ b/src/pyDataInterface/data_service/task_manager.py @@ -1,14 +1,12 @@ import asyncio import inspect -from abc import ABC, abstractmethod -from collections.abc import Callable from functools import wraps from typing import TypedDict from loguru import logger from tiqi_rpc import Any -from pyDataInterface.utils.helpers import get_class_and_instance_attributes +from .abstract_service_classes import AbstractDataService, AbstractTaskManager class TaskDict(TypedDict): @@ -16,7 +14,7 @@ class TaskDict(TypedDict): kwargs: dict[str, Any] -class TaskManager(ABC): +class TaskManager(AbstractTaskManager): """ The TaskManager class is a utility designed to manage asynchronous tasks. It provides functionality for starting, stopping, and tracking these tasks. The class @@ -70,35 +68,20 @@ class TaskManager(ABC): interfaces, but can also be used to write logs, etc. """ - def __init__(self) -> None: - self.__root__: "TaskManager" = self - """Keep track of the root object. This helps to filter the emission of - notifications.""" - + def __init__(self, service: AbstractDataService) -> None: + self.service = service self._loop = asyncio.get_event_loop() - self._autostart_tasks: dict[str, tuple[Any]] - if "_autostart_tasks" not in self.__dict__: - self._autostart_tasks = {} + self._tasks = {} - self._tasks: dict[str, TaskDict] = {} - """A dictionary to keep track of running tasks. The keys are the names of the - tasks and the values are TaskDict instances which include the task itself and - its kwargs. - """ - - self._task_status_change_callbacks: list[ - Callable[[str, dict[str, Any] | None], Any] - ] = [] - """A list of callback functions to be invoked when the status of a task (start - or stop) changes.""" + self._task_status_change_callbacks = [] self._set_start_and_stop_for_async_methods() def _set_start_and_stop_for_async_methods(self) -> None: # noqa: C901 # inspect the methods of the class for name, method in inspect.getmembers( - self, predicate=inspect.iscoroutinefunction + self.service, predicate=inspect.iscoroutinefunction ): @wraps(method) @@ -157,12 +140,12 @@ class TaskManager(ABC): callback(name, None) # create start and stop methods for each coroutine - setattr(self, f"start_{name}", start_task) - setattr(self, f"stop_{name}", stop_task) + setattr(self.service, f"start_{name}", start_task) + setattr(self.service, f"stop_{name}", stop_task) - def _start_autostart_tasks(self) -> None: - if self._autostart_tasks is not None: - for service_name, args in self._autostart_tasks.items(): + def start_autostart_tasks(self) -> None: + if self.service._autostart_tasks is not None: + for service_name, args in self.service._autostart_tasks.items(): start_method = getattr(self, f"start_{service_name}", None) if start_method is not None and callable(start_method): start_method(*args) diff --git a/src/pyDataInterface/server/server.py b/src/pyDataInterface/server/server.py index 8cbf96c..7d13502 100644 --- a/src/pyDataInterface/server/server.py +++ b/src/pyDataInterface/server/server.py @@ -99,7 +99,7 @@ class Server: self._loop = asyncio.get_running_loop() self._loop.set_exception_handler(self.custom_exception_handler) self.install_signal_handlers() - self._service._start_autostart_tasks() + self._service._task_manager.start_autostart_tasks() if self._enable_rpc: self.executor = ThreadPoolExecutor()