From f06f4962e8982e767507fe9a8caa06a5cdd0002a Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Mon, 16 Dec 2024 19:38:20 +0100 Subject: [PATCH] wip --- backend/bec_atlas/router/redis_router.py | 9 +++--- backend/tests/test_redis_websocket.py | 35 ++++++++++++++---------- backend/tests/test_scan_ingestor.py | 7 ++--- 3 files changed, 28 insertions(+), 23 deletions(-) diff --git a/backend/bec_atlas/router/redis_router.py b/backend/bec_atlas/router/redis_router.py index 8257070..e590e8b 100644 --- a/backend/bec_atlas/router/redis_router.py +++ b/backend/bec_atlas/router/redis_router.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any import socketio from bec_lib.endpoints import EndpointInfo, MessageEndpoints, MessageOp from bec_lib.logger import bec_logger +from bec_lib.serialization import json_ext from fastapi import APIRouter from bec_atlas.router.base_router import BaseRouter @@ -54,7 +55,7 @@ class RedisAtlasEndpoints: return f"internal/deployment/{deployment}/data/{endpoint}" @staticmethod - def socketio_endpoint_room(endpoint: str): + def socketio_endpoint_room(deployment: str, endpoint: str): """ Endpoint for the socketio room for an endpoint. @@ -64,7 +65,7 @@ class RedisAtlasEndpoints: Returns: str: The endpoint for the socketio room """ - return f"ENDPOINT::{endpoint}" + return f"socketio/rooms/{deployment}/{endpoint}" class RedisRouter(BaseRouter): @@ -333,7 +334,7 @@ class RedisWebsocket: RedisAtlasEndpoints.redis_data(deployment, endpoint), Any, MessageOp.STREAM ) - room = RedisAtlasEndpoints.socketio_endpoint_room(endpoint) + room = RedisAtlasEndpoints.socketio_endpoint_room(deployment, endpoint) self.redis.register(endpoint_info, cb=self.on_redis_message, parent=self, room=room) if endpoint not in self.users[sid]["subscriptions"]: await self.socket.enter_room(sid, room) @@ -348,7 +349,7 @@ class RedisWebsocket: else: msg = message["data"] outgoing = {"data": msg.content, "metadata": msg.metadata} - outgoing = json.dumps(outgoing) + outgoing = json_ext.dumps(outgoing) await parent.socket.emit("message", data=outgoing, room=room) # Run the coroutine in this loop diff --git a/backend/tests/test_redis_websocket.py b/backend/tests/test_redis_websocket.py index d051951..45bb852 100644 --- a/backend/tests/test_redis_websocket.py +++ b/backend/tests/test_redis_websocket.py @@ -2,6 +2,8 @@ import json from unittest import mock import pytest +from bec_atlas.router.redis_router import RedisAtlasEndpoints +from bec_lib.endpoints import MessageEndpoints @pytest.fixture @@ -18,7 +20,7 @@ def backend_client(backend): async def test_redis_websocket_connect(backend_client): client, app = backend_client await app.redis_websocket.socket.handlers["/"]["connect"]( - "sid", {"HTTP_QUERY": '{"user": "test"}'} + "sid", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'} ) assert "sid" in app.redis_websocket.users @@ -35,10 +37,10 @@ async def test_redis_websocket_disconnect(backend_client): async def test_redis_websocket_multiple_connect(backend_client): client, app = backend_client await app.redis_websocket.socket.handlers["/"]["connect"]( - "sid1", {"HTTP_QUERY": '{"user": "test1"}'} + "sid1", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'} ) await app.redis_websocket.socket.handlers["/"]["connect"]( - "sid2", {"HTTP_QUERY": '{"user": "test2"}'} + "sid2", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'} ) assert "sid1" in app.redis_websocket.users assert "sid2" in app.redis_websocket.users @@ -48,10 +50,10 @@ async def test_redis_websocket_multiple_connect(backend_client): async def test_redis_websocket_multiple_connect_same_sid(backend_client): client, app = backend_client await app.redis_websocket.socket.handlers["/"]["connect"]( - "sid", {"HTTP_QUERY": '{"user": "test"}'} + "sid", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'} ) await app.redis_websocket.socket.handlers["/"]["connect"]( - "sid", {"HTTP_QUERY": '{"user": "test"}'} + "sid", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'} ) assert "sid" in app.redis_websocket.users @@ -84,16 +86,19 @@ async def test_redis_websocket_register(backend_client): client, app = backend_client with mock.patch.object(app.redis_websocket.socket, "emit") as emit: with mock.patch.object(app.redis_websocket.socket, "enter_room") as enter_room: - with mock.patch.object(app.redis_websocket.socket.manager, "rooms") as rooms: - rooms.__getitem__.return_value = {"ENDPOINT::scan_status": "sid"} - await app.redis_websocket.socket.handlers["/"]["connect"]( - "sid", {"HTTP_QUERY": '{"user": "test"}'} - ) + await app.redis_websocket.socket.handlers["/"]["connect"]( + "sid", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'} + ) - await app.redis_websocket.socket.handlers["/"]["register"]( - "sid", json.dumps({"endpoint": "scan_status"}) - ) - assert mock.call("error", mock.ANY, room="sid") not in emit.mock_calls - enter_room.assert_called_with("sid", "ENDPOINT::scan_status") + await app.redis_websocket.socket.handlers["/"]["register"]( + "sid", json.dumps({"endpoint": "scan_status"}) + ) + assert mock.call("error", mock.ANY, room="sid") not in emit.mock_calls + enter_room.assert_called_with( + "sid", + RedisAtlasEndpoints.socketio_endpoint_room( + "test", MessageEndpoints.scan_status().endpoint + ), + ) assert mock.call("error", mock.ANY, room="sid") not in emit.mock_calls diff --git a/backend/tests/test_scan_ingestor.py b/backend/tests/test_scan_ingestor.py index 3ecd830..9fcb1e1 100644 --- a/backend/tests/test_scan_ingestor.py +++ b/backend/tests/test_scan_ingestor.py @@ -1,7 +1,6 @@ import pytest -from bec_lib import messages - from bec_atlas.ingestor.data_ingestor import DataIngestor +from bec_lib import messages @pytest.fixture @@ -77,7 +76,7 @@ def test_scan_ingestor_create_scan(scan_ingestor, backend): }, timestamp=1732610545.15924, ) - scan_ingestor.update_scan_status(msg) + scan_ingestor.update_scan_status(msg, deployment_id="5cc67967-744d-4115-a46b-13246580cb3f") response = client.post( "/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"} @@ -94,7 +93,7 @@ def test_scan_ingestor_create_scan(scan_ingestor, backend): assert out["status"] == "open" msg.status = "closed" - scan_ingestor.update_scan_status(msg) + scan_ingestor.update_scan_status(msg, deployment_id="5cc67967-744d-4115-a46b-13246580cb3f") response = client.get(f"/api/v1/scans/id/{scan_id}") assert response.status_code == 200 out = response.json()