diff --git a/src/pyDataInterface/data_service/data_service.py b/src/pyDataInterface/data_service/data_service.py index 442d44c..927b2ef 100644 --- a/src/pyDataInterface/data_service/data_service.py +++ b/src/pyDataInterface/data_service/data_service.py @@ -4,7 +4,7 @@ from collections.abc import Callable from enum import Enum from functools import wraps from itertools import chain -from typing import Any +from typing import Any, TypedDict import rpyc from loguru import logger @@ -16,6 +16,11 @@ from pyDataInterface.utils import ( from .data_service_list import DataServiceList +class TaskDict(TypedDict): + task: asyncio.Task[None] + kwargs: dict[str, Any] + + class DataService(rpyc.Service): _list_mapping: dict[int, DataServiceList] = {} """ @@ -63,6 +68,13 @@ class DataService(rpyc.Service): self._autostart_tasks = {} self._callbacks: set[Callable[[str, Any], None]] = set() + + self._task_status_change_callbacks: list[ + Callable[[str, dict[str, Any] | None], Any] + ] = [] + """A list of callback functions to be invoked when the status of a task (start + or stop) changes.""" + self._set_start_and_stop_for_async_methods() self._register_callbacks() @@ -133,7 +145,39 @@ class DataService(rpyc.Service): print(f"Task {name} was cancelled") if not self.__tasks.get(name): - self.__tasks[name] = self.__loop.create_task(task(*args, **kwargs)) + # Get the signature of the coroutine method to start + sig = inspect.signature(method) + + # Create a list of the parameter names from the method signature. + parameter_names = list(sig.parameters.keys()) + + # Extend the list of positional arguments with None values to match + # the length of the parameter names list. This is done to ensure + # that zip can pair each parameter name with a corresponding value. + args_padded = list(args) + [None] * ( + len(parameter_names) - len(args) + ) + + # Create a dictionary of keyword arguments by pairing the parameter + # names with the values in 'args_padded'. Then merge this dictionary + # with the 'kwargs' dictionary. If a parameter is specified in both + # 'args_padded' and 'kwargs', the value from 'kwargs' is used. + kwargs_updated = { + **dict(zip(parameter_names, args_padded)), + **kwargs, + } + + # Store the task and its arguments in the '__tasks' dictionary. The + # key is the name of the method, and the value is a dictionary + # containing the task object and the updated keyword arguments. + self.__tasks[name] = { + "task": self.__loop.create_task(task(*args, **kwargs)), + "kwargs": kwargs_updated, + } + + # emit the notification that the task was started + for callback in self._task_status_change_callbacks: + callback(name, kwargs_updated) else: logger.error(f"Task `{name}` is already running!") @@ -141,7 +185,11 @@ class DataService(rpyc.Service): # cancel the task task = self.__tasks.pop(name) if task is not None: - self.__loop.call_soon_threadsafe(task.cancel) + self.__loop.call_soon_threadsafe(task["task"].cancel) + + # emit the notification that the task was stopped + for callback in self._task_status_change_callbacks: + callback(name, None) # create start and stop methods for each coroutine setattr(self, f"start_{name}", start_task) @@ -153,6 +201,46 @@ class DataService(rpyc.Service): self, f"{self.__class__.__name__}" ) self._register_property_callbacks(self, f"{self.__class__.__name__}") + self._register_start_stop_task_callbacks(self, f"{self.__class__.__name__}") + + def _register_start_stop_task_callbacks( + self, obj: "DataService", parent_path: str + ) -> None: + """ + This function registers callbacks for start and stop methods of async functions. + These callbacks are stored in the '_task_status_change_callbacks' attribute and + are called when the status of a task changes. + + Parameters: + ----------- + obj: DataService + The target object on which callbacks are to be registered. + parent_path: str + The access path for the parent object. This is used to construct the full + access path for the notifications. + """ + + # Create and register a callback for the object + # only emit the notification when the call was registered by the root object + callback: Callable[[str, dict[str, Any] | None], None] = ( + lambda name, status: obj._emit_notification( + parent_path=parent_path, name=name, value=status + ) + if self == obj.__root__ + and not name.startswith("_") # we are only interested in public attributes + else None + ) + + obj._task_status_change_callbacks.append(callback) + + # Recursively register callbacks for all nested attributes of the object + attrs = obj.__get_class_and_instance_attributes() + + for nested_attr_name, nested_attr in attrs.items(): + if isinstance(nested_attr, DataService): + self._register_start_stop_task_callbacks( + nested_attr, parent_path=f"{parent_path}.{nested_attr_name}" + ) def _register_list_change_callbacks( self, obj: "DataService", parent_path: str @@ -545,11 +633,17 @@ class DataService(rpyc.Service): else None for k, v in sig.parameters.items() } + running_task_info = None + if key in self.__tasks: # If there's a running task for this method + task_info = self.__tasks[key] + running_task_info = task_info["kwargs"] + result[key] = { "type": "method", "async": asyncio.iscoroutinefunction(value), "parameters": parameters, "doc": inspect.getdoc(value), + "value": running_task_info, } elif isinstance(getattr(self.__class__, key, None), property): prop: property = getattr(self.__class__, key)