From c8a41565a2379611c2aa0037b3d5f23dd8b08b65 Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Fri, 22 Nov 2024 23:07:34 +0100 Subject: [PATCH] fix: fixed websocket connection and state transfer between redis and fastapi --- .gitlab-ci.yml | 3 + backend/bec_atlas/main.py | 10 +- backend/bec_atlas/router/redis_router.py | 173 +++++++++++++++++++++-- backend/tests/conftest.py | 47 ++++++ backend/tests/docker-compose.yml | 6 +- backend/tests/test_login.py | 43 ++---- backend/tests/test_redis_websocket.py | 99 +++++++++++++ 7 files changed, 333 insertions(+), 48 deletions(-) create mode 100644 backend/tests/test_redis_websocket.py diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 0ad091e..127686a 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -28,6 +28,9 @@ include: services: - name: $CI_DEPENDENCY_PROXY_GROUP_IMAGE_PREFIX/scylladb/scylla:latest alias: scylla + - name: $CI_DEPENDENCY_PROXY_GROUP_IMAGE_PREFIX/redis:latest + alias: redis + command: ["redis-server", "--port", "6380"] formatter: stage: Formatter diff --git a/backend/bec_atlas/main.py b/backend/bec_atlas/main.py index 60c1b44..ecc328a 100644 --- a/backend/bec_atlas/main.py +++ b/backend/bec_atlas/main.py @@ -16,6 +16,7 @@ class AtlasApp: def __init__(self, config=None): self.config = config or CONFIG self.app = FastAPI() + self.server = None self.prefix = f"/api/{self.API_VERSION}" self.datasources = DatasourceManager(config=self.config) self.register_event_handler() @@ -41,11 +42,16 @@ class AtlasApp: self.app.include_router(self.user_router.router) if "redis" in self.datasources.datasources: - self.redis_websocket = RedisWebsocket(prefix=self.prefix, datasources=self.datasources) + self.redis_websocket = RedisWebsocket( + prefix=self.prefix, datasources=self.datasources, app=self + ) self.app.mount("/", self.redis_websocket.app) def run(self, port=8000): - uvicorn.run(self.app, host="localhost", port=port) + config = uvicorn.Config(self.app, host="localhost", port=port) + self.server = uvicorn.Server(config=config) + self.server.run() + # uvicorn.run(self.app, host="localhost", port=port) def main(): diff --git a/backend/bec_atlas/router/redis_router.py b/backend/bec_atlas/router/redis_router.py index 46d0215..56dec3c 100644 --- a/backend/bec_atlas/router/redis_router.py +++ b/backend/bec_atlas/router/redis_router.py @@ -1,14 +1,19 @@ import asyncio +import functools import inspect import json +import traceback from typing import TYPE_CHECKING import socketio from bec_lib.endpoints import MessageEndpoints -from fastapi import APIRouter, WebSocket, WebSocketDisconnect +from bec_lib.logger import bec_logger +from fastapi import APIRouter from bec_atlas.router.base_router import BaseRouter +logger = bec_logger.logger + if TYPE_CHECKING: from bec_lib.redis_connector import RedisConnector @@ -37,52 +42,196 @@ class RedisRouter(BaseRouter): return self.redis.delete(key) +def safe_socket(fcn): + @functools.wraps(fcn) + async def wrapper(self, sid, *args, **kwargs): + try: + out = await fcn(self, sid, *args, **kwargs) + except Exception as exc: + content = traceback.format_exc() + logger.error(content) + await self.socket.emit("error", {"error": str(exc)}, room=sid) + return + return out + + return wrapper + + +class BECAsyncRedisManager(socketio.AsyncRedisManager): + def __init__( + self, + parent, + url="redis://localhost:6379/0", + channel="socketio", + write_only=False, + logger=None, + redis_options=None, + ): + self.parent = parent + super().__init__(url, channel, write_only, logger, redis_options) + self.requested_channels = [] + self.started_update_loop = False + + # task = asyncio.create_task(self._required_channel_heartbeat()) + # loop.run_until_complete(task) + + def start_update_loop(self): + self.started_update_loop = True + loop = asyncio.get_event_loop() + task = loop.create_task(self._backend_heartbeat()) + return task + + async def disconnect(self, sid, namespace, **kwargs): + if kwargs.get("ignore_queue"): + await super().disconnect(sid, namespace, **kwargs) + await self.update_state_info() + return + message = { + "method": "disconnect", + "sid": sid, + "namespace": namespace or "/", + "host_id": self.host_id, + } + await self._handle_disconnect(message) # handle in this host + await self._publish(message) # notify other hosts + + async def enter_room(self, sid, namespace, room, eio_sid=None): + await super().enter_room(sid, namespace, room, eio_sid=eio_sid) + await self.update_state_info() + + async def leave_room(self, sid, namespace, room): + await super().leave_room(sid, namespace, room) + await self.update_state_info() + + async def _backend_heartbeat(self): + while not self.parent.fastapi_app.server.should_exit: + await asyncio.sleep(1) + await self.redis.publish(f"deployments/{self.host_id}/heartbeat/", "ping") + data = json.dumps(self.parent.users) + print(f"Sending heartbeat: {data}") + await self.redis.set(f"deployments/{self.host_id}/state/", data, ex=30) + + async def update_state_info(self): + data = json.dumps(self.parent.users) + await self.redis.set(f"deployments/{self.host_id}/state/", data, ex=30) + await self.redis.publish(f"deployments/{self.host_id}/state/", data) + + async def update_websocket_states(self): + loop = asyncio.get_event_loop() + if not self.started_update_loop and loop.is_running(): + self.start_update_loop() + await self.update_state_info() + + async def remove_user(self, sid): + if sid in self.parent.users: + del self.parent.users[sid] + print(f"Removed user {sid}") + await self.update_state_info() + + class RedisWebsocket: """ This class is a websocket handler for the Redis API. It exposes the redis client through the websocket. """ - def __init__(self, prefix="/api/v1", datasources=None): + def __init__(self, prefix="/api/v1", datasources=None, app=None): self.redis: RedisConnector = datasources.datasources["redis"].connector self.prefix = prefix + self.fastapi_app = app self.active_connections = set() - self.socket = socketio.AsyncServer(cors_allowed_origins="*", async_mode="asgi") + self.socket = socketio.AsyncServer( + cors_allowed_origins="*", + async_mode="asgi", + client_manager=BECAsyncRedisManager( + self, url=f"redis://{self.redis.host}:{self.redis.port}/0" + ), + ) self.app = socketio.ASGIApp(self.socket) self.loop = asyncio.get_event_loop() + self.users = {} self.socket.on("connect", self.connect_client) self.socket.on("register", self.redis_register) + self.socket.on("unregister", self.redis.unregister) self.socket.on("disconnect", self.disconnect_client) - def connect_client(self, sid, environ=None): + @safe_socket + async def connect_client(self, sid, environ=None): print("Client connected") - self.active_connections.add(sid) + http_query = environ.get("HTTP_QUERY") + if not http_query: + raise ValueError("Query parameters not found") + query = json.loads(http_query) - def disconnect_client(self, sid, _environ=None): + if "user" not in query: + raise ValueError("User not found in query parameters") + user = query["user"] + + if sid not in self.users: + # check if the user was already registered in redis + deployment_keys = await self.socket.manager.redis.keys("deployments/*/state/") + if not deployment_keys: + state_data = [] + else: + state_data = await self.socket.manager.redis.mget(*deployment_keys) + info = {} + for data in state_data: + if not data: + continue + obj = json.loads(data) + for value in obj.values(): + info[value["user"]] = value["subscriptions"] + + if user in info: + self.users[sid] = {"user": user, "subscriptions": info[user]} + for endpoint in set(self.users[sid]["subscriptions"]): + await self.socket.enter_room(sid, f"ENDPOINT::{endpoint}") + await self.socket.manager.update_websocket_states() + else: + self.users[sid] = {"user": query["user"], "subscriptions": []} + + await self.socket.manager.update_websocket_states() + + async def disconnect_client(self, sid, _environ=None): print("Client disconnected") - self.active_connections.remove(sid) + is_exit = self.fastapi_app.server.should_exit + if is_exit: + return + await self.socket.manager.remove_user(sid) + @safe_socket async def redis_register(self, sid: str, msg: str): + """ + Register a client to a redis channel. + + Args: + sid (str): The socket id of the client + msg (str): The message sent by the client + """ if sid not in self.active_connections: self.active_connections.add(sid) try: print(msg) data = json.loads(msg) except json.JSONDecodeError: - return + raise ValueError("Invalid JSON message") - endpoint = getattr(MessageEndpoints, data.get("endpoint")) + endpoint = getattr(MessageEndpoints, data.get("endpoint"), None) + if endpoint is None: + raise ValueError(f"Endpoint {data.get('endpoint')} not found") # check if the endpoint receives arguments - if len(inspect.signature(endpoint).parameters) > 1: + if len(inspect.signature(endpoint).parameters) > 0: endpoint = endpoint(data.get("args")) else: endpoint = endpoint() self.redis.register(endpoint, cb=self.on_redis_message, parent=self) - await self.socket.enter_room(sid, endpoint.endpoint) - await self.socket.emit("registered", data={"endpoint": endpoint.endpoint}, room=sid) + if data.get("endpoint") not in self.users[sid]["subscriptions"]: + await self.socket.enter_room(sid, f"ENDPOINT::{data.get('endpoint')}") + self.users[sid]["subscriptions"].append(data.get("endpoint")) + await self.socket.manager.update_websocket_states() @staticmethod def on_redis_message(message, parent): diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 303f79b..324870a 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -3,6 +3,9 @@ import os from typing import Iterator import pytest +from bec_atlas.main import AtlasApp +from bec_atlas.utils.setup_database import setup_database +from fastapi.testclient import TestClient from pytest_docker.plugin import DockerComposeExecutor, Services @@ -96,3 +99,47 @@ def docker_services( docker_cleanup, ) as docker_service: yield docker_service + + +@pytest.fixture(scope="session") +def scylla_container(docker_ip, docker_services): + host = docker_ip + if os.path.exists("/.dockerenv"): + # if we are running in the CI, scylla was started as 'scylla' service + host = "scylla" + if docker_services is None: + port = 9042 + else: + port = docker_services.port_for("scylla", 9042) + + setup_database(host=host, port=port) + return host, port + + +@pytest.fixture(scope="session") +def redis_container(docker_ip, docker_services): + host = docker_ip + if os.path.exists("/.dockerenv"): + # if we are running in the CI, scylla was started as 'scylla' service + host = "redis" + if docker_services is None: + port = 6380 + else: + port = docker_services.port_for("redis", 6379) + + return host, port + + +@pytest.fixture(scope="session") +def backend(scylla_container, redis_container): + scylla_host, scylla_port = scylla_container + redis_host, redis_port = redis_container + config = { + "scylla": {"hosts": [(scylla_host, scylla_port)]}, + "redis": {"host": redis_host, "port": redis_port}, + } + + app = AtlasApp(config) + + with TestClient(app.app) as _client: + yield _client, app diff --git a/backend/tests/docker-compose.yml b/backend/tests/docker-compose.yml index 2e373f3..bd5a8f1 100644 --- a/backend/tests/docker-compose.yml +++ b/backend/tests/docker-compose.yml @@ -3,4 +3,8 @@ services: scylla: image: scylladb/scylla:latest ports: - - "9070:9042" \ No newline at end of file + - "9070:9042" + redis: + image: redis:latest + ports: + - "6380:6379" \ No newline at end of file diff --git a/backend/tests/test_login.py b/backend/tests/test_login.py index 82b8629..872e70c 100644 --- a/backend/tests/test_login.py +++ b/backend/tests/test_login.py @@ -1,41 +1,18 @@ -import os - import pytest -from bec_atlas.main import AtlasApp -from bec_atlas.utils.setup_database import setup_database -from fastapi.testclient import TestClient -@pytest.fixture(scope="session") -def scylla_container(docker_ip, docker_services): - host = docker_ip - if os.path.exists("/.dockerenv"): - # if we are running in the CI, scylla was started as 'scylla' service - host = "scylla" - if docker_services is None: - port = 9042 - else: - port = docker_services.port_for("scylla", 9042) - - setup_database(host=host, port=port) - return host, port - - -@pytest.fixture(scope="session") -def client(scylla_container): - host, port = scylla_container - config = {"scylla": {"hosts": [(host, port)]}} - - with TestClient(AtlasApp(config).app) as _client: - yield _client +@pytest.fixture +def backend_client(backend): + client, _ = backend + return client @pytest.mark.timeout(60) -def test_login(client): +def test_login(backend_client): """ Test that the login endpoint returns a token. """ - response = client.post( + response = backend_client.post( "/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"} ) assert response.status_code == 200 @@ -45,11 +22,11 @@ def test_login(client): @pytest.mark.timeout(60) -def test_login_wrong_password(client): +def test_login_wrong_password(backend_client): """ Test that the login returns a 401 when the password is wrong. """ - response = client.post( + response = backend_client.post( "/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "wrong_password"} ) assert response.status_code == 401 @@ -57,11 +34,11 @@ def test_login_wrong_password(client): @pytest.mark.timeout(60) -def test_login_unknown_user(client): +def test_login_unknown_user(backend_client): """ Test that the login returns a 404 when the user is unknown. """ - response = client.post( + response = backend_client.post( "/api/v1/user/login", json={"username": "no_user@bec_atlas.ch", "password": "wrong_password"}, ) diff --git a/backend/tests/test_redis_websocket.py b/backend/tests/test_redis_websocket.py new file mode 100644 index 0000000..5d6f13a --- /dev/null +++ b/backend/tests/test_redis_websocket.py @@ -0,0 +1,99 @@ +import json +from unittest import mock + +import pytest + + +@pytest.fixture +def backend_client(backend): + client, app = backend + app.server = mock.Mock() + app.server.should_exit = False + app.redis_websocket.users = {} + yield client, app + app.redis_websocket.redis._redis_conn.flushall() + + +@pytest.mark.asyncio(loop_scope="session") +async def test_redis_websocket_connect(backend_client): + client, app = backend_client + await app.redis_websocket.socket.handlers["/"]["connect"]( + "sid", {"HTTP_QUERY": '{"user": "test"}'} + ) + assert "sid" in app.redis_websocket.users + + +@pytest.mark.asyncio(loop_scope="session") +async def test_redis_websocket_disconnect(backend_client): + client, app = backend_client + app.redis_websocket.users["sid"] = {"user": "test", "subscriptions": []} + await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid") + assert "sid" not in app.redis_websocket.users + + +@pytest.mark.asyncio(loop_scope="session") +async def test_redis_websocket_multiple_connect(backend_client): + client, app = backend_client + await app.redis_websocket.socket.handlers["/"]["connect"]( + "sid1", {"HTTP_QUERY": '{"user": "test1"}'} + ) + await app.redis_websocket.socket.handlers["/"]["connect"]( + "sid2", {"HTTP_QUERY": '{"user": "test2"}'} + ) + assert "sid1" in app.redis_websocket.users + assert "sid2" in app.redis_websocket.users + + +@pytest.mark.asyncio(loop_scope="session") +async def test_redis_websocket_multiple_connect_same_sid(backend_client): + client, app = backend_client + await app.redis_websocket.socket.handlers["/"]["connect"]( + "sid", {"HTTP_QUERY": '{"user": "test"}'} + ) + await app.redis_websocket.socket.handlers["/"]["connect"]( + "sid", {"HTTP_QUERY": '{"user": "test"}'} + ) + + assert "sid" in app.redis_websocket.users + assert len(app.redis_websocket.users) == 1 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_redis_websocket_multiple_disconnect_same_sid(backend_client): + client, app = backend_client + app.redis_websocket.users["sid"] = {"user": "test", "subscriptions": []} + await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid") + await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid") + assert "sid" not in app.redis_websocket.users + assert len(app.redis_websocket.users) == 0 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_redis_websocket_register_wrong_endpoint_raises(backend_client): + client, app = backend_client + with mock.patch.object(app.redis_websocket.socket, "emit") as emit: + app.redis_websocket.socket.handlers["/"]["connect"]("sid") + await app.redis_websocket.socket.handlers["/"]["register"]( + "sid", json.dumps({"endpoint": "wrong_endpoint"}) + ) + assert mock.call("error", mock.ANY, room="sid") in emit.mock_calls + + +@pytest.mark.asyncio(loop_scope="session") +async def test_redis_websocket_register(backend_client): + client, app = backend_client + with mock.patch.object(app.redis_websocket.socket, "emit") as emit: + with mock.patch.object(app.redis_websocket.socket, "enter_room") as enter_room: + with mock.patch.object(app.redis_websocket.socket.manager, "rooms") as rooms: + rooms.__getitem__.return_value = {"ENDPOINT::scan_status": "sid"} + await app.redis_websocket.socket.handlers["/"]["connect"]( + "sid", {"HTTP_QUERY": '{"user": "test"}'} + ) + + await app.redis_websocket.socket.handlers["/"]["register"]( + "sid", json.dumps({"endpoint": "scan_status"}) + ) + assert mock.call("error", mock.ANY, room="sid") not in emit.mock_calls + enter_room.assert_called_with("sid", "ENDPOINT::scan_status") + + assert mock.call("error", mock.ANY, room="sid") not in emit.mock_calls