diff --git a/src/pydase/client/proxy_loader.py b/src/pydase/client/proxy_loader.py index af4e7d2..5909296 100644 --- a/src/pydase/client/proxy_loader.py +++ b/src/pydase/client/proxy_loader.py @@ -1,7 +1,6 @@ import asyncio import logging from collections.abc import Iterable -from copy import copy from typing import TYPE_CHECKING, Any, cast import socketio # type: ignore @@ -202,25 +201,8 @@ class ProxyClassMixin: def _handle_serialized_method( self, attr_name: str, serialized_object: SerializedObject ) -> None: - def add_prefix_to_last_path_element(s: str, prefix: str) -> str: - parts = s.split(".") - parts[-1] = f"{prefix}_{parts[-1]}" - return ".".join(parts) - if serialized_object["type"] == "method": - if serialized_object["async"] is True: - start_method = copy(serialized_object) - start_method["full_access_path"] = add_prefix_to_last_path_element( - start_method["full_access_path"], "start" - ) - stop_method = copy(serialized_object) - stop_method["full_access_path"] = add_prefix_to_last_path_element( - stop_method["full_access_path"], "stop" - ) - self._add_method_proxy(f"start_{attr_name}", start_method) - self._add_method_proxy(f"stop_{attr_name}", stop_method) - else: - self._add_method_proxy(attr_name, serialized_object) + self._add_method_proxy(attr_name, serialized_object) def _add_method_proxy( self, attr_name: str, serialized_object: SerializedObject diff --git a/src/pydase/server/web_server/api/v1/application.py b/src/pydase/server/web_server/api/v1/application.py index 37c79cf..ff3cc7a 100644 --- a/src/pydase/server/web_server/api/v1/application.py +++ b/src/pydase/server/web_server/api/v1/application.py @@ -1,3 +1,4 @@ +import inspect import logging from typing import TYPE_CHECKING @@ -7,9 +8,11 @@ import aiohttp_middlewares.error from pydase.data_service.state_manager import StateManager from pydase.server.web_server.api.v1.endpoints import ( get_value, + trigger_async_method, trigger_method, update_value, ) +from pydase.utils.helpers import get_object_attr_from_path from pydase.utils.serialization.serializer import dump if TYPE_CHECKING: @@ -21,48 +24,75 @@ STATUS_OK = 200 STATUS_FAILED = 400 +async def _get_value( + state_manager: StateManager, request: aiohttp.web.Request +) -> aiohttp.web.Response: + logger.info("Handle api request: %s", request) + + access_path = request.rel_url.query["access_path"] + + status = STATUS_OK + try: + result = get_value(state_manager, access_path) + except Exception as e: + logger.exception(e) + result = dump(e) + status = STATUS_FAILED + return aiohttp.web.json_response(result, status=status) + + +async def _update_value( + state_manager: StateManager, request: aiohttp.web.Request +) -> aiohttp.web.Response: + data: UpdateDict = await request.json() + + try: + update_value(state_manager, data) + + return aiohttp.web.json_response() + except Exception as e: + logger.exception(e) + return aiohttp.web.json_response(dump(e), status=STATUS_FAILED) + + +async def _trigger_method( + state_manager: StateManager, request: aiohttp.web.Request +) -> aiohttp.web.Response: + data: TriggerMethodDict = await request.json() + + method = get_object_attr_from_path(state_manager.service, data["access_path"]) + + try: + if inspect.iscoroutinefunction(method): + method_return = await trigger_async_method( + state_manager=state_manager, data=data + ) + else: + method_return = trigger_method(state_manager=state_manager, data=data) + + return aiohttp.web.json_response(method_return) + + except Exception as e: + logger.exception(e) + return aiohttp.web.json_response(dump(e), status=STATUS_FAILED) + + def create_api_application(state_manager: StateManager) -> aiohttp.web.Application: api_application = aiohttp.web.Application( middlewares=(aiohttp_middlewares.error.error_middleware(),) ) - async def _get_value(request: aiohttp.web.Request) -> aiohttp.web.Response: - logger.info("Handle api request: %s", request) - - access_path = request.rel_url.query["access_path"] - - status = STATUS_OK - try: - result = get_value(state_manager, access_path) - except Exception as e: - logger.exception(e) - result = dump(e) - status = STATUS_FAILED - return aiohttp.web.json_response(result, status=status) - - async def _update_value(request: aiohttp.web.Request) -> aiohttp.web.Response: - data: UpdateDict = await request.json() - - try: - update_value(state_manager, data) - - return aiohttp.web.json_response() - except Exception as e: - logger.exception(e) - return aiohttp.web.json_response(dump(e), status=STATUS_FAILED) - - async def _trigger_method(request: aiohttp.web.Request) -> aiohttp.web.Response: - data: TriggerMethodDict = await request.json() - - try: - return aiohttp.web.json_response(trigger_method(state_manager, data)) - - except Exception as e: - logger.exception(e) - return aiohttp.web.json_response(dump(e), status=STATUS_FAILED) - - api_application.router.add_get("/get_value", _get_value) - api_application.router.add_put("/update_value", _update_value) - api_application.router.add_put("/trigger_method", _trigger_method) + api_application.router.add_get( + "/get_value", + lambda request: _get_value(state_manager=state_manager, request=request), + ) + api_application.router.add_put( + "/update_value", + lambda request: _update_value(state_manager=state_manager, request=request), + ) + api_application.router.add_put( + "/trigger_method", + lambda request: _trigger_method(state_manager=state_manager, request=request), + ) return api_application diff --git a/src/pydase/server/web_server/api/v1/endpoints.py b/src/pydase/server/web_server/api/v1/endpoints.py index f09a8cf..dca28ab 100644 --- a/src/pydase/server/web_server/api/v1/endpoints.py +++ b/src/pydase/server/web_server/api/v1/endpoints.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import TYPE_CHECKING, Any import pydase.utils.serialization.deserializer import pydase.utils.serialization.serializer @@ -7,6 +7,9 @@ from pydase.server.web_server.sio_setup import TriggerMethodDict, UpdateDict from pydase.utils.helpers import get_object_attr_from_path from pydase.utils.serialization.types import SerializedObject +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + loads = pydase.utils.serialization.deserializer.loads Serializer = pydase.utils.serialization.serializer.Serializer @@ -36,3 +39,19 @@ def trigger_method(state_manager: StateManager, data: TriggerMethodDict) -> Any: kwargs: dict[str, Any] = loads(serialized_kwargs) if serialized_kwargs else {} return Serializer.serialize_object(method(*args, **kwargs)) + + +async def trigger_async_method( + state_manager: StateManager, data: TriggerMethodDict +) -> Any: + method: Callable[..., Awaitable[Any]] = get_object_attr_from_path( + state_manager.service, data["access_path"] + ) + + serialized_args = data.get("args", None) + args = loads(serialized_args) if serialized_args else [] + + serialized_kwargs = data.get("kwargs", None) + kwargs: dict[str, Any] = loads(serialized_kwargs) if serialized_kwargs else {} + + return Serializer.serialize_object(await method(*args, **kwargs)) diff --git a/src/pydase/server/web_server/sio_setup.py b/src/pydase/server/web_server/sio_setup.py index bac9f3b..f97a2fe 100644 --- a/src/pydase/server/web_server/sio_setup.py +++ b/src/pydase/server/web_server/sio_setup.py @@ -1,8 +1,11 @@ import asyncio +import inspect import logging import sys from typing import Any, TypedDict +from pydase.utils.helpers import get_object_attr_from_path + if sys.version_info < (3, 11): from typing_extensions import NotRequired else: @@ -11,11 +14,11 @@ else: import click import socketio # type: ignore[import-untyped] -import pydase.server.web_server.api.v1.endpoints import pydase.utils.serialization.deserializer import pydase.utils.serialization.serializer from pydase.data_service.data_service_observer import DataServiceObserver from pydase.data_service.state_manager import StateManager +from pydase.server.web_server.api.v1 import endpoints from pydase.utils.logging import SocketIOHandler from pydase.utils.serialization.serializer import SerializedObject @@ -155,9 +158,7 @@ def setup_sio_events(sio: socketio.AsyncServer, state_manager: StateManager) -> @sio.event async def update_value(sid: str, data: UpdateDict) -> SerializedObject | None: try: - pydase.server.web_server.api.v1.endpoints.update_value( - state_manager=state_manager, data=data - ) + endpoints.update_value(state_manager=state_manager, data=data) except Exception as e: logger.exception(e) return dump(e) @@ -166,7 +167,7 @@ def setup_sio_events(sio: socketio.AsyncServer, state_manager: StateManager) -> @sio.event async def get_value(sid: str, access_path: str) -> SerializedObject: try: - return pydase.server.web_server.api.v1.endpoints.get_value( + return endpoints.get_value( state_manager=state_manager, access_path=access_path ) except Exception as e: @@ -175,10 +176,14 @@ def setup_sio_events(sio: socketio.AsyncServer, state_manager: StateManager) -> @sio.event async def trigger_method(sid: str, data: TriggerMethodDict) -> Any: + method = get_object_attr_from_path(state_manager.service, data["access_path"]) + try: - return pydase.server.web_server.api.v1.endpoints.trigger_method( - state_manager=state_manager, data=data - ) + if inspect.iscoroutinefunction(method): + return await endpoints.trigger_async_method( + state_manager=state_manager, data=data + ) + return endpoints.trigger_method(state_manager=state_manager, data=data) except Exception as e: logger.error(e) return dump(e) diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 27c2cd0..ae29443 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -41,6 +41,9 @@ def pydase_client() -> Generator[pydase.Client, None, Any]: def my_method(self, input_str: str) -> str: return input_str + async def my_async_method(self, input_str: str) -> str: + return input_str + server = pydase.Server(MyService(), web_port=9999) thread = threading.Thread(target=server.run, daemon=True) thread.start() @@ -79,6 +82,14 @@ def test_method_execution(pydase_client: pydase.Client) -> None: pydase_client.proxy.my_method(kwarg="hello") +def test_async_method_execution(pydase_client: pydase.Client) -> None: + assert pydase_client.proxy.my_async_method("My return string") == "My return string" + assert ( + pydase_client.proxy.my_async_method(input_str="My return string") + == "My return string" + ) + + def test_nested_service(pydase_client: pydase.Client) -> None: assert pydase_client.proxy.sub_service.name == "SubService" pydase_client.proxy.sub_service.name = "New name" diff --git a/tests/server/web_server/api/v1/test_endpoints.py b/tests/server/web_server/api/v1/test_endpoints.py index 3fd416c..d737c8d 100644 --- a/tests/server/web_server/api/v1/test_endpoints.py +++ b/tests/server/web_server/api/v1/test_endpoints.py @@ -6,6 +6,7 @@ from typing import Any import aiohttp import pydase import pytest +from pydase.utils.serialization.deserializer import Deserializer @pytest.fixture() @@ -40,7 +41,10 @@ def pydase_server() -> Generator[None, None, None]: return self._readonly_attr def my_method(self, input_str: str) -> str: - return input_str + return f"{input_str}: my_method" + + async def my_async_method(self, input_str: str) -> str: + return f"{input_str}: my_async_method" server = pydase.Server(MyService(), web_port=9998) thread = threading.Thread(target=server.run, daemon=True) @@ -192,3 +196,57 @@ async def test_update_value( resp = await session.get(f"/api/v1/get_value?access_path={access_path}") content = json.loads(await resp.text()) assert content == new_value + + +@pytest.mark.parametrize( + "access_path, expected, ok", + [ + ( + "my_method", + "Hello from function: my_method", + True, + ), + ( + "my_async_method", + "Hello from function: my_async_method", + True, + ), + ( + "invalid_method", + None, + False, + ), + ], +) +@pytest.mark.asyncio() +async def test_trigger_method( + access_path: str, + expected: Any, + ok: bool, + pydase_server: pydase.DataService, +) -> None: + async with aiohttp.ClientSession("http://localhost:9998") as session: + resp = await session.put( + "/api/v1/trigger_method", + json={ + "access_path": access_path, + "kwargs": { + "full_access_path": "", + "type": "dict", + "value": { + "input_str": { + "docs": None, + "full_access_path": "", + "readonly": False, + "type": "str", + "value": "Hello from function", + }, + }, + }, + }, + ) + assert resp.ok == ok + + if resp.ok: + content = Deserializer.deserialize(json.loads(await resp.text())) + assert content == expected