fix: fixed websocket connection and state transfer between redis and fastapi

This commit is contained in:
2024-11-22 23:07:34 +01:00
parent fc7d4b8fd5
commit c8a41565a2
7 changed files with 333 additions and 48 deletions

View File

@ -3,6 +3,9 @@ import os
from typing import Iterator
import pytest
from bec_atlas.main import AtlasApp
from bec_atlas.utils.setup_database import setup_database
from fastapi.testclient import TestClient
from pytest_docker.plugin import DockerComposeExecutor, Services
@ -96,3 +99,47 @@ def docker_services(
docker_cleanup,
) as docker_service:
yield docker_service
@pytest.fixture(scope="session")
def scylla_container(docker_ip, docker_services):
host = docker_ip
if os.path.exists("/.dockerenv"):
# if we are running in the CI, scylla was started as 'scylla' service
host = "scylla"
if docker_services is None:
port = 9042
else:
port = docker_services.port_for("scylla", 9042)
setup_database(host=host, port=port)
return host, port
@pytest.fixture(scope="session")
def redis_container(docker_ip, docker_services):
host = docker_ip
if os.path.exists("/.dockerenv"):
# if we are running in the CI, scylla was started as 'scylla' service
host = "redis"
if docker_services is None:
port = 6380
else:
port = docker_services.port_for("redis", 6379)
return host, port
@pytest.fixture(scope="session")
def backend(scylla_container, redis_container):
scylla_host, scylla_port = scylla_container
redis_host, redis_port = redis_container
config = {
"scylla": {"hosts": [(scylla_host, scylla_port)]},
"redis": {"host": redis_host, "port": redis_port},
}
app = AtlasApp(config)
with TestClient(app.app) as _client:
yield _client, app

View File

@ -3,4 +3,8 @@ services:
scylla:
image: scylladb/scylla:latest
ports:
- "9070:9042"
- "9070:9042"
redis:
image: redis:latest
ports:
- "6380:6379"

View File

@ -1,41 +1,18 @@
import os
import pytest
from bec_atlas.main import AtlasApp
from bec_atlas.utils.setup_database import setup_database
from fastapi.testclient import TestClient
@pytest.fixture(scope="session")
def scylla_container(docker_ip, docker_services):
host = docker_ip
if os.path.exists("/.dockerenv"):
# if we are running in the CI, scylla was started as 'scylla' service
host = "scylla"
if docker_services is None:
port = 9042
else:
port = docker_services.port_for("scylla", 9042)
setup_database(host=host, port=port)
return host, port
@pytest.fixture(scope="session")
def client(scylla_container):
host, port = scylla_container
config = {"scylla": {"hosts": [(host, port)]}}
with TestClient(AtlasApp(config).app) as _client:
yield _client
@pytest.fixture
def backend_client(backend):
client, _ = backend
return client
@pytest.mark.timeout(60)
def test_login(client):
def test_login(backend_client):
"""
Test that the login endpoint returns a token.
"""
response = client.post(
response = backend_client.post(
"/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"}
)
assert response.status_code == 200
@ -45,11 +22,11 @@ def test_login(client):
@pytest.mark.timeout(60)
def test_login_wrong_password(client):
def test_login_wrong_password(backend_client):
"""
Test that the login returns a 401 when the password is wrong.
"""
response = client.post(
response = backend_client.post(
"/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "wrong_password"}
)
assert response.status_code == 401
@ -57,11 +34,11 @@ def test_login_wrong_password(client):
@pytest.mark.timeout(60)
def test_login_unknown_user(client):
def test_login_unknown_user(backend_client):
"""
Test that the login returns a 404 when the user is unknown.
"""
response = client.post(
response = backend_client.post(
"/api/v1/user/login",
json={"username": "no_user@bec_atlas.ch", "password": "wrong_password"},
)

View File

@ -0,0 +1,99 @@
import json
from unittest import mock
import pytest
@pytest.fixture
def backend_client(backend):
client, app = backend
app.server = mock.Mock()
app.server.should_exit = False
app.redis_websocket.users = {}
yield client, app
app.redis_websocket.redis._redis_conn.flushall()
@pytest.mark.asyncio(loop_scope="session")
async def test_redis_websocket_connect(backend_client):
client, app = backend_client
await app.redis_websocket.socket.handlers["/"]["connect"](
"sid", {"HTTP_QUERY": '{"user": "test"}'}
)
assert "sid" in app.redis_websocket.users
@pytest.mark.asyncio(loop_scope="session")
async def test_redis_websocket_disconnect(backend_client):
client, app = backend_client
app.redis_websocket.users["sid"] = {"user": "test", "subscriptions": []}
await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid")
assert "sid" not in app.redis_websocket.users
@pytest.mark.asyncio(loop_scope="session")
async def test_redis_websocket_multiple_connect(backend_client):
client, app = backend_client
await app.redis_websocket.socket.handlers["/"]["connect"](
"sid1", {"HTTP_QUERY": '{"user": "test1"}'}
)
await app.redis_websocket.socket.handlers["/"]["connect"](
"sid2", {"HTTP_QUERY": '{"user": "test2"}'}
)
assert "sid1" in app.redis_websocket.users
assert "sid2" in app.redis_websocket.users
@pytest.mark.asyncio(loop_scope="session")
async def test_redis_websocket_multiple_connect_same_sid(backend_client):
client, app = backend_client
await app.redis_websocket.socket.handlers["/"]["connect"](
"sid", {"HTTP_QUERY": '{"user": "test"}'}
)
await app.redis_websocket.socket.handlers["/"]["connect"](
"sid", {"HTTP_QUERY": '{"user": "test"}'}
)
assert "sid" in app.redis_websocket.users
assert len(app.redis_websocket.users) == 1
@pytest.mark.asyncio(loop_scope="session")
async def test_redis_websocket_multiple_disconnect_same_sid(backend_client):
client, app = backend_client
app.redis_websocket.users["sid"] = {"user": "test", "subscriptions": []}
await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid")
await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid")
assert "sid" not in app.redis_websocket.users
assert len(app.redis_websocket.users) == 0
@pytest.mark.asyncio(loop_scope="session")
async def test_redis_websocket_register_wrong_endpoint_raises(backend_client):
client, app = backend_client
with mock.patch.object(app.redis_websocket.socket, "emit") as emit:
app.redis_websocket.socket.handlers["/"]["connect"]("sid")
await app.redis_websocket.socket.handlers["/"]["register"](
"sid", json.dumps({"endpoint": "wrong_endpoint"})
)
assert mock.call("error", mock.ANY, room="sid") in emit.mock_calls
@pytest.mark.asyncio(loop_scope="session")
async def test_redis_websocket_register(backend_client):
client, app = backend_client
with mock.patch.object(app.redis_websocket.socket, "emit") as emit:
with mock.patch.object(app.redis_websocket.socket, "enter_room") as enter_room:
with mock.patch.object(app.redis_websocket.socket.manager, "rooms") as rooms:
rooms.__getitem__.return_value = {"ENDPOINT::scan_status": "sid"}
await app.redis_websocket.socket.handlers["/"]["connect"](
"sid", {"HTTP_QUERY": '{"user": "test"}'}
)
await app.redis_websocket.socket.handlers["/"]["register"](
"sid", json.dumps({"endpoint": "scan_status"})
)
assert mock.call("error", mock.ANY, room="sid") not in emit.mock_calls
enter_room.assert_called_with("sid", "ENDPOINT::scan_status")
assert mock.call("error", mock.ANY, room="sid") not in emit.mock_calls