diff --git a/src/pyDataInterface/data_service/data_service.py b/src/pyDataInterface/data_service/data_service.py index 927b2ef..180aebe 100644 --- a/src/pyDataInterface/data_service/data_service.py +++ b/src/pyDataInterface/data_service/data_service.py @@ -2,26 +2,21 @@ import asyncio import inspect from collections.abc import Callable from enum import Enum -from functools import wraps -from itertools import chain -from typing import Any, TypedDict +from typing import Any import rpyc from loguru import logger from pyDataInterface.utils import ( + get_class_and_instance_attributes, warn_if_instance_class_does_not_inherit_from_DataService, ) from .data_service_list import DataServiceList +from .task_manager import TaskManager -class TaskDict(TypedDict): - task: asyncio.Task[None] - kwargs: dict[str, Any] - - -class DataService(rpyc.Service): +class DataService(rpyc.Service, TaskManager): _list_mapping: dict[int, DataServiceList] = {} """ A dictionary mapping the id of the original lists to the corresponding @@ -54,14 +49,10 @@ class DataService(rpyc.Service): """ def __init__(self) -> None: + TaskManager.__init__(self) self.__root__: "DataService" = self """Keep track of the root object. This helps to filter the emission of - notifications.""" - - self.__loop = asyncio.get_event_loop() - - self.__tasks: dict[str, TaskDict] = {} - """Dictionary to keep track of running tasks.""" + notifications. This overwrite the TaksManager's __root__ attribute.""" self._autostart_tasks: dict[str, tuple[Any]] if "_autostart_tasks" not in self.__dict__: @@ -69,14 +60,6 @@ class DataService(rpyc.Service): self._callbacks: set[Callable[[str, Any], None]] = set() - self._task_status_change_callbacks: list[ - Callable[[str, dict[str, Any] | None], Any] - ] = [] - """A list of callback functions to be invoked when the status of a task (start - or stop) changes.""" - - self._set_start_and_stop_for_async_methods() - self._register_callbacks() self.__check_instance_classes() self._initialised = True @@ -130,71 +113,6 @@ class DataService(rpyc.Service): f"No start method found for service '{service_name}'" ) - def _set_start_and_stop_for_async_methods(self) -> None: # noqa: C901 - # inspect the methods of the class - for name, method in inspect.getmembers( - self, predicate=inspect.iscoroutinefunction - ): - - @wraps(method) - def start_task(*args: Any, **kwargs: Any) -> None: - async def task(*args: Any, **kwargs: Any) -> None: - try: - await method(*args, **kwargs) - except asyncio.CancelledError: - print(f"Task {name} was cancelled") - - if not self.__tasks.get(name): - # Get the signature of the coroutine method to start - sig = inspect.signature(method) - - # Create a list of the parameter names from the method signature. - parameter_names = list(sig.parameters.keys()) - - # Extend the list of positional arguments with None values to match - # the length of the parameter names list. This is done to ensure - # that zip can pair each parameter name with a corresponding value. - args_padded = list(args) + [None] * ( - len(parameter_names) - len(args) - ) - - # Create a dictionary of keyword arguments by pairing the parameter - # names with the values in 'args_padded'. Then merge this dictionary - # with the 'kwargs' dictionary. If a parameter is specified in both - # 'args_padded' and 'kwargs', the value from 'kwargs' is used. - kwargs_updated = { - **dict(zip(parameter_names, args_padded)), - **kwargs, - } - - # Store the task and its arguments in the '__tasks' dictionary. The - # key is the name of the method, and the value is a dictionary - # containing the task object and the updated keyword arguments. - self.__tasks[name] = { - "task": self.__loop.create_task(task(*args, **kwargs)), - "kwargs": kwargs_updated, - } - - # emit the notification that the task was started - for callback in self._task_status_change_callbacks: - callback(name, kwargs_updated) - else: - logger.error(f"Task `{name}` is already running!") - - def stop_task() -> None: - # cancel the task - task = self.__tasks.pop(name) - if task is not None: - self.__loop.call_soon_threadsafe(task["task"].cancel) - - # emit the notification that the task was stopped - for callback in self._task_status_change_callbacks: - callback(name, None) - - # create start and stop methods for each coroutine - setattr(self, f"start_{name}", start_task) - setattr(self, f"stop_{name}", stop_task) - def _register_callbacks(self) -> None: self._register_list_change_callbacks(self, f"{self.__class__.__name__}") self._register_DataService_instance_callbacks( @@ -203,45 +121,6 @@ class DataService(rpyc.Service): self._register_property_callbacks(self, f"{self.__class__.__name__}") self._register_start_stop_task_callbacks(self, f"{self.__class__.__name__}") - def _register_start_stop_task_callbacks( - self, obj: "DataService", parent_path: str - ) -> None: - """ - This function registers callbacks for start and stop methods of async functions. - These callbacks are stored in the '_task_status_change_callbacks' attribute and - are called when the status of a task changes. - - Parameters: - ----------- - obj: DataService - The target object on which callbacks are to be registered. - parent_path: str - The access path for the parent object. This is used to construct the full - access path for the notifications. - """ - - # Create and register a callback for the object - # only emit the notification when the call was registered by the root object - callback: Callable[[str, dict[str, Any] | None], None] = ( - lambda name, status: obj._emit_notification( - parent_path=parent_path, name=name, value=status - ) - if self == obj.__root__ - and not name.startswith("_") # we are only interested in public attributes - else None - ) - - obj._task_status_change_callbacks.append(callback) - - # Recursively register callbacks for all nested attributes of the object - attrs = obj.__get_class_and_instance_attributes() - - for nested_attr_name, nested_attr in attrs.items(): - if isinstance(nested_attr, DataService): - self._register_start_stop_task_callbacks( - nested_attr, parent_path=f"{parent_path}.{nested_attr_name}" - ) - def _register_list_change_callbacks( self, obj: "DataService", parent_path: str ) -> None: @@ -274,7 +153,7 @@ class DataService(rpyc.Service): """ # Convert all list attributes (both class and instance) to DataServiceList - attrs = obj.__get_class_and_instance_attributes() + attrs = get_class_and_instance_attributes(obj) for attr_name, attr_value in attrs.items(): if isinstance(attr_value, DataService): @@ -361,7 +240,7 @@ class DataService(rpyc.Service): obj._callbacks.add(callback) # Recursively register callbacks for all nested attributes of the object - attrs = obj.__get_class_and_instance_attributes() + attrs = get_class_and_instance_attributes(obj) for nested_attr_name, nested_attr in attrs.items(): if isinstance(nested_attr, DataServiceList): @@ -446,7 +325,7 @@ class DataService(rpyc.Service): propagates it through nested DataService instances. """ - attrs = obj.__get_class_and_instance_attributes() + attrs = get_class_and_instance_attributes(obj) for attr_name, attr_value in attrs.items(): if isinstance(attr_value, DataService): @@ -471,7 +350,7 @@ class DataService(rpyc.Service): # >>> return self.class_attr.voltage * self.current # # The dependencies for this property are: - # ('class_attr', 'voltage', 'current') + # > ('class_attr', 'voltage', 'current') if f"self.{dependency}" not in source_code_string: continue @@ -496,7 +375,7 @@ class DataService(rpyc.Service): ) else: callback = ( - lambda name, value, dependent_attr=attr_name, dep=dependency: obj._emit_notification( + lambda name, _, dependent_attr=attr_name, dep=dependency: obj._emit_notification( parent_path=parent_path, name=dependent_attr, value=getattr(obj, dependent_attr), @@ -507,22 +386,8 @@ class DataService(rpyc.Service): # Add to _callbacks obj._callbacks.add(callback) - def __get_class_and_instance_attributes(self) -> dict[str, Any]: - """Dictionary containing all attributes (both instance and class level) of a - given object. - - If an attribute exists at both the instance and class level,the value from the - instance attribute takes precedence. - The __root__ object is removed as this will lead to endless recursion in the for - loops. - """ - - attrs = dict(chain(type(self).__dict__.items(), self.__dict__.items())) - attrs.pop("__root__") - return attrs - def __check_instance_classes(self) -> None: - for attr_name, attr_value in self.__get_class_and_instance_attributes().items(): + for attr_name, attr_value in get_class_and_instance_attributes(self).items(): # every class defined by the user should inherit from DataService if not attr_name.startswith("_DataService__"): warn_if_instance_class_does_not_inherit_from_DataService(attr_value) @@ -634,8 +499,8 @@ class DataService(rpyc.Service): for k, v in sig.parameters.items() } running_task_info = None - if key in self.__tasks: # If there's a running task for this method - task_info = self.__tasks[key] + if key in self._tasks: # If there's a running task for this method + task_info = self._tasks[key] running_task_info = task_info["kwargs"] result[key] = { diff --git a/src/pyDataInterface/data_service/task_manager.py b/src/pyDataInterface/data_service/task_manager.py new file mode 100644 index 0000000..e5e36c5 --- /dev/null +++ b/src/pyDataInterface/data_service/task_manager.py @@ -0,0 +1,162 @@ +import asyncio +import inspect +from abc import abstractmethod +from collections.abc import Callable +from functools import wraps +from typing import TypedDict + +from loguru import logger +from tiqi_rpc import Any + +from pyDataInterface.utils import get_class_and_instance_attributes + + +class TaskDict(TypedDict): + task: asyncio.Task[None] + kwargs: dict[str, Any] + + +class TaskManager: + """ + The TaskManager class is a utility class designed to manage asynchronous tasks. It + provides functionality for starting and stopping these tasks. The class is primarily + used by the DataService class to manage its tasks. + + The TaskManager class has the following responsibilities: + + - Track all running tasks. + - Provide the ability to start and stop tasks. + - Emit notifications when the status of a task changes. + + The tasks are asynchronous functions which can be started or stopped with the + generated functions in this class. + """ + + def __init__(self) -> None: + self.__root__: "TaskManager" = self + """Keep track of the root object. This helps to filter the emission of + notifications.""" + + self._loop = asyncio.get_event_loop() + + self._tasks: dict[str, TaskDict] = {} + """A dictionary to keep track of running tasks. The keys are the names of the + tasks and the values are TaskDict instances which include the task itself and + its kwargs. + """ + + self._task_status_change_callbacks: list[ + Callable[[str, dict[str, Any] | None], Any] + ] = [] + """A list of callback functions to be invoked when the status of a task (start + or stop) changes.""" + + self._set_start_and_stop_for_async_methods() + + def _register_start_stop_task_callbacks( + self, obj: "TaskManager", parent_path: str + ) -> None: + """ + This function registers callbacks for start and stop methods of async functions. + These callbacks are stored in the '_task_status_change_callbacks' attribute and + are called when the status of a task changes. + + Parameters: + ----------- + obj: DataService + The target object on which callbacks are to be registered. + parent_path: str + The access path for the parent object. This is used to construct the full + access path for the notifications. + """ + + # Create and register a callback for the object + # only emit the notification when the call was registered by the root object + callback: Callable[[str, dict[str, Any] | None], None] = ( + lambda name, status: obj._emit_notification( + parent_path=parent_path, name=name, value=status + ) + if self == obj.__root__ + and not name.startswith("_") # we are only interested in public attributes + else None + ) + + obj._task_status_change_callbacks.append(callback) + + # Recursively register callbacks for all nested attributes of the object + attrs: dict[str, Any] = get_class_and_instance_attributes(obj) + + for nested_attr_name, nested_attr in attrs.items(): + if isinstance(nested_attr, TaskManager): + self._register_start_stop_task_callbacks( + nested_attr, parent_path=f"{parent_path}.{nested_attr_name}" + ) + + def _set_start_and_stop_for_async_methods(self) -> None: # noqa: C901 + # inspect the methods of the class + for name, method in inspect.getmembers( + self, predicate=inspect.iscoroutinefunction + ): + + @wraps(method) + def start_task(*args: Any, **kwargs: Any) -> None: + async def task(*args: Any, **kwargs: Any) -> None: + try: + await method(*args, **kwargs) + except asyncio.CancelledError: + print(f"Task {name} was cancelled") + + if not self._tasks.get(name): + # Get the signature of the coroutine method to start + sig = inspect.signature(method) + + # Create a list of the parameter names from the method signature. + parameter_names = list(sig.parameters.keys()) + + # Extend the list of positional arguments with None values to match + # the length of the parameter names list. This is done to ensure + # that zip can pair each parameter name with a corresponding value. + args_padded = list(args) + [None] * ( + len(parameter_names) - len(args) + ) + + # Create a dictionary of keyword arguments by pairing the parameter + # names with the values in 'args_padded'. Then merge this dictionary + # with the 'kwargs' dictionary. If a parameter is specified in both + # 'args_padded' and 'kwargs', the value from 'kwargs' is used. + kwargs_updated = { + **dict(zip(parameter_names, args_padded)), + **kwargs, + } + + # Store the task and its arguments in the '__tasks' dictionary. The + # key is the name of the method, and the value is a dictionary + # containing the task object and the updated keyword arguments. + self._tasks[name] = { + "task": self._loop.create_task(task(*args, **kwargs)), + "kwargs": kwargs_updated, + } + + # emit the notification that the task was started + for callback in self._task_status_change_callbacks: + callback(name, kwargs_updated) + else: + logger.error(f"Task `{name}` is already running!") + + def stop_task() -> None: + # cancel the task + task = self._tasks.pop(name) + if task is not None: + self._loop.call_soon_threadsafe(task["task"].cancel) + + # emit the notification that the task was stopped + for callback in self._task_status_change_callbacks: + callback(name, None) + + # create start and stop methods for each coroutine + setattr(self, f"start_{name}", start_task) + setattr(self, f"stop_{name}", stop_task) + + @abstractmethod + def _emit_notification(self, parent_path: str, name: str, value: Any) -> None: + raise NotImplementedError