diff --git a/bec_widgets/applications/companion_app.py b/bec_widgets/applications/companion_app.py index f229b677..43524194 100644 --- a/bec_widgets/applications/companion_app.py +++ b/bec_widgets/applications/companion_app.py @@ -5,6 +5,7 @@ import json import os import signal import sys +import traceback from contextlib import redirect_stderr, redirect_stdout import darkdetect @@ -63,6 +64,7 @@ class GUIServer: self.app: QApplication | None = None self.launcher_window: LaunchWindow | None = None self.dispatcher: BECDispatcher | None = None + self._shutdown_started = False def start(self): """ @@ -123,17 +125,8 @@ class GUIServer: self.app.aboutToQuit.connect(self.shutdown) self.app.setQuitOnLastWindowClosed(True) - def sigint_handler(*args): - # display message, for people to let it terminate gracefully - print("Caught SIGINT, exiting") - # Widgets should be all closed. - with RPCRegister.delayed_broadcast(): - for widget in QApplication.instance().topLevelWidgets(): # type: ignore - widget.close() - self.shutdown() - - signal.signal(signal.SIGINT, sigint_handler) - signal.signal(signal.SIGTERM, sigint_handler) + signal.signal(signal.SIGINT, self.request_shutdown) + signal.signal(signal.SIGTERM, self.request_shutdown) sys.exit(self.app.exec()) @@ -150,16 +143,67 @@ class GUIServer: ) self.app.setWindowIcon(icon) + def request_shutdown(self, signum=None, _frame=None): + """ + Request Qt application shutdown from an RPC call or OS signal. + + Cleanup itself is handled by ``shutdown()``, which is connected to + ``QApplication.aboutToQuit``. Calling it directly here would run BEC/RPC + teardown before Qt has processed the widget close events. + """ + signal_name = signal.Signals(signum).name if signum is not None else "shutdown" + pid = os.getpid() + if self.app is None: + logger.info(f"Caught {signal_name}, shutting down GUI server pid={pid} without app") + self.shutdown() + return + + widgets = [ + f"{widget.__class__.__name__}(objectName={widget.objectName()!r})" + for widget in self.app.topLevelWidgets() + ] + logger.info( + f"Caught {signal_name}, requesting GUI server shutdown pid={pid} " + f"top_level_widgets={widgets}" + ) + with RPCRegister.delayed_broadcast(): + for widget in self.app.topLevelWidgets(): + widget.close() + self.app.quit() + + @staticmethod + def _run_shutdown_step(step: str, callback): + try: + callback() + except Exception as exc: + logger.error( + f"GUIServer shutdown step failed pid={os.getpid()} step={step}: {exc}\n" + f"{traceback.format_exc()}" + ) + def shutdown(self): - logger.info("Shutdown GUIServer", repr(self)) - if self.launcher_window and shiboken6.isValid(self.launcher_window): - self.launcher_window.close() - self.launcher_window.deleteLater() - if pylsp_server.is_running(): - pylsp_server.stop() - if self.dispatcher: - self.dispatcher.stop_cli_server() - self.dispatcher.disconnect_all() + if self._shutdown_started: + return + self._shutdown_started = True + logger.info(f"Shutdown GUIServer pid={os.getpid()} {repr(self)}") + + def close_launcher_window(): + if self.launcher_window and shiboken6.isValid(self.launcher_window): + self.launcher_window.close() + self.launcher_window.deleteLater() + + def stop_pylsp_server(): + if pylsp_server.is_running(): + pylsp_server.stop() + + def stop_dispatcher(): + if self.dispatcher: + self.dispatcher.stop_cli_server() + self.dispatcher.disconnect_all() + + self._run_shutdown_step("close_launcher_window", close_launcher_window) + self._run_shutdown_step("stop_pylsp_server", stop_pylsp_server) + self._run_shutdown_step("stop_dispatcher", stop_dispatcher) def main(): diff --git a/bec_widgets/cli/client_utils.py b/bec_widgets/cli/client_utils.py index 551e41b8..57443805 100644 --- a/bec_widgets/cli/client_utils.py +++ b/bec_widgets/cli/client_utils.py @@ -5,6 +5,7 @@ from __future__ import annotations import json import os import select +import signal import subprocess import threading import time @@ -33,6 +34,9 @@ else: logger = bec_logger.logger IGNORE_WIDGETS = ["LaunchWindow"] +PROCESS_TERMINATION_TIMEOUT = 10 +PROCESS_OUTPUT_THREAD_JOIN_TIMEOUT = 2 +GRACEFUL_SERVER_SHUTDOWN_TIMEOUT = 5 RegistryState: TypeAlias = dict[ Literal["gui_id", "name", "widget_class", "config", "__rpc__", "container_proxy"], @@ -75,6 +79,123 @@ def _get_output(process, logger) -> None: logger.error(f"Error reading process output: {str(e)}") +def _process_group_id(process) -> int | None: + pid = getattr(process, "pid", None) + if os.name != "posix" or not isinstance(pid, int): + return None + try: + return os.getpgid(pid) + except ProcessLookupError: + return None + + +def _process_details(process) -> str: + args = getattr(process, "args", None) + if isinstance(args, list): + command = " ".join(str(arg) for arg in args) + else: + command = str(args) + return ( + f"pid={getattr(process, 'pid', None)} pgid={_process_group_id(process)} command={command}" + ) + + +def _process_group_snapshot(process) -> str: + pgid = _process_group_id(process) + if pgid is None: + return "Process group snapshot unavailable: process group no longer exists" + try: + result = subprocess.run( + ["ps", "-o", "pid,ppid,pgid,stat,command", "-g", str(pgid)], + check=False, + capture_output=True, + text=True, + timeout=2, + ) + except Exception as exc: + return f"Process group snapshot unavailable: {exc}" + output = result.stdout.strip() + if not output: + return f"Process group snapshot empty for pgid={pgid}" + return output + + +def _terminate_plot_process(process, logger, timeout: float = PROCESS_TERMINATION_TIMEOUT) -> None: + if process.poll() is not None: + return + + process_details = _process_details(process) + try: + pgid = _process_group_id(process) + if pgid is not None: + logger.info(f"Terminating GUI process group {process_details}") + os.killpg(pgid, signal.SIGTERM) + else: + logger.info(f"Terminating GUI process {process_details}") + process.terminate() + except ProcessLookupError: + process.wait(timeout=timeout) + return + except Exception as exc: + logger.warning( + f"Failed to terminate GUI process group: {exc}; terminating process only. " + f"{process_details}" + ) + process.terminate() + + try: + process.wait(timeout=timeout) + return + except subprocess.TimeoutExpired: + logger.warning( + f"GUI process did not stop within {timeout}s; killing it. " + f"{process_details}\n{_process_group_snapshot(process)}" + ) + + try: + pgid = _process_group_id(process) + if pgid is not None: + os.killpg(pgid, signal.SIGKILL) + else: + process.kill() + except ProcessLookupError as e: + logger.error(f"Failed to kill GUI process group: {e}") + process.wait(timeout=timeout) + return + process.wait(timeout=timeout) + + +def _wait_for_process_exit(process, timeout: float) -> bool: + try: + process.wait(timeout=timeout) + except subprocess.TimeoutExpired: + return False + return True + + +def _join_process_output_thread(process, thread: threading.Thread | None, logger) -> None: + if thread is None: + return + thread.join(timeout=PROCESS_OUTPUT_THREAD_JOIN_TIMEOUT) + if not thread.is_alive(): + return + + for stream in (process.stdout, process.stderr): + if stream is None: + continue + try: + stream.close() + except OSError as e: + logger.error(f"Failed to close stream {str(e)}") + pass + thread.join(timeout=PROCESS_OUTPUT_THREAD_JOIN_TIMEOUT) + if thread.is_alive(): + logger.warning( + "GUI process output reader thread did not stop after process shutdown. " + f"{_process_details(process)}" + ) + + def _start_plot_process( gui_id: str, gui_class_id: str, @@ -465,11 +586,13 @@ class BECGuiClient(RPCBase): if self._process: logger.success("Stopping GUI...") - self._process.terminate() - if self._process_output_processing_thread: - self._process_output_processing_thread.join() - self._process.wait() + if not self._request_server_shutdown(): + _terminate_plot_process(self._process, logger) + _join_process_output_thread( + self._process, self._process_output_processing_thread, logger + ) self._process = None + self._process_output_processing_thread = None # Unregister the registry state self._client.connector.unregister( @@ -488,6 +611,30 @@ class BECGuiClient(RPCBase): #### Private methods #### ######################### + def _request_server_shutdown(self) -> bool: + if self._process is None or self._process.poll() is not None: + return True + process_details = _process_details(self._process) + logger.info(f"Requesting graceful GUI shutdown {process_details}") + try: + self.launcher._run_rpc( # pylint: disable=protected-access + "system.shutdown", wait_for_rpc_response=False + ) + except Exception as exc: + logger.warning( + f"Could not request graceful GUI shutdown via RPC: {exc}. " f"{process_details}" + ) + return False + if _wait_for_process_exit(self._process, GRACEFUL_SERVER_SHUTDOWN_TIMEOUT): + logger.info(f"GUI server exited after graceful shutdown {process_details}") + return True + logger.warning( + "GUI server did not exit after graceful shutdown request; " + f"falling back to process termination. {process_details}\n" + f"{_process_group_snapshot(self._process)}" + ) + return False + def _check_if_server_is_alive(self): """Checks if the process is alive""" if self._process is None: diff --git a/bec_widgets/utils/rpc_server.py b/bec_widgets/utils/rpc_server.py index af27e4ca..d3b9cf4b 100644 --- a/bec_widgets/utils/rpc_server.py +++ b/bec_widgets/utils/rpc_server.py @@ -12,7 +12,7 @@ from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger from bec_lib.utils.import_utils import lazy_import from qtpy.QtCore import Qt, QTimer -from qtpy.QtWidgets import QWidget +from qtpy.QtWidgets import QApplication, QWidget from redis.exceptions import RedisError from bec_widgets.utils.bec_connector import BECConnector @@ -290,10 +290,23 @@ class RPCServer: def run_system_rpc(self, method: str, args: list, kwargs: dict): if method == "system.launch_dock_area": return self._launch_dock_area(*args, **kwargs) + if method == "system.shutdown": + return self._shutdown_gui_server() if method == "system.list_capabilities": - return {"system.launch_dock_area": True} + return {"system.launch_dock_area": True, "system.shutdown": True} raise ValueError(f"Unknown system RPC method: {method}") + @staticmethod + def _shutdown_gui_server() -> None: + app = QApplication.instance() + if app is None: + return + gui_server = getattr(app, "gui_server", None) + if gui_server is not None and hasattr(gui_server, "request_shutdown"): + QTimer.singleShot(0, gui_server.request_shutdown) + return + QTimer.singleShot(0, app.quit) + @staticmethod def _launch_dock_area( name: str | None = None, @@ -468,8 +481,9 @@ class RPCServer: container_proxy = parent.gui_id else: container_proxy = None - except Exception: + except Exception as e: container_proxy = None + logger.error(f"Error while serializing RPC result: {e}") if wait and not self.rpc_register.object_is_registered(connector): raise RegistryNotReadyError(f"Connector {connector} not registered yet") diff --git a/tests/unit_tests/test_client_utils.py b/tests/unit_tests/test_client_utils.py index 3c489b8a..efce65bd 100644 --- a/tests/unit_tests/test_client_utils.py +++ b/tests/unit_tests/test_client_utils.py @@ -1,3 +1,5 @@ +import signal +import subprocess from contextlib import contextmanager from unittest import mock @@ -266,3 +268,81 @@ def test_client_utils_gui_client_set_rpc_timeout(): gui.set_rpc_timeout(10) assert gui._rpc_timeout == 10 + + +def test_client_utils_kill_server_waits_for_process_before_joining_output_thread(): + gui = BECGuiClient() + gui._client = mock.MagicMock() + gui._process = mock.MagicMock(pid=123, stdout=None, stderr=None) + gui._process.poll.return_value = None + order = [] + gui._process.wait.side_effect = lambda timeout: order.append("wait") + gui._process_output_processing_thread = mock.MagicMock() + gui._process_output_processing_thread.join.side_effect = lambda timeout: order.append("join") + gui._process_output_processing_thread.is_alive.return_value = False + + with ( + mock.patch.object(gui, "_request_server_shutdown", return_value=False), + mock.patch("bec_widgets.cli.client_utils.os.getpgid", return_value=123), + mock.patch("bec_widgets.cli.client_utils.os.killpg") as killpg, + ): + gui.kill_server() + + killpg.assert_called_once_with(123, signal.SIGTERM) + assert order == ["wait", "join"] + assert gui._process is None + assert gui._process_output_processing_thread is None + + +def test_client_utils_kill_server_requests_graceful_shutdown_before_signal(): + gui = BECGuiClient() + gui._client = mock.MagicMock() + process = mock.MagicMock(stdout=None, stderr=None) + process.poll.return_value = None + gui._process = process + gui._process_output_processing_thread = mock.MagicMock() + gui._process_output_processing_thread.is_alive.return_value = False + launcher = mock.MagicMock() + + with ( + mock.patch.object( + BECGuiClient, "launcher", new_callable=mock.PropertyMock + ) as launcher_prop, + mock.patch("bec_widgets.cli.client_utils.os.killpg") as killpg, + ): + launcher_prop.return_value = launcher + gui.kill_server() + + launcher._run_rpc.assert_called_once_with("system.shutdown", wait_for_rpc_response=False) + process.wait.assert_called_once_with(timeout=5) + killpg.assert_not_called() + assert gui._process is None + assert gui._process_output_processing_thread is None + + +def test_client_utils_kill_server_kills_process_group_after_timeout(): + gui = BECGuiClient() + gui._client = mock.MagicMock() + process = mock.MagicMock(pid=123, stdout=None, stderr=None, args=["bec-gui-server"]) + process.poll.return_value = None + process.wait.side_effect = [subprocess.TimeoutExpired(cmd="bec-gui-server", timeout=10), None] + gui._process = process + + with ( + mock.patch.object(gui, "_request_server_shutdown", return_value=False), + mock.patch("bec_widgets.cli.client_utils.os.getpgid", return_value=123), + mock.patch("bec_widgets.cli.client_utils.os.killpg") as killpg, + mock.patch("bec_widgets.cli.client_utils.subprocess.run") as run, + ): + run.return_value.stdout = "PID PPID PGID STAT COMMAND\n123 1 123 S bec-gui-server" + gui.kill_server() + + assert killpg.call_args_list == [mock.call(123, signal.SIGTERM), mock.call(123, signal.SIGKILL)] + assert process.wait.call_args_list == [mock.call(timeout=10), mock.call(timeout=10)] + run.assert_called_once_with( + ["ps", "-o", "pid,ppid,pgid,stat,command", "-g", "123"], + check=False, + capture_output=True, + text=True, + timeout=2, + ) diff --git a/tests/unit_tests/test_rpc_server.py b/tests/unit_tests/test_rpc_server.py index 6ef01531..8b71d8f7 100644 --- a/tests/unit_tests/test_rpc_server.py +++ b/tests/unit_tests/test_rpc_server.py @@ -5,6 +5,7 @@ import pytest from bec_lib.service_config import ServiceConfig from qtpy.QtWidgets import QWidget +from bec_widgets.applications import companion_app as companion_app_module from bec_widgets.applications.companion_app import GUIServer from bec_widgets.utils import rpc_server as rpc_server_module from bec_widgets.utils.bec_connector import BECConnector @@ -59,6 +60,52 @@ def test_gui_server_get_service_config(gui_server): assert gui_server._get_service_config().config == ServiceConfig().config +def test_gui_server_signal_shutdown_closes_widgets_and_quits_app(gui_server): + widget = MagicMock() + gui_server.app = MagicMock() + gui_server.app.topLevelWidgets.return_value = [widget] + + gui_server.request_shutdown() + + widget.close.assert_called_once() + gui_server.app.quit.assert_called_once() + + +def test_gui_server_shutdown_is_idempotent(gui_server): + gui_server.launcher_window = MagicMock() + gui_server.dispatcher = MagicMock() + + with ( + patch.object(companion_app_module.shiboken6, "isValid", return_value=True), + patch.object(companion_app_module.pylsp_server, "is_running", return_value=False), + ): + gui_server.shutdown() + gui_server.shutdown() + + gui_server.launcher_window.close.assert_called_once() + gui_server.launcher_window.deleteLater.assert_called_once() + gui_server.dispatcher.stop_cli_server.assert_called_once() + gui_server.dispatcher.disconnect_all.assert_called_once() + + +def test_rpc_server_system_capabilities_include_shutdown(rpc_server): + assert rpc_server.run_system_rpc("system.list_capabilities", [], {}) == { + "system.launch_dock_area": True, + "system.shutdown": True, + } + + +def test_rpc_server_system_shutdown_requests_gui_server_shutdown(rpc_server, qapp): + gui_server = MagicMock() + qapp.gui_server = gui_server + + rpc_server.run_system_rpc("system.shutdown", [], {}) + qapp.processEvents() + + gui_server.request_shutdown.assert_called_once() + del qapp.gui_server + + def test_singleshot_rpc_repeat_raises_on_repeated_singleshot(rpc_server): """ Test that a singleshot RPC method raises an error when called multiple times.