diff --git a/backend/bec_atlas/router/base_router.py b/backend/bec_atlas/router/base_router.py index c8f4d8b..1b44150 100644 --- a/backend/bec_atlas/router/base_router.py +++ b/backend/bec_atlas/router/base_router.py @@ -1,10 +1,16 @@ +from __future__ import annotations + from functools import lru_cache +from typing import TYPE_CHECKING from bec_atlas.model.model import User +if TYPE_CHECKING: # pragma: no cover + from bec_atlas.datasources.datasource_manager import DatasourceManager + class BaseRouter: - def __init__(self, prefix: str = "/api/v1", datasources=None) -> None: + def __init__(self, prefix: str = "/api/v1", datasources: DatasourceManager = None) -> None: self.datasources = datasources self.prefix = prefix diff --git a/backend/bec_atlas/router/deployment_access_router.py b/backend/bec_atlas/router/deployment_access_router.py index bb7766b..0412e6b 100644 --- a/backend/bec_atlas/router/deployment_access_router.py +++ b/backend/bec_atlas/router/deployment_access_router.py @@ -118,6 +118,12 @@ class DeploymentAccessRouter(BaseRouter): + original.su_read_access + original.su_write_access ) + for profile in new_profiles: + # check if the user exists + user = self._is_valid_user(profile) + if not user: + raise HTTPException(status_code=400, detail=f"User {profile} does not exist") + removed_profiles = old_profiles - new_profiles for profile in removed_profiles: db.delete_one("bec_access_profiles", {"username": profile, "deployment_id": updated.id}) @@ -174,6 +180,20 @@ class DeploymentAccessRouter(BaseRouter): redis.connector.set_and_publish(endpoint_info, MsgpackSerialization.dumps(profiles)) + def _is_valid_user(self, user: str) -> bool: + """ + Check if the user exists. + + Args: + user (str): The user's email + + Returns: + bool: True if the user exists, False otherwise + """ + db: MongoDBDatasource = self.datasources.datasources.get("mongodb") + user = db.find_one("users", {"email": user}, User) + return user is not None + def _get_redis_access_profile(self, access_profile: str, username: str, deployment_id: str): """ Get the redis access profile. diff --git a/backend/bec_atlas/router/redis_router.py b/backend/bec_atlas/router/redis_router.py index 6d66c7d..1f828e9 100644 --- a/backend/bec_atlas/router/redis_router.py +++ b/backend/bec_atlas/router/redis_router.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import enum import functools @@ -5,27 +7,32 @@ import inspect import json import traceback import uuid -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import socketio +from bec_lib import messages from bec_lib.endpoints import EndpointInfo, MessageEndpoints, MessageOp from bec_lib.logger import bec_logger from bec_lib.serialization import MsgpackSerialization, json_ext from bson import ObjectId -from fastapi import APIRouter, Depends, Query, Response +from fastapi import APIRouter, Depends, HTTPException, Query, Response from bec_atlas.authentication import convert_to_user, get_current_user, get_current_user_sync -from bec_atlas.model.model import DeploymentAccess, User +from bec_atlas.model.model import BECAccessProfile, DeploymentAccess, User from bec_atlas.router.base_router import BaseRouter logger = bec_logger.logger -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from bec_lib.redis_connector import RedisConnector + from bec_atlas.datasources.datasource_manager import DatasourceManager + from bec_atlas.main import AtlasApp + class RemoteAccess(enum.Enum): READ = "read" + WRITE = "write" READ_WRITE = "read_write" NONE = "none" @@ -132,9 +139,10 @@ class RedisRouter(BaseRouter): the API. For pub/sub and stream operations, a websocket connection can be used. """ - def __init__(self, prefix="/api/v1", datasources=None): + def __init__(self, prefix="/api/v1", datasources: DatasourceManager = None): super().__init__(prefix, datasources) self.redis = self.datasources.datasources["redis"].async_connector + self.db = self.datasources.datasources["mongodb"] self.router = APIRouter(prefix=prefix) self.router.add_api_route( @@ -147,6 +155,7 @@ class RedisRouter(BaseRouter): async def redis_get( self, deployment: str, key: str = Query(...), current_user: User = Depends(get_current_user) ): + self.validate_user_bec_access(current_user, deployment, key, "get", "read") request_id = uuid.uuid4().hex response_endpoint = RedisAtlasEndpoints.redis_request_response(deployment, request_id) request_endpoint = RedisAtlasEndpoints.redis_request(deployment) @@ -158,19 +167,206 @@ class RedisRouter(BaseRouter): response = await pubsub.get_message(timeout=10) print(response) response = await pubsub.get_message(timeout=10) + if response is None: + return json_ext.dumps({"error": "Timeout waiting for response"}) out = MsgpackSerialization.loads(response["data"]) - return json_ext.dumps({"data": out.content, "metadata": out.metadata}) @convert_to_user async def redis_post( - self, key: str, value: str, current_user: User = Depends(get_current_user) + self, + deployment: str, + key: str, + value: dict, + redis_op: Literal["send", "set_and_publish", "lpush", "rpush", "set", "xadd"], + msg_type: str, + current_user: User = Depends(get_current_user), ): - return self.redis.set(key, value) + """ + Send a message to the BEC instance of the specified deployment. + + Args: + deployment (str): The deployment id + key (str): The key in Redis + value (dict): The value to send + redis_op (str): The operation to perform + msg_type (str): The message type + current_user (User): The current user + """ + self.validate_user_bec_access(current_user, deployment, key, redis_op, "write") + msg_type = getattr(messages, msg_type, None) + if msg_type is None: + raise HTTPException(status_code=400, detail="Invalid message type") + try: + msg = msg_type(**value) + except TypeError as exc: + raise HTTPException(status_code=400, detail=str(exc)) + # msg_dump = MsgpackSerialization.dumps(msg) + data = MsgpackSerialization.dumps( + messages.RawMessage(data={"action": redis_op, "key": key, "value": msg}) + ) + request_endpoint = RedisAtlasEndpoints.redis_request(deployment) + pubsub = self.redis.pubsub() + pubsub.ignore_subscribe_messages = True + await self.redis.publish(request_endpoint, data) + return {"status": "success"} @convert_to_user - async def redis_delete(self, key: str, current_user: User = Depends(get_current_user)): - return self.redis.delete(key) + async def redis_delete( + self, deployment: str, key: str, current_user: User = Depends(get_current_user) + ): + request_endpoint = RedisAtlasEndpoints.redis_request(deployment) + pubsub = self.redis.pubsub() + pubsub.ignore_subscribe_messages = True + data = {"action": "delete", "key": key} + await self.redis.publish(request_endpoint, data) + + def validate_user_bec_access( + self, + user: User, + deployment: str, + key: str, + redis_op: str, + operation_type: Literal["read", "write"], + ): + """ + Validate the user access to the specified key and operation for the specified deployment. + + Args: + user (User): The user object + deployment (str): The deployment name + key (str): The key in Redis + operation_type (str): The operation type (read or write) + + Raises: + HTTPException: If the user does not have access to the key + """ + deployment_access = self.db.find_one( + "deployment_access", {"_id": ObjectId(deployment)}, DeploymentAccess + ) + if not deployment_access: + raise ValueError("Deployment not found") + access = self.get_access(user, deployment_access) + + # check if the user has access to the deployment + if access == RemoteAccess.NONE: + raise HTTPException( + status_code=403, detail="User does not have remote access to the deployment" + ) + if operation_type == "read": + if access not in [RemoteAccess.READ, RemoteAccess.READ_WRITE]: + raise HTTPException(status_code=403, detail="User does not have read access") + elif operation_type == "write": + if access != RemoteAccess.READ_WRITE: + raise HTTPException(status_code=403, detail="User does not have write access") + else: + raise ValueError("Invalid operation type") + + # check if the user has access to the key + bec_access = self.db.find_one( + "bec_access_profiles", + {"deployment_id": deployment, "username": {"$in": [user.email, user.username]}}, + BECAccessProfile, + user=user, + ) + if not bec_access: + raise HTTPException(status_code=403, detail="User does not have access to the key") + + self.bec_access_profile_allows_op(bec_access, key, redis_op) + + def bec_access_profile_allows_op(self, bec_access: BECAccessProfile, key: str, redis_op: str): + """ + Check if the BEC access profile allows the operation on the key. + + Args: + bec_access (BECAccessProfile): The BEC access profile + key (str): The key in Redis + redis_op (str): The operation to perform + """ + if redis_op in ["lpush", "rpush", "set", "xadd", "delete"]: + access = self.get_key_pattern_access(key, bec_access.keys) + if access not in [RemoteAccess.WRITE, RemoteAccess.READ_WRITE]: + raise HTTPException(status_code=403, detail="User does not have access to the key") + elif redis_op == "send": + access = self.get_channel_pattern_access(key, bec_access.channels) + if access not in [RemoteAccess.WRITE, RemoteAccess.READ_WRITE]: + raise HTTPException(status_code=403, detail="User does not have access to the key") + elif redis_op == "set_and_publish": + access = self.get_key_pattern_access(key, bec_access.keys) + if access not in [RemoteAccess.WRITE, RemoteAccess.READ_WRITE]: + raise HTTPException(status_code=403, detail="User does not have access to the key") + access = self.get_channel_pattern_access(key, bec_access.channels) + if access not in [RemoteAccess.WRITE, RemoteAccess.READ_WRITE]: + raise HTTPException(status_code=403, detail="User does not have access to the key") + elif redis_op == "get": + access = self.get_key_pattern_access(key, bec_access.keys) + if access not in [RemoteAccess.READ, RemoteAccess.READ_WRITE]: + raise HTTPException(status_code=403, detail="User does not have access to the key") + else: + raise ValueError("Invalid operation") + + @staticmethod + def get_key_pattern_access(key: str, patterns: list[str]) -> bool: + """ + Check if the key matches the pattern. + + Args: + key (str): The key + patterns (list[str]): The patterns + + Returns: + bool: True if the key matches the pattern, False otherwise + """ + if "*" in patterns: + return RemoteAccess.READ_WRITE + for pattern in patterns: + components = pattern.split("~") + rule = components[0] + subpattern = "".join(components[1:]).split("*", maxsplit=1)[0] + if subpattern in key: + if rule == "%R": + return RemoteAccess.READ + if rule == "%W": + return RemoteAccess.WRITE + if rule == "%RW": + return RemoteAccess.READ_WRITE + return RemoteAccess.NONE + + @staticmethod + def get_channel_pattern_access(channel: str, patterns: list[str]) -> bool: + """ + Check if the channel matches the pattern. + + Args: + channel (str): The channel + patterns (list[str]): The patterns + + Returns: + bool: True if the channel matches the pattern, False otherwise + """ + for pattern in patterns: + prefix = pattern.split("*")[0] + if prefix in channel: + return RemoteAccess.READ_WRITE + return RemoteAccess.NONE + + @staticmethod + def get_access(user: User, deployment_access: DeploymentAccess) -> RemoteAccess: + """ + Get the access level of the user to the deployment. + """ + access = RemoteAccess.NONE + groups = set(user.groups) + if user.username is not None: + groups.add(user.username) + if user.email is not None: + groups.add(user.email) + + if groups & set(deployment_access.remote_read_access): + access = RemoteAccess.READ + if groups & set(deployment_access.remote_write_access): + access = RemoteAccess.READ_WRITE + return access def safe_socket(fcn): @@ -277,10 +473,11 @@ class RedisWebsocket: the websocket. """ - def __init__(self, prefix="/api/v1", datasources=None, app=None): + def __init__(self, prefix="/api/v1", datasources=None, app: AtlasApp = None): self.redis: RedisConnector = datasources.datasources["redis"].connector self.prefix = prefix self.fastapi_app = app + self.redis_router = app.redis_router self.active_connections = set() redis_host = datasources.datasources["redis"].config["host"] redis_port = datasources.datasources["redis"].config["port"] @@ -336,34 +533,17 @@ class RedisWebsocket: raise ValueError("Deployment not found in query parameters") deployment_access = self.db.find_one( - "deployments", {"_id": ObjectId(deployment)}, DeploymentAccess + "deployment_access", {"_id": ObjectId(deployment)}, DeploymentAccess ) if not deployment_access: raise ValueError("Deployment not found") - access = self.get_access(user, deployment_access) + access = self.redis_router.get_access(user, deployment_access) if access == RemoteAccess.NONE: raise ValueError("User does not have remote access to the deployment") return user, deployment, access - def get_access(self, user: User, deployment_access: DeploymentAccess) -> RemoteAccess: - """ - Get the access level of the user to the deployment. - """ - access = RemoteAccess.NONE - groups = set(user.groups) - if user.username is not None: - groups.add(user.username) - if user.email is not None: - groups.add(user.email) - - if groups & set(deployment_access.remote_read_access): - access = RemoteAccess.READ - if groups & set(deployment_access.remote_write_access): - access = RemoteAccess.READ_WRITE - return access - @safe_socket async def connect_client(self, sid, environ=None, auth=None, **kwargs): if sid in self.users: diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index f24ffdf..44a4d5a 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -86,6 +86,11 @@ def backend(redis_server): return fakeredis.FakeStrictRedis(server=redis_server) mongo_client = mongomock.MongoClient("localhost", 27027) + fake_async_redis = fakeredis.FakeAsyncRedis( + server=redis_server, username="ingestor", password="ingestor" + ) + fake_async_redis.connection_pool.connection_kwargs["username"] = "ingestor" + fake_async_redis.connection_pool.connection_kwargs["password"] = "ingestor" config = { "redis": { @@ -94,7 +99,7 @@ def backend(redis_server): "username": "ingestor", "password": "ingestor", "sync_instance": RedisConnector("localhost:1", redis_cls=_fake_redis), - "async_instance": fakeredis.FakeAsyncRedis(server=redis_server), + "async_instance": fake_async_redis, }, "mongodb": {"host": "localhost", "port": 27027, "mongodb_client": mongo_client}, } @@ -105,13 +110,7 @@ def backend(redis_server): class PatchedBECAsyncRedisManager(BECAsyncRedisManager): def _redis_connect(self): - self.redis = fakeredis.FakeAsyncRedis( - server=redis_server, - username=config["redis"]["username"], - password=config["redis"]["password"], - ) - self.redis.connection_pool.connection_kwargs["username"] = config["redis"]["username"] - self.redis.connection_pool.connection_kwargs["password"] = config["redis"]["password"] + self.redis = fake_async_redis self.pubsub = self.redis.pubsub(ignore_subscribe_messages=True) with mock.patch( diff --git a/backend/tests/test_deployment_access_router.py b/backend/tests/test_deployment_access_router.py index 3158d0b..98fe1fd 100644 --- a/backend/tests/test_deployment_access_router.py +++ b/backend/tests/test_deployment_access_router.py @@ -1,3 +1,5 @@ +from unittest import mock + import pytest from bec_atlas.model.model import DeploymentAccess @@ -42,10 +44,11 @@ def test_deployment_access_router(logged_in_client): out = DeploymentAccess(**out) -def test_patch_deployment_access(logged_in_client): +def test_patch_deployment_access(logged_in_client, backend): """ Test that the deployment access endpoint returns a 200 when the deployment id is valid. """ + _, app = backend deployments = logged_in_client.get( "/api/v1/deployments/realm", params={"realm": "demo_beamline_1"} ).json() @@ -58,18 +61,19 @@ def test_patch_deployment_access(logged_in_client): out = response.json() out = DeploymentAccess(**out) - response = logged_in_client.patch( - "/api/v1/deployment_access", - params={"deployment_id": deployment_id}, - json={ - "user_read_access": ["test1"], - "user_write_access": ["test2"], - "su_read_access": ["test3"], - "su_write_access": ["test4"], - "remote_read_access": ["test5"], - "remote_write_access": ["test6"], - }, - ) + with mock.patch.object(app.deployment_access_router, "_is_valid_user", return_value=True): + response = logged_in_client.patch( + "/api/v1/deployment_access", + params={"deployment_id": deployment_id}, + json={ + "user_read_access": ["test1"], + "user_write_access": ["test2"], + "su_read_access": ["test3"], + "su_write_access": ["test4"], + "remote_read_access": ["test5"], + "remote_write_access": ["test6"], + }, + ) assert response.status_code == 200 out = response.json() out = DeploymentAccess(**out) diff --git a/backend/tests/test_redis_router.py b/backend/tests/test_redis_router.py new file mode 100644 index 0000000..ed67d9e --- /dev/null +++ b/backend/tests/test_redis_router.py @@ -0,0 +1,370 @@ +import asyncio +from unittest import mock + +import pytest +from bec_lib import messages +from bec_lib.serialization import MsgpackSerialization + +from bec_atlas.model.model import BECAccessProfile, DeploymentAccess, User +from bec_atlas.router.redis_router import RemoteAccess + + +@pytest.fixture +def logged_in_client(backend): + client, _ = backend + response = client.post( + "/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"} + ) + assert response.status_code == 200 + token = response.json() + assert isinstance(token, str) + assert len(token) > 20 + return client + + +@pytest.fixture +def deployment(logged_in_client): + client = logged_in_client + response = client.get("/api/v1/deployments/realm", params={"realm": "demo_beamline_1"}) + assert response.status_code == 200 + return response.json()[0] + + +@pytest.mark.parametrize( + "key, patterns, access", + [ + ( + "public/some/key", + [ + "%R~public/*", # Read-only access + "%R~info/*", # Read-only access + "%RW~personal/test_username/*", # Read/Write access + "%RW~user/*", # Read/Write access + ], + RemoteAccess.READ, + ), + ( + "info/some/key", + [ + "%R~public/*", # Read-only access + "%R~info/*", # Read-only access + "%RW~personal/test_username/*", # Read/Write access + "%RW~user/*", # Read/Write access + ], + RemoteAccess.READ, + ), + ( + "personal/test_username/some/key", + [ + "%R~public/*", # Read-only access + "%R~info/*", # Read-only access + "%RW~personal/test_username/*", # Read/Write access + "%RW~user/*", # Read/Write access + ], + RemoteAccess.READ_WRITE, + ), + ( + "user/some/key", + [ + "%R~public/*", # Read-only access + "%R~info/*", # Read-only access + "%RW~personal/test_username/*", # Read/Write access + "%RW~user/*", # Read/Write access + ], + RemoteAccess.READ_WRITE, + ), + ("user/some/key", ["*"], RemoteAccess.READ_WRITE), + ("public/some/key", ["%W~public/*"], RemoteAccess.WRITE), + ("some/key", ["%W~public/*"], RemoteAccess.NONE), + ], +) +def test_get_key_pattern_access(backend, key, patterns, access): + _, app = backend + assert app.redis_router.get_key_pattern_access(key, patterns) == access + + +@pytest.mark.parametrize( + "channel, patterns, access", + [ + ( + "public/some/channel", + ["public/*", "info/*", "personal/test_username/*", "user/*"], + RemoteAccess.READ_WRITE, + ), + ("some/channel", ["public/*"], RemoteAccess.NONE), + ], +) +def test_get_channel_pattern_access(backend, channel, patterns, access): + _, app = backend + assert app.redis_router.get_channel_pattern_access(channel, patterns) == access + + +@pytest.mark.parametrize( + "user, deployment_access, expected_access", + [ + ( + User( + owner_groups=["admin"], + access_groups=["admin"], + email="admin@bec_atlas.ch", + groups=["admin"], + first_name="admin", + last_name="admin", + ), + DeploymentAccess( + owner_groups=["admin"], access_groups=["admin"], user_read_access=["admin"] + ), + RemoteAccess.NONE, + ), + ( + User( + owner_groups=["admin"], + access_groups=["admin"], + email="admin@bec_atlas.ch", + groups=["admin"], + first_name="admin", + last_name="admin", + ), + DeploymentAccess( + owner_groups=["admin"], access_groups=["admin"], remote_read_access=["admin"] + ), + RemoteAccess.READ, + ), + ( + User( + owner_groups=["admin"], + access_groups=["admin"], + email="admin@bec_atlas.ch", + groups=["admin"], + first_name="admin", + last_name="admin", + ), + DeploymentAccess( + owner_groups=["admin"], access_groups=["admin"], remote_write_access=["admin"] + ), + RemoteAccess.READ_WRITE, + ), + ], +) +def test_get_access(backend, user, deployment_access, expected_access): + _, app = backend + assert app.redis_router.get_access(user, deployment_access) == expected_access + + +@pytest.mark.parametrize( + "bec_access, key, redis_op, raise_exception", + [ + # Full access profile - should allow all operations + ( + BECAccessProfile( + deployment_id="test_id", + username="admin", + owner_groups=["admin"], + keys=["*"], + channels=["*"], + commands=["*"], + ), + "some/key", + "get", + False, + ), + # Read-only access to keys + ( + BECAccessProfile( + deployment_id="test_id", + username="reader", + owner_groups=["readers"], + keys=["%R~data/*"], + channels=["*"], + commands=["*"], + ), + "data/sensor1", + "get", + False, + ), + # Write operation with read-only access should fail + ( + BECAccessProfile( + deployment_id="test_id", + username="reader", + owner_groups=["readers"], + keys=["%R~data/*"], + channels=["*"], + commands=["*"], + ), + "data/sensor1", + "set", + True, + ), + # Send operation to allowed channel + ( + BECAccessProfile( + deployment_id="test_id", + username="writer", + owner_groups=["writers"], + keys=["*"], + channels=["commands/*"], + commands=["*"], + ), + "commands/motor1", + "send", + False, + ), + # Testing set_and_publish with mixed permissions + ( + BECAccessProfile( + deployment_id="test_id", + username="user", + owner_groups=["users"], + keys=["%RW~status/*"], + channels=["status/*"], + commands=["*"], + ), + "status/device1", + "set_and_publish", + False, + ), + # Testing set_and_publish with insufficient key permissions + ( + BECAccessProfile( + deployment_id="test_id", + username="user", + owner_groups=["users"], + keys=["%R~status/*"], + channels=["status/*"], + commands=["*"], + ), + "status/device1", + "set_and_publish", + True, + ), + # Testing invalid operation + ( + BECAccessProfile( + deployment_id="test_id", + username="admin", + owner_groups=["admin"], + keys=["*"], + channels=["*"], + commands=["*"], + ), + "some/key", + "invalid_op", + True, + ), + # Test send operation with insufficient channel permissions + ( + BECAccessProfile( + deployment_id="test_id", + username="user", + owner_groups=["users"], + keys=["*"], + channels=["internal/*"], + commands=["*"], + ), + "status/device1", + "send", + True, + ), + # Test set_and_publish with insufficient write permissions + ( + BECAccessProfile( + deployment_id="test_id", + username="user", + owner_groups=["users"], + keys=["%R~status/*"], + channels=["status/*"], + commands=["*"], + ), + "status/device1", + "set_and_publish", + True, + ), + # Test set_and_publish with insufficient channel permissions + ( + BECAccessProfile( + deployment_id="test_id", + username="user", + owner_groups=["users"], + keys=["%RW~status*"], + channels=["internal/*"], + commands=["*"], + ), + "status/device1", + "set_and_publish", + True, + ), + # Test get operation with insufficient read permissions + ( + BECAccessProfile( + deployment_id="test_id", + username="user", + owner_groups=["users"], + keys=["%W~status/*"], + channels=["status/*"], + commands=["*"], + ), + "status/device1", + "get", + True, + ), + ], + ids=[ + "Full access profile - should allow all operations", + "Read-only access to keys", + "Write operation with read-only access should fail", + "Send operation to allowed channel", + "Testing set_and_publish with mixed permissions", + "Testing set_and_publish with insufficient key permissions", + "Testing invalid operation", + "Test send operation with insufficient channel permissions", + "Test set_and_publish with insufficient write permissions", + "Test set_and_publish with insufficient channel permissions", + "Test get operation with insufficient read permissions", + ], +) +def test_bec_access_profile_allows_op(backend, bec_access, key, redis_op, raise_exception): + _, app = backend + if raise_exception: + with pytest.raises(Exception): + app.redis_router.bec_access_profile_allows_op(bec_access, key, redis_op) + else: + app.redis_router.bec_access_profile_allows_op(bec_access, key, redis_op) + + +# @pytest.mark.asyncio +def test_redis_get(logged_in_client, deployment, backend): + client = logged_in_client + _, app = backend + response = client.patch( + "/api/v1/deployment_access", + params={"deployment_id": deployment["_id"]}, + json={ + "user_read_access": ["admin@bec_atlas.ch"], + "remote_read_access": ["admin@bec_atlas.ch"], + }, + ) + assert response.status_code == 200 + + with mock.patch.object(app.redis_router.redis, "pubsub") as pubsub_mock: + msg = MsgpackSerialization.dumps( + messages.RawMessage(data={"test_key": "test"}, metadata={"message": "test"}) + ) + response = { + "type": "message", + "pattern": None, + "channel": "internal/deployment", + "data": msg, + } + + pubsub_mock().subscribe = mock.AsyncMock() + ret_msg = pubsub_mock().get_message = mock.AsyncMock() + ret_msg.side_effect = [None, response] + response = client.get( + "/api/v1/redis", params={"deployment": deployment["_id"], "key": "test_key"} + ) + assert response.status_code == 200 + assert response.json() == { + "data": {"data": {"test_key": "test"}}, + "metadata": {"message": "test"}, + } diff --git a/backend/tests/test_redis_websocket.py b/backend/tests/test_redis_websocket.py index 6e2d8ee..7bfef4c 100644 --- a/backend/tests/test_redis_websocket.py +++ b/backend/tests/test_redis_websocket.py @@ -28,7 +28,7 @@ def backend_client(backend): async def connected_ws(backend_client): client, app = backend_client deployment = client.get("/api/v1/deployments/realm", params={"realm": "demo_beamline_1"}).json() - with mock.patch.object(app.redis_websocket, "get_access", return_value=RemoteAccess.READ): + with mock.patch.object(app.redis_router, "get_access", return_value=RemoteAccess.READ): await app.redis_websocket.socket.handlers["/"]["connect"]( "sid", {