diff --git a/bec_widgets/utils/serialization.py b/bec_widgets/utils/serialization.py index fba7be8a..a9c0b111 100644 --- a/bec_widgets/utils/serialization.py +++ b/bec_widgets/utils/serialization.py @@ -1,44 +1,25 @@ +from bec_lib.codecs import BECCodec from bec_lib.serialization import msgpack from qtpy.QtCore import QPointF +class QPointFEncoder(BECCodec): + obj_type = QPointF + + @staticmethod + def encode(obj: QPointF) -> list[float]: + """Encode a QPointF object to a list of floats.""" + return [obj.x(), obj.y()] + + @staticmethod + def decode(type_name: str, data: list[float]) -> list[float]: + """No-op function since QPointF is encoded as a list of floats.""" + return data + + def register_serializer_extension(): """ Register the serializer extension for the BECConnector. """ - if not module_is_registered("bec_widgets.utils.serialization"): - msgpack.register_object_hook(encode_qpointf, decode_qpointf) - - -def module_is_registered(module_name: str) -> bool: - """ - Check if the module is registered in the encoder. - - Args: - module_name (str): The name of the module to check. - - Returns: - bool: True if the module is registered, False otherwise. - """ - # pylint: disable=protected-access - for enc in msgpack._encoder: - if enc[0].__module__ == module_name: - return True - return False - - -def encode_qpointf(obj): - """ - Encode a QPointF object to a list of floats. As this is mostly used for sending - data to the client, it is not necessary to convert it back to a QPointF object. - """ - if isinstance(obj, QPointF): - return [obj.x(), obj.y()] - return obj - - -def decode_qpointf(obj): - """ - no-op function since QPointF is encoded as a list of floats. - """ - return obj + if not msgpack.is_registered(QPointF): + msgpack.register(QPointF, QPointFEncoder.encode, QPointFEncoder.decode) diff --git a/tests/unit_tests/test_serializer.py b/tests/unit_tests/test_serializer.py index eadfd940..0acfb07a 100644 --- a/tests/unit_tests/test_serializer.py +++ b/tests/unit_tests/test_serializer.py @@ -21,7 +21,6 @@ def test_multiple_extension_registration(): """ Test that multiple extension registrations do not cause issues. """ - assert serialization.module_is_registered("bec_widgets.utils.serialization") + assert msgpack.is_registered(QPointF) serialization.register_serializer_extension() - assert serialization.module_is_registered("bec_widgets.utils.serialization") - assert len(msgpack._encoder) == len(set(msgpack._encoder)) + assert msgpack.is_registered(QPointF)