mirror of
https://github.com/bec-project/ophyd_devices.git
synced 2025-06-07 20:10:42 +02:00
feat (psi-device-base-utils): enhance TaskHandler to support task arguments in submit_task method
This commit is contained in:
parent
cfad4c09f4
commit
00ca4574ea
@ -73,17 +73,25 @@ class TaskHandler:
|
|||||||
self._parent = parent
|
self._parent = parent
|
||||||
self._lock = threading.RLock()
|
self._lock = threading.RLock()
|
||||||
|
|
||||||
def submit_task(self, task: Callable, run: bool = True) -> TaskStatus:
|
def submit_task(
|
||||||
|
self,
|
||||||
|
task: Callable,
|
||||||
|
task_args: tuple | None = None,
|
||||||
|
task_kwargs: dict | None = None,
|
||||||
|
run: bool = True,
|
||||||
|
) -> TaskStatus:
|
||||||
"""Submit a task to the task handler.
|
"""Submit a task to the task handler.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
task: The task to run.
|
task: The task to run.
|
||||||
run: Whether to run the task immediately.
|
run: Whether to run the task immediately.
|
||||||
"""
|
"""
|
||||||
|
task_args = task_args if task_args else ()
|
||||||
|
task_kwargs = task_kwargs if task_kwargs else {}
|
||||||
task_status = TaskStatus(device=self._parent)
|
task_status = TaskStatus(device=self._parent)
|
||||||
thread = threading.Thread(
|
thread = threading.Thread(
|
||||||
target=self._wrap_task,
|
target=self._wrap_task,
|
||||||
args=(task, task_status),
|
args=(task, task_args, task_kwargs, task_status),
|
||||||
name=f"task {task_status.task_id}",
|
name=f"task {task_status.task_id}",
|
||||||
daemon=True,
|
daemon=True,
|
||||||
)
|
)
|
||||||
@ -102,13 +110,15 @@ class TaskHandler:
|
|||||||
if thread.is_alive():
|
if thread.is_alive():
|
||||||
logger.warning(f"Task with ID {task_status.task_id} is already running.")
|
logger.warning(f"Task with ID {task_status.task_id} is already running.")
|
||||||
return
|
return
|
||||||
thread.start()
|
|
||||||
task_status.state = TaskState.RUNNING
|
task_status.state = TaskState.RUNNING
|
||||||
|
thread.start()
|
||||||
|
|
||||||
def _wrap_task(self, task: Callable, task_status: TaskStatus):
|
def _wrap_task(
|
||||||
|
self, task: Callable, task_args: tuple, task_kwargs: dict, task_status: TaskStatus
|
||||||
|
):
|
||||||
"""Wrap the task in a function"""
|
"""Wrap the task in a function"""
|
||||||
try:
|
try:
|
||||||
task()
|
task(*task_args, **task_kwargs)
|
||||||
except TimeoutError as exc:
|
except TimeoutError as exc:
|
||||||
content = traceback.format_exc()
|
content = traceback.format_exc()
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
@ -48,6 +48,44 @@ def test_utils_task_status(device):
|
|||||||
assert status.state == "completed"
|
assert status.state == "completed"
|
||||||
|
|
||||||
|
|
||||||
|
def test_utils_task_handler_submit_task_with_args(task_handler):
|
||||||
|
"""Ensure that task_handler has a submit_task method"""
|
||||||
|
|
||||||
|
def my_task(input_arg: bool, input_kwarg: bool = False):
|
||||||
|
if input_kwarg is True:
|
||||||
|
raise ValueError("input_kwarg is True")
|
||||||
|
if input_arg is True:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
# This should fail
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
status = task_handler.submit_task(my_task)
|
||||||
|
status.wait()
|
||||||
|
# This should pass
|
||||||
|
|
||||||
|
task_stopped = threading.Event()
|
||||||
|
|
||||||
|
def finished_cb():
|
||||||
|
task_stopped.set()
|
||||||
|
|
||||||
|
status = task_handler.submit_task(
|
||||||
|
my_task, task_args=(True,), task_kwargs={"input_kwarg": False}
|
||||||
|
)
|
||||||
|
status.add_callback(finished_cb)
|
||||||
|
task_stopped.wait()
|
||||||
|
assert status.done is True
|
||||||
|
assert status.state == TaskState.COMPLETED
|
||||||
|
# This should fail
|
||||||
|
task_stopped = threading.Event()
|
||||||
|
status = task_handler.submit_task(my_task, task_args=(True,), task_kwargs={"input_kwarg": True})
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
status.wait()
|
||||||
|
assert status.state == TaskState.ERROR
|
||||||
|
assert status.done is True
|
||||||
|
assert status.exception().__class__ == ValueError
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(100)
|
@pytest.mark.timeout(100)
|
||||||
def test_utils_task_handler_task_killed(task_handler):
|
def test_utils_task_handler_task_killed(task_handler):
|
||||||
"""Ensure that task_handler has a submit_task method"""
|
"""Ensure that task_handler has a submit_task method"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user