From 60a7dda60ab7720df0614fd94d5684fa37f3a3af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mose=20M=C3=BCller?= Date: Wed, 3 Apr 2024 10:28:06 +0200 Subject: [PATCH] restructures client to have separate thread for its asyncio loop --- src/pydase/client/client.py | 81 +++++++++++++++-------- src/pydase/client/proxy_class_factory.py | 61 ++++++++++------- src/pydase/server/web_server/sio_setup.py | 10 ++- 3 files changed, 100 insertions(+), 52 deletions(-) diff --git a/src/pydase/client/client.py b/src/pydase/client/client.py index 1b8d794..3a7543c 100644 --- a/src/pydase/client/client.py +++ b/src/pydase/client/client.py @@ -1,6 +1,7 @@ +import asyncio import logging -import time -from typing import Any, TypedDict +import threading +from typing import Any, TypedDict, cast import socketio # type: ignore @@ -26,30 +27,55 @@ class NotifyDict(TypedDict): class Client(pydase.DataService): def __init__(self, hostname: str, port: int): super().__init__() - self._sio = socketio.Client() - self._setup_events() - self._proxy_class_factory = ProxyClassFactory(self._sio) - self.proxy = ProxyConnection() - self._sio.connect( - f"ws://{hostname}:{port}", + self._hostname = hostname + self._port = port + self._sio = socketio.AsyncClient() + self._loop = asyncio.new_event_loop() + self._proxy_class_factory = ProxyClassFactory(self._sio, self._loop) + self._thread = threading.Thread( + target=self.__asyncio_loop_thread, args=(self._loop,), daemon=True + ) + self._thread.start() + self.proxy: ProxyConnection + asyncio.run_coroutine_threadsafe(self._connect(), self._loop).result() + + async def _connect(self) -> None: + logger.debug("Connecting to server '%s:%s' ...", self._hostname, self._port) + await self._setup_events() + await self._sio.connect( + f"ws://{self._hostname}:{self._port}", socketio_path="/ws/socket.io", transports=["websocket"], ) - while not self.proxy._initialised: - time.sleep(0.01) - def _setup_events(self) -> None: + def __asyncio_loop_thread(self, loop: asyncio.AbstractEventLoop) -> None: + asyncio.set_event_loop(loop) + loop.run_forever() + + async def _setup_events(self) -> None: @self._sio.event - def class_structure(data: SerializedDataService) -> None: - if not self.proxy._initialised: - self.proxy = self._proxy_class_factory.create_proxy(data) + async def connect() -> None: + logger.debug("Connected to '%s:%s' ...", self._hostname, self._port) + serialized_data = cast( + SerializedDataService, await self._sio.call("service_serialization") + ) + if not hasattr(self, "proxy"): + self.proxy = self._proxy_class_factory.create_proxy(serialized_data) else: # need to change to avoid overwriting the proxy class - data["type"] = "DeviceConnection" - super(pydase.DataService, self.proxy)._notify_changed("", loads(data)) + serialized_data["type"] = "DeviceConnection" + super(pydase.DataService, self.proxy)._notify_changed( + "", loads(serialized_data) + ) + self.proxy._connected = True @self._sio.event - def notify(data: NotifyDict) -> None: + async def disconnect() -> None: + logger.debug("Disconnected") + self.proxy._connected = False + + @self._sio.event + async def notify(data: NotifyDict) -> None: # Notify the DataServiceObserver directly, not going through # self._notify_changed as this would trigger the "update_value" event super(pydase.DataService, self.proxy)._notify_changed( @@ -57,8 +83,8 @@ class Client(pydase.DataService): loads(data["data"]["value"]), ) - def disconnect(self) -> None: - self._sio.disconnect() + async def _disconnect(self) -> None: + await self._sio.disconnect() def _notify_changed(self, changed_attribute: str, value: Any) -> None: if ( @@ -69,11 +95,14 @@ class Client(pydase.DataService): ): logger.debug(f"{changed_attribute}: {value}") - self._sio.call( - "update_value", - { - "access_path": changed_attribute[6:], - "value": dump(value), - }, - ) + async def update_value() -> None: + await self._sio.call( + "update_value", + { + "access_path": changed_attribute[6:], + "value": dump(value), + }, + ) + + asyncio.run_coroutine_threadsafe(update_value(), loop=self._loop) return super()._notify_changed(changed_attribute, value) diff --git a/src/pydase/client/proxy_class_factory.py b/src/pydase/client/proxy_class_factory.py index 2a866a2..70a20e6 100644 --- a/src/pydase/client/proxy_class_factory.py +++ b/src/pydase/client/proxy_class_factory.py @@ -1,3 +1,4 @@ +import asyncio import logging from collections.abc import Callable from copy import copy @@ -20,7 +21,8 @@ logger = logging.getLogger(__name__) class ProxyClassMixin: - _sio: socketio.Client + _sio: socketio.AsyncClient + _loop: asyncio.AbstractEventLoop def __setattr__(self, key: str, value: Any) -> None: # prevent overriding of proxy attributes @@ -44,18 +46,18 @@ class ProxyConnection(pydase.components.DeviceConnection, ProxyClassMixin): self._initialised = False self._reconnection_wait_time = 1.0 - @property - def connected(self) -> bool: - return self._sio.connected - class ProxyClassFactory: - def __init__(self, sio_client: socketio.Client) -> None: + def __init__( + self, sio_client: socketio.AsyncClient, loop: asyncio.AbstractEventLoop + ) -> None: self.sio_client = sio_client + self.loop = loop def create_proxy(self, data: SerializedObject) -> ProxyConnection: proxy_class = self._deserialize_component_type(data, ProxyConnection) proxy_class._sio = self.sio_client + proxy_class._loop = self.loop proxy_class._initialised = True return proxy_class # type: ignore @@ -87,6 +89,7 @@ class ProxyClassFactory: serialized_object, component_class ) proxy_class._sio = self.sio_client + proxy_class._loop = self.loop proxy_class._initialised = True return proxy_class return None @@ -95,18 +98,21 @@ class ProxyClassFactory: self, serialized_object: SerializedMethod ) -> Callable[..., Any]: def method_proxy(self: ProxyBaseClass, *args: Any, **kwargs: Any) -> Any: - serialized_response = cast( - dict[str, Any], - self._sio.call( + async def trigger_method() -> Any: + return await self._sio.call( "trigger_method", { "access_path": serialized_object["full_access_path"], "args": dump(list(args)), "kwargs": dump(kwargs), }, - ), - ) - return loads(serialized_response) # type: ignore + ) + + result = asyncio.run_coroutine_threadsafe( + trigger_method(), + loop=self._loop, + ).result() + return loads(result) return method_proxy @@ -160,27 +166,34 @@ class ProxyClassFactory: return create_proxy_class(serialized_object)() def _create_attr_property(self, serialized_attr: SerializedObject) -> property: - def get(self: ProxyBaseClass) -> Any: # type: ignore - return loads( - cast( - SerializedObject, - self._sio.call("get_value", serialized_attr["full_access_path"]), + def get(self: ProxyBaseClass) -> Any: + async def get_result() -> Any: + return await self._sio.call( + "get_value", serialized_attr["full_access_path"] ) - ) + + result = asyncio.run_coroutine_threadsafe( + get_result(), + loop=self._loop, + ).result() + return loads(result) get.__doc__ = serialized_attr["doc"] - def set(self: ProxyBaseClass, value: Any) -> None: # type: ignore - result = cast( - SerializedObject | None, - self._sio.call( + def set(self: ProxyBaseClass, value: Any) -> None: + async def set_result() -> Any: + return await self._sio.call( "update_value", { "access_path": serialized_attr["full_access_path"], "value": dump(value), }, - ), - ) + ) + + result: SerializedObject | None = asyncio.run_coroutine_threadsafe( + set_result(), + loop=self._loop, + ).result() if result is not None: loads(result) diff --git a/src/pydase/server/web_server/sio_setup.py b/src/pydase/server/web_server/sio_setup.py index 30034a2..13149c7 100644 --- a/src/pydase/server/web_server/sio_setup.py +++ b/src/pydase/server/web_server/sio_setup.py @@ -119,12 +119,18 @@ def setup_sio_events(sio: socketio.AsyncServer, state_manager: StateManager) -> async def connect(sid: str, environ: Any) -> None: logging.debug("Client [%s] connected", click.style(str(sid), fg="cyan")) - await sio.emit("class_structure", state_manager.cache, to=sid) - @sio.event # type: ignore async def disconnect(sid: str) -> None: logging.debug("Client [%s] disconnected", click.style(str(sid), fg="cyan")) + @sio.event # type: ignore + async def service_serialization(sid: str) -> SerializedObject: + logging.debug( + "Client [%s] requested service serialization", + click.style(str(sid), fg="cyan"), + ) + return state_manager.cache + @sio.event async def update_value(sid: str, data: UpdateDict) -> SerializedObject | None: # type: ignore path = data["access_path"]