mirror of
https://github.com/tiqi-group/pydase.git
synced 2025-04-21 00:40:01 +02:00
Merge pull request #190 from tiqi-group/fix/async_functions
Fix: triggering async functions
This commit is contained in:
commit
7b786be892
@ -1,7 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from copy import copy
|
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
import socketio # type: ignore
|
import socketio # type: ignore
|
||||||
@ -202,24 +201,7 @@ class ProxyClassMixin:
|
|||||||
def _handle_serialized_method(
|
def _handle_serialized_method(
|
||||||
self, attr_name: str, serialized_object: SerializedObject
|
self, attr_name: str, serialized_object: SerializedObject
|
||||||
) -> None:
|
) -> None:
|
||||||
def add_prefix_to_last_path_element(s: str, prefix: str) -> str:
|
|
||||||
parts = s.split(".")
|
|
||||||
parts[-1] = f"{prefix}_{parts[-1]}"
|
|
||||||
return ".".join(parts)
|
|
||||||
|
|
||||||
if serialized_object["type"] == "method":
|
if serialized_object["type"] == "method":
|
||||||
if serialized_object["async"] is True:
|
|
||||||
start_method = copy(serialized_object)
|
|
||||||
start_method["full_access_path"] = add_prefix_to_last_path_element(
|
|
||||||
start_method["full_access_path"], "start"
|
|
||||||
)
|
|
||||||
stop_method = copy(serialized_object)
|
|
||||||
stop_method["full_access_path"] = add_prefix_to_last_path_element(
|
|
||||||
stop_method["full_access_path"], "stop"
|
|
||||||
)
|
|
||||||
self._add_method_proxy(f"start_{attr_name}", start_method)
|
|
||||||
self._add_method_proxy(f"stop_{attr_name}", stop_method)
|
|
||||||
else:
|
|
||||||
self._add_method_proxy(attr_name, serialized_object)
|
self._add_method_proxy(attr_name, serialized_object)
|
||||||
|
|
||||||
def _add_method_proxy(
|
def _add_method_proxy(
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
@ -7,9 +8,11 @@ import aiohttp_middlewares.error
|
|||||||
from pydase.data_service.state_manager import StateManager
|
from pydase.data_service.state_manager import StateManager
|
||||||
from pydase.server.web_server.api.v1.endpoints import (
|
from pydase.server.web_server.api.v1.endpoints import (
|
||||||
get_value,
|
get_value,
|
||||||
|
trigger_async_method,
|
||||||
trigger_method,
|
trigger_method,
|
||||||
update_value,
|
update_value,
|
||||||
)
|
)
|
||||||
|
from pydase.utils.helpers import get_object_attr_from_path
|
||||||
from pydase.utils.serialization.serializer import dump
|
from pydase.utils.serialization.serializer import dump
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -21,12 +24,9 @@ STATUS_OK = 200
|
|||||||
STATUS_FAILED = 400
|
STATUS_FAILED = 400
|
||||||
|
|
||||||
|
|
||||||
def create_api_application(state_manager: StateManager) -> aiohttp.web.Application:
|
async def _get_value(
|
||||||
api_application = aiohttp.web.Application(
|
state_manager: StateManager, request: aiohttp.web.Request
|
||||||
middlewares=(aiohttp_middlewares.error.error_middleware(),)
|
) -> aiohttp.web.Response:
|
||||||
)
|
|
||||||
|
|
||||||
async def _get_value(request: aiohttp.web.Request) -> aiohttp.web.Response:
|
|
||||||
logger.info("Handle api request: %s", request)
|
logger.info("Handle api request: %s", request)
|
||||||
|
|
||||||
access_path = request.rel_url.query["access_path"]
|
access_path = request.rel_url.query["access_path"]
|
||||||
@ -40,7 +40,10 @@ def create_api_application(state_manager: StateManager) -> aiohttp.web.Applicati
|
|||||||
status = STATUS_FAILED
|
status = STATUS_FAILED
|
||||||
return aiohttp.web.json_response(result, status=status)
|
return aiohttp.web.json_response(result, status=status)
|
||||||
|
|
||||||
async def _update_value(request: aiohttp.web.Request) -> aiohttp.web.Response:
|
|
||||||
|
async def _update_value(
|
||||||
|
state_manager: StateManager, request: aiohttp.web.Request
|
||||||
|
) -> aiohttp.web.Response:
|
||||||
data: UpdateDict = await request.json()
|
data: UpdateDict = await request.json()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -51,18 +54,45 @@ def create_api_application(state_manager: StateManager) -> aiohttp.web.Applicati
|
|||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
return aiohttp.web.json_response(dump(e), status=STATUS_FAILED)
|
return aiohttp.web.json_response(dump(e), status=STATUS_FAILED)
|
||||||
|
|
||||||
async def _trigger_method(request: aiohttp.web.Request) -> aiohttp.web.Response:
|
|
||||||
|
async def _trigger_method(
|
||||||
|
state_manager: StateManager, request: aiohttp.web.Request
|
||||||
|
) -> aiohttp.web.Response:
|
||||||
data: TriggerMethodDict = await request.json()
|
data: TriggerMethodDict = await request.json()
|
||||||
|
|
||||||
|
method = get_object_attr_from_path(state_manager.service, data["access_path"])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return aiohttp.web.json_response(trigger_method(state_manager, data))
|
if inspect.iscoroutinefunction(method):
|
||||||
|
method_return = await trigger_async_method(
|
||||||
|
state_manager=state_manager, data=data
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
method_return = trigger_method(state_manager=state_manager, data=data)
|
||||||
|
|
||||||
|
return aiohttp.web.json_response(method_return)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
return aiohttp.web.json_response(dump(e), status=STATUS_FAILED)
|
return aiohttp.web.json_response(dump(e), status=STATUS_FAILED)
|
||||||
|
|
||||||
api_application.router.add_get("/get_value", _get_value)
|
|
||||||
api_application.router.add_put("/update_value", _update_value)
|
def create_api_application(state_manager: StateManager) -> aiohttp.web.Application:
|
||||||
api_application.router.add_put("/trigger_method", _trigger_method)
|
api_application = aiohttp.web.Application(
|
||||||
|
middlewares=(aiohttp_middlewares.error.error_middleware(),)
|
||||||
|
)
|
||||||
|
|
||||||
|
api_application.router.add_get(
|
||||||
|
"/get_value",
|
||||||
|
lambda request: _get_value(state_manager=state_manager, request=request),
|
||||||
|
)
|
||||||
|
api_application.router.add_put(
|
||||||
|
"/update_value",
|
||||||
|
lambda request: _update_value(state_manager=state_manager, request=request),
|
||||||
|
)
|
||||||
|
api_application.router.add_put(
|
||||||
|
"/trigger_method",
|
||||||
|
lambda request: _trigger_method(state_manager=state_manager, request=request),
|
||||||
|
)
|
||||||
|
|
||||||
return api_application
|
return api_application
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import pydase.utils.serialization.deserializer
|
import pydase.utils.serialization.deserializer
|
||||||
import pydase.utils.serialization.serializer
|
import pydase.utils.serialization.serializer
|
||||||
@ -7,6 +7,9 @@ from pydase.server.web_server.sio_setup import TriggerMethodDict, UpdateDict
|
|||||||
from pydase.utils.helpers import get_object_attr_from_path
|
from pydase.utils.helpers import get_object_attr_from_path
|
||||||
from pydase.utils.serialization.types import SerializedObject
|
from pydase.utils.serialization.types import SerializedObject
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
|
||||||
loads = pydase.utils.serialization.deserializer.loads
|
loads = pydase.utils.serialization.deserializer.loads
|
||||||
Serializer = pydase.utils.serialization.serializer.Serializer
|
Serializer = pydase.utils.serialization.serializer.Serializer
|
||||||
|
|
||||||
@ -36,3 +39,19 @@ def trigger_method(state_manager: StateManager, data: TriggerMethodDict) -> Any:
|
|||||||
kwargs: dict[str, Any] = loads(serialized_kwargs) if serialized_kwargs else {}
|
kwargs: dict[str, Any] = loads(serialized_kwargs) if serialized_kwargs else {}
|
||||||
|
|
||||||
return Serializer.serialize_object(method(*args, **kwargs))
|
return Serializer.serialize_object(method(*args, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
|
async def trigger_async_method(
|
||||||
|
state_manager: StateManager, data: TriggerMethodDict
|
||||||
|
) -> Any:
|
||||||
|
method: Callable[..., Awaitable[Any]] = get_object_attr_from_path(
|
||||||
|
state_manager.service, data["access_path"]
|
||||||
|
)
|
||||||
|
|
||||||
|
serialized_args = data.get("args", None)
|
||||||
|
args = loads(serialized_args) if serialized_args else []
|
||||||
|
|
||||||
|
serialized_kwargs = data.get("kwargs", None)
|
||||||
|
kwargs: dict[str, Any] = loads(serialized_kwargs) if serialized_kwargs else {}
|
||||||
|
|
||||||
|
return Serializer.serialize_object(await method(*args, **kwargs))
|
||||||
|
@ -1,8 +1,11 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, TypedDict
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
|
from pydase.utils.helpers import get_object_attr_from_path
|
||||||
|
|
||||||
if sys.version_info < (3, 11):
|
if sys.version_info < (3, 11):
|
||||||
from typing_extensions import NotRequired
|
from typing_extensions import NotRequired
|
||||||
else:
|
else:
|
||||||
@ -11,11 +14,11 @@ else:
|
|||||||
import click
|
import click
|
||||||
import socketio # type: ignore[import-untyped]
|
import socketio # type: ignore[import-untyped]
|
||||||
|
|
||||||
import pydase.server.web_server.api.v1.endpoints
|
|
||||||
import pydase.utils.serialization.deserializer
|
import pydase.utils.serialization.deserializer
|
||||||
import pydase.utils.serialization.serializer
|
import pydase.utils.serialization.serializer
|
||||||
from pydase.data_service.data_service_observer import DataServiceObserver
|
from pydase.data_service.data_service_observer import DataServiceObserver
|
||||||
from pydase.data_service.state_manager import StateManager
|
from pydase.data_service.state_manager import StateManager
|
||||||
|
from pydase.server.web_server.api.v1 import endpoints
|
||||||
from pydase.utils.logging import SocketIOHandler
|
from pydase.utils.logging import SocketIOHandler
|
||||||
from pydase.utils.serialization.serializer import SerializedObject
|
from pydase.utils.serialization.serializer import SerializedObject
|
||||||
|
|
||||||
@ -155,9 +158,7 @@ def setup_sio_events(sio: socketio.AsyncServer, state_manager: StateManager) ->
|
|||||||
@sio.event
|
@sio.event
|
||||||
async def update_value(sid: str, data: UpdateDict) -> SerializedObject | None:
|
async def update_value(sid: str, data: UpdateDict) -> SerializedObject | None:
|
||||||
try:
|
try:
|
||||||
pydase.server.web_server.api.v1.endpoints.update_value(
|
endpoints.update_value(state_manager=state_manager, data=data)
|
||||||
state_manager=state_manager, data=data
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
return dump(e)
|
return dump(e)
|
||||||
@ -166,7 +167,7 @@ def setup_sio_events(sio: socketio.AsyncServer, state_manager: StateManager) ->
|
|||||||
@sio.event
|
@sio.event
|
||||||
async def get_value(sid: str, access_path: str) -> SerializedObject:
|
async def get_value(sid: str, access_path: str) -> SerializedObject:
|
||||||
try:
|
try:
|
||||||
return pydase.server.web_server.api.v1.endpoints.get_value(
|
return endpoints.get_value(
|
||||||
state_manager=state_manager, access_path=access_path
|
state_manager=state_manager, access_path=access_path
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -175,10 +176,14 @@ def setup_sio_events(sio: socketio.AsyncServer, state_manager: StateManager) ->
|
|||||||
|
|
||||||
@sio.event
|
@sio.event
|
||||||
async def trigger_method(sid: str, data: TriggerMethodDict) -> Any:
|
async def trigger_method(sid: str, data: TriggerMethodDict) -> Any:
|
||||||
|
method = get_object_attr_from_path(state_manager.service, data["access_path"])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return pydase.server.web_server.api.v1.endpoints.trigger_method(
|
if inspect.iscoroutinefunction(method):
|
||||||
|
return await endpoints.trigger_async_method(
|
||||||
state_manager=state_manager, data=data
|
state_manager=state_manager, data=data
|
||||||
)
|
)
|
||||||
|
return endpoints.trigger_method(state_manager=state_manager, data=data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
return dump(e)
|
return dump(e)
|
||||||
|
@ -41,6 +41,9 @@ def pydase_client() -> Generator[pydase.Client, None, Any]:
|
|||||||
def my_method(self, input_str: str) -> str:
|
def my_method(self, input_str: str) -> str:
|
||||||
return input_str
|
return input_str
|
||||||
|
|
||||||
|
async def my_async_method(self, input_str: str) -> str:
|
||||||
|
return input_str
|
||||||
|
|
||||||
server = pydase.Server(MyService(), web_port=9999)
|
server = pydase.Server(MyService(), web_port=9999)
|
||||||
thread = threading.Thread(target=server.run, daemon=True)
|
thread = threading.Thread(target=server.run, daemon=True)
|
||||||
thread.start()
|
thread.start()
|
||||||
@ -79,6 +82,14 @@ def test_method_execution(pydase_client: pydase.Client) -> None:
|
|||||||
pydase_client.proxy.my_method(kwarg="hello")
|
pydase_client.proxy.my_method(kwarg="hello")
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_method_execution(pydase_client: pydase.Client) -> None:
|
||||||
|
assert pydase_client.proxy.my_async_method("My return string") == "My return string"
|
||||||
|
assert (
|
||||||
|
pydase_client.proxy.my_async_method(input_str="My return string")
|
||||||
|
== "My return string"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_nested_service(pydase_client: pydase.Client) -> None:
|
def test_nested_service(pydase_client: pydase.Client) -> None:
|
||||||
assert pydase_client.proxy.sub_service.name == "SubService"
|
assert pydase_client.proxy.sub_service.name == "SubService"
|
||||||
pydase_client.proxy.sub_service.name = "New name"
|
pydase_client.proxy.sub_service.name = "New name"
|
||||||
|
@ -6,6 +6,7 @@ from typing import Any
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
import pydase
|
import pydase
|
||||||
import pytest
|
import pytest
|
||||||
|
from pydase.utils.serialization.deserializer import Deserializer
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
@ -40,7 +41,10 @@ def pydase_server() -> Generator[None, None, None]:
|
|||||||
return self._readonly_attr
|
return self._readonly_attr
|
||||||
|
|
||||||
def my_method(self, input_str: str) -> str:
|
def my_method(self, input_str: str) -> str:
|
||||||
return input_str
|
return f"{input_str}: my_method"
|
||||||
|
|
||||||
|
async def my_async_method(self, input_str: str) -> str:
|
||||||
|
return f"{input_str}: my_async_method"
|
||||||
|
|
||||||
server = pydase.Server(MyService(), web_port=9998)
|
server = pydase.Server(MyService(), web_port=9998)
|
||||||
thread = threading.Thread(target=server.run, daemon=True)
|
thread = threading.Thread(target=server.run, daemon=True)
|
||||||
@ -192,3 +196,57 @@ async def test_update_value(
|
|||||||
resp = await session.get(f"/api/v1/get_value?access_path={access_path}")
|
resp = await session.get(f"/api/v1/get_value?access_path={access_path}")
|
||||||
content = json.loads(await resp.text())
|
content = json.loads(await resp.text())
|
||||||
assert content == new_value
|
assert content == new_value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"access_path, expected, ok",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"my_method",
|
||||||
|
"Hello from function: my_method",
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"my_async_method",
|
||||||
|
"Hello from function: my_async_method",
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"invalid_method",
|
||||||
|
None,
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
async def test_trigger_method(
|
||||||
|
access_path: str,
|
||||||
|
expected: Any,
|
||||||
|
ok: bool,
|
||||||
|
pydase_server: pydase.DataService,
|
||||||
|
) -> None:
|
||||||
|
async with aiohttp.ClientSession("http://localhost:9998") as session:
|
||||||
|
resp = await session.put(
|
||||||
|
"/api/v1/trigger_method",
|
||||||
|
json={
|
||||||
|
"access_path": access_path,
|
||||||
|
"kwargs": {
|
||||||
|
"full_access_path": "",
|
||||||
|
"type": "dict",
|
||||||
|
"value": {
|
||||||
|
"input_str": {
|
||||||
|
"docs": None,
|
||||||
|
"full_access_path": "",
|
||||||
|
"readonly": False,
|
||||||
|
"type": "str",
|
||||||
|
"value": "Hello from function",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.ok == ok
|
||||||
|
|
||||||
|
if resp.ok:
|
||||||
|
content = Deserializer.deserialize(json.loads(await resp.text()))
|
||||||
|
assert content == expected
|
||||||
|
Loading…
x
Reference in New Issue
Block a user