diff --git a/src/pydase/task/decorator.py b/src/pydase/task/decorator.py index b6974b2..0149bb5 100644 --- a/src/pydase/task/decorator.py +++ b/src/pydase/task/decorator.py @@ -26,15 +26,25 @@ class PerInstanceTaskDescriptor(Generic[R]): the service class. """ - def __init__( + def __init__( # noqa: PLR0913 self, func: Callable[[Any], Coroutine[None, None, R]] | Callable[[], Coroutine[None, None, R]], - autostart: bool = False, + autostart: bool, + restart_on_failure: bool, + restart_sec: float, + start_limit_interval_sec: float | None, + start_limit_burst: int, + exit_on_failure: bool, ) -> None: self.__func = func self.__autostart = autostart self.__task_instances: dict[object, Task[R]] = {} + self.__restart_on_failure = restart_on_failure + self.__restart_sec = restart_sec + self.__start_limit_interval_sec = start_limit_interval_sec + self.__start_limit_burst = start_limit_burst + self.__exit_on_failure = exit_on_failure def __set_name__(self, owner: type[DataService], name: str) -> None: """Stores the name of the task within the owning class. This method is called @@ -67,14 +77,28 @@ class PerInstanceTaskDescriptor(Generic[R]): if instance not in self.__task_instances: self.__task_instances[instance] = instance._initialise_new_objects( self.__task_name, - Task(self.__func.__get__(instance, owner), autostart=self.__autostart), + Task( + self.__func.__get__(instance, owner), + autostart=self.__autostart, + restart_on_failure=self.__restart_on_failure, + restart_sec=self.__restart_sec, + start_limit_interval_sec=self.__start_limit_interval_sec, + start_limit_burst=self.__start_limit_burst, + exit_on_failure=self.__exit_on_failure, + ), ) return self.__task_instances[instance] -def task( - *, autostart: bool = False +def task( # noqa: PLR0913 + *, + autostart: bool = False, + restart_on_failure: bool = True, + restart_sec: float = 1.0, + start_limit_interval_sec: float | None = None, + start_limit_burst: int = 3, + exit_on_failure: bool = False, ) -> Callable[ [ Callable[[Any], Coroutine[None, None, R]] @@ -96,13 +120,30 @@ def task( periodically or perform asynchronous operations, such as polling data sources, updating databases, or any recurring job that should be managed within the context of a `DataService`. - time. + + The keyword arguments that can be passed to this decorator are inspired by systemd + unit services. Args: autostart: If set to True, the task will automatically start when the service is initialized. Defaults to False. - + restart_on_failure: + Configures whether the task shall be restarted when it exits with an + exception other than [`asyncio.CancelledError`][asyncio.CancelledError]. + restart_sec: + Configures the time to sleep before restarting a task. Defaults to 1.0. + start_limit_interval_sec: + Configures start rate limiting. Tasks which are started more than + `start_limit_burst` times within an `start_limit_interval_sec` time span are + not permitted to start any more. Defaults to None (disabled rate limiting). + start_limit_burst: + Configures unit start rate limiting. Tasks which are started more than + `start_limit_burst` times within an `start_limit_interval_sec` time span are + not permitted to start any more. Defaults to 3. + exit_on_failure: + If True, exit the service if the task fails and restart_on_failure is False + or burst limits are exceeded. Returns: A decorator that wraps an asynchronous function in a [`PerInstanceTaskDescriptor`][pydase.task.decorator.PerInstanceTaskDescriptor] @@ -140,6 +181,14 @@ def task( func: Callable[[Any], Coroutine[None, None, R]] | Callable[[], Coroutine[None, None, R]], ) -> PerInstanceTaskDescriptor[R]: - return PerInstanceTaskDescriptor(func, autostart=autostart) + return PerInstanceTaskDescriptor( + func, + autostart=autostart, + restart_on_failure=restart_on_failure, + restart_sec=restart_sec, + start_limit_interval_sec=start_limit_interval_sec, + start_limit_burst=start_limit_burst, + exit_on_failure=exit_on_failure, + ) return decorator diff --git a/src/pydase/task/task.py b/src/pydase/task/task.py index 6865f0f..321a907 100644 --- a/src/pydase/task/task.py +++ b/src/pydase/task/task.py @@ -1,7 +1,10 @@ import asyncio -import inspect import logging +import os +import signal from collections.abc import Callable, Coroutine +from datetime import datetime +from time import time from typing import ( Generic, TypeVar, @@ -28,6 +31,9 @@ class Task(pydase.data_service.data_service.DataService, Generic[R]): decorator, it is replaced by a `Task` instance that controls the execution of the original function. + The keyword arguments that can be passed to this class are inspired by systemd unit + services. + Args: func: The asynchronous function that this task wraps. It must be a coroutine @@ -35,6 +41,22 @@ class Task(pydase.data_service.data_service.DataService, Generic[R]): autostart: If set to True, the task will automatically start when the service is initialized. Defaults to False. + restart_on_failure: + Configures whether the task shall be restarted when it exits with an + exception other than [`asyncio.CancelledError`][asyncio.CancelledError]. + restart_sec: + Configures the time to sleep before restarting a task. Defaults to 1.0. + start_limit_interval_sec: + Configures start rate limiting. Tasks which are started more than + `start_limit_burst` times within an `start_limit_interval_sec` time span are + not permitted to start any more. Defaults to None (disabled rate limiting). + start_limit_burst: + Configures unit start rate limiting. Tasks which are started more than + `start_limit_burst` times within an `start_limit_interval_sec` time span are + not permitted to start any more. Defaults to 3. + exit_on_failure: + If True, exit the service if the task fails and restart_on_failure is False + or burst limits are exceeded. Example: ```python @@ -63,14 +85,24 @@ class Task(pydase.data_service.data_service.DataService, Generic[R]): `service.my_task.start()` and `service.my_task.stop()`, respectively. """ - def __init__( + def __init__( # noqa: PLR0913 self, func: Callable[[], Coroutine[None, None, R | None]], *, - autostart: bool = False, + autostart: bool, + restart_on_failure: bool, + restart_sec: float, + start_limit_interval_sec: float | None, + start_limit_burst: int, + exit_on_failure: bool, ) -> None: super().__init__() self._autostart = autostart + self._restart_on_failure = restart_on_failure + self._restart_sec = restart_sec + self._start_limit_interval_sec = start_limit_interval_sec + self._start_limit_burst = start_limit_burst + self._exit_on_failure = exit_on_failure self._func_name = func.__name__ self._func = func self._task: asyncio.Task[R | None] | None = None @@ -109,37 +141,95 @@ class Task(pydase.data_service.data_service.DataService, Generic[R]): self._task = None self._status = TaskStatus.NOT_RUNNING - exception = task.exception() + exception = None + try: + exception = task.exception() + except asyncio.CancelledError: + return + if exception is not None: - logger.exception( - "Task '%s' encountered an exception: %s: %s", + logger.error( + "Task '%s' encountered an exception: %r", self._func_name, - type(exception).__name__, exception, ) - raise exception - - self._result = task.result() - - async def run_task() -> R | None: - if inspect.iscoroutinefunction(self._func): - logger.info("Starting task %r", self._func_name) - self._status = TaskStatus.RUNNING - res: Coroutine[None, None, R | None] = self._func() - try: - return await res - except asyncio.CancelledError: - logger.info("Task '%s' was cancelled", self._func_name) - return None - logger.warning( - "Cannot start task %r. Function has not been bound yet", self._func_name - ) - return None + os.kill(os.getpid(), signal.SIGTERM) + else: + self._result = task.result() logger.info("Creating task %r", self._func_name) - self._task = self._loop.create_task(run_task()) + self._task = self._loop.create_task(self.__running_task_loop()) self._task.add_done_callback(task_done_callback) + async def __running_task_loop(self) -> R | None: + logger.info("Starting task %r", self._func_name) + self._status = TaskStatus.RUNNING + attempts = 0 + start_time_of_start_limit_interval = None + + while True: + try: + await self._func() + except asyncio.CancelledError: + logger.info("Task '%s' was cancelled", self._func_name) + raise + except Exception as e: + attempts, start_time_of_start_limit_interval = ( + self._handle_task_exception( + e, attempts, start_time_of_start_limit_interval + ) + ) + if not self._should_restart_task( + attempts, start_time_of_start_limit_interval + ): + if self._exit_on_failure: + raise e + break + await asyncio.sleep(self._restart_sec) + return None + + def _handle_task_exception( + self, + exception: Exception, + attempts: int, + start_time_of_start_limit_interval: float | None, + ) -> tuple[int, float]: + """Handle an exception raised during task execution.""" + if start_time_of_start_limit_interval is None: + start_time_of_start_limit_interval = time() + + attempts += 1 + logger.exception( + "Task %r encountered an exception: %r [attempt %s since %s].", + self._func.__name__, + exception, + attempts, + datetime.fromtimestamp(start_time_of_start_limit_interval), + ) + return attempts, start_time_of_start_limit_interval + + def _should_restart_task( + self, attempts: int, start_time_of_start_limit_interval: float + ) -> bool: + """Determine if the task should be restarted.""" + if not self._restart_on_failure: + return False + + if self._start_limit_interval_sec is not None: + if ( + time() - start_time_of_start_limit_interval + ) > self._start_limit_interval_sec: + # Reset attempts if interval is exceeded + start_time_of_start_limit_interval = time() + attempts = 1 + elif attempts > self._start_limit_burst: + logger.error( + "Task %r exceeded restart burst limit. Stopping.", + self._func.__name__, + ) + return False + return True + def stop(self) -> None: """Stops the running asynchronous task by cancelling it.""" diff --git a/tests/task/test_task.py b/tests/task/test_task.py index 9c83d42..ea267ff 100644 --- a/tests/task/test_task.py +++ b/tests/task/test_task.py @@ -289,3 +289,150 @@ async def test_manual_start_with_multiple_service_instances( await asyncio.sleep(0.01) assert "Task 'my_task' was cancelled" in caplog.text + + +@pytest.mark.asyncio(scope="function") +async def test_restart_on_failure(caplog: LogCaptureFixture) -> None: + class MyService(pydase.DataService): + @task(restart_on_failure=True, restart_sec=0.1) + async def my_task(self) -> None: + logger.info("Triggered task.") + raise Exception("Task failure") + + service_instance = MyService() + state_manager = StateManager(service_instance) + DataServiceObserver(state_manager) + service_instance.my_task.start() + + await asyncio.sleep(0.01) + assert "Task 'my_task' encountered an exception" in caplog.text + caplog.clear() + await asyncio.sleep(0.1) + assert service_instance.my_task.status == TaskStatus.RUNNING + assert "Task 'my_task' encountered an exception" in caplog.text + assert "Triggered task." in caplog.text + + +@pytest.mark.asyncio(scope="function") +async def test_restart_sec(caplog: LogCaptureFixture) -> None: + class MyService(pydase.DataService): + @task(restart_on_failure=True, restart_sec=0.1) + async def my_task(self) -> None: + logger.info("Triggered task.") + raise Exception("Task failure") + + service_instance = MyService() + state_manager = StateManager(service_instance) + DataServiceObserver(state_manager) + service_instance.my_task.start() + + await asyncio.sleep(0.001) + assert "Triggered task." in caplog.text + caplog.clear() + await asyncio.sleep(0.05) + assert "Triggered task." not in caplog.text + await asyncio.sleep(0.05) + assert "Triggered task." in caplog.text # Ensures the task restarted after 0.2s + + +@pytest.mark.asyncio(scope="function") +async def test_exceeding_start_limit_interval_sec_and_burst( + caplog: LogCaptureFixture, +) -> None: + class MyService(pydase.DataService): + @task( + restart_on_failure=True, + restart_sec=0.0, + start_limit_interval_sec=1.0, + start_limit_burst=2, + ) + async def my_task(self) -> None: + raise Exception("Task failure") + + service_instance = MyService() + state_manager = StateManager(service_instance) + DataServiceObserver(state_manager) + service_instance.my_task.start() + + await asyncio.sleep(0.1) + assert "Task 'my_task' exceeded restart burst limit" in caplog.text + assert service_instance.my_task.status == TaskStatus.NOT_RUNNING + + +@pytest.mark.asyncio(scope="function") +async def test_non_exceeding_start_limit_interval_sec_and_burst( + caplog: LogCaptureFixture, +) -> None: + class MyService(pydase.DataService): + @task( + restart_on_failure=True, + restart_sec=0.1, + start_limit_interval_sec=0.1, + start_limit_burst=2, + ) + async def my_task(self) -> None: + raise Exception("Task failure") + + service_instance = MyService() + state_manager = StateManager(service_instance) + DataServiceObserver(state_manager) + service_instance.my_task.start() + + await asyncio.sleep(0.5) + assert "Task 'my_task' exceeded restart burst limit" not in caplog.text + assert service_instance.my_task.status == TaskStatus.RUNNING + + +@pytest.mark.asyncio(scope="function") +async def test_exit_on_failure( + monkeypatch: pytest.MonkeyPatch, caplog: LogCaptureFixture +) -> None: + class MyService(pydase.DataService): + @task(restart_on_failure=False, exit_on_failure=True) + async def my_task(self) -> None: + logger.info("Triggered task.") + raise Exception("Critical failure") + + def mock_os_kill(pid: int, signal: int) -> None: + logger.critical("os.kill called with signal=%s and pid=%s", signal, pid) + + monkeypatch.setattr("os.kill", mock_os_kill) + + service_instance = MyService() + state_manager = StateManager(service_instance) + DataServiceObserver(state_manager) + service_instance.my_task.start() + + await asyncio.sleep(0.1) + assert "os.kill called with signal=" in caplog.text + assert "Task 'my_task' encountered an exception" in caplog.text + + +@pytest.mark.asyncio(scope="function") +async def test_exit_on_failure_exceeding_rate_limit( + monkeypatch: pytest.MonkeyPatch, caplog: LogCaptureFixture +) -> None: + class MyService(pydase.DataService): + @task( + restart_on_failure=True, + restart_sec=0.0, + start_limit_interval_sec=0.1, + start_limit_burst=2, + exit_on_failure=True, + ) + async def my_task(self) -> None: + raise Exception("Critical failure") + + def mock_os_kill(pid: int, signal: int) -> None: + logger.critical("os.kill called with signal=%s and pid=%s", signal, pid) + + monkeypatch.setattr("os.kill", mock_os_kill) + + service_instance = MyService() + state_manager = StateManager(service_instance) + DataServiceObserver(state_manager) + service_instance.my_task.start() + + await asyncio.sleep(0.5) + assert "os.kill called with signal=" in caplog.text + assert "Task 'my_task' encountered an exception" in caplog.text