test(psi-device-base-utils): add tests for task handler

This commit is contained in:
appel_c 2025-02-24 17:55:00 +01:00
parent b75207b7c0
commit 8ed3f37b14
2 changed files with 158 additions and 17 deletions

View File

@ -5,7 +5,7 @@ import threading
import traceback
import uuid
from enum import Enum
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable
from bec_lib.file_utils import get_full_path
from bec_lib.logger import bec_logger
@ -49,11 +49,6 @@ class TaskStatus(DeviceStatus):
self._state = TaskState.NOT_STARTED
self._task_id = str(uuid.uuid4())
@property
def exc(self) -> Exception:
"""Get the exception of the task"""
return self.exception()
@property
def state(self) -> str:
"""Get the state of the task"""
@ -61,7 +56,7 @@ class TaskStatus(DeviceStatus):
@state.setter
def state(self, value: TaskState):
self._state = value
self._state = TaskState(value)
@property
def task_id(self) -> bool:
@ -76,8 +71,9 @@ class TaskHandler:
"""Initialize the handler"""
self._tasks = {}
self._parent = parent
self._lock = threading.RLock()
def submit_task(self, task: callable, run: bool = True) -> TaskStatus:
def submit_task(self, task: Callable, run: bool = True) -> TaskStatus:
"""Submit a task to the task handler.
Args:
@ -109,7 +105,7 @@ class TaskHandler:
thread.start()
task_status.state = TaskState.RUNNING
def _wrap_task(self, task: callable, task_status: TaskStatus):
def _wrap_task(self, task: Callable, task_status: TaskStatus):
"""Wrap the task in a function"""
try:
task()
@ -121,8 +117,8 @@ class TaskHandler:
f" Traceback: {content}"
)
)
task_status.set_exception(exc)
task_status.state = TaskState.TIMEOUT
task_status.set_exception(exc)
except TaskKilledError as exc:
exc = exc.__class__(
f"Task {task_status.task_id} was killed. ThreadID:"
@ -135,20 +131,21 @@ class TaskHandler:
f" Traceback: {content}"
)
)
task_status.set_exception(exc)
task_status.state = TaskState.KILLED
task_status.set_exception(exc)
except Exception as exc: # pylint: disable=broad-except
content = traceback.format_exc()
logger.warning(
f"Exception in task handler for task {task_status.task_id}, Traceback: {content}"
)
task_status.set_exception(exc)
task_status.state = TaskState.ERROR
task_status.set_exception(exc)
else:
task_status.set_finished()
task_status.state = TaskState.COMPLETED
task_status.set_finished()
finally:
self._tasks.pop(task_status.task_id)
with self._lock:
self._tasks.pop(task_status.task_id)
def kill_task(self, task_status: TaskStatus) -> None:
"""Kill the thread
@ -172,9 +169,9 @@ class TaskHandler:
def shutdown(self):
"""Shutdown all tasks of task handler"""
for info in self._tasks.values():
self.kill_task(info[0])
self._tasks.clear()
with self._lock:
for info in self._tasks.values():
self.kill_task(info[0])
class FileHandler:

144
tests/test_utils.py Normal file
View File

@ -0,0 +1,144 @@
import threading
import time
import pytest
from ophyd import Device
from ophyd_devices.utils.psi_device_base_utils import (
FileHandler,
TaskHandler,
TaskKilledError,
TaskState,
TaskStatus,
)
@pytest.fixture
def file_handler():
"""Fixture for FileHandler"""
yield FileHandler()
@pytest.fixture
def device():
"""Fixture for Device"""
yield Device(name="device")
@pytest.fixture
def task_handler(device):
"""Fixture for TaskHandler"""
yield TaskHandler(parent=device)
def test_utils_file_handler_has_full_path(file_handler):
"""Ensure that file_handler has a get_full_path method"""
assert hasattr(file_handler, "get_full_path")
def test_utils_task_status(device):
"""Test TaskStatus creation"""
status = TaskStatus(device=device)
assert status.device.name == "device"
assert status.state == "not_started"
assert status.task_id == status._task_id
status.state = "running"
assert status.state == TaskState.RUNNING
status.state = TaskState.COMPLETED
assert status.state == "completed"
@pytest.mark.timeout(100)
def test_utils_task_handler_task_killed(task_handler):
"""Ensure that task_handler has a submit_task method"""
# No tasks should be running
assert len(task_handler._tasks) == 0
event = threading.Event()
task_stopped = threading.Event()
task_started = threading.Event()
def finished_cb():
task_stopped.set()
def my_wait_task():
task_started.set()
for _ in range(100):
event.wait(timeout=0.1)
# Create task
status = task_handler.submit_task(my_wait_task, run=False)
status.add_callback(finished_cb)
assert status.state == TaskState.NOT_STARTED
# Start task
task_handler.start_task(status)
task_started.wait()
assert status.state == TaskState.RUNNING
# Stop task
task_handler.kill_task(status)
task_stopped.wait()
assert status.state == TaskState.KILLED
assert status.exception().__class__ == TaskKilledError
@pytest.mark.timeout(100)
def test_utils_task_handler_task_successful(task_handler):
"""Ensure that the task handler runs a successful task"""
assert len(task_handler._tasks) == 0
event = threading.Event()
task_stopped = threading.Event()
task_started = threading.Event()
def finished_cb():
task_stopped.set()
def my_wait_task():
task_started.set()
for _ in range(100):
ret = event.wait(timeout=0.1)
if ret is True:
break
status = task_handler.submit_task(my_wait_task, run=False)
status.add_callback(finished_cb)
task_handler.start_task(status)
task_started.wait()
assert status.state == TaskState.RUNNING
event.set()
task_stopped.wait()
assert status.state == TaskState.COMPLETED
def test_utils_task_handler_shutdown(task_handler):
"""Test to shutdown the handler"""
task_completed_cb1 = threading.Event()
task_completed_cb2 = threading.Event()
def finished_cb1():
task_completed_cb1.set()
def finished_cb2():
task_completed_cb2.set()
def cb1():
for _ in range(1000):
time.sleep(0.2)
def cb2():
for _ in range(1000):
time.sleep(0.2)
status1 = task_handler.submit_task(cb1)
status1.add_callback(finished_cb1)
status2 = task_handler.submit_task(cb2)
status2.add_callback(finished_cb2)
assert len(task_handler._tasks) == 2
assert status1.state == TaskState.RUNNING
assert status2.state == TaskState.RUNNING
task_handler.shutdown()
task_completed_cb1.wait()
task_completed_cb2.wait()
assert len(task_handler._tasks) == 0
assert status1.state == TaskState.KILLED
assert status2.state == TaskState.KILLED
assert status1.exception().__class__ == TaskKilledError