mirror of
https://github.com/tiqi-group/pydase.git
synced 2025-04-21 16:50:02 +02:00
Merge pull request #53 from tiqi-group/27-task-autostart-not-working-in-nested-classes
27 task autostart not working in nested classes
This commit is contained in:
commit
3d42366ada
@ -359,6 +359,12 @@ class CallbackManager:
|
|||||||
attrs: dict[str, Any] = get_class_and_instance_attributes(obj)
|
attrs: dict[str, Any] = get_class_and_instance_attributes(obj)
|
||||||
|
|
||||||
for nested_attr_name, nested_attr in attrs.items():
|
for nested_attr_name, nested_attr in attrs.items():
|
||||||
|
if isinstance(nested_attr, DataServiceList):
|
||||||
|
for i, item in enumerate(nested_attr):
|
||||||
|
if isinstance(item, AbstractDataService):
|
||||||
|
self._register_start_stop_task_callbacks(
|
||||||
|
item, parent_path=f"{parent_path}.{nested_attr_name}[{i}]"
|
||||||
|
)
|
||||||
if isinstance(nested_attr, AbstractDataService):
|
if isinstance(nested_attr, AbstractDataService):
|
||||||
self._register_start_stop_task_callbacks(
|
self._register_start_stop_task_callbacks(
|
||||||
nested_attr, parent_path=f"{parent_path}.{nested_attr_name}"
|
nested_attr, parent_path=f"{parent_path}.{nested_attr_name}"
|
||||||
|
@ -7,6 +7,10 @@ from collections.abc import Callable
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import TYPE_CHECKING, Any, TypedDict
|
from typing import TYPE_CHECKING, Any, TypedDict
|
||||||
|
|
||||||
|
from pydase.data_service.abstract_data_service import AbstractDataService
|
||||||
|
from pydase.data_service.data_service_list import DataServiceList
|
||||||
|
from pydase.utils.helpers import get_class_and_instance_attributes
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .data_service import DataService
|
from .data_service import DataService
|
||||||
|
|
||||||
@ -95,92 +99,11 @@ class TaskManager:
|
|||||||
for name, method in inspect.getmembers(
|
for name, method in inspect.getmembers(
|
||||||
self.service, predicate=inspect.iscoroutinefunction
|
self.service, predicate=inspect.iscoroutinefunction
|
||||||
):
|
):
|
||||||
|
|
||||||
@wraps(method)
|
|
||||||
def start_task(*args: Any, **kwargs: Any) -> None:
|
|
||||||
def task_done_callback(task: asyncio.Task, name: str) -> None:
|
|
||||||
"""Handles tasks that have finished.
|
|
||||||
|
|
||||||
Removes a task from the tasks dictionary, calls the defined
|
|
||||||
callbacks, and logs and re-raises exceptions."""
|
|
||||||
|
|
||||||
# removing the finished task from the tasks i
|
|
||||||
self.tasks.pop(name, None)
|
|
||||||
|
|
||||||
# emit the notification that the task was stopped
|
|
||||||
for callback in self.task_status_change_callbacks:
|
|
||||||
callback(name, None)
|
|
||||||
|
|
||||||
exception = task.exception()
|
|
||||||
if exception is not None:
|
|
||||||
# Handle the exception, or you can re-raise it.
|
|
||||||
logger.error(
|
|
||||||
f"Task '{name}' encountered an exception: "
|
|
||||||
f"{type(exception).__name__}: {exception}"
|
|
||||||
)
|
|
||||||
raise exception
|
|
||||||
|
|
||||||
async def task(*args: Any, **kwargs: Any) -> None:
|
|
||||||
try:
|
|
||||||
await method(*args, **kwargs)
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
logger.info(f"Task {name} was cancelled")
|
|
||||||
|
|
||||||
if not self.tasks.get(name):
|
|
||||||
# Get the signature of the coroutine method to start
|
|
||||||
sig = inspect.signature(method)
|
|
||||||
|
|
||||||
# Create a list of the parameter names from the method signature.
|
|
||||||
parameter_names = list(sig.parameters.keys())
|
|
||||||
|
|
||||||
# Extend the list of positional arguments with None values to match
|
|
||||||
# the length of the parameter names list. This is done to ensure
|
|
||||||
# that zip can pair each parameter name with a corresponding value.
|
|
||||||
args_padded = list(args) + [None] * (
|
|
||||||
len(parameter_names) - len(args)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a dictionary of keyword arguments by pairing the parameter
|
|
||||||
# names with the values in 'args_padded'. Then merge this dictionary
|
|
||||||
# with the 'kwargs' dictionary. If a parameter is specified in both
|
|
||||||
# 'args_padded' and 'kwargs', the value from 'kwargs' is used.
|
|
||||||
kwargs_updated = {
|
|
||||||
**dict(zip(parameter_names, args_padded)),
|
|
||||||
**kwargs,
|
|
||||||
}
|
|
||||||
|
|
||||||
# creating the task and adding the task_done_callback which checks
|
|
||||||
# if an exception has occured during the task execution
|
|
||||||
task_object = self._loop.create_task(task(*args, **kwargs))
|
|
||||||
task_object.add_done_callback(
|
|
||||||
lambda task: task_done_callback(task, name)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store the task and its arguments in the '__tasks' dictionary. The
|
|
||||||
# key is the name of the method, and the value is a dictionary
|
|
||||||
# containing the task object and the updated keyword arguments.
|
|
||||||
self.tasks[name] = {
|
|
||||||
"task": task_object,
|
|
||||||
"kwargs": kwargs_updated,
|
|
||||||
}
|
|
||||||
|
|
||||||
# emit the notification that the task was started
|
|
||||||
for callback in self.task_status_change_callbacks:
|
|
||||||
callback(name, kwargs_updated)
|
|
||||||
else:
|
|
||||||
logger.error(f"Task `{name}` is already running!")
|
|
||||||
|
|
||||||
def stop_task() -> None:
|
|
||||||
# cancel the task
|
|
||||||
task = self.tasks.get(name, None)
|
|
||||||
if task is not None:
|
|
||||||
self._loop.call_soon_threadsafe(task["task"].cancel)
|
|
||||||
|
|
||||||
# create start and stop methods for each coroutine
|
# create start and stop methods for each coroutine
|
||||||
setattr(self.service, f"start_{name}", start_task)
|
setattr(self.service, f"start_{name}", self._make_start_task(name, method))
|
||||||
setattr(self.service, f"stop_{name}", stop_task)
|
setattr(self.service, f"stop_{name}", self._make_stop_task(name))
|
||||||
|
|
||||||
def start_autostart_tasks(self) -> None:
|
def _initiate_task_startup(self) -> None:
|
||||||
if self.service._autostart_tasks is not None:
|
if self.service._autostart_tasks is not None:
|
||||||
for service_name, args in self.service._autostart_tasks.items():
|
for service_name, args in self.service._autostart_tasks.items():
|
||||||
start_method = getattr(self.service, f"start_{service_name}", None)
|
start_method = getattr(self.service, f"start_{service_name}", None)
|
||||||
@ -190,3 +113,123 @@ class TaskManager:
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
f"No start method found for service '{service_name}'"
|
f"No start method found for service '{service_name}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def start_autostart_tasks(self) -> None:
|
||||||
|
self._initiate_task_startup()
|
||||||
|
attrs = get_class_and_instance_attributes(self.service)
|
||||||
|
|
||||||
|
for _, attr_value in attrs.items():
|
||||||
|
if isinstance(attr_value, AbstractDataService):
|
||||||
|
attr_value._task_manager.start_autostart_tasks()
|
||||||
|
elif isinstance(attr_value, DataServiceList):
|
||||||
|
for i, item in enumerate(attr_value):
|
||||||
|
if isinstance(item, AbstractDataService):
|
||||||
|
item._task_manager.start_autostart_tasks()
|
||||||
|
|
||||||
|
def _make_stop_task(self, name: str) -> Callable[..., Any]:
|
||||||
|
"""
|
||||||
|
Factory function to create a 'stop_task' function for a running task.
|
||||||
|
|
||||||
|
The generated function cancels the associated asyncio task using 'name' for
|
||||||
|
identification, ensuring proper cleanup. Avoids closure and late binding issues.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): The name of the coroutine task, used for its identification.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def stop_task() -> None:
|
||||||
|
# cancel the task
|
||||||
|
task = self.tasks.get(name, None)
|
||||||
|
if task is not None:
|
||||||
|
self._loop.call_soon_threadsafe(task["task"].cancel)
|
||||||
|
|
||||||
|
return stop_task
|
||||||
|
|
||||||
|
def _make_start_task( # noqa
|
||||||
|
self, name: str, method: Callable[..., Any]
|
||||||
|
) -> Callable[..., Any]:
|
||||||
|
"""
|
||||||
|
Factory function to create a 'start_task' function for a coroutine.
|
||||||
|
|
||||||
|
The generated function starts the coroutine as an asyncio task, handling
|
||||||
|
registration and monitoring.
|
||||||
|
It uses 'name' and 'method' to avoid the closure and late binding issue.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): The name of the coroutine, used for task management.
|
||||||
|
method (callable): The coroutine to be turned into an asyncio task.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(method)
|
||||||
|
def start_task(*args: Any, **kwargs: Any) -> None:
|
||||||
|
def task_done_callback(task: asyncio.Task, name: str) -> None:
|
||||||
|
"""Handles tasks that have finished.
|
||||||
|
|
||||||
|
Removes a task from the tasks dictionary, calls the defined
|
||||||
|
callbacks, and logs and re-raises exceptions."""
|
||||||
|
|
||||||
|
# removing the finished task from the tasks i
|
||||||
|
self.tasks.pop(name, None)
|
||||||
|
|
||||||
|
# emit the notification that the task was stopped
|
||||||
|
for callback in self.task_status_change_callbacks:
|
||||||
|
callback(name, None)
|
||||||
|
|
||||||
|
exception = task.exception()
|
||||||
|
if exception is not None:
|
||||||
|
# Handle the exception, or you can re-raise it.
|
||||||
|
logger.error(
|
||||||
|
f"Task '{name}' encountered an exception: "
|
||||||
|
f"{type(exception).__name__}: {exception}"
|
||||||
|
)
|
||||||
|
raise exception
|
||||||
|
|
||||||
|
async def task(*args: Any, **kwargs: Any) -> None:
|
||||||
|
try:
|
||||||
|
await method(*args, **kwargs)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info(f"Task {name} was cancelled")
|
||||||
|
|
||||||
|
if not self.tasks.get(name):
|
||||||
|
# Get the signature of the coroutine method to start
|
||||||
|
sig = inspect.signature(method)
|
||||||
|
|
||||||
|
# Create a list of the parameter names from the method signature.
|
||||||
|
parameter_names = list(sig.parameters.keys())
|
||||||
|
|
||||||
|
# Extend the list of positional arguments with None values to match
|
||||||
|
# the length of the parameter names list. This is done to ensure
|
||||||
|
# that zip can pair each parameter name with a corresponding value.
|
||||||
|
args_padded = list(args) + [None] * (len(parameter_names) - len(args))
|
||||||
|
|
||||||
|
# Create a dictionary of keyword arguments by pairing the parameter
|
||||||
|
# names with the values in 'args_padded'. Then merge this dictionary
|
||||||
|
# with the 'kwargs' dictionary. If a parameter is specified in both
|
||||||
|
# 'args_padded' and 'kwargs', the value from 'kwargs' is used.
|
||||||
|
kwargs_updated = {
|
||||||
|
**dict(zip(parameter_names, args_padded)),
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
# creating the task and adding the task_done_callback which checks
|
||||||
|
# if an exception has occured during the task execution
|
||||||
|
task_object = self._loop.create_task(task(*args, **kwargs))
|
||||||
|
task_object.add_done_callback(
|
||||||
|
lambda task: task_done_callback(task, name)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store the task and its arguments in the '__tasks' dictionary. The
|
||||||
|
# key is the name of the method, and the value is a dictionary
|
||||||
|
# containing the task object and the updated keyword arguments.
|
||||||
|
self.tasks[name] = {
|
||||||
|
"task": task_object,
|
||||||
|
"kwargs": kwargs_updated,
|
||||||
|
}
|
||||||
|
|
||||||
|
# emit the notification that the task was started
|
||||||
|
for callback in self.task_status_change_callbacks:
|
||||||
|
callback(name, kwargs_updated)
|
||||||
|
else:
|
||||||
|
logger.error(f"Task `{name}` is already running!")
|
||||||
|
|
||||||
|
return start_task
|
||||||
|
56
tests/data_service/test_callback_manager.py
Normal file
56
tests/data_service/test_callback_manager.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from pytest import CaptureFixture
|
||||||
|
|
||||||
|
import pydase
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
def test_DataService_task_callback(capsys: CaptureFixture) -> None:
|
||||||
|
class MyService(pydase.DataService):
|
||||||
|
async def my_task(self) -> None:
|
||||||
|
logger.info("Triggered task.")
|
||||||
|
|
||||||
|
async def my_other_task(self) -> None:
|
||||||
|
logger.info("Triggered other task.")
|
||||||
|
|
||||||
|
service = MyService()
|
||||||
|
service.start_my_task() # type: ignore
|
||||||
|
service.start_my_other_task() # type: ignore
|
||||||
|
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
expected_output = sorted(
|
||||||
|
[
|
||||||
|
"MyService.my_task = {}",
|
||||||
|
"MyService.my_other_task = {}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
actual_output = sorted(captured.out.strip().split("\n")) # type: ignore
|
||||||
|
assert expected_output == actual_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_DataServiceList_task_callback(capsys: CaptureFixture) -> None:
|
||||||
|
class MySubService(pydase.DataService):
|
||||||
|
async def my_task(self) -> None:
|
||||||
|
logger.info("Triggered task.")
|
||||||
|
|
||||||
|
async def my_other_task(self) -> None:
|
||||||
|
logger.info("Triggered other task.")
|
||||||
|
|
||||||
|
class MyService(pydase.DataService):
|
||||||
|
sub_services_list = [MySubService() for i in range(2)]
|
||||||
|
|
||||||
|
service = MyService()
|
||||||
|
service.sub_services_list[0].start_my_task() # type: ignore
|
||||||
|
service.sub_services_list[1].start_my_other_task() # type: ignore
|
||||||
|
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
expected_output = sorted(
|
||||||
|
[
|
||||||
|
"MyService.sub_services_list[0].my_task = {}",
|
||||||
|
"MyService.sub_services_list[1].my_other_task = {}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
actual_output = sorted(captured.out.strip().split("\n")) # type: ignore
|
||||||
|
assert expected_output == actual_output
|
104
tests/data_service/test_task_manager.py
Normal file
104
tests/data_service/test_task_manager.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from pytest import CaptureFixture
|
||||||
|
|
||||||
|
import pydase
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
def test_autostart_task_callback(capsys: CaptureFixture) -> None:
|
||||||
|
class MyService(pydase.DataService):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._autostart_tasks = { # type: ignore
|
||||||
|
"my_task": (),
|
||||||
|
"my_other_task": (),
|
||||||
|
}
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
async def my_task(self) -> None:
|
||||||
|
logger.info("Triggered task.")
|
||||||
|
|
||||||
|
async def my_other_task(self) -> None:
|
||||||
|
logger.info("Triggered other task.")
|
||||||
|
|
||||||
|
service = MyService()
|
||||||
|
service._task_manager.start_autostart_tasks()
|
||||||
|
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
expected_output = sorted(
|
||||||
|
[
|
||||||
|
"MyService.my_task = {}",
|
||||||
|
"MyService.my_other_task = {}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
actual_output = sorted(captured.out.strip().split("\n")) # type: ignore
|
||||||
|
assert expected_output == actual_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_DataService_subclass_autostart_task_callback(capsys: CaptureFixture) -> None:
|
||||||
|
class MySubService(pydase.DataService):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._autostart_tasks = { # type: ignore
|
||||||
|
"my_task": (),
|
||||||
|
"my_other_task": (),
|
||||||
|
}
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
async def my_task(self) -> None:
|
||||||
|
logger.info("Triggered task.")
|
||||||
|
|
||||||
|
async def my_other_task(self) -> None:
|
||||||
|
logger.info("Triggered other task.")
|
||||||
|
|
||||||
|
class MyService(pydase.DataService):
|
||||||
|
sub_service = MySubService()
|
||||||
|
|
||||||
|
service = MyService()
|
||||||
|
service._task_manager.start_autostart_tasks()
|
||||||
|
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
expected_output = sorted(
|
||||||
|
[
|
||||||
|
"MyService.sub_service.my_task = {}",
|
||||||
|
"MyService.sub_service.my_other_task = {}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
actual_output = sorted(captured.out.strip().split("\n")) # type: ignore
|
||||||
|
assert expected_output == actual_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_DataServiceList_subclass_autostart_task_callback(
|
||||||
|
capsys: CaptureFixture,
|
||||||
|
) -> None:
|
||||||
|
class MySubService(pydase.DataService):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._autostart_tasks = { # type: ignore
|
||||||
|
"my_task": (),
|
||||||
|
"my_other_task": (),
|
||||||
|
}
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
async def my_task(self) -> None:
|
||||||
|
logger.info("Triggered task.")
|
||||||
|
|
||||||
|
async def my_other_task(self) -> None:
|
||||||
|
logger.info("Triggered other task.")
|
||||||
|
|
||||||
|
class MyService(pydase.DataService):
|
||||||
|
sub_services_list = [MySubService() for i in range(2)]
|
||||||
|
|
||||||
|
service = MyService()
|
||||||
|
service._task_manager.start_autostart_tasks()
|
||||||
|
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
expected_output = sorted(
|
||||||
|
[
|
||||||
|
"MyService.sub_services_list[0].my_task = {}",
|
||||||
|
"MyService.sub_services_list[0].my_other_task = {}",
|
||||||
|
"MyService.sub_services_list[1].my_task = {}",
|
||||||
|
"MyService.sub_services_list[1].my_other_task = {}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
actual_output = sorted(captured.out.strip().split("\n")) # type: ignore
|
||||||
|
assert expected_output == actual_output
|
Loading…
x
Reference in New Issue
Block a user