Merge pull request #190 from tiqi-group/fix/async_functions

Fix: triggering async functions
This commit is contained in:
Mose Müller 2024-12-02 14:58:55 +01:00 committed by GitHub
commit 7b786be892
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 172 additions and 67 deletions

View File

@ -1,7 +1,6 @@
import asyncio
import logging
from collections.abc import Iterable
from copy import copy
from typing import TYPE_CHECKING, Any, cast
import socketio # type: ignore
@ -202,24 +201,7 @@ class ProxyClassMixin:
def _handle_serialized_method(
self, attr_name: str, serialized_object: SerializedObject
) -> None:
def add_prefix_to_last_path_element(s: str, prefix: str) -> str:
parts = s.split(".")
parts[-1] = f"{prefix}_{parts[-1]}"
return ".".join(parts)
if serialized_object["type"] == "method":
if serialized_object["async"] is True:
start_method = copy(serialized_object)
start_method["full_access_path"] = add_prefix_to_last_path_element(
start_method["full_access_path"], "start"
)
stop_method = copy(serialized_object)
stop_method["full_access_path"] = add_prefix_to_last_path_element(
stop_method["full_access_path"], "stop"
)
self._add_method_proxy(f"start_{attr_name}", start_method)
self._add_method_proxy(f"stop_{attr_name}", stop_method)
else:
self._add_method_proxy(attr_name, serialized_object)
def _add_method_proxy(

View File

@ -1,3 +1,4 @@
import inspect
import logging
from typing import TYPE_CHECKING
@ -7,9 +8,11 @@ import aiohttp_middlewares.error
from pydase.data_service.state_manager import StateManager
from pydase.server.web_server.api.v1.endpoints import (
get_value,
trigger_async_method,
trigger_method,
update_value,
)
from pydase.utils.helpers import get_object_attr_from_path
from pydase.utils.serialization.serializer import dump
if TYPE_CHECKING:
@ -21,12 +24,9 @@ STATUS_OK = 200
STATUS_FAILED = 400
def create_api_application(state_manager: StateManager) -> aiohttp.web.Application:
api_application = aiohttp.web.Application(
middlewares=(aiohttp_middlewares.error.error_middleware(),)
)
async def _get_value(request: aiohttp.web.Request) -> aiohttp.web.Response:
async def _get_value(
state_manager: StateManager, request: aiohttp.web.Request
) -> aiohttp.web.Response:
logger.info("Handle api request: %s", request)
access_path = request.rel_url.query["access_path"]
@ -40,7 +40,10 @@ def create_api_application(state_manager: StateManager) -> aiohttp.web.Applicati
status = STATUS_FAILED
return aiohttp.web.json_response(result, status=status)
async def _update_value(request: aiohttp.web.Request) -> aiohttp.web.Response:
async def _update_value(
state_manager: StateManager, request: aiohttp.web.Request
) -> aiohttp.web.Response:
data: UpdateDict = await request.json()
try:
@ -51,18 +54,45 @@ def create_api_application(state_manager: StateManager) -> aiohttp.web.Applicati
logger.exception(e)
return aiohttp.web.json_response(dump(e), status=STATUS_FAILED)
async def _trigger_method(request: aiohttp.web.Request) -> aiohttp.web.Response:
async def _trigger_method(
state_manager: StateManager, request: aiohttp.web.Request
) -> aiohttp.web.Response:
data: TriggerMethodDict = await request.json()
method = get_object_attr_from_path(state_manager.service, data["access_path"])
try:
return aiohttp.web.json_response(trigger_method(state_manager, data))
if inspect.iscoroutinefunction(method):
method_return = await trigger_async_method(
state_manager=state_manager, data=data
)
else:
method_return = trigger_method(state_manager=state_manager, data=data)
return aiohttp.web.json_response(method_return)
except Exception as e:
logger.exception(e)
return aiohttp.web.json_response(dump(e), status=STATUS_FAILED)
api_application.router.add_get("/get_value", _get_value)
api_application.router.add_put("/update_value", _update_value)
api_application.router.add_put("/trigger_method", _trigger_method)
def create_api_application(state_manager: StateManager) -> aiohttp.web.Application:
api_application = aiohttp.web.Application(
middlewares=(aiohttp_middlewares.error.error_middleware(),)
)
api_application.router.add_get(
"/get_value",
lambda request: _get_value(state_manager=state_manager, request=request),
)
api_application.router.add_put(
"/update_value",
lambda request: _update_value(state_manager=state_manager, request=request),
)
api_application.router.add_put(
"/trigger_method",
lambda request: _trigger_method(state_manager=state_manager, request=request),
)
return api_application

View File

@ -1,4 +1,4 @@
from typing import Any
from typing import TYPE_CHECKING, Any
import pydase.utils.serialization.deserializer
import pydase.utils.serialization.serializer
@ -7,6 +7,9 @@ from pydase.server.web_server.sio_setup import TriggerMethodDict, UpdateDict
from pydase.utils.helpers import get_object_attr_from_path
from pydase.utils.serialization.types import SerializedObject
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
loads = pydase.utils.serialization.deserializer.loads
Serializer = pydase.utils.serialization.serializer.Serializer
@ -36,3 +39,19 @@ def trigger_method(state_manager: StateManager, data: TriggerMethodDict) -> Any:
kwargs: dict[str, Any] = loads(serialized_kwargs) if serialized_kwargs else {}
return Serializer.serialize_object(method(*args, **kwargs))
async def trigger_async_method(
state_manager: StateManager, data: TriggerMethodDict
) -> Any:
method: Callable[..., Awaitable[Any]] = get_object_attr_from_path(
state_manager.service, data["access_path"]
)
serialized_args = data.get("args", None)
args = loads(serialized_args) if serialized_args else []
serialized_kwargs = data.get("kwargs", None)
kwargs: dict[str, Any] = loads(serialized_kwargs) if serialized_kwargs else {}
return Serializer.serialize_object(await method(*args, **kwargs))

View File

@ -1,8 +1,11 @@
import asyncio
import inspect
import logging
import sys
from typing import Any, TypedDict
from pydase.utils.helpers import get_object_attr_from_path
if sys.version_info < (3, 11):
from typing_extensions import NotRequired
else:
@ -11,11 +14,11 @@ else:
import click
import socketio # type: ignore[import-untyped]
import pydase.server.web_server.api.v1.endpoints
import pydase.utils.serialization.deserializer
import pydase.utils.serialization.serializer
from pydase.data_service.data_service_observer import DataServiceObserver
from pydase.data_service.state_manager import StateManager
from pydase.server.web_server.api.v1 import endpoints
from pydase.utils.logging import SocketIOHandler
from pydase.utils.serialization.serializer import SerializedObject
@ -155,9 +158,7 @@ def setup_sio_events(sio: socketio.AsyncServer, state_manager: StateManager) ->
@sio.event
async def update_value(sid: str, data: UpdateDict) -> SerializedObject | None:
try:
pydase.server.web_server.api.v1.endpoints.update_value(
state_manager=state_manager, data=data
)
endpoints.update_value(state_manager=state_manager, data=data)
except Exception as e:
logger.exception(e)
return dump(e)
@ -166,7 +167,7 @@ def setup_sio_events(sio: socketio.AsyncServer, state_manager: StateManager) ->
@sio.event
async def get_value(sid: str, access_path: str) -> SerializedObject:
try:
return pydase.server.web_server.api.v1.endpoints.get_value(
return endpoints.get_value(
state_manager=state_manager, access_path=access_path
)
except Exception as e:
@ -175,10 +176,14 @@ def setup_sio_events(sio: socketio.AsyncServer, state_manager: StateManager) ->
@sio.event
async def trigger_method(sid: str, data: TriggerMethodDict) -> Any:
method = get_object_attr_from_path(state_manager.service, data["access_path"])
try:
return pydase.server.web_server.api.v1.endpoints.trigger_method(
if inspect.iscoroutinefunction(method):
return await endpoints.trigger_async_method(
state_manager=state_manager, data=data
)
return endpoints.trigger_method(state_manager=state_manager, data=data)
except Exception as e:
logger.error(e)
return dump(e)

View File

@ -41,6 +41,9 @@ def pydase_client() -> Generator[pydase.Client, None, Any]:
def my_method(self, input_str: str) -> str:
return input_str
async def my_async_method(self, input_str: str) -> str:
return input_str
server = pydase.Server(MyService(), web_port=9999)
thread = threading.Thread(target=server.run, daemon=True)
thread.start()
@ -79,6 +82,14 @@ def test_method_execution(pydase_client: pydase.Client) -> None:
pydase_client.proxy.my_method(kwarg="hello")
def test_async_method_execution(pydase_client: pydase.Client) -> None:
assert pydase_client.proxy.my_async_method("My return string") == "My return string"
assert (
pydase_client.proxy.my_async_method(input_str="My return string")
== "My return string"
)
def test_nested_service(pydase_client: pydase.Client) -> None:
assert pydase_client.proxy.sub_service.name == "SubService"
pydase_client.proxy.sub_service.name = "New name"

View File

@ -6,6 +6,7 @@ from typing import Any
import aiohttp
import pydase
import pytest
from pydase.utils.serialization.deserializer import Deserializer
@pytest.fixture()
@ -40,7 +41,10 @@ def pydase_server() -> Generator[None, None, None]:
return self._readonly_attr
def my_method(self, input_str: str) -> str:
return input_str
return f"{input_str}: my_method"
async def my_async_method(self, input_str: str) -> str:
return f"{input_str}: my_async_method"
server = pydase.Server(MyService(), web_port=9998)
thread = threading.Thread(target=server.run, daemon=True)
@ -192,3 +196,57 @@ async def test_update_value(
resp = await session.get(f"/api/v1/get_value?access_path={access_path}")
content = json.loads(await resp.text())
assert content == new_value
@pytest.mark.parametrize(
"access_path, expected, ok",
[
(
"my_method",
"Hello from function: my_method",
True,
),
(
"my_async_method",
"Hello from function: my_async_method",
True,
),
(
"invalid_method",
None,
False,
),
],
)
@pytest.mark.asyncio()
async def test_trigger_method(
access_path: str,
expected: Any,
ok: bool,
pydase_server: pydase.DataService,
) -> None:
async with aiohttp.ClientSession("http://localhost:9998") as session:
resp = await session.put(
"/api/v1/trigger_method",
json={
"access_path": access_path,
"kwargs": {
"full_access_path": "",
"type": "dict",
"value": {
"input_str": {
"docs": None,
"full_access_path": "",
"readonly": False,
"type": "str",
"value": "Hello from function",
},
},
},
},
)
assert resp.ok == ok
if resp.ok:
content = Deserializer.deserialize(json.loads(await resp.text()))
assert content == expected