diff --git a/src/pyDataInterface/__init__.py b/src/pyDataInterface/__init__.py index d8f4aa8..87ef78f 100644 --- a/src/pyDataInterface/__init__.py +++ b/src/pyDataInterface/__init__.py @@ -1,4 +1,5 @@ from .data_service import DataService +from .server import Server from .version import __major__, __minor__, __patch__, __version__ __all__ = [ @@ -7,5 +8,5 @@ __all__ = [ "__minor__", "__patch__", "DataService", - # "Server", + "Server", ] diff --git a/src/pyDataInterface/server/__init__.py b/src/pyDataInterface/server/__init__.py index e69de29..09ed39c 100644 --- a/src/pyDataInterface/server/__init__.py +++ b/src/pyDataInterface/server/__init__.py @@ -0,0 +1,3 @@ +from .server import Server + +__all__ = ["Server"] diff --git a/src/pyDataInterface/server/server.py b/src/pyDataInterface/server/server.py index e69de29..b55f8dc 100644 --- a/src/pyDataInterface/server/server.py +++ b/src/pyDataInterface/server/server.py @@ -0,0 +1,232 @@ +import asyncio +import os +import signal +import threading +from concurrent.futures import ThreadPoolExecutor +from types import FrameType +from typing import Any, Optional + +import uvicorn +from fastapi import FastAPI +from loguru import logger +from rpyc import ( + ForkingServer, # can be used for multiprocessing, E.g. a database interface server +) +from rpyc import ThreadedServer +from uvicorn.server import HANDLED_SIGNALS + +from pyDataInterface import DataService +from pyDataInterface.version import __version__ + +try: + import tiqi_rpc +except ImportError: + logger.debug("tiqi_rpc is not installed. tiqi_rpc.Server will not be exposed.") + tiqi_rpc = None # type: ignore + + +class Server: + def __init__( + self, + service: DataService, + host: str = "0.0.0.0", + rpc_port: int = 18871, + tiqi_rpc_port: int = 6007, + web_port: int = 8001, + enable_rpc: bool = True, + enable_tiqi_rpc: bool = True, + enable_web: bool = True, + use_forking_server: bool = False, + web_settings: dict[str, Any] = {}, + *args: Any, + **kwargs: Any, + ) -> None: + self._service = service + self._host = host + self._rpc_port = rpc_port + self._tiqi_rpc_port = tiqi_rpc_port + self._web_port = web_port + self._enable_rpc = enable_rpc + self._enable_tiqi_rpc = enable_tiqi_rpc + self._enable_web = enable_web + self._web_settings = web_settings + self._args = args + self._kwargs = kwargs + self._loop: asyncio.AbstractEventLoop + self._rpc_server_type = ForkingServer if use_forking_server else ThreadedServer + self.should_exit = False + # self.servers: list[asyncio.Future[Any]] = [] + self.servers: dict[str, asyncio.Future[Any]] = {} + self.executor: ThreadPoolExecutor | None = None + self._info: dict[str, Any] = { + "name": self._service.get_service_name(), + "version": __version__, + "rpc_port": self._rpc_port, + "tiqi_rpc_port": self._tiqi_rpc_port, + "web_port": self._web_port, + "enable_rpc": self._enable_rpc, + "enable_tiqi_rpc": self._enable_tiqi_rpc, + "enable_web": self._enable_web, + "web_settings": self._web_settings, + **kwargs, + } + + def run(self) -> None: + try: + self._loop = asyncio.get_event_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + self._loop.run_until_complete(self.serve()) + + async def serve(self) -> None: + process_id = os.getpid() + + logger.info(f"Started server process [{process_id}]") + + await self.startup() + if self.should_exit: + return + await self.main_loop() + await self.shutdown() + + logger.info(f"Finished server process [{process_id}]") + + def _start_autostart_tasks(self) -> None: + self._service._start_autostart_tasks() + + async def startup(self) -> None: + self._loop = asyncio.get_running_loop() + self._loop.set_exception_handler(self.custom_exception_handler) + self.install_signal_handlers() + self._start_autostart_tasks() + + if self._enable_rpc: + self.executor = ThreadPoolExecutor() + self._rpc_server = self._rpc_server_type( + self._service, + port=self._rpc_port, + protocol_config={ + "allow_all_attrs": True, + "allow_setattr": True, + }, + ) + future_or_task = self._loop.run_in_executor( + executor=self.executor, func=self._rpc_server.start + ) + self.servers["rpyc"] = future_or_task + if self._enable_tiqi_rpc and tiqi_rpc is not None: + tiqi_rpc_server = tiqi_rpc.Server( + RPCInterface( + self._data_model, *self._args, info=self._info, **self._kwargs + ), + host=self._host, + port=self._rpc_port, + ) + tiqi_rpc_server.install_signal_handlers = lambda: None + future_or_task = self._loop.create_task(tiqi_rpc_server.serve()) + self.servers["tiqi-rpc"] = future_or_task + if self._enable_web: + # async def print_client_color() -> None: + # while True: + # print(self._service.name) + # await asyncio.sleep(1) + + # future_or_task = self._loop.create_task(print_client_color()) + # self._wapi: FastAPI = web_api( + # data_model=self._data_model, + # info=self._info, + # *self._args, + # **self._kwargs, + # ) + # web_server = uvicorn.Server( + # uvicorn.Config(self._wapi, host=self._host, port=self._web_port) + # ) + # # overwrite uvicorn's signal handlers, otherwise it will bogart SIGINT and + # # SIGTERM, which makes it impossible to escape out of + # web_server.install_signal_handlers = lambda: None + # future_or_task = self._loop.create_task(web_server.serve()) + # self.servers["web"] = future_or_task + pass + + async def main_loop(self) -> None: + while not self.should_exit: + await asyncio.sleep(0.1) + + async def shutdown(self) -> None: + logger.info("Shutting down") + + await self._cancel_servers() + await self._cancel_tasks() + + if self._enable_rpc: + logger.debug("Closing rpyc server.") + self._rpc_server.close() + + async def _cancel_servers(self) -> None: + for server_name, task in self.servers.items(): + task.cancel() + try: + await task + except asyncio.CancelledError: + logger.debug(f"Cancelled {server_name} server.") + except Exception as e: + logger.warning(f"Unexpected exception: {e}.") + + async def _cancel_tasks(self) -> None: + for task in asyncio.all_tasks(self._loop): + task.cancel() + try: + await task + except asyncio.CancelledError: + logger.debug(f"Cancelled task {task.get_coro()}.") + except Exception as e: + logger.warning(f"Unexpected exception: {e}.") + + def install_signal_handlers(self) -> None: + if threading.current_thread() is not threading.main_thread(): + # Signals can only be listened to from the main thread. + return + + try: + for sig in HANDLED_SIGNALS: + self._loop.add_signal_handler(sig, self.handle_exit, sig, None) + except NotImplementedError: # pragma: no cover + # Windows + for sig in HANDLED_SIGNALS: + signal.signal(sig, self.handle_exit) + + def handle_exit(self, sig: int = 0, frame: Optional[FrameType] = None) -> None: + logger.info("Handling exit") + if self.should_exit and sig == signal.SIGINT: + self.force_exit = True + else: + self.should_exit = True + + def custom_exception_handler( + self, loop: asyncio.AbstractEventLoop, context: dict[str, Any] + ) -> None: + # if any background task creates an unhandled exception, shut down the entire + # loop. It's possible we don't want to do this, maybe make this optional in the + # future + loop.default_exception_handler(context) + + # here we exclude most kinds of exceptions from triggering this kind of shutdown + exc = context.get("exception") + if type(exc) not in [RuntimeError, KeyboardInterrupt, asyncio.CancelledError]: + if self._enable_web: + + async def emit_exception() -> None: + await self._wapi._sio.emit( + "notify", + { + "data": { + "exception": str(exc), + "type": exc.__class__.__name__, + } + }, + ) + + loop.create_task(emit_exception()) + else: + self.handle_exit()