Merge pull request #202 from tiqi-group/feat/add_more_task_config_options

Feat: add more task config options
This commit is contained in:
Mose Müller 2025-01-17 20:40:28 +01:00 committed by GitHub
commit 36ee760610
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 320 additions and 34 deletions

View File

@ -26,15 +26,25 @@ class PerInstanceTaskDescriptor(Generic[R]):
the service class.
"""
def __init__(
def __init__( # noqa: PLR0913
self,
func: Callable[[Any], Coroutine[None, None, R]]
| Callable[[], Coroutine[None, None, R]],
autostart: bool = False,
autostart: bool,
restart_on_failure: bool,
restart_sec: float,
start_limit_interval_sec: float | None,
start_limit_burst: int,
exit_on_failure: bool,
) -> None:
self.__func = func
self.__autostart = autostart
self.__task_instances: dict[object, Task[R]] = {}
self.__restart_on_failure = restart_on_failure
self.__restart_sec = restart_sec
self.__start_limit_interval_sec = start_limit_interval_sec
self.__start_limit_burst = start_limit_burst
self.__exit_on_failure = exit_on_failure
def __set_name__(self, owner: type[DataService], name: str) -> None:
"""Stores the name of the task within the owning class. This method is called
@ -67,14 +77,28 @@ class PerInstanceTaskDescriptor(Generic[R]):
if instance not in self.__task_instances:
self.__task_instances[instance] = instance._initialise_new_objects(
self.__task_name,
Task(self.__func.__get__(instance, owner), autostart=self.__autostart),
Task(
self.__func.__get__(instance, owner),
autostart=self.__autostart,
restart_on_failure=self.__restart_on_failure,
restart_sec=self.__restart_sec,
start_limit_interval_sec=self.__start_limit_interval_sec,
start_limit_burst=self.__start_limit_burst,
exit_on_failure=self.__exit_on_failure,
),
)
return self.__task_instances[instance]
def task(
*, autostart: bool = False
def task( # noqa: PLR0913
*,
autostart: bool = False,
restart_on_failure: bool = True,
restart_sec: float = 1.0,
start_limit_interval_sec: float | None = None,
start_limit_burst: int = 3,
exit_on_failure: bool = False,
) -> Callable[
[
Callable[[Any], Coroutine[None, None, R]]
@ -96,13 +120,30 @@ def task(
periodically or perform asynchronous operations, such as polling data sources,
updating databases, or any recurring job that should be managed within the context
of a `DataService`.
time.
The keyword arguments that can be passed to this decorator are inspired by systemd
unit services.
Args:
autostart:
If set to True, the task will automatically start when the service is
initialized. Defaults to False.
restart_on_failure:
Configures whether the task shall be restarted when it exits with an
exception other than [`asyncio.CancelledError`][asyncio.CancelledError].
Defaults to True.
restart_sec:
Configures the time to sleep before restarting a task. Defaults to 1.0.
start_limit_interval_sec:
Configures start rate limiting. Tasks which are started more than
`start_limit_burst` times within a `start_limit_interval_sec` time span are
not permitted to start any more. Defaults to None (disabled rate limiting).
start_limit_burst:
Configures start rate limiting. Tasks which are started more than
`start_limit_burst` times within a `start_limit_interval_sec` time span are
not permitted to start any more. Defaults to 3.
exit_on_failure:
If True, exit the service if the task fails and restart_on_failure is False
or burst limits are exceeded. Defaults to False.
Returns:
A decorator that wraps an asynchronous function in a
[`PerInstanceTaskDescriptor`][pydase.task.decorator.PerInstanceTaskDescriptor]
@ -140,6 +181,14 @@ def task(
func: Callable[[Any], Coroutine[None, None, R]]
| Callable[[], Coroutine[None, None, R]],
) -> PerInstanceTaskDescriptor[R]:
return PerInstanceTaskDescriptor(func, autostart=autostart)
return PerInstanceTaskDescriptor(
func,
autostart=autostart,
restart_on_failure=restart_on_failure,
restart_sec=restart_sec,
start_limit_interval_sec=start_limit_interval_sec,
start_limit_burst=start_limit_burst,
exit_on_failure=exit_on_failure,
)
return decorator

View File

@ -1,7 +1,10 @@
import asyncio
import inspect
import logging
import os
import signal
from collections.abc import Callable, Coroutine
from datetime import datetime
from time import time
from typing import (
Generic,
TypeVar,
@ -28,6 +31,9 @@ class Task(pydase.data_service.data_service.DataService, Generic[R]):
decorator, it is replaced by a `Task` instance that controls the execution of the
original function.
The keyword arguments that can be passed to this class are inspired by systemd unit
services.
Args:
func:
The asynchronous function that this task wraps. It must be a coroutine
@ -35,6 +41,22 @@ class Task(pydase.data_service.data_service.DataService, Generic[R]):
autostart:
If set to True, the task will automatically start when the service is
initialized. Defaults to False.
restart_on_failure:
Configures whether the task shall be restarted when it exits with an
exception other than [`asyncio.CancelledError`][asyncio.CancelledError].
restart_sec:
Configures the time to sleep before restarting a task. Defaults to 1.0.
start_limit_interval_sec:
Configures start rate limiting. Tasks which are started more than
`start_limit_burst` times within a `start_limit_interval_sec` time span are
not permitted to start any more. Defaults to None (disabled rate limiting).
start_limit_burst:
Configures start rate limiting. Tasks which are started more than
`start_limit_burst` times within a `start_limit_interval_sec` time span are
not permitted to start any more. Defaults to 3.
exit_on_failure:
If True, exit the service if the task fails and restart_on_failure is False
or burst limits are exceeded.
Example:
```python
@ -63,14 +85,24 @@ class Task(pydase.data_service.data_service.DataService, Generic[R]):
`service.my_task.start()` and `service.my_task.stop()`, respectively.
"""
def __init__(
def __init__( # noqa: PLR0913
self,
func: Callable[[], Coroutine[None, None, R | None]],
*,
autostart: bool = False,
autostart: bool,
restart_on_failure: bool,
restart_sec: float,
start_limit_interval_sec: float | None,
start_limit_burst: int,
exit_on_failure: bool,
) -> None:
super().__init__()
self._autostart = autostart
self._restart_on_failure = restart_on_failure
self._restart_sec = restart_sec
self._start_limit_interval_sec = start_limit_interval_sec
self._start_limit_burst = start_limit_burst
self._exit_on_failure = exit_on_failure
self._func_name = func.__name__
self._func = func
self._task: asyncio.Task[R | None] | None = None
@ -109,36 +141,94 @@ class Task(pydase.data_service.data_service.DataService, Generic[R]):
self._task = None
self._status = TaskStatus.NOT_RUNNING
exception = None
try:
exception = task.exception()
except asyncio.CancelledError:
return
if exception is not None:
logger.exception(
"Task '%s' encountered an exception: %s: %s",
logger.error(
"Task '%s' encountered an exception: %r",
self._func_name,
type(exception).__name__,
exception,
)
raise exception
os.kill(os.getpid(), signal.SIGTERM)
else:
self._result = task.result()
async def run_task() -> R | None:
if inspect.iscoroutinefunction(self._func):
logger.info("Creating task %r", self._func_name)
self._task = self._loop.create_task(self.__running_task_loop())
self._task.add_done_callback(task_done_callback)
async def __running_task_loop(self) -> R | None:
logger.info("Starting task %r", self._func_name)
self._status = TaskStatus.RUNNING
res: Coroutine[None, None, R | None] = self._func()
attempts = 0
start_time_of_start_limit_interval = None
while True:
try:
return await res
await self._func()
except asyncio.CancelledError:
logger.info("Task '%s' was cancelled", self._func_name)
return None
logger.warning(
"Cannot start task %r. Function has not been bound yet", self._func_name
raise
except Exception as e:
attempts, start_time_of_start_limit_interval = (
self._handle_task_exception(
e, attempts, start_time_of_start_limit_interval
)
)
if not self._should_restart_task(
attempts, start_time_of_start_limit_interval
):
if self._exit_on_failure:
raise e
break
await asyncio.sleep(self._restart_sec)
return None
logger.info("Creating task %r", self._func_name)
self._task = self._loop.create_task(run_task())
self._task.add_done_callback(task_done_callback)
def _handle_task_exception(
self,
exception: Exception,
attempts: int,
start_time_of_start_limit_interval: float | None,
) -> tuple[int, float]:
"""Handle an exception raised during task execution."""
if start_time_of_start_limit_interval is None:
start_time_of_start_limit_interval = time()
attempts += 1
logger.exception(
"Task %r encountered an exception: %r [attempt %s since %s].",
self._func.__name__,
exception,
attempts,
datetime.fromtimestamp(start_time_of_start_limit_interval),
)
return attempts, start_time_of_start_limit_interval
def _should_restart_task(
    self, attempts: int, start_time_of_start_limit_interval: float
) -> bool:
    """Determine if the task should be restarted after a failure.

    This method only reads its arguments. Rebinding them here would not be
    visible to the caller, so no counters are reset in this method.

    Args:
        attempts:
            Number of failures counted in the current rate-limiting window.
        start_time_of_start_limit_interval:
            Timestamp (as returned by `time.time()`) at which the current
            window was opened.

    Returns:
        False if restarting is disabled, or if rate limiting is enabled, the
        window is still open and `attempts` exceeds `start_limit_burst`.
        True otherwise.
    """
    if not self._restart_on_failure:
        return False

    interval = self._start_limit_interval_sec
    if interval is None:
        # Rate limiting is disabled: always restart.
        return True

    if time() - start_time_of_start_limit_interval > interval:
        # The window has already elapsed, so the burst limit does not apply
        # to this failure.
        return True

    if attempts > self._start_limit_burst:
        logger.error(
            "Task %r exceeded restart burst limit. Stopping.",
            self._func.__name__,
        )
        return False

    return True
def stop(self) -> None:
"""Stops the running asynchronous task by cancelling it."""

View File

@ -289,3 +289,150 @@ async def test_manual_start_with_multiple_service_instances(
await asyncio.sleep(0.01)
assert "Task 'my_task' was cancelled" in caplog.text
@pytest.mark.asyncio(scope="function")
async def test_restart_on_failure(caplog: LogCaptureFixture) -> None:
    """A task that raises is run again after `restart_sec` has passed."""

    class MyService(pydase.DataService):
        @task(restart_on_failure=True, restart_sec=0.1)
        async def my_task(self) -> None:
            logger.info("Triggered task.")
            raise Exception("Task failure")

    service = MyService()
    manager = StateManager(service)
    DataServiceObserver(manager)

    service.my_task.start()

    # The first run fails right away and the failure is logged.
    await asyncio.sleep(0.01)
    assert "Task 'my_task' encountered an exception" in caplog.text

    # Drop the first failure from the log so the checks below can only be
    # satisfied by a second run.
    caplog.clear()
    await asyncio.sleep(0.1)

    assert service.my_task.status == TaskStatus.RUNNING
    assert "Task 'my_task' encountered an exception" in caplog.text
    assert "Triggered task." in caplog.text
@pytest.mark.asyncio(scope="function")
async def test_restart_sec(caplog: LogCaptureFixture) -> None:
    """The task is restarted after `restart_sec` seconds, and not before."""

    class MyService(pydase.DataService):
        @task(restart_on_failure=True, restart_sec=0.1)
        async def my_task(self) -> None:
            logger.info("Triggered task.")
            raise Exception("Task failure")

    service_instance = MyService()
    state_manager = StateManager(service_instance)
    DataServiceObserver(state_manager)
    service_instance.my_task.start()

    # First run happens as soon as the event loop gets control.
    await asyncio.sleep(0.001)
    assert "Triggered task." in caplog.text
    caplog.clear()

    # About 0.05 s after the failure: still inside the 0.1 s restart delay.
    await asyncio.sleep(0.05)
    assert "Triggered task." not in caplog.text

    # About 0.15 s after the failure: the restart due at 0.1 s must have
    # happened. Sleeping 0.1 s (instead of 0.05 s) keeps the check well clear
    # of the restart instant, so timer jitter cannot make the test flaky.
    await asyncio.sleep(0.1)
    assert "Triggered task." in caplog.text
@pytest.mark.asyncio(scope="function")
async def test_exceeding_start_limit_interval_sec_and_burst(
    caplog: LogCaptureFixture,
) -> None:
    """Failing more often than the burst limit within the interval stops the task."""

    class MyService(pydase.DataService):
        @task(
            restart_on_failure=True,
            restart_sec=0.0,
            start_limit_interval_sec=1.0,
            start_limit_burst=2,
        )
        async def my_task(self) -> None:
            raise Exception("Task failure")

    service = MyService()
    manager = StateManager(service)
    DataServiceObserver(manager)

    service.my_task.start()

    # With no restart delay the failures pile up well inside the 1 s interval.
    await asyncio.sleep(0.1)

    assert "Task 'my_task' exceeded restart burst limit" in caplog.text
    assert service.my_task.status == TaskStatus.NOT_RUNNING
@pytest.mark.asyncio(scope="function")
async def test_non_exceeding_start_limit_interval_sec_and_burst(
    caplog: LogCaptureFixture,
) -> None:
    """Failures spaced as far apart as the interval never hit the burst limit."""

    class MyService(pydase.DataService):
        @task(
            restart_on_failure=True,
            restart_sec=0.1,
            start_limit_interval_sec=0.1,
            start_limit_burst=2,
        )
        async def my_task(self) -> None:
            raise Exception("Task failure")

    service = MyService()
    manager = StateManager(service)
    DataServiceObserver(manager)

    service.my_task.start()

    # Long enough for several restart cycles.
    await asyncio.sleep(0.5)

    assert "Task 'my_task' exceeded restart burst limit" not in caplog.text
    assert service.my_task.status == TaskStatus.RUNNING
@pytest.mark.asyncio(scope="function")
async def test_exit_on_failure(
    monkeypatch: pytest.MonkeyPatch, caplog: LogCaptureFixture
) -> None:
    """With restarts disabled and `exit_on_failure` set, a failure calls `os.kill`."""

    class MyService(pydase.DataService):
        @task(restart_on_failure=False, exit_on_failure=True)
        async def my_task(self) -> None:
            logger.info("Triggered task.")
            raise Exception("Critical failure")

    def fake_kill(pid: int, sig: int) -> None:
        # Stand-in for os.kill: log the call instead of signalling the process.
        logger.critical("os.kill called with signal=%s and pid=%s", sig, pid)

    monkeypatch.setattr("os.kill", fake_kill)

    service = MyService()
    manager = StateManager(service)
    DataServiceObserver(manager)

    service.my_task.start()
    await asyncio.sleep(0.1)

    assert "os.kill called with signal=" in caplog.text
    assert "Task 'my_task' encountered an exception" in caplog.text
@pytest.mark.asyncio(scope="function")
async def test_exit_on_failure_exceeding_rate_limit(
    monkeypatch: pytest.MonkeyPatch, caplog: LogCaptureFixture
) -> None:
    """Exceeding the burst limit with `exit_on_failure` set calls `os.kill`."""

    class MyService(pydase.DataService):
        @task(
            restart_on_failure=True,
            restart_sec=0.0,
            start_limit_interval_sec=0.1,
            start_limit_burst=2,
            exit_on_failure=True,
        )
        async def my_task(self) -> None:
            raise Exception("Critical failure")

    def fake_kill(pid: int, sig: int) -> None:
        # Stand-in for os.kill: log the call instead of signalling the process.
        logger.critical("os.kill called with signal=%s and pid=%s", sig, pid)

    monkeypatch.setattr("os.kill", fake_kill)

    service = MyService()
    manager = StateManager(service)
    DataServiceObserver(manager)

    service.my_task.start()
    await asyncio.sleep(0.5)

    assert "os.kill called with signal=" in caplog.text
    assert "Task 'my_task' encountered an exception" in caplog.text