diff --git a/ophyd_devices/utils/psi_device_base_utils.py b/ophyd_devices/utils/psi_device_base_utils.py index dcff133..269e94e 100644 --- a/ophyd_devices/utils/psi_device_base_utils.py +++ b/ophyd_devices/utils/psi_device_base_utils.py @@ -1,4 +1,4 @@ -""" Utility handler to run tasks (function, conditions) in an asynchronous fashion.""" +"""Utility handler to run tasks (function, conditions) in an asynchronous fashion.""" import ctypes import threading @@ -73,17 +73,25 @@ class TaskHandler: self._parent = parent self._lock = threading.RLock() - def submit_task(self, task: Callable, run: bool = True) -> TaskStatus: + def submit_task( + self, + task: Callable, + task_args: tuple | None = None, + task_kwargs: dict | None = None, + run: bool = True, + ) -> TaskStatus: """Submit a task to the task handler. Args: task: The task to run. run: Whether to run the task immediately. """ + task_args = task_args if task_args else () + task_kwargs = task_kwargs if task_kwargs else {} task_status = TaskStatus(device=self._parent) thread = threading.Thread( target=self._wrap_task, - args=(task, task_status), + args=(task, task_args, task_kwargs, task_status), name=f"task {task_status.task_id}", daemon=True, ) @@ -102,13 +110,15 @@ class TaskHandler: if thread.is_alive(): logger.warning(f"Task with ID {task_status.task_id} is already running.") return - thread.start() task_status.state = TaskState.RUNNING + thread.start() - def _wrap_task(self, task: Callable, task_status: TaskStatus): + def _wrap_task( + self, task: Callable, task_args: tuple, task_kwargs: dict, task_status: TaskStatus + ): """Wrap the task in a function""" try: - task() + task(*task_args, **task_kwargs) except TimeoutError as exc: content = traceback.format_exc() logger.warning( diff --git a/tests/test_utils.py b/tests/test_utils.py index b6cbc9f..55e8519 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -48,6 +48,44 @@ def test_utils_task_status(device): assert status.state == "completed" +def test_utils_task_handler_submit_task_with_args(task_handler): + """Ensure that task_handler has a submit_task method""" + + def my_task(input_arg: bool, input_kwarg: bool = False): + if input_kwarg is True: + raise ValueError("input_kwarg is True") + if input_arg is True: + return True + return False + + # This should fail + with pytest.raises(TypeError): + status = task_handler.submit_task(my_task) + status.wait() + # This should pass + + task_stopped = threading.Event() + + def finished_cb(): + task_stopped.set() + + status = task_handler.submit_task( + my_task, task_args=(True,), task_kwargs={"input_kwarg": False} + ) + status.add_callback(finished_cb) + task_stopped.wait() + assert status.done is True + assert status.state == TaskState.COMPLETED + # This should fail + task_stopped = threading.Event() + status = task_handler.submit_task(my_task, task_args=(True,), task_kwargs={"input_kwarg": True}) + with pytest.raises(ValueError): + status.wait() + assert status.state == TaskState.ERROR + assert status.done is True + assert status.exception().__class__ == ValueError + + @pytest.mark.timeout(100) def test_utils_task_handler_task_killed(task_handler): """Ensure that task_handler has a submit_task method"""