From 755a30323913ef40396a7482517849f9a104ef8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mose=20M=C3=BCller?= Date: Thu, 25 Jul 2024 11:36:05 +0200 Subject: [PATCH] creates api definition, using that in sio_setup --- src/pydase/server/web_server/api.py | 88 +++++++++++++++++++++++ src/pydase/server/web_server/sio_setup.py | 32 +++++---- 2 files changed, 106 insertions(+), 14 deletions(-) create mode 100644 src/pydase/server/web_server/api.py diff --git a/src/pydase/server/web_server/api.py b/src/pydase/server/web_server/api.py new file mode 100644 index 0000000..d0d6805 --- /dev/null +++ b/src/pydase/server/web_server/api.py @@ -0,0 +1,88 @@ +import logging +from typing import Any + +import aiohttp.web +import aiohttp_middlewares.error + +from pydase.data_service.state_manager import StateManager +from pydase.server.web_server.sio_setup import TriggerMethodDict, UpdateDict +from pydase.utils.helpers import get_object_attr_from_path +from pydase.utils.serialization.deserializer import loads +from pydase.utils.serialization.serializer import dump +from pydase.utils.serialization.types import SerializedObject + +logger = logging.getLogger(__name__) + +API_VERSION = "v1" + + +def update_value(state_manager: StateManager, data: UpdateDict) -> None: + path = data["access_path"] + + state_manager.set_service_attribute_value_by_path( + path=path, serialized_value=data["value"] + ) + + +def get_value(state_manager: StateManager, access_path: str) -> SerializedObject: + return state_manager._data_service_cache.get_value_dict_from_cache(access_path) + + +def trigger_method(state_manager: StateManager, data: TriggerMethodDict) -> Any: + method = get_object_attr_from_path(state_manager.service, data["access_path"]) + + serialized_args = data.get("args", None) + args = loads(serialized_args) if serialized_args else [] + + serialized_kwargs = data.get("kwargs", None) + kwargs: dict[str, Any] = loads(serialized_kwargs) if serialized_kwargs else {} + + return dump(method(*args, **kwargs)) + + +def create_api_application(state_manager: StateManager) -> aiohttp.web.Application: + api_application = aiohttp.web.Application( + middlewares=(aiohttp_middlewares.error.error_middleware(),) + ) + + async def _get_value(request: aiohttp.web.Request) -> aiohttp.web.Response: + logger.info("Handle api request: %s", request) + api_version = request.match_info["version"] + logger.info("Version number: %s", api_version) + + access_path = request.rel_url.query["access_path"] + + try: + result = get_value(state_manager, access_path) + except Exception as e: + logger.exception(e) + result = dump(e) + return aiohttp.web.json_response(result) + + async def _update_value(request: aiohttp.web.Request) -> aiohttp.web.Response: + data: UpdateDict = await request.json() + + try: + update_value(state_manager, data) + + return aiohttp.web.Response() + except Exception as e: + logger.exception(e) + return aiohttp.web.json_response(dump(e)) + + async def _trigger_method(request: aiohttp.web.Request) -> aiohttp.web.Response: + data: TriggerMethodDict = await request.json() + + try: + trigger_method(state_manager, data) + + return aiohttp.web.Response() + except Exception as e: + logger.exception(e) + return aiohttp.web.json_response(dump(e)) + + api_application.router.add_get("/{version}/get_value", _get_value) + api_application.router.add_post("/{version}/update_value", _update_value) + api_application.router.add_post("/{version}/trigger_method", _trigger_method) + + return api_application diff --git a/src/pydase/server/web_server/sio_setup.py b/src/pydase/server/web_server/sio_setup.py index aa160f4..e91c7a4 100644 --- a/src/pydase/server/web_server/sio_setup.py +++ b/src/pydase/server/web_server/sio_setup.py @@ -1,15 +1,21 @@ import asyncio import logging +import sys from typing import Any, TypedDict +if sys.version_info < (3, 11): + from typing_extensions import NotRequired +else: + from typing import NotRequired + import click import socketio # type: ignore[import-untyped] +import pydase.server.web_server.api import pydase.utils.serialization.deserializer import pydase.utils.serialization.serializer from pydase.data_service.data_service_observer import DataServiceObserver from pydase.data_service.state_manager import StateManager -from pydase.utils.helpers import get_object_attr_from_path from pydase.utils.logging import SocketIOHandler from pydase.utils.serialization.serializer import SerializedObject @@ -39,8 +45,8 @@ class UpdateDict(TypedDict): class TriggerMethodDict(TypedDict): access_path: str - args: SerializedObject - kwargs: SerializedObject + args: NotRequired[SerializedObject] + kwargs: NotRequired[SerializedObject] class RunMethodDict(TypedDict): @@ -137,21 +143,22 @@ def setup_sio_events(sio: socketio.AsyncServer, state_manager: StateManager) -> return state_manager.cache_manager.cache @sio.event - async def update_value(sid: str, data: UpdateDict) -> SerializedObject | None: # type: ignore - path = data["access_path"] - + async def update_value(sid: str, data: UpdateDict) -> SerializedObject | None: try: - state_manager.set_service_attribute_value_by_path( - path=path, serialized_value=data["value"] + pydase.server.web_server.api.update_value( + state_manager=state_manager, data=data ) except Exception as e: logger.exception(e) return dump(e) + return None @sio.event async def get_value(sid: str, access_path: str) -> SerializedObject: try: - return state_manager.cache_manager.get_value_dict_from_cache(access_path) + return pydase.server.web_server.api.get_value( + state_manager=state_manager, access_path=access_path + ) except Exception as e: logger.exception(e) return dump(e) @@ -159,12 +166,9 @@ def setup_sio_events(sio: socketio.AsyncServer, state_manager: StateManager) -> @sio.event async def trigger_method(sid: str, data: TriggerMethodDict) -> Any: try: - method = get_object_attr_from_path( - state_manager.service, data["access_path"] + return pydase.server.web_server.api.trigger_method( + state_manager=state_manager, data=data ) - args = loads(data["args"]) - kwargs: dict[str, Any] = loads(data["kwargs"]) - return dump(method(*args, **kwargs)) except Exception as e: logger.error(e) return dump(e)