Merge pull request #202 from tiqi-group/feat/add_more_task_config_options

Feat: add more task config options
This commit is contained in:
Mose Müller 2025-01-17 20:40:28 +01:00 committed by GitHub
commit 36ee760610
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 320 additions and 34 deletions

View File

@ -26,15 +26,25 @@ class PerInstanceTaskDescriptor(Generic[R]):
the service class. the service class.
""" """
def __init__( def __init__( # noqa: PLR0913
self, self,
func: Callable[[Any], Coroutine[None, None, R]] func: Callable[[Any], Coroutine[None, None, R]]
| Callable[[], Coroutine[None, None, R]], | Callable[[], Coroutine[None, None, R]],
autostart: bool = False, autostart: bool,
restart_on_failure: bool,
restart_sec: float,
start_limit_interval_sec: float | None,
start_limit_burst: int,
exit_on_failure: bool,
) -> None: ) -> None:
self.__func = func self.__func = func
self.__autostart = autostart self.__autostart = autostart
self.__task_instances: dict[object, Task[R]] = {} self.__task_instances: dict[object, Task[R]] = {}
self.__restart_on_failure = restart_on_failure
self.__restart_sec = restart_sec
self.__start_limit_interval_sec = start_limit_interval_sec
self.__start_limit_burst = start_limit_burst
self.__exit_on_failure = exit_on_failure
def __set_name__(self, owner: type[DataService], name: str) -> None: def __set_name__(self, owner: type[DataService], name: str) -> None:
"""Stores the name of the task within the owning class. This method is called """Stores the name of the task within the owning class. This method is called
@ -67,14 +77,28 @@ class PerInstanceTaskDescriptor(Generic[R]):
if instance not in self.__task_instances: if instance not in self.__task_instances:
self.__task_instances[instance] = instance._initialise_new_objects( self.__task_instances[instance] = instance._initialise_new_objects(
self.__task_name, self.__task_name,
Task(self.__func.__get__(instance, owner), autostart=self.__autostart), Task(
self.__func.__get__(instance, owner),
autostart=self.__autostart,
restart_on_failure=self.__restart_on_failure,
restart_sec=self.__restart_sec,
start_limit_interval_sec=self.__start_limit_interval_sec,
start_limit_burst=self.__start_limit_burst,
exit_on_failure=self.__exit_on_failure,
),
) )
return self.__task_instances[instance] return self.__task_instances[instance]
def task( def task( # noqa: PLR0913
*, autostart: bool = False *,
autostart: bool = False,
restart_on_failure: bool = True,
restart_sec: float = 1.0,
start_limit_interval_sec: float | None = None,
start_limit_burst: int = 3,
exit_on_failure: bool = False,
) -> Callable[ ) -> Callable[
[ [
Callable[[Any], Coroutine[None, None, R]] Callable[[Any], Coroutine[None, None, R]]
@ -96,13 +120,30 @@ def task(
periodically or perform asynchronous operations, such as polling data sources, periodically or perform asynchronous operations, such as polling data sources,
updating databases, or any recurring job that should be managed within the context updating databases, or any recurring job that should be managed within the context
of a `DataService`. of a `DataService`.
time.
The keyword arguments that can be passed to this decorator are inspired by systemd
unit services.
Args: Args:
autostart: autostart:
If set to True, the task will automatically start when the service is If set to True, the task will automatically start when the service is
initialized. Defaults to False. initialized. Defaults to False.
restart_on_failure:
Configures whether the task shall be restarted when it exits with an
exception other than [`asyncio.CancelledError`][asyncio.CancelledError].
restart_sec:
Configures the time to sleep before restarting a task. Defaults to 1.0.
start_limit_interval_sec:
Configures start rate limiting. Tasks which are started more than
`start_limit_burst` times within an `start_limit_interval_sec` time span are
not permitted to start any more. Defaults to None (disabled rate limiting).
start_limit_burst:
Configures unit start rate limiting. Tasks which are started more than
`start_limit_burst` times within an `start_limit_interval_sec` time span are
not permitted to start any more. Defaults to 3.
exit_on_failure:
If True, exit the service if the task fails and restart_on_failure is False
or burst limits are exceeded.
Returns: Returns:
A decorator that wraps an asynchronous function in a A decorator that wraps an asynchronous function in a
[`PerInstanceTaskDescriptor`][pydase.task.decorator.PerInstanceTaskDescriptor] [`PerInstanceTaskDescriptor`][pydase.task.decorator.PerInstanceTaskDescriptor]
@ -140,6 +181,14 @@ def task(
func: Callable[[Any], Coroutine[None, None, R]] func: Callable[[Any], Coroutine[None, None, R]]
| Callable[[], Coroutine[None, None, R]], | Callable[[], Coroutine[None, None, R]],
) -> PerInstanceTaskDescriptor[R]: ) -> PerInstanceTaskDescriptor[R]:
return PerInstanceTaskDescriptor(func, autostart=autostart) return PerInstanceTaskDescriptor(
func,
autostart=autostart,
restart_on_failure=restart_on_failure,
restart_sec=restart_sec,
start_limit_interval_sec=start_limit_interval_sec,
start_limit_burst=start_limit_burst,
exit_on_failure=exit_on_failure,
)
return decorator return decorator

View File

@ -1,7 +1,10 @@
import asyncio import asyncio
import inspect
import logging import logging
import os
import signal
from collections.abc import Callable, Coroutine from collections.abc import Callable, Coroutine
from datetime import datetime
from time import time
from typing import ( from typing import (
Generic, Generic,
TypeVar, TypeVar,
@ -28,6 +31,9 @@ class Task(pydase.data_service.data_service.DataService, Generic[R]):
decorator, it is replaced by a `Task` instance that controls the execution of the decorator, it is replaced by a `Task` instance that controls the execution of the
original function. original function.
The keyword arguments that can be passed to this class are inspired by systemd unit
services.
Args: Args:
func: func:
The asynchronous function that this task wraps. It must be a coroutine The asynchronous function that this task wraps. It must be a coroutine
@ -35,6 +41,22 @@ class Task(pydase.data_service.data_service.DataService, Generic[R]):
autostart: autostart:
If set to True, the task will automatically start when the service is If set to True, the task will automatically start when the service is
initialized. Defaults to False. initialized. Defaults to False.
restart_on_failure:
Configures whether the task shall be restarted when it exits with an
exception other than [`asyncio.CancelledError`][asyncio.CancelledError].
restart_sec:
Configures the time to sleep before restarting a task. Defaults to 1.0.
start_limit_interval_sec:
Configures start rate limiting. Tasks which are started more than
`start_limit_burst` times within an `start_limit_interval_sec` time span are
not permitted to start any more. Defaults to None (disabled rate limiting).
start_limit_burst:
Configures unit start rate limiting. Tasks which are started more than
`start_limit_burst` times within an `start_limit_interval_sec` time span are
not permitted to start any more. Defaults to 3.
exit_on_failure:
If True, exit the service if the task fails and restart_on_failure is False
or burst limits are exceeded.
Example: Example:
```python ```python
@ -63,14 +85,24 @@ class Task(pydase.data_service.data_service.DataService, Generic[R]):
`service.my_task.start()` and `service.my_task.stop()`, respectively. `service.my_task.start()` and `service.my_task.stop()`, respectively.
""" """
def __init__( def __init__( # noqa: PLR0913
self, self,
func: Callable[[], Coroutine[None, None, R | None]], func: Callable[[], Coroutine[None, None, R | None]],
*, *,
autostart: bool = False, autostart: bool,
restart_on_failure: bool,
restart_sec: float,
start_limit_interval_sec: float | None,
start_limit_burst: int,
exit_on_failure: bool,
) -> None: ) -> None:
super().__init__() super().__init__()
self._autostart = autostart self._autostart = autostart
self._restart_on_failure = restart_on_failure
self._restart_sec = restart_sec
self._start_limit_interval_sec = start_limit_interval_sec
self._start_limit_burst = start_limit_burst
self._exit_on_failure = exit_on_failure
self._func_name = func.__name__ self._func_name = func.__name__
self._func = func self._func = func
self._task: asyncio.Task[R | None] | None = None self._task: asyncio.Task[R | None] | None = None
@ -109,37 +141,95 @@ class Task(pydase.data_service.data_service.DataService, Generic[R]):
self._task = None self._task = None
self._status = TaskStatus.NOT_RUNNING self._status = TaskStatus.NOT_RUNNING
exception = task.exception() exception = None
try:
exception = task.exception()
except asyncio.CancelledError:
return
if exception is not None: if exception is not None:
logger.exception( logger.error(
"Task '%s' encountered an exception: %s: %s", "Task '%s' encountered an exception: %r",
self._func_name, self._func_name,
type(exception).__name__,
exception, exception,
) )
raise exception os.kill(os.getpid(), signal.SIGTERM)
else:
self._result = task.result() self._result = task.result()
async def run_task() -> R | None:
if inspect.iscoroutinefunction(self._func):
logger.info("Starting task %r", self._func_name)
self._status = TaskStatus.RUNNING
res: Coroutine[None, None, R | None] = self._func()
try:
return await res
except asyncio.CancelledError:
logger.info("Task '%s' was cancelled", self._func_name)
return None
logger.warning(
"Cannot start task %r. Function has not been bound yet", self._func_name
)
return None
logger.info("Creating task %r", self._func_name) logger.info("Creating task %r", self._func_name)
self._task = self._loop.create_task(run_task()) self._task = self._loop.create_task(self.__running_task_loop())
self._task.add_done_callback(task_done_callback) self._task.add_done_callback(task_done_callback)
async def __running_task_loop(self) -> R | None:
logger.info("Starting task %r", self._func_name)
self._status = TaskStatus.RUNNING
attempts = 0
start_time_of_start_limit_interval = None
while True:
try:
await self._func()
except asyncio.CancelledError:
logger.info("Task '%s' was cancelled", self._func_name)
raise
except Exception as e:
attempts, start_time_of_start_limit_interval = (
self._handle_task_exception(
e, attempts, start_time_of_start_limit_interval
)
)
if not self._should_restart_task(
attempts, start_time_of_start_limit_interval
):
if self._exit_on_failure:
raise e
break
await asyncio.sleep(self._restart_sec)
return None
def _handle_task_exception(
self,
exception: Exception,
attempts: int,
start_time_of_start_limit_interval: float | None,
) -> tuple[int, float]:
"""Handle an exception raised during task execution."""
if start_time_of_start_limit_interval is None:
start_time_of_start_limit_interval = time()
attempts += 1
logger.exception(
"Task %r encountered an exception: %r [attempt %s since %s].",
self._func.__name__,
exception,
attempts,
datetime.fromtimestamp(start_time_of_start_limit_interval),
)
return attempts, start_time_of_start_limit_interval
def _should_restart_task(
self, attempts: int, start_time_of_start_limit_interval: float
) -> bool:
"""Determine if the task should be restarted."""
if not self._restart_on_failure:
return False
if self._start_limit_interval_sec is not None:
if (
time() - start_time_of_start_limit_interval
) > self._start_limit_interval_sec:
# Reset attempts if interval is exceeded
start_time_of_start_limit_interval = time()
attempts = 1
elif attempts > self._start_limit_burst:
logger.error(
"Task %r exceeded restart burst limit. Stopping.",
self._func.__name__,
)
return False
return True
def stop(self) -> None: def stop(self) -> None:
"""Stops the running asynchronous task by cancelling it.""" """Stops the running asynchronous task by cancelling it."""

View File

@ -289,3 +289,150 @@ async def test_manual_start_with_multiple_service_instances(
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
assert "Task 'my_task' was cancelled" in caplog.text assert "Task 'my_task' was cancelled" in caplog.text
@pytest.mark.asyncio(scope="function")
async def test_restart_on_failure(caplog: LogCaptureFixture) -> None:
class MyService(pydase.DataService):
@task(restart_on_failure=True, restart_sec=0.1)
async def my_task(self) -> None:
logger.info("Triggered task.")
raise Exception("Task failure")
service_instance = MyService()
state_manager = StateManager(service_instance)
DataServiceObserver(state_manager)
service_instance.my_task.start()
await asyncio.sleep(0.01)
assert "Task 'my_task' encountered an exception" in caplog.text
caplog.clear()
await asyncio.sleep(0.1)
assert service_instance.my_task.status == TaskStatus.RUNNING
assert "Task 'my_task' encountered an exception" in caplog.text
assert "Triggered task." in caplog.text
@pytest.mark.asyncio(scope="function")
async def test_restart_sec(caplog: LogCaptureFixture) -> None:
class MyService(pydase.DataService):
@task(restart_on_failure=True, restart_sec=0.1)
async def my_task(self) -> None:
logger.info("Triggered task.")
raise Exception("Task failure")
service_instance = MyService()
state_manager = StateManager(service_instance)
DataServiceObserver(state_manager)
service_instance.my_task.start()
await asyncio.sleep(0.001)
assert "Triggered task." in caplog.text
caplog.clear()
await asyncio.sleep(0.05)
assert "Triggered task." not in caplog.text
await asyncio.sleep(0.05)
assert "Triggered task." in caplog.text # Ensures the task restarted after 0.2s
@pytest.mark.asyncio(scope="function")
async def test_exceeding_start_limit_interval_sec_and_burst(
caplog: LogCaptureFixture,
) -> None:
class MyService(pydase.DataService):
@task(
restart_on_failure=True,
restart_sec=0.0,
start_limit_interval_sec=1.0,
start_limit_burst=2,
)
async def my_task(self) -> None:
raise Exception("Task failure")
service_instance = MyService()
state_manager = StateManager(service_instance)
DataServiceObserver(state_manager)
service_instance.my_task.start()
await asyncio.sleep(0.1)
assert "Task 'my_task' exceeded restart burst limit" in caplog.text
assert service_instance.my_task.status == TaskStatus.NOT_RUNNING
@pytest.mark.asyncio(scope="function")
async def test_non_exceeding_start_limit_interval_sec_and_burst(
caplog: LogCaptureFixture,
) -> None:
class MyService(pydase.DataService):
@task(
restart_on_failure=True,
restart_sec=0.1,
start_limit_interval_sec=0.1,
start_limit_burst=2,
)
async def my_task(self) -> None:
raise Exception("Task failure")
service_instance = MyService()
state_manager = StateManager(service_instance)
DataServiceObserver(state_manager)
service_instance.my_task.start()
await asyncio.sleep(0.5)
assert "Task 'my_task' exceeded restart burst limit" not in caplog.text
assert service_instance.my_task.status == TaskStatus.RUNNING
@pytest.mark.asyncio(scope="function")
async def test_exit_on_failure(
monkeypatch: pytest.MonkeyPatch, caplog: LogCaptureFixture
) -> None:
class MyService(pydase.DataService):
@task(restart_on_failure=False, exit_on_failure=True)
async def my_task(self) -> None:
logger.info("Triggered task.")
raise Exception("Critical failure")
def mock_os_kill(pid: int, signal: int) -> None:
logger.critical("os.kill called with signal=%s and pid=%s", signal, pid)
monkeypatch.setattr("os.kill", mock_os_kill)
service_instance = MyService()
state_manager = StateManager(service_instance)
DataServiceObserver(state_manager)
service_instance.my_task.start()
await asyncio.sleep(0.1)
assert "os.kill called with signal=" in caplog.text
assert "Task 'my_task' encountered an exception" in caplog.text
@pytest.mark.asyncio(scope="function")
async def test_exit_on_failure_exceeding_rate_limit(
monkeypatch: pytest.MonkeyPatch, caplog: LogCaptureFixture
) -> None:
class MyService(pydase.DataService):
@task(
restart_on_failure=True,
restart_sec=0.0,
start_limit_interval_sec=0.1,
start_limit_burst=2,
exit_on_failure=True,
)
async def my_task(self) -> None:
raise Exception("Critical failure")
def mock_os_kill(pid: int, signal: int) -> None:
logger.critical("os.kill called with signal=%s and pid=%s", signal, pid)
monkeypatch.setattr("os.kill", mock_os_kill)
service_instance = MyService()
state_manager = StateManager(service_instance)
DataServiceObserver(state_manager)
service_instance.my_task.start()
await asyncio.sleep(0.5)
assert "os.kill called with signal=" in caplog.text
assert "Task 'my_task' encountered an exception" in caplog.text