mirror of
https://github.com/tiqi-group/pydase.git
synced 2026-02-14 06:18:41 +01:00
restructures client to have separate thread for its asyncio loop
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, TypedDict
|
||||
import threading
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
import socketio # type: ignore
|
||||
|
||||
@@ -26,30 +27,55 @@ class NotifyDict(TypedDict):
|
||||
class Client(pydase.DataService):
|
||||
def __init__(self, hostname: str, port: int):
|
||||
super().__init__()
|
||||
self._sio = socketio.Client()
|
||||
self._setup_events()
|
||||
self._proxy_class_factory = ProxyClassFactory(self._sio)
|
||||
self.proxy = ProxyConnection()
|
||||
self._sio.connect(
|
||||
f"ws://{hostname}:{port}",
|
||||
self._hostname = hostname
|
||||
self._port = port
|
||||
self._sio = socketio.AsyncClient()
|
||||
self._loop = asyncio.new_event_loop()
|
||||
self._proxy_class_factory = ProxyClassFactory(self._sio, self._loop)
|
||||
self._thread = threading.Thread(
|
||||
target=self.__asyncio_loop_thread, args=(self._loop,), daemon=True
|
||||
)
|
||||
self._thread.start()
|
||||
self.proxy: ProxyConnection
|
||||
asyncio.run_coroutine_threadsafe(self._connect(), self._loop).result()
|
||||
|
||||
async def _connect(self) -> None:
|
||||
logger.debug("Connecting to server '%s:%s' ...", self._hostname, self._port)
|
||||
await self._setup_events()
|
||||
await self._sio.connect(
|
||||
f"ws://{self._hostname}:{self._port}",
|
||||
socketio_path="/ws/socket.io",
|
||||
transports=["websocket"],
|
||||
)
|
||||
while not self.proxy._initialised:
|
||||
time.sleep(0.01)
|
||||
|
||||
def _setup_events(self) -> None:
|
||||
def __asyncio_loop_thread(self, loop: asyncio.AbstractEventLoop) -> None:
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_forever()
|
||||
|
||||
async def _setup_events(self) -> None:
|
||||
@self._sio.event
|
||||
def class_structure(data: SerializedDataService) -> None:
|
||||
if not self.proxy._initialised:
|
||||
self.proxy = self._proxy_class_factory.create_proxy(data)
|
||||
async def connect() -> None:
|
||||
logger.debug("Connected to '%s:%s' ...", self._hostname, self._port)
|
||||
serialized_data = cast(
|
||||
SerializedDataService, await self._sio.call("service_serialization")
|
||||
)
|
||||
if not hasattr(self, "proxy"):
|
||||
self.proxy = self._proxy_class_factory.create_proxy(serialized_data)
|
||||
else:
|
||||
# need to change to avoid overwriting the proxy class
|
||||
data["type"] = "DeviceConnection"
|
||||
super(pydase.DataService, self.proxy)._notify_changed("", loads(data))
|
||||
serialized_data["type"] = "DeviceConnection"
|
||||
super(pydase.DataService, self.proxy)._notify_changed(
|
||||
"", loads(serialized_data)
|
||||
)
|
||||
self.proxy._connected = True
|
||||
|
||||
@self._sio.event
|
||||
def notify(data: NotifyDict) -> None:
|
||||
async def disconnect() -> None:
|
||||
logger.debug("Disconnected")
|
||||
self.proxy._connected = False
|
||||
|
||||
@self._sio.event
|
||||
async def notify(data: NotifyDict) -> None:
|
||||
# Notify the DataServiceObserver directly, not going through
|
||||
# self._notify_changed as this would trigger the "update_value" event
|
||||
super(pydase.DataService, self.proxy)._notify_changed(
|
||||
@@ -57,8 +83,8 @@ class Client(pydase.DataService):
|
||||
loads(data["data"]["value"]),
|
||||
)
|
||||
|
||||
def disconnect(self) -> None:
|
||||
self._sio.disconnect()
|
||||
async def _disconnect(self) -> None:
|
||||
await self._sio.disconnect()
|
||||
|
||||
def _notify_changed(self, changed_attribute: str, value: Any) -> None:
|
||||
if (
|
||||
@@ -69,11 +95,14 @@ class Client(pydase.DataService):
|
||||
):
|
||||
logger.debug(f"{changed_attribute}: {value}")
|
||||
|
||||
self._sio.call(
|
||||
"update_value",
|
||||
{
|
||||
"access_path": changed_attribute[6:],
|
||||
"value": dump(value),
|
||||
},
|
||||
)
|
||||
async def update_value() -> None:
|
||||
await self._sio.call(
|
||||
"update_value",
|
||||
{
|
||||
"access_path": changed_attribute[6:],
|
||||
"value": dump(value),
|
||||
},
|
||||
)
|
||||
|
||||
asyncio.run_coroutine_threadsafe(update_value(), loop=self._loop)
|
||||
return super()._notify_changed(changed_attribute, value)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from copy import copy
|
||||
@@ -20,7 +21,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProxyClassMixin:
|
||||
_sio: socketio.Client
|
||||
_sio: socketio.AsyncClient
|
||||
_loop: asyncio.AbstractEventLoop
|
||||
|
||||
def __setattr__(self, key: str, value: Any) -> None:
|
||||
# prevent overriding of proxy attributes
|
||||
@@ -44,18 +46,18 @@ class ProxyConnection(pydase.components.DeviceConnection, ProxyClassMixin):
|
||||
self._initialised = False
|
||||
self._reconnection_wait_time = 1.0
|
||||
|
||||
@property
|
||||
def connected(self) -> bool:
|
||||
return self._sio.connected
|
||||
|
||||
|
||||
class ProxyClassFactory:
|
||||
def __init__(self, sio_client: socketio.Client) -> None:
|
||||
def __init__(
|
||||
self, sio_client: socketio.AsyncClient, loop: asyncio.AbstractEventLoop
|
||||
) -> None:
|
||||
self.sio_client = sio_client
|
||||
self.loop = loop
|
||||
|
||||
def create_proxy(self, data: SerializedObject) -> ProxyConnection:
|
||||
proxy_class = self._deserialize_component_type(data, ProxyConnection)
|
||||
proxy_class._sio = self.sio_client
|
||||
proxy_class._loop = self.loop
|
||||
proxy_class._initialised = True
|
||||
return proxy_class # type: ignore
|
||||
|
||||
@@ -87,6 +89,7 @@ class ProxyClassFactory:
|
||||
serialized_object, component_class
|
||||
)
|
||||
proxy_class._sio = self.sio_client
|
||||
proxy_class._loop = self.loop
|
||||
proxy_class._initialised = True
|
||||
return proxy_class
|
||||
return None
|
||||
@@ -95,18 +98,21 @@ class ProxyClassFactory:
|
||||
self, serialized_object: SerializedMethod
|
||||
) -> Callable[..., Any]:
|
||||
def method_proxy(self: ProxyBaseClass, *args: Any, **kwargs: Any) -> Any:
|
||||
serialized_response = cast(
|
||||
dict[str, Any],
|
||||
self._sio.call(
|
||||
async def trigger_method() -> Any:
|
||||
return await self._sio.call(
|
||||
"trigger_method",
|
||||
{
|
||||
"access_path": serialized_object["full_access_path"],
|
||||
"args": dump(list(args)),
|
||||
"kwargs": dump(kwargs),
|
||||
},
|
||||
),
|
||||
)
|
||||
return loads(serialized_response) # type: ignore
|
||||
)
|
||||
|
||||
result = asyncio.run_coroutine_threadsafe(
|
||||
trigger_method(),
|
||||
loop=self._loop,
|
||||
).result()
|
||||
return loads(result)
|
||||
|
||||
return method_proxy
|
||||
|
||||
@@ -160,27 +166,34 @@ class ProxyClassFactory:
|
||||
return create_proxy_class(serialized_object)()
|
||||
|
||||
def _create_attr_property(self, serialized_attr: SerializedObject) -> property:
|
||||
def get(self: ProxyBaseClass) -> Any: # type: ignore
|
||||
return loads(
|
||||
cast(
|
||||
SerializedObject,
|
||||
self._sio.call("get_value", serialized_attr["full_access_path"]),
|
||||
def get(self: ProxyBaseClass) -> Any:
|
||||
async def get_result() -> Any:
|
||||
return await self._sio.call(
|
||||
"get_value", serialized_attr["full_access_path"]
|
||||
)
|
||||
)
|
||||
|
||||
result = asyncio.run_coroutine_threadsafe(
|
||||
get_result(),
|
||||
loop=self._loop,
|
||||
).result()
|
||||
return loads(result)
|
||||
|
||||
get.__doc__ = serialized_attr["doc"]
|
||||
|
||||
def set(self: ProxyBaseClass, value: Any) -> None: # type: ignore
|
||||
result = cast(
|
||||
SerializedObject | None,
|
||||
self._sio.call(
|
||||
def set(self: ProxyBaseClass, value: Any) -> None:
|
||||
async def set_result() -> Any:
|
||||
return await self._sio.call(
|
||||
"update_value",
|
||||
{
|
||||
"access_path": serialized_attr["full_access_path"],
|
||||
"value": dump(value),
|
||||
},
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
result: SerializedObject | None = asyncio.run_coroutine_threadsafe(
|
||||
set_result(),
|
||||
loop=self._loop,
|
||||
).result()
|
||||
if result is not None:
|
||||
loads(result)
|
||||
|
||||
|
||||
@@ -119,12 +119,18 @@ def setup_sio_events(sio: socketio.AsyncServer, state_manager: StateManager) ->
|
||||
async def connect(sid: str, environ: Any) -> None:
|
||||
logging.debug("Client [%s] connected", click.style(str(sid), fg="cyan"))
|
||||
|
||||
await sio.emit("class_structure", state_manager.cache, to=sid)
|
||||
|
||||
@sio.event # type: ignore
|
||||
async def disconnect(sid: str) -> None:
|
||||
logging.debug("Client [%s] disconnected", click.style(str(sid), fg="cyan"))
|
||||
|
||||
@sio.event # type: ignore
|
||||
async def service_serialization(sid: str) -> SerializedObject:
|
||||
logging.debug(
|
||||
"Client [%s] requested service serialization",
|
||||
click.style(str(sid), fg="cyan"),
|
||||
)
|
||||
return state_manager.cache
|
||||
|
||||
@sio.event
|
||||
async def update_value(sid: str, data: UpdateDict) -> SerializedObject | None: # type: ignore
|
||||
path = data["access_path"]
|
||||
|
||||
Reference in New Issue
Block a user