mirror of
https://github.com/bec-project/ophyd_devices.git
synced 2025-06-06 11:50:40 +02:00
test(psi-device-base-utils): add tests for task handler
This commit is contained in:
parent
b75207b7c0
commit
8ed3f37b14
@ -5,7 +5,7 @@ import threading
|
|||||||
import traceback
|
import traceback
|
||||||
import uuid
|
import uuid
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Callable
|
||||||
|
|
||||||
from bec_lib.file_utils import get_full_path
|
from bec_lib.file_utils import get_full_path
|
||||||
from bec_lib.logger import bec_logger
|
from bec_lib.logger import bec_logger
|
||||||
@ -49,11 +49,6 @@ class TaskStatus(DeviceStatus):
|
|||||||
self._state = TaskState.NOT_STARTED
|
self._state = TaskState.NOT_STARTED
|
||||||
self._task_id = str(uuid.uuid4())
|
self._task_id = str(uuid.uuid4())
|
||||||
|
|
||||||
@property
|
|
||||||
def exc(self) -> Exception:
|
|
||||||
"""Get the exception of the task"""
|
|
||||||
return self.exception()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def state(self) -> str:
|
def state(self) -> str:
|
||||||
"""Get the state of the task"""
|
"""Get the state of the task"""
|
||||||
@ -61,7 +56,7 @@ class TaskStatus(DeviceStatus):
|
|||||||
|
|
||||||
@state.setter
|
@state.setter
|
||||||
def state(self, value: TaskState):
|
def state(self, value: TaskState):
|
||||||
self._state = value
|
self._state = TaskState(value)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def task_id(self) -> bool:
|
def task_id(self) -> bool:
|
||||||
@ -76,8 +71,9 @@ class TaskHandler:
|
|||||||
"""Initialize the handler"""
|
"""Initialize the handler"""
|
||||||
self._tasks = {}
|
self._tasks = {}
|
||||||
self._parent = parent
|
self._parent = parent
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
|
||||||
def submit_task(self, task: callable, run: bool = True) -> TaskStatus:
|
def submit_task(self, task: Callable, run: bool = True) -> TaskStatus:
|
||||||
"""Submit a task to the task handler.
|
"""Submit a task to the task handler.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -109,7 +105,7 @@ class TaskHandler:
|
|||||||
thread.start()
|
thread.start()
|
||||||
task_status.state = TaskState.RUNNING
|
task_status.state = TaskState.RUNNING
|
||||||
|
|
||||||
def _wrap_task(self, task: callable, task_status: TaskStatus):
|
def _wrap_task(self, task: Callable, task_status: TaskStatus):
|
||||||
"""Wrap the task in a function"""
|
"""Wrap the task in a function"""
|
||||||
try:
|
try:
|
||||||
task()
|
task()
|
||||||
@ -121,8 +117,8 @@ class TaskHandler:
|
|||||||
f" Traceback: {content}"
|
f" Traceback: {content}"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
task_status.set_exception(exc)
|
|
||||||
task_status.state = TaskState.TIMEOUT
|
task_status.state = TaskState.TIMEOUT
|
||||||
|
task_status.set_exception(exc)
|
||||||
except TaskKilledError as exc:
|
except TaskKilledError as exc:
|
||||||
exc = exc.__class__(
|
exc = exc.__class__(
|
||||||
f"Task {task_status.task_id} was killed. ThreadID:"
|
f"Task {task_status.task_id} was killed. ThreadID:"
|
||||||
@ -135,20 +131,21 @@ class TaskHandler:
|
|||||||
f" Traceback: {content}"
|
f" Traceback: {content}"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
task_status.set_exception(exc)
|
|
||||||
task_status.state = TaskState.KILLED
|
task_status.state = TaskState.KILLED
|
||||||
|
task_status.set_exception(exc)
|
||||||
except Exception as exc: # pylint: disable=broad-except
|
except Exception as exc: # pylint: disable=broad-except
|
||||||
content = traceback.format_exc()
|
content = traceback.format_exc()
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Exception in task handler for task {task_status.task_id}, Traceback: {content}"
|
f"Exception in task handler for task {task_status.task_id}, Traceback: {content}"
|
||||||
)
|
)
|
||||||
task_status.set_exception(exc)
|
|
||||||
task_status.state = TaskState.ERROR
|
task_status.state = TaskState.ERROR
|
||||||
|
task_status.set_exception(exc)
|
||||||
else:
|
else:
|
||||||
task_status.set_finished()
|
|
||||||
task_status.state = TaskState.COMPLETED
|
task_status.state = TaskState.COMPLETED
|
||||||
|
task_status.set_finished()
|
||||||
finally:
|
finally:
|
||||||
self._tasks.pop(task_status.task_id)
|
with self._lock:
|
||||||
|
self._tasks.pop(task_status.task_id)
|
||||||
|
|
||||||
def kill_task(self, task_status: TaskStatus) -> None:
|
def kill_task(self, task_status: TaskStatus) -> None:
|
||||||
"""Kill the thread
|
"""Kill the thread
|
||||||
@ -172,9 +169,9 @@ class TaskHandler:
|
|||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
"""Shutdown all tasks of task handler"""
|
"""Shutdown all tasks of task handler"""
|
||||||
for info in self._tasks.values():
|
with self._lock:
|
||||||
self.kill_task(info[0])
|
for info in self._tasks.values():
|
||||||
self._tasks.clear()
|
self.kill_task(info[0])
|
||||||
|
|
||||||
|
|
||||||
class FileHandler:
|
class FileHandler:
|
||||||
|
144
tests/test_utils.py
Normal file
144
tests/test_utils.py
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from ophyd import Device
|
||||||
|
|
||||||
|
from ophyd_devices.utils.psi_device_base_utils import (
|
||||||
|
FileHandler,
|
||||||
|
TaskHandler,
|
||||||
|
TaskKilledError,
|
||||||
|
TaskState,
|
||||||
|
TaskStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def file_handler():
|
||||||
|
"""Fixture for FileHandler"""
|
||||||
|
yield FileHandler()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def device():
|
||||||
|
"""Fixture for Device"""
|
||||||
|
yield Device(name="device")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def task_handler(device):
|
||||||
|
"""Fixture for TaskHandler"""
|
||||||
|
yield TaskHandler(parent=device)
|
||||||
|
|
||||||
|
|
||||||
|
def test_utils_file_handler_has_full_path(file_handler):
|
||||||
|
"""Ensure that file_handler has a get_full_path method"""
|
||||||
|
assert hasattr(file_handler, "get_full_path")
|
||||||
|
|
||||||
|
|
||||||
|
def test_utils_task_status(device):
|
||||||
|
"""Test TaskStatus creation"""
|
||||||
|
status = TaskStatus(device=device)
|
||||||
|
assert status.device.name == "device"
|
||||||
|
assert status.state == "not_started"
|
||||||
|
assert status.task_id == status._task_id
|
||||||
|
status.state = "running"
|
||||||
|
assert status.state == TaskState.RUNNING
|
||||||
|
status.state = TaskState.COMPLETED
|
||||||
|
assert status.state == "completed"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.timeout(100)
|
||||||
|
def test_utils_task_handler_task_killed(task_handler):
|
||||||
|
"""Ensure that task_handler has a submit_task method"""
|
||||||
|
# No tasks should be running
|
||||||
|
assert len(task_handler._tasks) == 0
|
||||||
|
event = threading.Event()
|
||||||
|
task_stopped = threading.Event()
|
||||||
|
task_started = threading.Event()
|
||||||
|
|
||||||
|
def finished_cb():
|
||||||
|
task_stopped.set()
|
||||||
|
|
||||||
|
def my_wait_task():
|
||||||
|
task_started.set()
|
||||||
|
for _ in range(100):
|
||||||
|
event.wait(timeout=0.1)
|
||||||
|
|
||||||
|
# Create task
|
||||||
|
status = task_handler.submit_task(my_wait_task, run=False)
|
||||||
|
status.add_callback(finished_cb)
|
||||||
|
assert status.state == TaskState.NOT_STARTED
|
||||||
|
# Start task
|
||||||
|
task_handler.start_task(status)
|
||||||
|
task_started.wait()
|
||||||
|
assert status.state == TaskState.RUNNING
|
||||||
|
# Stop task
|
||||||
|
task_handler.kill_task(status)
|
||||||
|
task_stopped.wait()
|
||||||
|
assert status.state == TaskState.KILLED
|
||||||
|
assert status.exception().__class__ == TaskKilledError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.timeout(100)
|
||||||
|
def test_utils_task_handler_task_successful(task_handler):
|
||||||
|
"""Ensure that the task handler runs a successful task"""
|
||||||
|
assert len(task_handler._tasks) == 0
|
||||||
|
event = threading.Event()
|
||||||
|
task_stopped = threading.Event()
|
||||||
|
task_started = threading.Event()
|
||||||
|
|
||||||
|
def finished_cb():
|
||||||
|
task_stopped.set()
|
||||||
|
|
||||||
|
def my_wait_task():
|
||||||
|
task_started.set()
|
||||||
|
for _ in range(100):
|
||||||
|
ret = event.wait(timeout=0.1)
|
||||||
|
if ret is True:
|
||||||
|
break
|
||||||
|
|
||||||
|
status = task_handler.submit_task(my_wait_task, run=False)
|
||||||
|
status.add_callback(finished_cb)
|
||||||
|
task_handler.start_task(status)
|
||||||
|
task_started.wait()
|
||||||
|
assert status.state == TaskState.RUNNING
|
||||||
|
event.set()
|
||||||
|
task_stopped.wait()
|
||||||
|
assert status.state == TaskState.COMPLETED
|
||||||
|
|
||||||
|
|
||||||
|
def test_utils_task_handler_shutdown(task_handler):
|
||||||
|
"""Test to shutdown the handler"""
|
||||||
|
|
||||||
|
task_completed_cb1 = threading.Event()
|
||||||
|
task_completed_cb2 = threading.Event()
|
||||||
|
|
||||||
|
def finished_cb1():
|
||||||
|
task_completed_cb1.set()
|
||||||
|
|
||||||
|
def finished_cb2():
|
||||||
|
task_completed_cb2.set()
|
||||||
|
|
||||||
|
def cb1():
|
||||||
|
for _ in range(1000):
|
||||||
|
time.sleep(0.2)
|
||||||
|
|
||||||
|
def cb2():
|
||||||
|
for _ in range(1000):
|
||||||
|
time.sleep(0.2)
|
||||||
|
|
||||||
|
status1 = task_handler.submit_task(cb1)
|
||||||
|
status1.add_callback(finished_cb1)
|
||||||
|
status2 = task_handler.submit_task(cb2)
|
||||||
|
status2.add_callback(finished_cb2)
|
||||||
|
assert len(task_handler._tasks) == 2
|
||||||
|
assert status1.state == TaskState.RUNNING
|
||||||
|
assert status2.state == TaskState.RUNNING
|
||||||
|
task_handler.shutdown()
|
||||||
|
task_completed_cb1.wait()
|
||||||
|
task_completed_cb2.wait()
|
||||||
|
assert len(task_handler._tasks) == 0
|
||||||
|
assert status1.state == TaskState.KILLED
|
||||||
|
assert status2.state == TaskState.KILLED
|
||||||
|
assert status1.exception().__class__ == TaskKilledError
|
Loading…
x
Reference in New Issue
Block a user