diff --git a/src/pyDataInterface/server/__init__.py b/src/pyDataInterface/server/__init__.py index 09ed39c..5ccd59f 100644 --- a/src/pyDataInterface/server/__init__.py +++ b/src/pyDataInterface/server/__init__.py @@ -1,3 +1,4 @@ from .server import Server +from .web_server import WebAPI __all__ = ["Server"] diff --git a/src/pyDataInterface/server/server.py b/src/pyDataInterface/server/server.py index b55f8dc..2de51d4 100644 --- a/src/pyDataInterface/server/server.py +++ b/src/pyDataInterface/server/server.py @@ -18,6 +18,8 @@ from uvicorn.server import HANDLED_SIGNALS from pyDataInterface import DataService from pyDataInterface.version import __version__ +from .web_server import WebAPI + try: import tiqi_rpc except ImportError: @@ -123,31 +125,47 @@ class Server: host=self._host, port=self._rpc_port, ) - tiqi_rpc_server.install_signal_handlers = lambda: None + tiqi_rpc_server.install_signal_handlers = lambda: None # type: ignore future_or_task = self._loop.create_task(tiqi_rpc_server.serve()) self.servers["tiqi-rpc"] = future_or_task if self._enable_web: - # async def print_client_color() -> None: - # while True: - # print(self._service.name) - # await asyncio.sleep(1) + self._wapi: WebAPI = WebAPI( + *self._args, + service=self._service, + info=self._info, + **self._kwargs, + ) + web_server = uvicorn.Server( + uvicorn.Config( + self._wapi.fastapi_app, host=self._host, port=self._web_port + ) + ) - # future_or_task = self._loop.create_task(print_client_color()) - # self._wapi: FastAPI = web_api( - # data_model=self._data_model, - # info=self._info, - # *self._args, - # **self._kwargs, - # ) - # web_server = uvicorn.Server( - # uvicorn.Config(self._wapi, host=self._host, port=self._web_port) - # ) - # # overwrite uvicorn's signal handlers, otherwise it will bogart SIGINT and - # # SIGTERM, which makes it impossible to escape out of - # web_server.install_signal_handlers = lambda: None - # future_or_task = self._loop.create_task(web_server.serve()) - # self.servers["web"] = future_or_task - pass + def sio_callback(parent_path: str, name: str, value: Any) -> None: + async def notify() -> None: + try: + await self._wapi.sio.emit( + "notify", + { + "data": { + "parent_path": parent_path, + "name": name, + "value": value, + } + }, + ) + except Exception as e: + logger.warning(f"Failed to send notification: {e}") + + self._loop.create_task(notify()) + + self._service.add_notification_callback(sio_callback) + + # overwrite uvicorn's signal handlers, otherwise it will bogart SIGINT and + # SIGTERM, which makes it impossible to escape out of + web_server.install_signal_handlers = lambda: None # type: ignore + future_or_task = self._loop.create_task(web_server.serve()) + self.servers["web"] = future_or_task async def main_loop(self) -> None: while not self.should_exit: @@ -156,14 +174,14 @@ class Server: async def shutdown(self) -> None: logger.info("Shutting down") - await self._cancel_servers() - await self._cancel_tasks() + await self.__cancel_servers() + await self.__cancel_tasks() if self._enable_rpc: logger.debug("Closing rpyc server.") self._rpc_server.close() - async def _cancel_servers(self) -> None: + async def __cancel_servers(self) -> None: for server_name, task in self.servers.items(): task.cancel() try: @@ -173,7 +191,7 @@ class Server: except Exception as e: logger.warning(f"Unexpected exception: {e}.") - async def _cancel_tasks(self) -> None: + async def __cancel_tasks(self) -> None: for task in asyncio.all_tasks(self._loop): task.cancel() try: @@ -217,7 +235,7 @@ class Server: if self._enable_web: async def emit_exception() -> None: - await self._wapi._sio.emit( + await self._wapi.sio.emit( "notify", { "data": { diff --git a/src/pyDataInterface/server/web_server.py b/src/pyDataInterface/server/web_server.py new file mode 100644 index 0000000..bbcb515 --- /dev/null +++ b/src/pyDataInterface/server/web_server.py @@ -0,0 +1,106 @@ +from pathlib import Path +from typing import Any + +import socketio +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles + +from pyDataInterface import DataService +from pyDataInterface.config import OperationMode +from pyDataInterface.version import __version__ + + +class WebAPI: + __sio_app: socketio.ASGIApp + __fastapi_app: FastAPI + + def __init__( + self, + service: DataService, + frontend: str | Path | None = None, + css: str | Path | None = None, + enable_CORS: bool = True, + info: dict[str, Any] = {}, + *args: Any, + **kwargs: Any, + ): + self.service = service + self.frontend = frontend + self.css = css + self.enable_CORS = enable_CORS + self.info = info + self.args = args + self.kwargs = kwargs + + self.setup_socketio() + self.setup_fastapi_app() + + def setup_socketio(self) -> None: + # the socketio ASGI app, to notify clients when params update + if self.enable_CORS: + self.__sio = socketio.AsyncServer( + async_mode="asgi", cors_allowed_origins="*" + ) + else: + self.__sio = socketio.AsyncServer(async_mode="asgi") + self.__sio_app = socketio.ASGIApp(self.__sio) + + def setup_fastapi_app(self) -> None: # noqa: CFQ004 + app = FastAPI() + + if self.enable_CORS: + app.add_middleware( + CORSMiddleware, + allow_credentials=True, + allow_origins=["*"], + allow_methods=["*"], + allow_headers=["*"], + ) + app.mount("/ws", self.__sio_app) + + # @app.get("/version", include_in_schema=False) + @app.get("/version") + def version() -> str: + return __version__ + + @app.get("/name") + def name() -> str: + return self.service.get_service_name() + + @app.get("/info") + def info() -> dict[str, Any]: + return self.info + + @app.get("/service-properties") + def service_properties() -> dict[str, Any]: + return self.service.serialize() + + if OperationMode().environment == "production": + app.mount( + "/", + StaticFiles( + directory=Path(__file__).parent.parent.parent.parent + / "frontend" + / "build", + html=True, + ), + ) + + self.__fastapi_app = app + + def add_endpoint(self, name: str) -> None: + # your endpoint creation code + pass + + def get_custom_openapi(self) -> None: + # your custom openapi generation code + pass + + @property + def sio(self) -> socketio.AsyncServer: + return self.__sio + + @property + def fastapi_app(self) -> FastAPI: + return self.__fastapi_app