mirror of
https://github.com/bec-project/bec_atlas.git
synced 2025-07-13 22:51:49 +02:00
fix: fixed router access to redis; added tests
This commit is contained in:
@ -1,10 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from bec_atlas.model.model import User
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from bec_atlas.datasources.datasource_manager import DatasourceManager
|
||||
|
||||
|
||||
class BaseRouter:
|
||||
def __init__(self, prefix: str = "/api/v1", datasources=None) -> None:
|
||||
def __init__(self, prefix: str = "/api/v1", datasources: DatasourceManager = None) -> None:
|
||||
self.datasources = datasources
|
||||
self.prefix = prefix
|
||||
|
||||
|
@ -118,6 +118,12 @@ class DeploymentAccessRouter(BaseRouter):
|
||||
+ original.su_read_access
|
||||
+ original.su_write_access
|
||||
)
|
||||
for profile in new_profiles:
|
||||
# check if the user exists
|
||||
user = self._is_valid_user(profile)
|
||||
if not user:
|
||||
raise HTTPException(status_code=400, detail=f"User {profile} does not exist")
|
||||
|
||||
removed_profiles = old_profiles - new_profiles
|
||||
for profile in removed_profiles:
|
||||
db.delete_one("bec_access_profiles", {"username": profile, "deployment_id": updated.id})
|
||||
@ -174,6 +180,20 @@ class DeploymentAccessRouter(BaseRouter):
|
||||
|
||||
redis.connector.set_and_publish(endpoint_info, MsgpackSerialization.dumps(profiles))
|
||||
|
||||
def _is_valid_user(self, user: str) -> bool:
|
||||
"""
|
||||
Check if the user exists.
|
||||
|
||||
Args:
|
||||
user (str): The user's email
|
||||
|
||||
Returns:
|
||||
bool: True if the user exists, False otherwise
|
||||
"""
|
||||
db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
||||
user = db.find_one("users", {"email": user}, User)
|
||||
return user is not None
|
||||
|
||||
def _get_redis_access_profile(self, access_profile: str, username: str, deployment_id: str):
|
||||
"""
|
||||
Get the redis access profile.
|
||||
|
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import enum
|
||||
import functools
|
||||
@ -5,27 +7,32 @@ import inspect
|
||||
import json
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import socketio
|
||||
from bec_lib import messages
|
||||
from bec_lib.endpoints import EndpointInfo, MessageEndpoints, MessageOp
|
||||
from bec_lib.logger import bec_logger
|
||||
from bec_lib.serialization import MsgpackSerialization, json_ext
|
||||
from bson import ObjectId
|
||||
from fastapi import APIRouter, Depends, Query, Response
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response
|
||||
|
||||
from bec_atlas.authentication import convert_to_user, get_current_user, get_current_user_sync
|
||||
from bec_atlas.model.model import DeploymentAccess, User
|
||||
from bec_atlas.model.model import BECAccessProfile, DeploymentAccess, User
|
||||
from bec_atlas.router.base_router import BaseRouter
|
||||
|
||||
logger = bec_logger.logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from bec_lib.redis_connector import RedisConnector
|
||||
|
||||
from bec_atlas.datasources.datasource_manager import DatasourceManager
|
||||
from bec_atlas.main import AtlasApp
|
||||
|
||||
|
||||
class RemoteAccess(enum.Enum):
|
||||
READ = "read"
|
||||
WRITE = "write"
|
||||
READ_WRITE = "read_write"
|
||||
NONE = "none"
|
||||
|
||||
@ -132,9 +139,10 @@ class RedisRouter(BaseRouter):
|
||||
the API. For pub/sub and stream operations, a websocket connection can be used.
|
||||
"""
|
||||
|
||||
def __init__(self, prefix="/api/v1", datasources=None):
|
||||
def __init__(self, prefix="/api/v1", datasources: DatasourceManager = None):
|
||||
super().__init__(prefix, datasources)
|
||||
self.redis = self.datasources.datasources["redis"].async_connector
|
||||
self.db = self.datasources.datasources["mongodb"]
|
||||
|
||||
self.router = APIRouter(prefix=prefix)
|
||||
self.router.add_api_route(
|
||||
@ -147,6 +155,7 @@ class RedisRouter(BaseRouter):
|
||||
async def redis_get(
|
||||
self, deployment: str, key: str = Query(...), current_user: User = Depends(get_current_user)
|
||||
):
|
||||
self.validate_user_bec_access(current_user, deployment, key, "get", "read")
|
||||
request_id = uuid.uuid4().hex
|
||||
response_endpoint = RedisAtlasEndpoints.redis_request_response(deployment, request_id)
|
||||
request_endpoint = RedisAtlasEndpoints.redis_request(deployment)
|
||||
@ -158,19 +167,206 @@ class RedisRouter(BaseRouter):
|
||||
response = await pubsub.get_message(timeout=10)
|
||||
print(response)
|
||||
response = await pubsub.get_message(timeout=10)
|
||||
if response is None:
|
||||
return json_ext.dumps({"error": "Timeout waiting for response"})
|
||||
out = MsgpackSerialization.loads(response["data"])
|
||||
|
||||
return json_ext.dumps({"data": out.content, "metadata": out.metadata})
|
||||
|
||||
@convert_to_user
|
||||
async def redis_post(
|
||||
self, key: str, value: str, current_user: User = Depends(get_current_user)
|
||||
self,
|
||||
deployment: str,
|
||||
key: str,
|
||||
value: dict,
|
||||
redis_op: Literal["send", "set_and_publish", "lpush", "rpush", "set", "xadd"],
|
||||
msg_type: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
return self.redis.set(key, value)
|
||||
"""
|
||||
Send a message to the BEC instance of the specified deployment.
|
||||
|
||||
Args:
|
||||
deployment (str): The deployment id
|
||||
key (str): The key in Redis
|
||||
value (dict): The value to send
|
||||
redis_op (str): The operation to perform
|
||||
msg_type (str): The message type
|
||||
current_user (User): The current user
|
||||
"""
|
||||
self.validate_user_bec_access(current_user, deployment, key, redis_op, "write")
|
||||
msg_type = getattr(messages, msg_type, None)
|
||||
if msg_type is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid message type")
|
||||
try:
|
||||
msg = msg_type(**value)
|
||||
except TypeError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
# msg_dump = MsgpackSerialization.dumps(msg)
|
||||
data = MsgpackSerialization.dumps(
|
||||
messages.RawMessage(data={"action": redis_op, "key": key, "value": msg})
|
||||
)
|
||||
request_endpoint = RedisAtlasEndpoints.redis_request(deployment)
|
||||
pubsub = self.redis.pubsub()
|
||||
pubsub.ignore_subscribe_messages = True
|
||||
await self.redis.publish(request_endpoint, data)
|
||||
return {"status": "success"}
|
||||
|
||||
@convert_to_user
|
||||
async def redis_delete(self, key: str, current_user: User = Depends(get_current_user)):
|
||||
return self.redis.delete(key)
|
||||
async def redis_delete(
|
||||
self, deployment: str, key: str, current_user: User = Depends(get_current_user)
|
||||
):
|
||||
request_endpoint = RedisAtlasEndpoints.redis_request(deployment)
|
||||
pubsub = self.redis.pubsub()
|
||||
pubsub.ignore_subscribe_messages = True
|
||||
data = {"action": "delete", "key": key}
|
||||
await self.redis.publish(request_endpoint, data)
|
||||
|
||||
def validate_user_bec_access(
|
||||
self,
|
||||
user: User,
|
||||
deployment: str,
|
||||
key: str,
|
||||
redis_op: str,
|
||||
operation_type: Literal["read", "write"],
|
||||
):
|
||||
"""
|
||||
Validate the user access to the specified key and operation for the specified deployment.
|
||||
|
||||
Args:
|
||||
user (User): The user object
|
||||
deployment (str): The deployment name
|
||||
key (str): The key in Redis
|
||||
operation_type (str): The operation type (read or write)
|
||||
|
||||
Raises:
|
||||
HTTPException: If the user does not have access to the key
|
||||
"""
|
||||
deployment_access = self.db.find_one(
|
||||
"deployment_access", {"_id": ObjectId(deployment)}, DeploymentAccess
|
||||
)
|
||||
if not deployment_access:
|
||||
raise ValueError("Deployment not found")
|
||||
access = self.get_access(user, deployment_access)
|
||||
|
||||
# check if the user has access to the deployment
|
||||
if access == RemoteAccess.NONE:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="User does not have remote access to the deployment"
|
||||
)
|
||||
if operation_type == "read":
|
||||
if access not in [RemoteAccess.READ, RemoteAccess.READ_WRITE]:
|
||||
raise HTTPException(status_code=403, detail="User does not have read access")
|
||||
elif operation_type == "write":
|
||||
if access != RemoteAccess.READ_WRITE:
|
||||
raise HTTPException(status_code=403, detail="User does not have write access")
|
||||
else:
|
||||
raise ValueError("Invalid operation type")
|
||||
|
||||
# check if the user has access to the key
|
||||
bec_access = self.db.find_one(
|
||||
"bec_access_profiles",
|
||||
{"deployment_id": deployment, "username": {"$in": [user.email, user.username]}},
|
||||
BECAccessProfile,
|
||||
user=user,
|
||||
)
|
||||
if not bec_access:
|
||||
raise HTTPException(status_code=403, detail="User does not have access to the key")
|
||||
|
||||
self.bec_access_profile_allows_op(bec_access, key, redis_op)
|
||||
|
||||
def bec_access_profile_allows_op(self, bec_access: BECAccessProfile, key: str, redis_op: str):
|
||||
"""
|
||||
Check if the BEC access profile allows the operation on the key.
|
||||
|
||||
Args:
|
||||
bec_access (BECAccessProfile): The BEC access profile
|
||||
key (str): The key in Redis
|
||||
redis_op (str): The operation to perform
|
||||
"""
|
||||
if redis_op in ["lpush", "rpush", "set", "xadd", "delete"]:
|
||||
access = self.get_key_pattern_access(key, bec_access.keys)
|
||||
if access not in [RemoteAccess.WRITE, RemoteAccess.READ_WRITE]:
|
||||
raise HTTPException(status_code=403, detail="User does not have access to the key")
|
||||
elif redis_op == "send":
|
||||
access = self.get_channel_pattern_access(key, bec_access.channels)
|
||||
if access not in [RemoteAccess.WRITE, RemoteAccess.READ_WRITE]:
|
||||
raise HTTPException(status_code=403, detail="User does not have access to the key")
|
||||
elif redis_op == "set_and_publish":
|
||||
access = self.get_key_pattern_access(key, bec_access.keys)
|
||||
if access not in [RemoteAccess.WRITE, RemoteAccess.READ_WRITE]:
|
||||
raise HTTPException(status_code=403, detail="User does not have access to the key")
|
||||
access = self.get_channel_pattern_access(key, bec_access.channels)
|
||||
if access not in [RemoteAccess.WRITE, RemoteAccess.READ_WRITE]:
|
||||
raise HTTPException(status_code=403, detail="User does not have access to the key")
|
||||
elif redis_op == "get":
|
||||
access = self.get_key_pattern_access(key, bec_access.keys)
|
||||
if access not in [RemoteAccess.READ, RemoteAccess.READ_WRITE]:
|
||||
raise HTTPException(status_code=403, detail="User does not have access to the key")
|
||||
else:
|
||||
raise ValueError("Invalid operation")
|
||||
|
||||
@staticmethod
|
||||
def get_key_pattern_access(key: str, patterns: list[str]) -> bool:
|
||||
"""
|
||||
Check if the key matches the pattern.
|
||||
|
||||
Args:
|
||||
key (str): The key
|
||||
patterns (list[str]): The patterns
|
||||
|
||||
Returns:
|
||||
bool: True if the key matches the pattern, False otherwise
|
||||
"""
|
||||
if "*" in patterns:
|
||||
return RemoteAccess.READ_WRITE
|
||||
for pattern in patterns:
|
||||
components = pattern.split("~")
|
||||
rule = components[0]
|
||||
subpattern = "".join(components[1:]).split("*", maxsplit=1)[0]
|
||||
if subpattern in key:
|
||||
if rule == "%R":
|
||||
return RemoteAccess.READ
|
||||
if rule == "%W":
|
||||
return RemoteAccess.WRITE
|
||||
if rule == "%RW":
|
||||
return RemoteAccess.READ_WRITE
|
||||
return RemoteAccess.NONE
|
||||
|
||||
@staticmethod
|
||||
def get_channel_pattern_access(channel: str, patterns: list[str]) -> bool:
|
||||
"""
|
||||
Check if the channel matches the pattern.
|
||||
|
||||
Args:
|
||||
channel (str): The channel
|
||||
patterns (list[str]): The patterns
|
||||
|
||||
Returns:
|
||||
bool: True if the channel matches the pattern, False otherwise
|
||||
"""
|
||||
for pattern in patterns:
|
||||
prefix = pattern.split("*")[0]
|
||||
if prefix in channel:
|
||||
return RemoteAccess.READ_WRITE
|
||||
return RemoteAccess.NONE
|
||||
|
||||
@staticmethod
|
||||
def get_access(user: User, deployment_access: DeploymentAccess) -> RemoteAccess:
|
||||
"""
|
||||
Get the access level of the user to the deployment.
|
||||
"""
|
||||
access = RemoteAccess.NONE
|
||||
groups = set(user.groups)
|
||||
if user.username is not None:
|
||||
groups.add(user.username)
|
||||
if user.email is not None:
|
||||
groups.add(user.email)
|
||||
|
||||
if groups & set(deployment_access.remote_read_access):
|
||||
access = RemoteAccess.READ
|
||||
if groups & set(deployment_access.remote_write_access):
|
||||
access = RemoteAccess.READ_WRITE
|
||||
return access
|
||||
|
||||
|
||||
def safe_socket(fcn):
|
||||
@ -277,10 +473,11 @@ class RedisWebsocket:
|
||||
the websocket.
|
||||
"""
|
||||
|
||||
def __init__(self, prefix="/api/v1", datasources=None, app=None):
|
||||
def __init__(self, prefix="/api/v1", datasources=None, app: AtlasApp = None):
|
||||
self.redis: RedisConnector = datasources.datasources["redis"].connector
|
||||
self.prefix = prefix
|
||||
self.fastapi_app = app
|
||||
self.redis_router = app.redis_router
|
||||
self.active_connections = set()
|
||||
redis_host = datasources.datasources["redis"].config["host"]
|
||||
redis_port = datasources.datasources["redis"].config["port"]
|
||||
@ -336,34 +533,17 @@ class RedisWebsocket:
|
||||
raise ValueError("Deployment not found in query parameters")
|
||||
|
||||
deployment_access = self.db.find_one(
|
||||
"deployments", {"_id": ObjectId(deployment)}, DeploymentAccess
|
||||
"deployment_access", {"_id": ObjectId(deployment)}, DeploymentAccess
|
||||
)
|
||||
if not deployment_access:
|
||||
raise ValueError("Deployment not found")
|
||||
|
||||
access = self.get_access(user, deployment_access)
|
||||
access = self.redis_router.get_access(user, deployment_access)
|
||||
if access == RemoteAccess.NONE:
|
||||
raise ValueError("User does not have remote access to the deployment")
|
||||
|
||||
return user, deployment, access
|
||||
|
||||
def get_access(self, user: User, deployment_access: DeploymentAccess) -> RemoteAccess:
|
||||
"""
|
||||
Get the access level of the user to the deployment.
|
||||
"""
|
||||
access = RemoteAccess.NONE
|
||||
groups = set(user.groups)
|
||||
if user.username is not None:
|
||||
groups.add(user.username)
|
||||
if user.email is not None:
|
||||
groups.add(user.email)
|
||||
|
||||
if groups & set(deployment_access.remote_read_access):
|
||||
access = RemoteAccess.READ
|
||||
if groups & set(deployment_access.remote_write_access):
|
||||
access = RemoteAccess.READ_WRITE
|
||||
return access
|
||||
|
||||
@safe_socket
|
||||
async def connect_client(self, sid, environ=None, auth=None, **kwargs):
|
||||
if sid in self.users:
|
||||
|
@ -86,6 +86,11 @@ def backend(redis_server):
|
||||
return fakeredis.FakeStrictRedis(server=redis_server)
|
||||
|
||||
mongo_client = mongomock.MongoClient("localhost", 27027)
|
||||
fake_async_redis = fakeredis.FakeAsyncRedis(
|
||||
server=redis_server, username="ingestor", password="ingestor"
|
||||
)
|
||||
fake_async_redis.connection_pool.connection_kwargs["username"] = "ingestor"
|
||||
fake_async_redis.connection_pool.connection_kwargs["password"] = "ingestor"
|
||||
|
||||
config = {
|
||||
"redis": {
|
||||
@ -94,7 +99,7 @@ def backend(redis_server):
|
||||
"username": "ingestor",
|
||||
"password": "ingestor",
|
||||
"sync_instance": RedisConnector("localhost:1", redis_cls=_fake_redis),
|
||||
"async_instance": fakeredis.FakeAsyncRedis(server=redis_server),
|
||||
"async_instance": fake_async_redis,
|
||||
},
|
||||
"mongodb": {"host": "localhost", "port": 27027, "mongodb_client": mongo_client},
|
||||
}
|
||||
@ -105,13 +110,7 @@ def backend(redis_server):
|
||||
|
||||
class PatchedBECAsyncRedisManager(BECAsyncRedisManager):
|
||||
def _redis_connect(self):
|
||||
self.redis = fakeredis.FakeAsyncRedis(
|
||||
server=redis_server,
|
||||
username=config["redis"]["username"],
|
||||
password=config["redis"]["password"],
|
||||
)
|
||||
self.redis.connection_pool.connection_kwargs["username"] = config["redis"]["username"]
|
||||
self.redis.connection_pool.connection_kwargs["password"] = config["redis"]["password"]
|
||||
self.redis = fake_async_redis
|
||||
self.pubsub = self.redis.pubsub(ignore_subscribe_messages=True)
|
||||
|
||||
with mock.patch(
|
||||
|
@ -1,3 +1,5 @@
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from bec_atlas.model.model import DeploymentAccess
|
||||
@ -42,10 +44,11 @@ def test_deployment_access_router(logged_in_client):
|
||||
out = DeploymentAccess(**out)
|
||||
|
||||
|
||||
def test_patch_deployment_access(logged_in_client):
|
||||
def test_patch_deployment_access(logged_in_client, backend):
|
||||
"""
|
||||
Test that the deployment access endpoint returns a 200 when the deployment id is valid.
|
||||
"""
|
||||
_, app = backend
|
||||
deployments = logged_in_client.get(
|
||||
"/api/v1/deployments/realm", params={"realm": "demo_beamline_1"}
|
||||
).json()
|
||||
@ -58,18 +61,19 @@ def test_patch_deployment_access(logged_in_client):
|
||||
out = response.json()
|
||||
out = DeploymentAccess(**out)
|
||||
|
||||
response = logged_in_client.patch(
|
||||
"/api/v1/deployment_access",
|
||||
params={"deployment_id": deployment_id},
|
||||
json={
|
||||
"user_read_access": ["test1"],
|
||||
"user_write_access": ["test2"],
|
||||
"su_read_access": ["test3"],
|
||||
"su_write_access": ["test4"],
|
||||
"remote_read_access": ["test5"],
|
||||
"remote_write_access": ["test6"],
|
||||
},
|
||||
)
|
||||
with mock.patch.object(app.deployment_access_router, "_is_valid_user", return_value=True):
|
||||
response = logged_in_client.patch(
|
||||
"/api/v1/deployment_access",
|
||||
params={"deployment_id": deployment_id},
|
||||
json={
|
||||
"user_read_access": ["test1"],
|
||||
"user_write_access": ["test2"],
|
||||
"su_read_access": ["test3"],
|
||||
"su_write_access": ["test4"],
|
||||
"remote_read_access": ["test5"],
|
||||
"remote_write_access": ["test6"],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
out = response.json()
|
||||
out = DeploymentAccess(**out)
|
||||
|
370
backend/tests/test_redis_router.py
Normal file
370
backend/tests/test_redis_router.py
Normal file
@ -0,0 +1,370 @@
|
||||
import asyncio
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from bec_lib import messages
|
||||
from bec_lib.serialization import MsgpackSerialization
|
||||
|
||||
from bec_atlas.model.model import BECAccessProfile, DeploymentAccess, User
|
||||
from bec_atlas.router.redis_router import RemoteAccess
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def logged_in_client(backend):
|
||||
client, _ = backend
|
||||
response = client.post(
|
||||
"/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
token = response.json()
|
||||
assert isinstance(token, str)
|
||||
assert len(token) > 20
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def deployment(logged_in_client):
|
||||
client = logged_in_client
|
||||
response = client.get("/api/v1/deployments/realm", params={"realm": "demo_beamline_1"})
|
||||
assert response.status_code == 200
|
||||
return response.json()[0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"key, patterns, access",
|
||||
[
|
||||
(
|
||||
"public/some/key",
|
||||
[
|
||||
"%R~public/*", # Read-only access
|
||||
"%R~info/*", # Read-only access
|
||||
"%RW~personal/test_username/*", # Read/Write access
|
||||
"%RW~user/*", # Read/Write access
|
||||
],
|
||||
RemoteAccess.READ,
|
||||
),
|
||||
(
|
||||
"info/some/key",
|
||||
[
|
||||
"%R~public/*", # Read-only access
|
||||
"%R~info/*", # Read-only access
|
||||
"%RW~personal/test_username/*", # Read/Write access
|
||||
"%RW~user/*", # Read/Write access
|
||||
],
|
||||
RemoteAccess.READ,
|
||||
),
|
||||
(
|
||||
"personal/test_username/some/key",
|
||||
[
|
||||
"%R~public/*", # Read-only access
|
||||
"%R~info/*", # Read-only access
|
||||
"%RW~personal/test_username/*", # Read/Write access
|
||||
"%RW~user/*", # Read/Write access
|
||||
],
|
||||
RemoteAccess.READ_WRITE,
|
||||
),
|
||||
(
|
||||
"user/some/key",
|
||||
[
|
||||
"%R~public/*", # Read-only access
|
||||
"%R~info/*", # Read-only access
|
||||
"%RW~personal/test_username/*", # Read/Write access
|
||||
"%RW~user/*", # Read/Write access
|
||||
],
|
||||
RemoteAccess.READ_WRITE,
|
||||
),
|
||||
("user/some/key", ["*"], RemoteAccess.READ_WRITE),
|
||||
("public/some/key", ["%W~public/*"], RemoteAccess.WRITE),
|
||||
("some/key", ["%W~public/*"], RemoteAccess.NONE),
|
||||
],
|
||||
)
|
||||
def test_get_key_pattern_access(backend, key, patterns, access):
|
||||
_, app = backend
|
||||
assert app.redis_router.get_key_pattern_access(key, patterns) == access
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"channel, patterns, access",
|
||||
[
|
||||
(
|
||||
"public/some/channel",
|
||||
["public/*", "info/*", "personal/test_username/*", "user/*"],
|
||||
RemoteAccess.READ_WRITE,
|
||||
),
|
||||
("some/channel", ["public/*"], RemoteAccess.NONE),
|
||||
],
|
||||
)
|
||||
def test_get_channel_pattern_access(backend, channel, patterns, access):
|
||||
_, app = backend
|
||||
assert app.redis_router.get_channel_pattern_access(channel, patterns) == access
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"user, deployment_access, expected_access",
|
||||
[
|
||||
(
|
||||
User(
|
||||
owner_groups=["admin"],
|
||||
access_groups=["admin"],
|
||||
email="admin@bec_atlas.ch",
|
||||
groups=["admin"],
|
||||
first_name="admin",
|
||||
last_name="admin",
|
||||
),
|
||||
DeploymentAccess(
|
||||
owner_groups=["admin"], access_groups=["admin"], user_read_access=["admin"]
|
||||
),
|
||||
RemoteAccess.NONE,
|
||||
),
|
||||
(
|
||||
User(
|
||||
owner_groups=["admin"],
|
||||
access_groups=["admin"],
|
||||
email="admin@bec_atlas.ch",
|
||||
groups=["admin"],
|
||||
first_name="admin",
|
||||
last_name="admin",
|
||||
),
|
||||
DeploymentAccess(
|
||||
owner_groups=["admin"], access_groups=["admin"], remote_read_access=["admin"]
|
||||
),
|
||||
RemoteAccess.READ,
|
||||
),
|
||||
(
|
||||
User(
|
||||
owner_groups=["admin"],
|
||||
access_groups=["admin"],
|
||||
email="admin@bec_atlas.ch",
|
||||
groups=["admin"],
|
||||
first_name="admin",
|
||||
last_name="admin",
|
||||
),
|
||||
DeploymentAccess(
|
||||
owner_groups=["admin"], access_groups=["admin"], remote_write_access=["admin"]
|
||||
),
|
||||
RemoteAccess.READ_WRITE,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_access(backend, user, deployment_access, expected_access):
|
||||
_, app = backend
|
||||
assert app.redis_router.get_access(user, deployment_access) == expected_access
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"bec_access, key, redis_op, raise_exception",
|
||||
[
|
||||
# Full access profile - should allow all operations
|
||||
(
|
||||
BECAccessProfile(
|
||||
deployment_id="test_id",
|
||||
username="admin",
|
||||
owner_groups=["admin"],
|
||||
keys=["*"],
|
||||
channels=["*"],
|
||||
commands=["*"],
|
||||
),
|
||||
"some/key",
|
||||
"get",
|
||||
False,
|
||||
),
|
||||
# Read-only access to keys
|
||||
(
|
||||
BECAccessProfile(
|
||||
deployment_id="test_id",
|
||||
username="reader",
|
||||
owner_groups=["readers"],
|
||||
keys=["%R~data/*"],
|
||||
channels=["*"],
|
||||
commands=["*"],
|
||||
),
|
||||
"data/sensor1",
|
||||
"get",
|
||||
False,
|
||||
),
|
||||
# Write operation with read-only access should fail
|
||||
(
|
||||
BECAccessProfile(
|
||||
deployment_id="test_id",
|
||||
username="reader",
|
||||
owner_groups=["readers"],
|
||||
keys=["%R~data/*"],
|
||||
channels=["*"],
|
||||
commands=["*"],
|
||||
),
|
||||
"data/sensor1",
|
||||
"set",
|
||||
True,
|
||||
),
|
||||
# Send operation to allowed channel
|
||||
(
|
||||
BECAccessProfile(
|
||||
deployment_id="test_id",
|
||||
username="writer",
|
||||
owner_groups=["writers"],
|
||||
keys=["*"],
|
||||
channels=["commands/*"],
|
||||
commands=["*"],
|
||||
),
|
||||
"commands/motor1",
|
||||
"send",
|
||||
False,
|
||||
),
|
||||
# Testing set_and_publish with mixed permissions
|
||||
(
|
||||
BECAccessProfile(
|
||||
deployment_id="test_id",
|
||||
username="user",
|
||||
owner_groups=["users"],
|
||||
keys=["%RW~status/*"],
|
||||
channels=["status/*"],
|
||||
commands=["*"],
|
||||
),
|
||||
"status/device1",
|
||||
"set_and_publish",
|
||||
False,
|
||||
),
|
||||
# Testing set_and_publish with insufficient key permissions
|
||||
(
|
||||
BECAccessProfile(
|
||||
deployment_id="test_id",
|
||||
username="user",
|
||||
owner_groups=["users"],
|
||||
keys=["%R~status/*"],
|
||||
channels=["status/*"],
|
||||
commands=["*"],
|
||||
),
|
||||
"status/device1",
|
||||
"set_and_publish",
|
||||
True,
|
||||
),
|
||||
# Testing invalid operation
|
||||
(
|
||||
BECAccessProfile(
|
||||
deployment_id="test_id",
|
||||
username="admin",
|
||||
owner_groups=["admin"],
|
||||
keys=["*"],
|
||||
channels=["*"],
|
||||
commands=["*"],
|
||||
),
|
||||
"some/key",
|
||||
"invalid_op",
|
||||
True,
|
||||
),
|
||||
# Test send operation with insufficient channel permissions
|
||||
(
|
||||
BECAccessProfile(
|
||||
deployment_id="test_id",
|
||||
username="user",
|
||||
owner_groups=["users"],
|
||||
keys=["*"],
|
||||
channels=["internal/*"],
|
||||
commands=["*"],
|
||||
),
|
||||
"status/device1",
|
||||
"send",
|
||||
True,
|
||||
),
|
||||
# Test set_and_publish with insufficient write permissions
|
||||
(
|
||||
BECAccessProfile(
|
||||
deployment_id="test_id",
|
||||
username="user",
|
||||
owner_groups=["users"],
|
||||
keys=["%R~status/*"],
|
||||
channels=["status/*"],
|
||||
commands=["*"],
|
||||
),
|
||||
"status/device1",
|
||||
"set_and_publish",
|
||||
True,
|
||||
),
|
||||
# Test set_and_publish with insufficient channel permissions
|
||||
(
|
||||
BECAccessProfile(
|
||||
deployment_id="test_id",
|
||||
username="user",
|
||||
owner_groups=["users"],
|
||||
keys=["%RW~status*"],
|
||||
channels=["internal/*"],
|
||||
commands=["*"],
|
||||
),
|
||||
"status/device1",
|
||||
"set_and_publish",
|
||||
True,
|
||||
),
|
||||
# Test get operation with insufficient read permissions
|
||||
(
|
||||
BECAccessProfile(
|
||||
deployment_id="test_id",
|
||||
username="user",
|
||||
owner_groups=["users"],
|
||||
keys=["%W~status/*"],
|
||||
channels=["status/*"],
|
||||
commands=["*"],
|
||||
),
|
||||
"status/device1",
|
||||
"get",
|
||||
True,
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"Full access profile - should allow all operations",
|
||||
"Read-only access to keys",
|
||||
"Write operation with read-only access should fail",
|
||||
"Send operation to allowed channel",
|
||||
"Testing set_and_publish with mixed permissions",
|
||||
"Testing set_and_publish with insufficient key permissions",
|
||||
"Testing invalid operation",
|
||||
"Test send operation with insufficient channel permissions",
|
||||
"Test set_and_publish with insufficient write permissions",
|
||||
"Test set_and_publish with insufficient channel permissions",
|
||||
"Test get operation with insufficient read permissions",
|
||||
],
|
||||
)
|
||||
def test_bec_access_profile_allows_op(backend, bec_access, key, redis_op, raise_exception):
|
||||
_, app = backend
|
||||
if raise_exception:
|
||||
with pytest.raises(Exception):
|
||||
app.redis_router.bec_access_profile_allows_op(bec_access, key, redis_op)
|
||||
else:
|
||||
app.redis_router.bec_access_profile_allows_op(bec_access, key, redis_op)
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
def test_redis_get(logged_in_client, deployment, backend):
|
||||
client = logged_in_client
|
||||
_, app = backend
|
||||
response = client.patch(
|
||||
"/api/v1/deployment_access",
|
||||
params={"deployment_id": deployment["_id"]},
|
||||
json={
|
||||
"user_read_access": ["admin@bec_atlas.ch"],
|
||||
"remote_read_access": ["admin@bec_atlas.ch"],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
with mock.patch.object(app.redis_router.redis, "pubsub") as pubsub_mock:
|
||||
msg = MsgpackSerialization.dumps(
|
||||
messages.RawMessage(data={"test_key": "test"}, metadata={"message": "test"})
|
||||
)
|
||||
response = {
|
||||
"type": "message",
|
||||
"pattern": None,
|
||||
"channel": "internal/deployment",
|
||||
"data": msg,
|
||||
}
|
||||
|
||||
pubsub_mock().subscribe = mock.AsyncMock()
|
||||
ret_msg = pubsub_mock().get_message = mock.AsyncMock()
|
||||
ret_msg.side_effect = [None, response]
|
||||
response = client.get(
|
||||
"/api/v1/redis", params={"deployment": deployment["_id"], "key": "test_key"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"data": {"data": {"test_key": "test"}},
|
||||
"metadata": {"message": "test"},
|
||||
}
|
@ -28,7 +28,7 @@ def backend_client(backend):
|
||||
async def connected_ws(backend_client):
|
||||
client, app = backend_client
|
||||
deployment = client.get("/api/v1/deployments/realm", params={"realm": "demo_beamline_1"}).json()
|
||||
with mock.patch.object(app.redis_websocket, "get_access", return_value=RemoteAccess.READ):
|
||||
with mock.patch.object(app.redis_router, "get_access", return_value=RemoteAccess.READ):
|
||||
await app.redis_websocket.socket.handlers["/"]["connect"](
|
||||
"sid",
|
||||
{
|
||||
|
Reference in New Issue
Block a user