# Mirror of https://github.com/bec-project/bec_atlas.git
# Synced 2025-07-14 07:01:48 +02:00
# 122 lines, 4.2 KiB, Python
import json
|
|
from unittest import mock
|
|
|
|
import pytest
|
|
from bec_lib.endpoints import MessageEndpoints
|
|
|
|
from bec_atlas.router.redis_router import RedisAtlasEndpoints, RemoteAccess
|
|
import pytest_asyncio
|
|
|
|
|
|
@pytest.fixture
def backend_client(backend):
    """
    Provide a logged-in ``(client, app)`` pair for the websocket tests.

    The app's ``server`` is replaced by a mock whose ``should_exit`` flag is
    False, the websocket user registry is reset to an empty dict, and the admin
    user is logged in through the REST login route.

    Args:
        backend: Two-item ``(client, app)`` value supplied by the ``backend``
            fixture, which is defined outside this file.

    Returns:
        tuple: The same ``(client, app)`` pair once the login has succeeded.
    """
    client, app = backend

    # Replace the server with a mock that reports should_exit == False.
    app.server = mock.Mock(should_exit=False)
    # Every test starts with an empty websocket user registry.
    app.redis_websocket.users = {}

    credentials = {"username": "admin@bec_atlas.ch", "password": "admin"}
    login_response = client.post("/api/v1/user/login", json=credentials)
    assert login_response.status_code == 200

    # The login route answers with a bare token string.
    access_token = login_response.json()
    assert isinstance(access_token, str)
    assert len(access_token) > 20

    return client, app
|
|
|
|
|
|
@pytest_asyncio.fixture
async def connected_ws(backend_client):
    """
    Yield ``backend_client`` after connecting a websocket client as ``"sid"``.

    The deployments of realm ``demo_beamline_1`` are fetched through the REST
    API; the ``_id`` of the first entry is sent as the connection query together
    with the access-token cookie of the logged-in client. ``get_access`` on the
    app's redis router is patched to return ``RemoteAccess.READ`` while the
    connect handler runs.

    Args:
        backend_client: The logged-in ``(client, app)`` pair.

    Yields:
        tuple: The unchanged ``backend_client`` pair.
    """
    client, app = backend_client
    realm_query = {"realm": "demo_beamline_1"}
    deployment = client.get("/api/v1/deployments/realm", params=realm_query).json()

    with mock.patch.object(app.redis_router, "get_access", return_value=RemoteAccess.READ):
        connect = app.redis_websocket.socket.handlers["/"]["connect"]
        environ = {
            "HTTP_QUERY": json.dumps({"deployment": deployment[0]["_id"]}),
            "HTTP_COOKIE": f"access_token={client.cookies.get('access_token')}",
        }
        await connect("sid", environ)
        # NOTE(review): the indentation of this file was lost in the copy under
        # review. The yield is assumed to sit inside the patch so that get_access
        # stays patched while the dependent test runs -- confirm against upstream.
        yield backend_client
|
|
|
|
|
|
async def test_redis_websocket_connect(connected_ws):
    """The sid connected by the ``connected_ws`` fixture is a registered user."""
    _client, app = connected_ws
    registered_users = app.redis_websocket.users
    assert "sid" in registered_users
|
|
|
|
|
|
async def test_redis_websocket_disconnect(connected_ws):
    """Disconnecting the connected sid removes it from the user registry."""
    _client, app = connected_ws
    websocket = app.redis_websocket

    disconnect = websocket.socket.handlers["/"]["disconnect"]
    await disconnect("sid")

    # Read the registry only after the handler has run.
    assert "sid" not in websocket.users
|
|
|
|
|
|
async def test_redis_websocket_multiple_connect(connected_ws):
    """A second sid can connect to the deployment already used by ``"sid"``."""
    client, app = connected_ws
    websocket = app.redis_websocket

    connect = websocket.socket.handlers["/"]["connect"]
    # Reuse the deployment recorded for the first connection.
    environ = {
        "HTTP_QUERY": json.dumps({"deployment": websocket.users["sid"]["deployment"]}),
        "HTTP_COOKIE": f"access_token={client.cookies.get('access_token')}",
    }
    await connect("sid2", environ)

    # Both the original and the new sid must be registered afterwards.
    assert "sid" in websocket.users
    assert "sid2" in websocket.users
|
|
|
|
|
|
async def test_redis_websocket_multiple_connect_same_sid(connected_ws):
    """Connecting again with the same sid leaves exactly one registered user."""
    client, app = connected_ws
    websocket = app.redis_websocket

    connect = websocket.socket.handlers["/"]["connect"]
    # Same sid and same deployment as the connection made by the fixture.
    environ = {
        "HTTP_QUERY": json.dumps({"deployment": websocket.users["sid"]["deployment"]}),
        "HTTP_COOKIE": f"access_token={client.cookies.get('access_token')}",
    }
    await connect("sid", environ)

    assert "sid" in websocket.users
    assert len(websocket.users) == 1
|
|
|
|
|
|
async def test_redis_websocket_multiple_disconnect_same_sid(connected_ws):
    """Disconnecting the same sid twice leaves the user registry empty."""
    _client, app = connected_ws
    websocket = app.redis_websocket

    disconnect = websocket.socket.handlers["/"]["disconnect"]
    # The second call targets a sid that has already been disconnected.
    for _attempt in range(2):
        await disconnect("sid")

    assert "sid" not in websocket.users
    assert len(websocket.users) == 0
|
|
|
|
|
|
async def test_redis_websocket_register_wrong_endpoint_raises(backend_client):
    """An ``error`` event is emitted to room ``"sid"`` for an unknown endpoint."""
    _client, app = backend_client
    socket = app.redis_websocket.socket

    with mock.patch.object(socket, "emit") as emit:
        # NOTE(review): connect is called without an environ and inside the same
        # patch, so an "error" emitted by connect would also satisfy the
        # assertion below -- confirm this is intended.
        await socket.handlers["/"]["connect"]("sid")
        payload = json.dumps({"endpoint": "wrong_endpoint"})
        await socket.handlers["/"]["register"]("sid", payload)

    # The mock keeps its recorded calls after the patch has been undone.
    expected_error = mock.call("error", mock.ANY, room="sid")
    assert expected_error in emit.mock_calls
|
|
|
|
|
|
async def test_redis_websocket_register(connected_ws):
    """
    Registering ``scan_status`` enters the sid into the matching room.

    The test passes when no ``error`` event was emitted to room ``"sid"`` and
    ``enter_room`` was last called with the sid and the room name that
    ``RedisAtlasEndpoints.socketio_endpoint_room`` builds from the sid's
    deployment and the ``scan_status`` endpoint.
    """
    _client, app = connected_ws
    websocket = app.redis_websocket

    with mock.patch.object(websocket.socket, "emit") as emit:
        with mock.patch.object(websocket.socket, "enter_room") as enter_room:
            register = websocket.socket.handlers["/"]["register"]
            await register("sid", json.dumps({"endpoint": "scan_status"}))

    # The mocks keep their recorded calls after the patches are undone, so one
    # check of the emitted events is enough (the original asserted it twice).
    assert mock.call("error", mock.ANY, room="sid") not in emit.mock_calls

    expected_room = RedisAtlasEndpoints.socketio_endpoint_room(
        websocket.users["sid"]["deployment"],
        MessageEndpoints.scan_status().endpoint,
    )
    enter_room.assert_called_with("sid", expected_room)
|