mirror of
https://github.com/bec-project/bec_widgets.git
synced 2026-03-04 16:02:51 +01:00
fix(rpc_server): use single shot instead of processEvents to avoid deadlocks
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import time
|
||||
import traceback
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
@@ -12,7 +11,6 @@ from bec_lib.endpoints import MessageEndpoints
|
||||
from bec_lib.logger import bec_logger
|
||||
from bec_lib.utils.import_utils import lazy_import
|
||||
from qtpy.QtCore import Qt, QTimer
|
||||
from qtpy.QtWidgets import QApplication
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from bec_widgets.cli.rpc.rpc_register import RPCRegister
|
||||
@@ -32,6 +30,10 @@ logger = bec_logger.logger
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class RegistryNotReadyError(Exception):
    """Signal that an object requested from the RPC registry has not been registered yet."""
|
||||
|
||||
|
||||
@contextmanager
|
||||
def rpc_exception_hook(err_func):
|
||||
"""This context replaces the popup message box for error display with a specific hook"""
|
||||
@@ -55,6 +57,19 @@ def rpc_exception_hook(err_func):
|
||||
popup.custom_exception_hook = old_exception_hook
|
||||
|
||||
|
||||
class SingleshotRPCRepeat:
    """
    Running total of the delay spent on single-shot retries of one RPC request.

    Delays are added with ``+=``. As soon as the total goes above ``max_delay``
    the addition raises RegistryNotReadyError instead of returning.
    """

    def __init__(self, max_delay: int = 2000):
        # Amount consumed so far, and the budget it is compared against.
        # NOTE(review): presumably milliseconds, since the values are also
        # handed to QTimer.singleShot elsewhere in the file - confirm.
        self.accumulated_delay = 0
        self.max_delay = max_delay

    def __iadd__(self, delay: int):
        """
        Add ``delay`` to the running total.

        Args:
            delay (int): The delay to account for.

        Returns:
            SingleshotRPCRepeat: This instance, so ``tracker += delay`` keeps working.

        Raises:
            RegistryNotReadyError: If the new total is greater than ``max_delay``.
        """
        total = self.accumulated_delay + delay
        # The total is stored before the check, so it is updated even when
        # the budget turns out to be exceeded.
        self.accumulated_delay = total
        if total > self.max_delay:
            raise RegistryNotReadyError("Max delay exceeded for RPC singleshot repeat")
        return self
|
||||
|
||||
|
||||
class RPCServer:
|
||||
|
||||
client: BECClient
|
||||
@@ -86,6 +101,7 @@ class RPCServer:
|
||||
self._heartbeat_timer.start(200)
|
||||
self._registry_update_callbacks = []
|
||||
self._broadcasted_data = {}
|
||||
self._rpc_singleshot_repeats: dict[str, SingleshotRPCRepeat] = {}
|
||||
|
||||
self.status = messages.BECStatus.RUNNING
|
||||
logger.success(f"Server started with gui_id: {self.gui_id}")
|
||||
@@ -109,7 +125,8 @@ class RPCServer:
|
||||
self.send_response(request_id, False, {"error": content})
|
||||
else:
|
||||
logger.debug(f"RPC instruction executed successfully: {res}")
|
||||
self.send_response(request_id, True, {"result": res})
|
||||
self._rpc_singleshot_repeats[request_id] = SingleshotRPCRepeat()
|
||||
QTimer.singleShot(0, lambda: self.serialize_result_and_send(request_id, res))
|
||||
|
||||
def send_response(self, request_id: str, accepted: bool, msg: dict):
|
||||
self.client.connector.set_and_publish(
|
||||
@@ -167,14 +184,61 @@ class RPCServer:
|
||||
res = None
|
||||
else:
|
||||
res = method_obj(*args, **kwargs)
|
||||
return res
|
||||
|
||||
def serialize_result_and_send(self, request_id: str, res: object):
    """
    Serialize the result of an RPC call and send it back to the client.

    Lists and dicts are serialized element by element (dict keys are kept),
    anything else is passed to serialize_object as a whole.

    Note: If an object is not yet registered in the RPC registry, serialization
    raises RegistryNotReadyError. Instead of calling processEvents in the middle
    of serialization, this method then re-schedules itself with QTimer.singleShot,
    which lets the attempt 'float' to a later event loop iteration until the
    object is registered. The delay spent on retries is tracked per request in
    self._rpc_singleshot_repeats; once that tracker reports that its maximum
    delay is exceeded, an error response is sent instead of retrying again.

    Exactly one response is sent per request (success, serialization error or
    max-delay error), and the request's retry tracker is removed at that point.

    Args:
        request_id (str): The ID of the request.
        res (object): The result of the RPC call.
    """
    retry_delay = 100  # passed to QTimer.singleShot and added to the retry tracker
    try:
        # Keep the raw result untouched in `res` so a retry starts from scratch.
        if isinstance(res, list):
            serialized = [self.serialize_object(obj) for obj in res]
        elif isinstance(res, dict):
            serialized = {key: self.serialize_object(val) for key, val in res.items()}
        else:
            serialized = self.serialize_object(res)
    except RegistryNotReadyError:
        # Create the tracker on demand so a call for an untracked request_id
        # still gets a bounded number of retries instead of a KeyError.
        tracker = self._rpc_singleshot_repeats.setdefault(request_id, SingleshotRPCRepeat())
        try:
            tracker += retry_delay
        except RegistryNotReadyError:
            # Retry budget used up: give up and report the failure.
            logger.error(
                f"Max delay exceeded for RPC request {request_id}, sending error response"
            )
            self.send_response(
                request_id,
                False,
                {
                    "error": f"Max delay exceeded for RPC request {request_id}, object not registered in time."
                },
            )
            self._rpc_singleshot_repeats.pop(request_id, None)
        else:
            # Try again on a later event loop iteration.
            QTimer.singleShot(
                retry_delay, lambda: self.serialize_result_and_send(request_id, res)
            )
    except Exception as exc:
        logger.error(f"Error while serializing RPC result: {exc}")
        self.send_response(
            request_id,
            False,
            {"error": f"Error while serializing RPC result: {exc}\n{traceback.format_exc()}"},
        )
        # The request is finished; drop its tracker so the dict does not grow.
        self._rpc_singleshot_repeats.pop(request_id, None)
    else:
        self.send_response(request_id, True, {"result": serialized})
        self._rpc_singleshot_repeats.pop(request_id, None)
|
||||
|
||||
def serialize_object(self, obj: T) -> None | dict | T:
|
||||
"""
|
||||
@@ -256,11 +320,8 @@ class RPCServer:
|
||||
except Exception:
|
||||
container_proxy = None
|
||||
|
||||
if wait:
|
||||
while not self.rpc_register.object_is_registered(connector):
|
||||
QApplication.processEvents()
|
||||
logger.info(f"Waiting for {connector} to be registered...")
|
||||
time.sleep(0.1)
|
||||
if wait and not self.rpc_register.object_is_registered(connector):
|
||||
raise RegistryNotReadyError(f"Connector {connector} not registered yet")
|
||||
|
||||
widget_class = getattr(connector, "rpc_widget_class", None)
|
||||
if not widget_class:
|
||||
|
||||
@@ -1,9 +1,28 @@
|
||||
import argparse
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from bec_lib.service_config import ServiceConfig
|
||||
from qtpy.QtWidgets import QWidget
|
||||
|
||||
from bec_widgets.cli.server import GUIServer
|
||||
from bec_widgets.utils.bec_connector import BECConnector
|
||||
from bec_widgets.utils.rpc_server import RegistryNotReadyError, RPCServer, SingleshotRPCRepeat
|
||||
|
||||
from .client_mocks import mocked_client
|
||||
|
||||
|
||||
class DummyWidget(BECConnector, QWidget):
    """Bare QWidget that is also a BECConnector; serves as the RPC result object in these tests."""

    def __init__(self, parent=None, client=None, **kwargs):
        # Forward everything to the base-class initialisers first ...
        super().__init__(parent=parent, client=client, **kwargs)
        # ... then give the widget a fixed object name.
        self.setObjectName("DummyWidget")
|
||||
|
||||
|
||||
@pytest.fixture
def dummy_widget(qtbot, mocked_client):
    """Provide a DummyWidget built on the mocked client and handed to qtbot.addWidget."""
    instance = DummyWidget(client=mocked_client)
    qtbot.addWidget(instance)
    return instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -14,6 +33,13 @@ def gui_server():
|
||||
return GUIServer(args=args)
|
||||
|
||||
|
||||
@pytest.fixture
def rpc_server(mocked_client):
    """Yield an RPCServer with gui_id "test_gui" on the mocked client and shut it down afterwards."""
    server = RPCServer(gui_id="test_gui", client=mocked_client)
    yield server
    # Runs after the test that used the fixture has finished.
    server.shutdown()
|
||||
|
||||
|
||||
def test_gui_server_start_server_without_service_config(gui_server):
|
||||
"""
|
||||
Test that the server is started with the correct arguments.
|
||||
@@ -30,3 +56,85 @@ def test_gui_server_get_service_config(gui_server):
|
||||
Test that the server is started with the correct arguments.
|
||||
"""
|
||||
assert gui_server._get_service_config().config == ServiceConfig().config
|
||||
|
||||
|
||||
def test_singleshot_rpc_repeat_raises_on_repeated_singleshot(rpc_server):
    """
    Test that adding delays to a SingleshotRPCRepeat works until the accumulated
    delay exceeds the default maximum, at which point RegistryNotReadyError is raised.
    """
    tracker = SingleshotRPCRepeat()
    rpc_server._rpc_singleshot_repeats["test_method"] = tracker

    # 100 in total: still within the default budget of 2000
    tracker += 100

    # 2100 in total: over the budget, so the addition must raise
    with pytest.raises(RegistryNotReadyError):
        tracker += 2000
|
||||
|
||||
|
||||
def test_serialize_result_and_send_with_singleshot_retry(rpc_server, qtbot, dummy_widget):
    """
    Test that serialize_result_and_send keeps retrying while serialization raises
    RegistryNotReadyError and sends a success response once serialization works.
    """
    request_id = "test_request_123"

    # One entry per serialize_object call; its length is the attempt counter.
    attempts = []

    def flaky_serialize(obj):
        attempts.append(obj)
        attempt = len(attempts)
        # The first two attempts pretend the object is not registered yet
        if attempt <= 2:
            raise RegistryNotReadyError(f"Not ready yet (call {attempt})")
        # From the third attempt on, serialization succeeds
        return {"gui_id": dummy_widget.gui_id, "success": True}

    # Replace serialize_object to control when RegistryNotReadyError is raised,
    # and send_response to record what would be sent to the client.
    with patch.object(
        rpc_server, "serialize_object", side_effect=flaky_serialize
    ), patch.object(rpc_server, "send_response") as mock_send_response:
        rpc_server._rpc_singleshot_repeats[request_id] = SingleshotRPCRepeat()
        rpc_server.serialize_result_and_send(request_id, dummy_widget)

        # Let the single-shot timers fire until the third attempt has happened
        qtbot.waitUntil(lambda: len(attempts) >= 3, timeout=5000)

        # Exactly one response, and it reports success
        mock_send_response.assert_called_once()
        sent_request_id, accepted, payload = mock_send_response.call_args[0]
        assert sent_request_id == request_id
        assert accepted is True  # accepted=True
        assert "result" in payload
|
||||
|
||||
|
||||
def test_serialize_result_and_send_max_delay_exceeded(rpc_server, qtbot, dummy_widget):
    """
    Test that serialize_result_and_send sends an error response when max delay is exceeded.
    """
    request_id = "test_request_456"

    # Serialization never succeeds: every attempt raises RegistryNotReadyError
    with patch.object(
        rpc_server, "serialize_object", side_effect=RegistryNotReadyError("Always not ready")
    ), patch.object(rpc_server, "send_response") as mock_send_response:
        rpc_server._rpc_singleshot_repeats[request_id] = SingleshotRPCRepeat()
        rpc_server.serialize_result_and_send(request_id, dummy_widget)

        # The default budget is 2000 with retries every 100, i.e. about 20 timer
        # round trips. Wait for the response itself rather than sleeping for a
        # fixed time: this finishes as soon as the error is sent and leaves
        # head room for slow machines.
        qtbot.waitUntil(lambda: mock_send_response.called, timeout=5000)

        # Exactly one response, and it is the max-delay error
        mock_send_response.assert_called_once()
        sent_request_id, accepted, payload = mock_send_response.call_args[0]
        assert sent_request_id == request_id
        assert accepted is False  # accepted=False
        assert "error" in payload
        assert "Max delay exceeded" in payload["error"]

        # The retry tracker of the failed request is removed again
        assert request_id not in rpc_server._rpc_singleshot_repeats
|
||||
|
||||
Reference in New Issue
Block a user