diff --git a/ophyd_devices/utils/psi_device_base_utils.py b/ophyd_devices/utils/psi_device_base_utils.py index 5dc42d7..d202dc9 100644 --- a/ophyd_devices/utils/psi_device_base_utils.py +++ b/ophyd_devices/utils/psi_device_base_utils.py @@ -5,7 +5,7 @@ import threading import traceback import uuid from enum import Enum -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable from bec_lib.file_utils import get_full_path from bec_lib.logger import bec_logger @@ -49,11 +49,6 @@ class TaskStatus(DeviceStatus): self._state = TaskState.NOT_STARTED self._task_id = str(uuid.uuid4()) - @property - def exc(self) -> Exception: - """Get the exception of the task""" - return self.exception() - @property def state(self) -> str: """Get the state of the task""" @@ -61,7 +56,7 @@ class TaskStatus(DeviceStatus): @state.setter def state(self, value: TaskState): - self._state = value + self._state = TaskState(value) @property def task_id(self) -> bool: @@ -76,8 +71,9 @@ class TaskHandler: """Initialize the handler""" self._tasks = {} self._parent = parent + self._lock = threading.RLock() - def submit_task(self, task: callable, run: bool = True) -> TaskStatus: + def submit_task(self, task: Callable, run: bool = True) -> TaskStatus: """Submit a task to the task handler. Args: @@ -109,7 +105,7 @@ class TaskHandler: thread.start() task_status.state = TaskState.RUNNING - def _wrap_task(self, task: callable, task_status: TaskStatus): + def _wrap_task(self, task: Callable, task_status: TaskStatus): """Wrap the task in a function""" try: task() @@ -121,8 +117,8 @@ class TaskHandler: f" Traceback: {content}" ) ) - task_status.set_exception(exc) task_status.state = TaskState.TIMEOUT + task_status.set_exception(exc) except TaskKilledError as exc: exc = exc.__class__( f"Task {task_status.task_id} was killed. ThreadID:" @@ -135,20 +131,21 @@ class TaskHandler: f" Traceback: {content}" ) ) - task_status.set_exception(exc) task_status.state = TaskState.KILLED + task_status.set_exception(exc) except Exception as exc: # pylint: disable=broad-except content = traceback.format_exc() logger.warning( f"Exception in task handler for task {task_status.task_id}, Traceback: {content}" ) - task_status.set_exception(exc) task_status.state = TaskState.ERROR + task_status.set_exception(exc) else: - task_status.set_finished() task_status.state = TaskState.COMPLETED + task_status.set_finished() finally: - self._tasks.pop(task_status.task_id) + with self._lock: + self._tasks.pop(task_status.task_id) def kill_task(self, task_status: TaskStatus) -> None: """Kill the thread @@ -172,9 +169,9 @@ class TaskHandler: def shutdown(self): """Shutdown all tasks of task handler""" - for info in self._tasks.values(): - self.kill_task(info[0]) - self._tasks.clear() + with self._lock: + for info in self._tasks.values(): + self.kill_task(info[0]) class FileHandler: diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..b6cbc9f --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,144 @@ +import threading +import time + +import pytest +from ophyd import Device + +from ophyd_devices.utils.psi_device_base_utils import ( + FileHandler, + TaskHandler, + TaskKilledError, + TaskState, + TaskStatus, +) + + +@pytest.fixture +def file_handler(): + """Fixture for FileHandler""" + yield FileHandler() + + +@pytest.fixture +def device(): + """Fixture for Device""" + yield Device(name="device") + + +@pytest.fixture +def task_handler(device): + """Fixture for TaskHandler""" + yield TaskHandler(parent=device) + + +def test_utils_file_handler_has_full_path(file_handler): + """Ensure that file_handler has a get_full_path method""" + assert hasattr(file_handler, "get_full_path") + + +def test_utils_task_status(device): + """Test TaskStatus creation""" + status = TaskStatus(device=device) + assert status.device.name == "device" + assert status.state == "not_started" + assert status.task_id == status._task_id + status.state = "running" + assert status.state == TaskState.RUNNING + status.state = TaskState.COMPLETED + assert status.state == "completed" + + +@pytest.mark.timeout(100) +def test_utils_task_handler_task_killed(task_handler): + """Ensure that task_handler has a submit_task method""" + # No tasks should be running + assert len(task_handler._tasks) == 0 + event = threading.Event() + task_stopped = threading.Event() + task_started = threading.Event() + + def finished_cb(): + task_stopped.set() + + def my_wait_task(): + task_started.set() + for _ in range(100): + event.wait(timeout=0.1) + + # Create task + status = task_handler.submit_task(my_wait_task, run=False) + status.add_callback(finished_cb) + assert status.state == TaskState.NOT_STARTED + # Start task + task_handler.start_task(status) + task_started.wait() + assert status.state == TaskState.RUNNING + # Stop task + task_handler.kill_task(status) + task_stopped.wait() + assert status.state == TaskState.KILLED + assert status.exception().__class__ == TaskKilledError + + +@pytest.mark.timeout(100) +def test_utils_task_handler_task_successful(task_handler): + """Ensure that the task handler runs a successful task""" + assert len(task_handler._tasks) == 0 + event = threading.Event() + task_stopped = threading.Event() + task_started = threading.Event() + + def finished_cb(): + task_stopped.set() + + def my_wait_task(): + task_started.set() + for _ in range(100): + ret = event.wait(timeout=0.1) + if ret is True: + break + + status = task_handler.submit_task(my_wait_task, run=False) + status.add_callback(finished_cb) + task_handler.start_task(status) + task_started.wait() + assert status.state == TaskState.RUNNING + event.set() + task_stopped.wait() + assert status.state == TaskState.COMPLETED + + +def test_utils_task_handler_shutdown(task_handler): + """Test to shutdown the handler""" + + task_completed_cb1 = threading.Event() + task_completed_cb2 = threading.Event() + + def finished_cb1(): + task_completed_cb1.set() + + def finished_cb2(): + task_completed_cb2.set() + + def cb1(): + for _ in range(1000): + time.sleep(0.2) + + def cb2(): + for _ in range(1000): + time.sleep(0.2) + + status1 = task_handler.submit_task(cb1) + status1.add_callback(finished_cb1) + status2 = task_handler.submit_task(cb2) + status2.add_callback(finished_cb2) + assert len(task_handler._tasks) == 2 + assert status1.state == TaskState.RUNNING + assert status2.state == TaskState.RUNNING + task_handler.shutdown() + task_completed_cb1.wait() + task_completed_cb2.wait() + assert len(task_handler._tasks) == 0 + assert status1.state == TaskState.KILLED + assert status2.state == TaskState.KILLED + assert status1.exception().__class__ == TaskKilledError