restructures client to have separate thread for its asyncio loop

This commit is contained in:
Mose Müller
2024-04-03 10:28:06 +02:00
parent 381d98b078
commit 60a7dda60a
3 changed files with 100 additions and 52 deletions

View File

@@ -1,6 +1,7 @@
import asyncio
import logging import logging
import time import threading
from typing import Any, TypedDict from typing import Any, TypedDict, cast
import socketio # type: ignore import socketio # type: ignore
@@ -26,30 +27,55 @@ class NotifyDict(TypedDict):
class Client(pydase.DataService): class Client(pydase.DataService):
def __init__(self, hostname: str, port: int): def __init__(self, hostname: str, port: int):
super().__init__() super().__init__()
self._sio = socketio.Client() self._hostname = hostname
self._setup_events() self._port = port
self._proxy_class_factory = ProxyClassFactory(self._sio) self._sio = socketio.AsyncClient()
self.proxy = ProxyConnection() self._loop = asyncio.new_event_loop()
self._sio.connect( self._proxy_class_factory = ProxyClassFactory(self._sio, self._loop)
f"ws://{hostname}:{port}", self._thread = threading.Thread(
target=self.__asyncio_loop_thread, args=(self._loop,), daemon=True
)
self._thread.start()
self.proxy: ProxyConnection
asyncio.run_coroutine_threadsafe(self._connect(), self._loop).result()
async def _connect(self) -> None:
logger.debug("Connecting to server '%s:%s' ...", self._hostname, self._port)
await self._setup_events()
await self._sio.connect(
f"ws://{self._hostname}:{self._port}",
socketio_path="/ws/socket.io", socketio_path="/ws/socket.io",
transports=["websocket"], transports=["websocket"],
) )
while not self.proxy._initialised:
time.sleep(0.01)
def _setup_events(self) -> None: def __asyncio_loop_thread(self, loop: asyncio.AbstractEventLoop) -> None:
asyncio.set_event_loop(loop)
loop.run_forever()
async def _setup_events(self) -> None:
@self._sio.event @self._sio.event
def class_structure(data: SerializedDataService) -> None: async def connect() -> None:
if not self.proxy._initialised: logger.debug("Connected to '%s:%s' ...", self._hostname, self._port)
self.proxy = self._proxy_class_factory.create_proxy(data) serialized_data = cast(
SerializedDataService, await self._sio.call("service_serialization")
)
if not hasattr(self, "proxy"):
self.proxy = self._proxy_class_factory.create_proxy(serialized_data)
else: else:
# need to change to avoid overwriting the proxy class # need to change to avoid overwriting the proxy class
data["type"] = "DeviceConnection" serialized_data["type"] = "DeviceConnection"
super(pydase.DataService, self.proxy)._notify_changed("", loads(data)) super(pydase.DataService, self.proxy)._notify_changed(
"", loads(serialized_data)
)
self.proxy._connected = True
@self._sio.event @self._sio.event
def notify(data: NotifyDict) -> None: async def disconnect() -> None:
logger.debug("Disconnected")
self.proxy._connected = False
@self._sio.event
async def notify(data: NotifyDict) -> None:
# Notify the DataServiceObserver directly, not going through # Notify the DataServiceObserver directly, not going through
# self._notify_changed as this would trigger the "update_value" event # self._notify_changed as this would trigger the "update_value" event
super(pydase.DataService, self.proxy)._notify_changed( super(pydase.DataService, self.proxy)._notify_changed(
@@ -57,8 +83,8 @@ class Client(pydase.DataService):
loads(data["data"]["value"]), loads(data["data"]["value"]),
) )
def disconnect(self) -> None: async def _disconnect(self) -> None:
self._sio.disconnect() await self._sio.disconnect()
def _notify_changed(self, changed_attribute: str, value: Any) -> None: def _notify_changed(self, changed_attribute: str, value: Any) -> None:
if ( if (
@@ -69,11 +95,14 @@ class Client(pydase.DataService):
): ):
logger.debug(f"{changed_attribute}: {value}") logger.debug(f"{changed_attribute}: {value}")
self._sio.call( async def update_value() -> None:
await self._sio.call(
"update_value", "update_value",
{ {
"access_path": changed_attribute[6:], "access_path": changed_attribute[6:],
"value": dump(value), "value": dump(value),
}, },
) )
asyncio.run_coroutine_threadsafe(update_value(), loop=self._loop)
return super()._notify_changed(changed_attribute, value) return super()._notify_changed(changed_attribute, value)

View File

@@ -1,3 +1,4 @@
import asyncio
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from copy import copy from copy import copy
@@ -20,7 +21,8 @@ logger = logging.getLogger(__name__)
class ProxyClassMixin: class ProxyClassMixin:
_sio: socketio.Client _sio: socketio.AsyncClient
_loop: asyncio.AbstractEventLoop
def __setattr__(self, key: str, value: Any) -> None: def __setattr__(self, key: str, value: Any) -> None:
# prevent overriding of proxy attributes # prevent overriding of proxy attributes
@@ -44,18 +46,18 @@ class ProxyConnection(pydase.components.DeviceConnection, ProxyClassMixin):
self._initialised = False self._initialised = False
self._reconnection_wait_time = 1.0 self._reconnection_wait_time = 1.0
@property
def connected(self) -> bool:
return self._sio.connected
class ProxyClassFactory: class ProxyClassFactory:
def __init__(self, sio_client: socketio.Client) -> None: def __init__(
self, sio_client: socketio.AsyncClient, loop: asyncio.AbstractEventLoop
) -> None:
self.sio_client = sio_client self.sio_client = sio_client
self.loop = loop
def create_proxy(self, data: SerializedObject) -> ProxyConnection: def create_proxy(self, data: SerializedObject) -> ProxyConnection:
proxy_class = self._deserialize_component_type(data, ProxyConnection) proxy_class = self._deserialize_component_type(data, ProxyConnection)
proxy_class._sio = self.sio_client proxy_class._sio = self.sio_client
proxy_class._loop = self.loop
proxy_class._initialised = True proxy_class._initialised = True
return proxy_class # type: ignore return proxy_class # type: ignore
@@ -87,6 +89,7 @@ class ProxyClassFactory:
serialized_object, component_class serialized_object, component_class
) )
proxy_class._sio = self.sio_client proxy_class._sio = self.sio_client
proxy_class._loop = self.loop
proxy_class._initialised = True proxy_class._initialised = True
return proxy_class return proxy_class
return None return None
@@ -95,18 +98,21 @@ class ProxyClassFactory:
self, serialized_object: SerializedMethod self, serialized_object: SerializedMethod
) -> Callable[..., Any]: ) -> Callable[..., Any]:
def method_proxy(self: ProxyBaseClass, *args: Any, **kwargs: Any) -> Any: def method_proxy(self: ProxyBaseClass, *args: Any, **kwargs: Any) -> Any:
serialized_response = cast( async def trigger_method() -> Any:
dict[str, Any], return await self._sio.call(
self._sio.call(
"trigger_method", "trigger_method",
{ {
"access_path": serialized_object["full_access_path"], "access_path": serialized_object["full_access_path"],
"args": dump(list(args)), "args": dump(list(args)),
"kwargs": dump(kwargs), "kwargs": dump(kwargs),
}, },
),
) )
return loads(serialized_response) # type: ignore
result = asyncio.run_coroutine_threadsafe(
trigger_method(),
loop=self._loop,
).result()
return loads(result)
return method_proxy return method_proxy
@@ -160,27 +166,34 @@ class ProxyClassFactory:
return create_proxy_class(serialized_object)() return create_proxy_class(serialized_object)()
def _create_attr_property(self, serialized_attr: SerializedObject) -> property: def _create_attr_property(self, serialized_attr: SerializedObject) -> property:
def get(self: ProxyBaseClass) -> Any: # type: ignore def get(self: ProxyBaseClass) -> Any:
return loads( async def get_result() -> Any:
cast( return await self._sio.call(
SerializedObject, "get_value", serialized_attr["full_access_path"]
self._sio.call("get_value", serialized_attr["full_access_path"]),
)
) )
result = asyncio.run_coroutine_threadsafe(
get_result(),
loop=self._loop,
).result()
return loads(result)
get.__doc__ = serialized_attr["doc"] get.__doc__ = serialized_attr["doc"]
def set(self: ProxyBaseClass, value: Any) -> None: # type: ignore def set(self: ProxyBaseClass, value: Any) -> None:
result = cast( async def set_result() -> Any:
SerializedObject | None, return await self._sio.call(
self._sio.call(
"update_value", "update_value",
{ {
"access_path": serialized_attr["full_access_path"], "access_path": serialized_attr["full_access_path"],
"value": dump(value), "value": dump(value),
}, },
),
) )
result: SerializedObject | None = asyncio.run_coroutine_threadsafe(
set_result(),
loop=self._loop,
).result()
if result is not None: if result is not None:
loads(result) loads(result)

View File

@@ -119,12 +119,18 @@ def setup_sio_events(sio: socketio.AsyncServer, state_manager: StateManager) ->
async def connect(sid: str, environ: Any) -> None: async def connect(sid: str, environ: Any) -> None:
logging.debug("Client [%s] connected", click.style(str(sid), fg="cyan")) logging.debug("Client [%s] connected", click.style(str(sid), fg="cyan"))
await sio.emit("class_structure", state_manager.cache, to=sid)
@sio.event # type: ignore @sio.event # type: ignore
async def disconnect(sid: str) -> None: async def disconnect(sid: str) -> None:
logging.debug("Client [%s] disconnected", click.style(str(sid), fg="cyan")) logging.debug("Client [%s] disconnected", click.style(str(sid), fg="cyan"))
@sio.event # type: ignore
async def service_serialization(sid: str) -> SerializedObject:
logging.debug(
"Client [%s] requested service serialization",
click.style(str(sid), fg="cyan"),
)
return state_manager.cache
@sio.event @sio.event
async def update_value(sid: str, data: UpdateDict) -> SerializedObject | None: # type: ignore async def update_value(sid: str, data: UpdateDict) -> SerializedObject | None: # type: ignore
path = data["access_path"] path = data["access_path"]