mirror of
https://github.com/tiqi-group/pydase.git
synced 2025-04-20 08:20:02 +02:00
Merge pull request #190 from tiqi-group/fix/async_functions
Fix: triggering async functions
This commit is contained in:
commit
7b786be892
@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from copy import copy
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import socketio # type: ignore
|
||||
@ -202,24 +201,7 @@ class ProxyClassMixin:
|
||||
def _handle_serialized_method(
|
||||
self, attr_name: str, serialized_object: SerializedObject
|
||||
) -> None:
|
||||
def add_prefix_to_last_path_element(s: str, prefix: str) -> str:
|
||||
parts = s.split(".")
|
||||
parts[-1] = f"{prefix}_{parts[-1]}"
|
||||
return ".".join(parts)
|
||||
|
||||
if serialized_object["type"] == "method":
|
||||
if serialized_object["async"] is True:
|
||||
start_method = copy(serialized_object)
|
||||
start_method["full_access_path"] = add_prefix_to_last_path_element(
|
||||
start_method["full_access_path"], "start"
|
||||
)
|
||||
stop_method = copy(serialized_object)
|
||||
stop_method["full_access_path"] = add_prefix_to_last_path_element(
|
||||
stop_method["full_access_path"], "stop"
|
||||
)
|
||||
self._add_method_proxy(f"start_{attr_name}", start_method)
|
||||
self._add_method_proxy(f"stop_{attr_name}", stop_method)
|
||||
else:
|
||||
self._add_method_proxy(attr_name, serialized_object)
|
||||
|
||||
def _add_method_proxy(
|
||||
|
@ -1,3 +1,4 @@
|
||||
import inspect
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@ -7,9 +8,11 @@ import aiohttp_middlewares.error
|
||||
from pydase.data_service.state_manager import StateManager
|
||||
from pydase.server.web_server.api.v1.endpoints import (
|
||||
get_value,
|
||||
trigger_async_method,
|
||||
trigger_method,
|
||||
update_value,
|
||||
)
|
||||
from pydase.utils.helpers import get_object_attr_from_path
|
||||
from pydase.utils.serialization.serializer import dump
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -21,12 +24,9 @@ STATUS_OK = 200
|
||||
STATUS_FAILED = 400
|
||||
|
||||
|
||||
def create_api_application(state_manager: StateManager) -> aiohttp.web.Application:
|
||||
api_application = aiohttp.web.Application(
|
||||
middlewares=(aiohttp_middlewares.error.error_middleware(),)
|
||||
)
|
||||
|
||||
async def _get_value(request: aiohttp.web.Request) -> aiohttp.web.Response:
|
||||
async def _get_value(
|
||||
state_manager: StateManager, request: aiohttp.web.Request
|
||||
) -> aiohttp.web.Response:
|
||||
logger.info("Handle api request: %s", request)
|
||||
|
||||
access_path = request.rel_url.query["access_path"]
|
||||
@ -40,7 +40,10 @@ def create_api_application(state_manager: StateManager) -> aiohttp.web.Applicati
|
||||
status = STATUS_FAILED
|
||||
return aiohttp.web.json_response(result, status=status)
|
||||
|
||||
async def _update_value(request: aiohttp.web.Request) -> aiohttp.web.Response:
|
||||
|
||||
async def _update_value(
|
||||
state_manager: StateManager, request: aiohttp.web.Request
|
||||
) -> aiohttp.web.Response:
|
||||
data: UpdateDict = await request.json()
|
||||
|
||||
try:
|
||||
@ -51,18 +54,45 @@ def create_api_application(state_manager: StateManager) -> aiohttp.web.Applicati
|
||||
logger.exception(e)
|
||||
return aiohttp.web.json_response(dump(e), status=STATUS_FAILED)
|
||||
|
||||
async def _trigger_method(request: aiohttp.web.Request) -> aiohttp.web.Response:
|
||||
|
||||
async def _trigger_method(
|
||||
state_manager: StateManager, request: aiohttp.web.Request
|
||||
) -> aiohttp.web.Response:
|
||||
data: TriggerMethodDict = await request.json()
|
||||
|
||||
method = get_object_attr_from_path(state_manager.service, data["access_path"])
|
||||
|
||||
try:
|
||||
return aiohttp.web.json_response(trigger_method(state_manager, data))
|
||||
if inspect.iscoroutinefunction(method):
|
||||
method_return = await trigger_async_method(
|
||||
state_manager=state_manager, data=data
|
||||
)
|
||||
else:
|
||||
method_return = trigger_method(state_manager=state_manager, data=data)
|
||||
|
||||
return aiohttp.web.json_response(method_return)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return aiohttp.web.json_response(dump(e), status=STATUS_FAILED)
|
||||
|
||||
api_application.router.add_get("/get_value", _get_value)
|
||||
api_application.router.add_put("/update_value", _update_value)
|
||||
api_application.router.add_put("/trigger_method", _trigger_method)
|
||||
|
||||
def create_api_application(state_manager: StateManager) -> aiohttp.web.Application:
|
||||
api_application = aiohttp.web.Application(
|
||||
middlewares=(aiohttp_middlewares.error.error_middleware(),)
|
||||
)
|
||||
|
||||
api_application.router.add_get(
|
||||
"/get_value",
|
||||
lambda request: _get_value(state_manager=state_manager, request=request),
|
||||
)
|
||||
api_application.router.add_put(
|
||||
"/update_value",
|
||||
lambda request: _update_value(state_manager=state_manager, request=request),
|
||||
)
|
||||
api_application.router.add_put(
|
||||
"/trigger_method",
|
||||
lambda request: _trigger_method(state_manager=state_manager, request=request),
|
||||
)
|
||||
|
||||
return api_application
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import pydase.utils.serialization.deserializer
|
||||
import pydase.utils.serialization.serializer
|
||||
@ -7,6 +7,9 @@ from pydase.server.web_server.sio_setup import TriggerMethodDict, UpdateDict
|
||||
from pydase.utils.helpers import get_object_attr_from_path
|
||||
from pydase.utils.serialization.types import SerializedObject
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
loads = pydase.utils.serialization.deserializer.loads
|
||||
Serializer = pydase.utils.serialization.serializer.Serializer
|
||||
|
||||
@ -36,3 +39,19 @@ def trigger_method(state_manager: StateManager, data: TriggerMethodDict) -> Any:
|
||||
kwargs: dict[str, Any] = loads(serialized_kwargs) if serialized_kwargs else {}
|
||||
|
||||
return Serializer.serialize_object(method(*args, **kwargs))
|
||||
|
||||
|
||||
async def trigger_async_method(
|
||||
state_manager: StateManager, data: TriggerMethodDict
|
||||
) -> Any:
|
||||
method: Callable[..., Awaitable[Any]] = get_object_attr_from_path(
|
||||
state_manager.service, data["access_path"]
|
||||
)
|
||||
|
||||
serialized_args = data.get("args", None)
|
||||
args = loads(serialized_args) if serialized_args else []
|
||||
|
||||
serialized_kwargs = data.get("kwargs", None)
|
||||
kwargs: dict[str, Any] = loads(serialized_kwargs) if serialized_kwargs else {}
|
||||
|
||||
return Serializer.serialize_object(await method(*args, **kwargs))
|
||||
|
@ -1,8 +1,11 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from pydase.utils.helpers import get_object_attr_from_path
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from typing_extensions import NotRequired
|
||||
else:
|
||||
@ -11,11 +14,11 @@ else:
|
||||
import click
|
||||
import socketio # type: ignore[import-untyped]
|
||||
|
||||
import pydase.server.web_server.api.v1.endpoints
|
||||
import pydase.utils.serialization.deserializer
|
||||
import pydase.utils.serialization.serializer
|
||||
from pydase.data_service.data_service_observer import DataServiceObserver
|
||||
from pydase.data_service.state_manager import StateManager
|
||||
from pydase.server.web_server.api.v1 import endpoints
|
||||
from pydase.utils.logging import SocketIOHandler
|
||||
from pydase.utils.serialization.serializer import SerializedObject
|
||||
|
||||
@ -155,9 +158,7 @@ def setup_sio_events(sio: socketio.AsyncServer, state_manager: StateManager) ->
|
||||
@sio.event
|
||||
async def update_value(sid: str, data: UpdateDict) -> SerializedObject | None:
|
||||
try:
|
||||
pydase.server.web_server.api.v1.endpoints.update_value(
|
||||
state_manager=state_manager, data=data
|
||||
)
|
||||
endpoints.update_value(state_manager=state_manager, data=data)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return dump(e)
|
||||
@ -166,7 +167,7 @@ def setup_sio_events(sio: socketio.AsyncServer, state_manager: StateManager) ->
|
||||
@sio.event
|
||||
async def get_value(sid: str, access_path: str) -> SerializedObject:
|
||||
try:
|
||||
return pydase.server.web_server.api.v1.endpoints.get_value(
|
||||
return endpoints.get_value(
|
||||
state_manager=state_manager, access_path=access_path
|
||||
)
|
||||
except Exception as e:
|
||||
@ -175,10 +176,14 @@ def setup_sio_events(sio: socketio.AsyncServer, state_manager: StateManager) ->
|
||||
|
||||
@sio.event
|
||||
async def trigger_method(sid: str, data: TriggerMethodDict) -> Any:
|
||||
method = get_object_attr_from_path(state_manager.service, data["access_path"])
|
||||
|
||||
try:
|
||||
return pydase.server.web_server.api.v1.endpoints.trigger_method(
|
||||
if inspect.iscoroutinefunction(method):
|
||||
return await endpoints.trigger_async_method(
|
||||
state_manager=state_manager, data=data
|
||||
)
|
||||
return endpoints.trigger_method(state_manager=state_manager, data=data)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return dump(e)
|
||||
|
@ -41,6 +41,9 @@ def pydase_client() -> Generator[pydase.Client, None, Any]:
|
||||
def my_method(self, input_str: str) -> str:
|
||||
return input_str
|
||||
|
||||
async def my_async_method(self, input_str: str) -> str:
|
||||
return input_str
|
||||
|
||||
server = pydase.Server(MyService(), web_port=9999)
|
||||
thread = threading.Thread(target=server.run, daemon=True)
|
||||
thread.start()
|
||||
@ -79,6 +82,14 @@ def test_method_execution(pydase_client: pydase.Client) -> None:
|
||||
pydase_client.proxy.my_method(kwarg="hello")
|
||||
|
||||
|
||||
def test_async_method_execution(pydase_client: pydase.Client) -> None:
|
||||
assert pydase_client.proxy.my_async_method("My return string") == "My return string"
|
||||
assert (
|
||||
pydase_client.proxy.my_async_method(input_str="My return string")
|
||||
== "My return string"
|
||||
)
|
||||
|
||||
|
||||
def test_nested_service(pydase_client: pydase.Client) -> None:
|
||||
assert pydase_client.proxy.sub_service.name == "SubService"
|
||||
pydase_client.proxy.sub_service.name = "New name"
|
||||
|
@ -6,6 +6,7 @@ from typing import Any
|
||||
import aiohttp
|
||||
import pydase
|
||||
import pytest
|
||||
from pydase.utils.serialization.deserializer import Deserializer
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@ -40,7 +41,10 @@ def pydase_server() -> Generator[None, None, None]:
|
||||
return self._readonly_attr
|
||||
|
||||
def my_method(self, input_str: str) -> str:
|
||||
return input_str
|
||||
return f"{input_str}: my_method"
|
||||
|
||||
async def my_async_method(self, input_str: str) -> str:
|
||||
return f"{input_str}: my_async_method"
|
||||
|
||||
server = pydase.Server(MyService(), web_port=9998)
|
||||
thread = threading.Thread(target=server.run, daemon=True)
|
||||
@ -192,3 +196,57 @@ async def test_update_value(
|
||||
resp = await session.get(f"/api/v1/get_value?access_path={access_path}")
|
||||
content = json.loads(await resp.text())
|
||||
assert content == new_value
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"access_path, expected, ok",
|
||||
[
|
||||
(
|
||||
"my_method",
|
||||
"Hello from function: my_method",
|
||||
True,
|
||||
),
|
||||
(
|
||||
"my_async_method",
|
||||
"Hello from function: my_async_method",
|
||||
True,
|
||||
),
|
||||
(
|
||||
"invalid_method",
|
||||
None,
|
||||
False,
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio()
|
||||
async def test_trigger_method(
|
||||
access_path: str,
|
||||
expected: Any,
|
||||
ok: bool,
|
||||
pydase_server: pydase.DataService,
|
||||
) -> None:
|
||||
async with aiohttp.ClientSession("http://localhost:9998") as session:
|
||||
resp = await session.put(
|
||||
"/api/v1/trigger_method",
|
||||
json={
|
||||
"access_path": access_path,
|
||||
"kwargs": {
|
||||
"full_access_path": "",
|
||||
"type": "dict",
|
||||
"value": {
|
||||
"input_str": {
|
||||
"docs": None,
|
||||
"full_access_path": "",
|
||||
"readonly": False,
|
||||
"type": "str",
|
||||
"value": "Hello from function",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
assert resp.ok == ok
|
||||
|
||||
if resp.ok:
|
||||
content = Deserializer.deserialize(json.loads(await resp.text()))
|
||||
assert content == expected
|
||||
|
Loading…
x
Reference in New Issue
Block a user