diff --git a/bec_widgets/cli/client_utils.py b/bec_widgets/cli/client_utils.py index 27746f49..6ef053cf 100644 --- a/bec_widgets/cli/client_utils.py +++ b/bec_widgets/cli/client_utils.py @@ -20,6 +20,7 @@ from rich.table import Table import bec_widgets.cli.client as client from bec_widgets.cli.rpc.rpc_base import RPCBase, RPCReference +from bec_widgets.utils.serialization import register_serializer_extension if TYPE_CHECKING: # pragma: no cover from bec_lib.messages import GUIRegistryStateMessage @@ -215,6 +216,7 @@ class BECGuiClient(RPCBase): self._server_registry: dict[str, RegistryState] = {} self._ipython_registry: dict[str, RPCReference] = {} self.available_widgets = AvailableWidgetsNamespace() + register_serializer_extension() #################### #### Client API #### diff --git a/bec_widgets/utils/bec_dispatcher.py b/bec_widgets/utils/bec_dispatcher.py index d92d77ea..5f722c8b 100644 --- a/bec_widgets/utils/bec_dispatcher.py +++ b/bec_widgets/utils/bec_dispatcher.py @@ -14,6 +14,8 @@ from bec_lib.service_config import ServiceConfig from qtpy.QtCore import QObject from qtpy.QtCore import Signal as pyqtSignal +from bec_widgets.utils.serialization import register_serializer_extension + logger = bec_logger.logger if TYPE_CHECKING: # pragma: no cover @@ -120,6 +122,8 @@ class BECDispatcher: except redis.exceptions.ConnectionError: logger.warning("Could not connect to Redis, skipping start of BECClient.") + register_serializer_extension() + logger.success("Initialized BECDispatcher") self.start_cli_server(gui_id=gui_id) diff --git a/bec_widgets/utils/serialization.py b/bec_widgets/utils/serialization.py new file mode 100644 index 00000000..fba7be8a --- /dev/null +++ b/bec_widgets/utils/serialization.py @@ -0,0 +1,44 @@ +from bec_lib.serialization import msgpack +from qtpy.QtCore import QPointF + + +def register_serializer_extension(): + """ + Register the serializer extension for the BECConnector. + """ + if not module_is_registered("bec_widgets.utils.serialization"): + msgpack.register_object_hook(encode_qpointf, decode_qpointf) + + +def module_is_registered(module_name: str) -> bool: + """ + Check if the module is registered in the encoder. + + Args: + module_name (str): The name of the module to check. + + Returns: + bool: True if the module is registered, False otherwise. + """ + # pylint: disable=protected-access + for enc in msgpack._encoder: + if enc[0].__module__ == module_name: + return True + return False + + +def encode_qpointf(obj): + """ + Encode a QPointF object to a list of floats. As this is mostly used for sending + data to the client, it is not necessary to convert it back to a QPointF object. + """ + if isinstance(obj, QPointF): + return [obj.x(), obj.y()] + return obj + + +def decode_qpointf(obj): + """ + no-op function since QPointF is encoded as a list of floats. + """ + return obj diff --git a/tests/unit_tests/test_serializer.py b/tests/unit_tests/test_serializer.py new file mode 100644 index 00000000..eadfd940 --- /dev/null +++ b/tests/unit_tests/test_serializer.py @@ -0,0 +1,27 @@ +import pytest +from bec_lib.serialization import msgpack +from qtpy.QtCore import QPointF + +from bec_widgets.utils import serialization + + +@pytest.mark.parametrize("data, expected", [(QPointF(20, 10), [20, 10])]) +def test_serialize(data, expected): + """ + Test serialization of various data types. Note that the auto-use fixture of + the bec-dispatcher already registers the serializer extension, so we don't need to + register it again here. + """ + + serialized_data = msgpack.loads(msgpack.dumps(data)) + assert serialized_data == expected + + +def test_multiple_extension_registration(): + """ + Test that multiple extension registrations do not cause issues. + """ + assert serialization.module_is_registered("bec_widgets.utils.serialization") + serialization.register_serializer_extension() + assert serialization.module_is_registered("bec_widgets.utils.serialization") + assert len(msgpack._encoder) == len(set(msgpack._encoder))