mirror of
https://github.com/bec-project/bec_atlas.git
synced 2025-07-14 07:01:48 +02:00
fix: fixed websocket connection and state transfer between redis and fastapi
This commit is contained in:
@ -28,6 +28,9 @@ include:
|
|||||||
services:
|
services:
|
||||||
- name: $CI_DEPENDENCY_PROXY_GROUP_IMAGE_PREFIX/scylladb/scylla:latest
|
- name: $CI_DEPENDENCY_PROXY_GROUP_IMAGE_PREFIX/scylladb/scylla:latest
|
||||||
alias: scylla
|
alias: scylla
|
||||||
|
- name: $CI_DEPENDENCY_PROXY_GROUP_IMAGE_PREFIX/redis:latest
|
||||||
|
alias: redis
|
||||||
|
command: ["redis-server", "--port", "6380"]
|
||||||
|
|
||||||
formatter:
|
formatter:
|
||||||
stage: Formatter
|
stage: Formatter
|
||||||
|
@ -16,6 +16,7 @@ class AtlasApp:
|
|||||||
def __init__(self, config=None):
|
def __init__(self, config=None):
|
||||||
self.config = config or CONFIG
|
self.config = config or CONFIG
|
||||||
self.app = FastAPI()
|
self.app = FastAPI()
|
||||||
|
self.server = None
|
||||||
self.prefix = f"/api/{self.API_VERSION}"
|
self.prefix = f"/api/{self.API_VERSION}"
|
||||||
self.datasources = DatasourceManager(config=self.config)
|
self.datasources = DatasourceManager(config=self.config)
|
||||||
self.register_event_handler()
|
self.register_event_handler()
|
||||||
@ -41,11 +42,16 @@ class AtlasApp:
|
|||||||
self.app.include_router(self.user_router.router)
|
self.app.include_router(self.user_router.router)
|
||||||
|
|
||||||
if "redis" in self.datasources.datasources:
|
if "redis" in self.datasources.datasources:
|
||||||
self.redis_websocket = RedisWebsocket(prefix=self.prefix, datasources=self.datasources)
|
self.redis_websocket = RedisWebsocket(
|
||||||
|
prefix=self.prefix, datasources=self.datasources, app=self
|
||||||
|
)
|
||||||
self.app.mount("/", self.redis_websocket.app)
|
self.app.mount("/", self.redis_websocket.app)
|
||||||
|
|
||||||
def run(self, port=8000):
|
def run(self, port=8000):
|
||||||
uvicorn.run(self.app, host="localhost", port=port)
|
config = uvicorn.Config(self.app, host="localhost", port=port)
|
||||||
|
self.server = uvicorn.Server(config=config)
|
||||||
|
self.server.run()
|
||||||
|
# uvicorn.run(self.app, host="localhost", port=port)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -1,14 +1,19 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
|
import traceback
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import socketio
|
import socketio
|
||||||
from bec_lib.endpoints import MessageEndpoints
|
from bec_lib.endpoints import MessageEndpoints
|
||||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
from bec_lib.logger import bec_logger
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
from bec_atlas.router.base_router import BaseRouter
|
from bec_atlas.router.base_router import BaseRouter
|
||||||
|
|
||||||
|
logger = bec_logger.logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from bec_lib.redis_connector import RedisConnector
|
from bec_lib.redis_connector import RedisConnector
|
||||||
|
|
||||||
@ -37,52 +42,196 @@ class RedisRouter(BaseRouter):
|
|||||||
return self.redis.delete(key)
|
return self.redis.delete(key)
|
||||||
|
|
||||||
|
|
||||||
|
def safe_socket(fcn):
|
||||||
|
@functools.wraps(fcn)
|
||||||
|
async def wrapper(self, sid, *args, **kwargs):
|
||||||
|
try:
|
||||||
|
out = await fcn(self, sid, *args, **kwargs)
|
||||||
|
except Exception as exc:
|
||||||
|
content = traceback.format_exc()
|
||||||
|
logger.error(content)
|
||||||
|
await self.socket.emit("error", {"error": str(exc)}, room=sid)
|
||||||
|
return
|
||||||
|
return out
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class BECAsyncRedisManager(socketio.AsyncRedisManager):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
url="redis://localhost:6379/0",
|
||||||
|
channel="socketio",
|
||||||
|
write_only=False,
|
||||||
|
logger=None,
|
||||||
|
redis_options=None,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
super().__init__(url, channel, write_only, logger, redis_options)
|
||||||
|
self.requested_channels = []
|
||||||
|
self.started_update_loop = False
|
||||||
|
|
||||||
|
# task = asyncio.create_task(self._required_channel_heartbeat())
|
||||||
|
# loop.run_until_complete(task)
|
||||||
|
|
||||||
|
def start_update_loop(self):
|
||||||
|
self.started_update_loop = True
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
task = loop.create_task(self._backend_heartbeat())
|
||||||
|
return task
|
||||||
|
|
||||||
|
async def disconnect(self, sid, namespace, **kwargs):
|
||||||
|
if kwargs.get("ignore_queue"):
|
||||||
|
await super().disconnect(sid, namespace, **kwargs)
|
||||||
|
await self.update_state_info()
|
||||||
|
return
|
||||||
|
message = {
|
||||||
|
"method": "disconnect",
|
||||||
|
"sid": sid,
|
||||||
|
"namespace": namespace or "/",
|
||||||
|
"host_id": self.host_id,
|
||||||
|
}
|
||||||
|
await self._handle_disconnect(message) # handle in this host
|
||||||
|
await self._publish(message) # notify other hosts
|
||||||
|
|
||||||
|
async def enter_room(self, sid, namespace, room, eio_sid=None):
|
||||||
|
await super().enter_room(sid, namespace, room, eio_sid=eio_sid)
|
||||||
|
await self.update_state_info()
|
||||||
|
|
||||||
|
async def leave_room(self, sid, namespace, room):
|
||||||
|
await super().leave_room(sid, namespace, room)
|
||||||
|
await self.update_state_info()
|
||||||
|
|
||||||
|
async def _backend_heartbeat(self):
|
||||||
|
while not self.parent.fastapi_app.server.should_exit:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
await self.redis.publish(f"deployments/{self.host_id}/heartbeat/", "ping")
|
||||||
|
data = json.dumps(self.parent.users)
|
||||||
|
print(f"Sending heartbeat: {data}")
|
||||||
|
await self.redis.set(f"deployments/{self.host_id}/state/", data, ex=30)
|
||||||
|
|
||||||
|
async def update_state_info(self):
|
||||||
|
data = json.dumps(self.parent.users)
|
||||||
|
await self.redis.set(f"deployments/{self.host_id}/state/", data, ex=30)
|
||||||
|
await self.redis.publish(f"deployments/{self.host_id}/state/", data)
|
||||||
|
|
||||||
|
async def update_websocket_states(self):
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
if not self.started_update_loop and loop.is_running():
|
||||||
|
self.start_update_loop()
|
||||||
|
await self.update_state_info()
|
||||||
|
|
||||||
|
async def remove_user(self, sid):
|
||||||
|
if sid in self.parent.users:
|
||||||
|
del self.parent.users[sid]
|
||||||
|
print(f"Removed user {sid}")
|
||||||
|
await self.update_state_info()
|
||||||
|
|
||||||
|
|
||||||
class RedisWebsocket:
|
class RedisWebsocket:
|
||||||
"""
|
"""
|
||||||
This class is a websocket handler for the Redis API. It exposes the redis client through
|
This class is a websocket handler for the Redis API. It exposes the redis client through
|
||||||
the websocket.
|
the websocket.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, prefix="/api/v1", datasources=None):
|
def __init__(self, prefix="/api/v1", datasources=None, app=None):
|
||||||
self.redis: RedisConnector = datasources.datasources["redis"].connector
|
self.redis: RedisConnector = datasources.datasources["redis"].connector
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
|
self.fastapi_app = app
|
||||||
self.active_connections = set()
|
self.active_connections = set()
|
||||||
self.socket = socketio.AsyncServer(cors_allowed_origins="*", async_mode="asgi")
|
self.socket = socketio.AsyncServer(
|
||||||
|
cors_allowed_origins="*",
|
||||||
|
async_mode="asgi",
|
||||||
|
client_manager=BECAsyncRedisManager(
|
||||||
|
self, url=f"redis://{self.redis.host}:{self.redis.port}/0"
|
||||||
|
),
|
||||||
|
)
|
||||||
self.app = socketio.ASGIApp(self.socket)
|
self.app = socketio.ASGIApp(self.socket)
|
||||||
self.loop = asyncio.get_event_loop()
|
self.loop = asyncio.get_event_loop()
|
||||||
|
self.users = {}
|
||||||
|
|
||||||
self.socket.on("connect", self.connect_client)
|
self.socket.on("connect", self.connect_client)
|
||||||
self.socket.on("register", self.redis_register)
|
self.socket.on("register", self.redis_register)
|
||||||
|
self.socket.on("unregister", self.redis.unregister)
|
||||||
self.socket.on("disconnect", self.disconnect_client)
|
self.socket.on("disconnect", self.disconnect_client)
|
||||||
|
|
||||||
def connect_client(self, sid, environ=None):
|
@safe_socket
|
||||||
|
async def connect_client(self, sid, environ=None):
|
||||||
print("Client connected")
|
print("Client connected")
|
||||||
self.active_connections.add(sid)
|
http_query = environ.get("HTTP_QUERY")
|
||||||
|
if not http_query:
|
||||||
|
raise ValueError("Query parameters not found")
|
||||||
|
query = json.loads(http_query)
|
||||||
|
|
||||||
def disconnect_client(self, sid, _environ=None):
|
if "user" not in query:
|
||||||
|
raise ValueError("User not found in query parameters")
|
||||||
|
user = query["user"]
|
||||||
|
|
||||||
|
if sid not in self.users:
|
||||||
|
# check if the user was already registered in redis
|
||||||
|
deployment_keys = await self.socket.manager.redis.keys("deployments/*/state/")
|
||||||
|
if not deployment_keys:
|
||||||
|
state_data = []
|
||||||
|
else:
|
||||||
|
state_data = await self.socket.manager.redis.mget(*deployment_keys)
|
||||||
|
info = {}
|
||||||
|
for data in state_data:
|
||||||
|
if not data:
|
||||||
|
continue
|
||||||
|
obj = json.loads(data)
|
||||||
|
for value in obj.values():
|
||||||
|
info[value["user"]] = value["subscriptions"]
|
||||||
|
|
||||||
|
if user in info:
|
||||||
|
self.users[sid] = {"user": user, "subscriptions": info[user]}
|
||||||
|
for endpoint in set(self.users[sid]["subscriptions"]):
|
||||||
|
await self.socket.enter_room(sid, f"ENDPOINT::{endpoint}")
|
||||||
|
await self.socket.manager.update_websocket_states()
|
||||||
|
else:
|
||||||
|
self.users[sid] = {"user": query["user"], "subscriptions": []}
|
||||||
|
|
||||||
|
await self.socket.manager.update_websocket_states()
|
||||||
|
|
||||||
|
async def disconnect_client(self, sid, _environ=None):
|
||||||
print("Client disconnected")
|
print("Client disconnected")
|
||||||
self.active_connections.remove(sid)
|
is_exit = self.fastapi_app.server.should_exit
|
||||||
|
if is_exit:
|
||||||
|
return
|
||||||
|
await self.socket.manager.remove_user(sid)
|
||||||
|
|
||||||
|
@safe_socket
|
||||||
async def redis_register(self, sid: str, msg: str):
|
async def redis_register(self, sid: str, msg: str):
|
||||||
|
"""
|
||||||
|
Register a client to a redis channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sid (str): The socket id of the client
|
||||||
|
msg (str): The message sent by the client
|
||||||
|
"""
|
||||||
if sid not in self.active_connections:
|
if sid not in self.active_connections:
|
||||||
self.active_connections.add(sid)
|
self.active_connections.add(sid)
|
||||||
try:
|
try:
|
||||||
print(msg)
|
print(msg)
|
||||||
data = json.loads(msg)
|
data = json.loads(msg)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return
|
raise ValueError("Invalid JSON message")
|
||||||
|
|
||||||
endpoint = getattr(MessageEndpoints, data.get("endpoint"))
|
endpoint = getattr(MessageEndpoints, data.get("endpoint"), None)
|
||||||
|
if endpoint is None:
|
||||||
|
raise ValueError(f"Endpoint {data.get('endpoint')} not found")
|
||||||
|
|
||||||
# check if the endpoint receives arguments
|
# check if the endpoint receives arguments
|
||||||
if len(inspect.signature(endpoint).parameters) > 1:
|
if len(inspect.signature(endpoint).parameters) > 0:
|
||||||
endpoint = endpoint(data.get("args"))
|
endpoint = endpoint(data.get("args"))
|
||||||
else:
|
else:
|
||||||
endpoint = endpoint()
|
endpoint = endpoint()
|
||||||
|
|
||||||
self.redis.register(endpoint, cb=self.on_redis_message, parent=self)
|
self.redis.register(endpoint, cb=self.on_redis_message, parent=self)
|
||||||
await self.socket.enter_room(sid, endpoint.endpoint)
|
if data.get("endpoint") not in self.users[sid]["subscriptions"]:
|
||||||
await self.socket.emit("registered", data={"endpoint": endpoint.endpoint}, room=sid)
|
await self.socket.enter_room(sid, f"ENDPOINT::{data.get('endpoint')}")
|
||||||
|
self.users[sid]["subscriptions"].append(data.get("endpoint"))
|
||||||
|
await self.socket.manager.update_websocket_states()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def on_redis_message(message, parent):
|
def on_redis_message(message, parent):
|
||||||
|
@ -3,6 +3,9 @@ import os
|
|||||||
from typing import Iterator
|
from typing import Iterator
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from bec_atlas.main import AtlasApp
|
||||||
|
from bec_atlas.utils.setup_database import setup_database
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
from pytest_docker.plugin import DockerComposeExecutor, Services
|
from pytest_docker.plugin import DockerComposeExecutor, Services
|
||||||
|
|
||||||
|
|
||||||
@ -96,3 +99,47 @@ def docker_services(
|
|||||||
docker_cleanup,
|
docker_cleanup,
|
||||||
) as docker_service:
|
) as docker_service:
|
||||||
yield docker_service
|
yield docker_service
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def scylla_container(docker_ip, docker_services):
|
||||||
|
host = docker_ip
|
||||||
|
if os.path.exists("/.dockerenv"):
|
||||||
|
# if we are running in the CI, scylla was started as 'scylla' service
|
||||||
|
host = "scylla"
|
||||||
|
if docker_services is None:
|
||||||
|
port = 9042
|
||||||
|
else:
|
||||||
|
port = docker_services.port_for("scylla", 9042)
|
||||||
|
|
||||||
|
setup_database(host=host, port=port)
|
||||||
|
return host, port
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def redis_container(docker_ip, docker_services):
|
||||||
|
host = docker_ip
|
||||||
|
if os.path.exists("/.dockerenv"):
|
||||||
|
# if we are running in the CI, scylla was started as 'scylla' service
|
||||||
|
host = "redis"
|
||||||
|
if docker_services is None:
|
||||||
|
port = 6380
|
||||||
|
else:
|
||||||
|
port = docker_services.port_for("redis", 6379)
|
||||||
|
|
||||||
|
return host, port
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def backend(scylla_container, redis_container):
|
||||||
|
scylla_host, scylla_port = scylla_container
|
||||||
|
redis_host, redis_port = redis_container
|
||||||
|
config = {
|
||||||
|
"scylla": {"hosts": [(scylla_host, scylla_port)]},
|
||||||
|
"redis": {"host": redis_host, "port": redis_port},
|
||||||
|
}
|
||||||
|
|
||||||
|
app = AtlasApp(config)
|
||||||
|
|
||||||
|
with TestClient(app.app) as _client:
|
||||||
|
yield _client, app
|
||||||
|
@ -3,4 +3,8 @@ services:
|
|||||||
scylla:
|
scylla:
|
||||||
image: scylladb/scylla:latest
|
image: scylladb/scylla:latest
|
||||||
ports:
|
ports:
|
||||||
- "9070:9042"
|
- "9070:9042"
|
||||||
|
redis:
|
||||||
|
image: redis:latest
|
||||||
|
ports:
|
||||||
|
- "6380:6379"
|
@ -1,41 +1,18 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from bec_atlas.main import AtlasApp
|
|
||||||
from bec_atlas.utils.setup_database import setup_database
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture
|
||||||
def scylla_container(docker_ip, docker_services):
|
def backend_client(backend):
|
||||||
host = docker_ip
|
client, _ = backend
|
||||||
if os.path.exists("/.dockerenv"):
|
return client
|
||||||
# if we are running in the CI, scylla was started as 'scylla' service
|
|
||||||
host = "scylla"
|
|
||||||
if docker_services is None:
|
|
||||||
port = 9042
|
|
||||||
else:
|
|
||||||
port = docker_services.port_for("scylla", 9042)
|
|
||||||
|
|
||||||
setup_database(host=host, port=port)
|
|
||||||
return host, port
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def client(scylla_container):
|
|
||||||
host, port = scylla_container
|
|
||||||
config = {"scylla": {"hosts": [(host, port)]}}
|
|
||||||
|
|
||||||
with TestClient(AtlasApp(config).app) as _client:
|
|
||||||
yield _client
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(60)
|
@pytest.mark.timeout(60)
|
||||||
def test_login(client):
|
def test_login(backend_client):
|
||||||
"""
|
"""
|
||||||
Test that the login endpoint returns a token.
|
Test that the login endpoint returns a token.
|
||||||
"""
|
"""
|
||||||
response = client.post(
|
response = backend_client.post(
|
||||||
"/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"}
|
"/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"}
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@ -45,11 +22,11 @@ def test_login(client):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(60)
|
@pytest.mark.timeout(60)
|
||||||
def test_login_wrong_password(client):
|
def test_login_wrong_password(backend_client):
|
||||||
"""
|
"""
|
||||||
Test that the login returns a 401 when the password is wrong.
|
Test that the login returns a 401 when the password is wrong.
|
||||||
"""
|
"""
|
||||||
response = client.post(
|
response = backend_client.post(
|
||||||
"/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "wrong_password"}
|
"/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "wrong_password"}
|
||||||
)
|
)
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
@ -57,11 +34,11 @@ def test_login_wrong_password(client):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(60)
|
@pytest.mark.timeout(60)
|
||||||
def test_login_unknown_user(client):
|
def test_login_unknown_user(backend_client):
|
||||||
"""
|
"""
|
||||||
Test that the login returns a 404 when the user is unknown.
|
Test that the login returns a 404 when the user is unknown.
|
||||||
"""
|
"""
|
||||||
response = client.post(
|
response = backend_client.post(
|
||||||
"/api/v1/user/login",
|
"/api/v1/user/login",
|
||||||
json={"username": "no_user@bec_atlas.ch", "password": "wrong_password"},
|
json={"username": "no_user@bec_atlas.ch", "password": "wrong_password"},
|
||||||
)
|
)
|
||||||
|
99
backend/tests/test_redis_websocket.py
Normal file
99
backend/tests/test_redis_websocket.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
import json
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def backend_client(backend):
|
||||||
|
client, app = backend
|
||||||
|
app.server = mock.Mock()
|
||||||
|
app.server.should_exit = False
|
||||||
|
app.redis_websocket.users = {}
|
||||||
|
yield client, app
|
||||||
|
app.redis_websocket.redis._redis_conn.flushall()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_redis_websocket_connect(backend_client):
|
||||||
|
client, app = backend_client
|
||||||
|
await app.redis_websocket.socket.handlers["/"]["connect"](
|
||||||
|
"sid", {"HTTP_QUERY": '{"user": "test"}'}
|
||||||
|
)
|
||||||
|
assert "sid" in app.redis_websocket.users
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_redis_websocket_disconnect(backend_client):
|
||||||
|
client, app = backend_client
|
||||||
|
app.redis_websocket.users["sid"] = {"user": "test", "subscriptions": []}
|
||||||
|
await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid")
|
||||||
|
assert "sid" not in app.redis_websocket.users
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_redis_websocket_multiple_connect(backend_client):
|
||||||
|
client, app = backend_client
|
||||||
|
await app.redis_websocket.socket.handlers["/"]["connect"](
|
||||||
|
"sid1", {"HTTP_QUERY": '{"user": "test1"}'}
|
||||||
|
)
|
||||||
|
await app.redis_websocket.socket.handlers["/"]["connect"](
|
||||||
|
"sid2", {"HTTP_QUERY": '{"user": "test2"}'}
|
||||||
|
)
|
||||||
|
assert "sid1" in app.redis_websocket.users
|
||||||
|
assert "sid2" in app.redis_websocket.users
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_redis_websocket_multiple_connect_same_sid(backend_client):
|
||||||
|
client, app = backend_client
|
||||||
|
await app.redis_websocket.socket.handlers["/"]["connect"](
|
||||||
|
"sid", {"HTTP_QUERY": '{"user": "test"}'}
|
||||||
|
)
|
||||||
|
await app.redis_websocket.socket.handlers["/"]["connect"](
|
||||||
|
"sid", {"HTTP_QUERY": '{"user": "test"}'}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "sid" in app.redis_websocket.users
|
||||||
|
assert len(app.redis_websocket.users) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_redis_websocket_multiple_disconnect_same_sid(backend_client):
|
||||||
|
client, app = backend_client
|
||||||
|
app.redis_websocket.users["sid"] = {"user": "test", "subscriptions": []}
|
||||||
|
await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid")
|
||||||
|
await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid")
|
||||||
|
assert "sid" not in app.redis_websocket.users
|
||||||
|
assert len(app.redis_websocket.users) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_redis_websocket_register_wrong_endpoint_raises(backend_client):
|
||||||
|
client, app = backend_client
|
||||||
|
with mock.patch.object(app.redis_websocket.socket, "emit") as emit:
|
||||||
|
app.redis_websocket.socket.handlers["/"]["connect"]("sid")
|
||||||
|
await app.redis_websocket.socket.handlers["/"]["register"](
|
||||||
|
"sid", json.dumps({"endpoint": "wrong_endpoint"})
|
||||||
|
)
|
||||||
|
assert mock.call("error", mock.ANY, room="sid") in emit.mock_calls
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_redis_websocket_register(backend_client):
|
||||||
|
client, app = backend_client
|
||||||
|
with mock.patch.object(app.redis_websocket.socket, "emit") as emit:
|
||||||
|
with mock.patch.object(app.redis_websocket.socket, "enter_room") as enter_room:
|
||||||
|
with mock.patch.object(app.redis_websocket.socket.manager, "rooms") as rooms:
|
||||||
|
rooms.__getitem__.return_value = {"ENDPOINT::scan_status": "sid"}
|
||||||
|
await app.redis_websocket.socket.handlers["/"]["connect"](
|
||||||
|
"sid", {"HTTP_QUERY": '{"user": "test"}'}
|
||||||
|
)
|
||||||
|
|
||||||
|
await app.redis_websocket.socket.handlers["/"]["register"](
|
||||||
|
"sid", json.dumps({"endpoint": "scan_status"})
|
||||||
|
)
|
||||||
|
assert mock.call("error", mock.ANY, room="sid") not in emit.mock_calls
|
||||||
|
enter_room.assert_called_with("sid", "ENDPOINT::scan_status")
|
||||||
|
|
||||||
|
assert mock.call("error", mock.ANY, room="sid") not in emit.mock_calls
|
Reference in New Issue
Block a user