fix: fixed router access to redis; added tests

This commit is contained in:
2025-02-18 20:16:36 +01:00
parent 3ee7c0f652
commit 7569bc920a
7 changed files with 632 additions and 53 deletions

View File

@ -1,10 +1,16 @@
from __future__ import annotations
from functools import lru_cache
from typing import TYPE_CHECKING
from bec_atlas.model.model import User
if TYPE_CHECKING: # pragma: no cover
from bec_atlas.datasources.datasource_manager import DatasourceManager
class BaseRouter:
def __init__(self, prefix: str = "/api/v1", datasources=None) -> None:
def __init__(self, prefix: str = "/api/v1", datasources: DatasourceManager = None) -> None:
self.datasources = datasources
self.prefix = prefix

View File

@ -118,6 +118,12 @@ class DeploymentAccessRouter(BaseRouter):
+ original.su_read_access
+ original.su_write_access
)
for profile in new_profiles:
# check if the user exists
user = self._is_valid_user(profile)
if not user:
raise HTTPException(status_code=400, detail=f"User {profile} does not exist")
removed_profiles = old_profiles - new_profiles
for profile in removed_profiles:
db.delete_one("bec_access_profiles", {"username": profile, "deployment_id": updated.id})
@ -174,6 +180,20 @@ class DeploymentAccessRouter(BaseRouter):
redis.connector.set_and_publish(endpoint_info, MsgpackSerialization.dumps(profiles))
def _is_valid_user(self, user: str) -> bool:
"""
Check if the user exists.
Args:
user (str): The user's email
Returns:
bool: True if the user exists, False otherwise
"""
db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
user = db.find_one("users", {"email": user}, User)
return user is not None
def _get_redis_access_profile(self, access_profile: str, username: str, deployment_id: str):
"""
Get the redis access profile.

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import asyncio
import enum
import functools
@ -5,27 +7,32 @@ import inspect
import json
import traceback
import uuid
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal
import socketio
from bec_lib import messages
from bec_lib.endpoints import EndpointInfo, MessageEndpoints, MessageOp
from bec_lib.logger import bec_logger
from bec_lib.serialization import MsgpackSerialization, json_ext
from bson import ObjectId
from fastapi import APIRouter, Depends, Query, Response
from fastapi import APIRouter, Depends, HTTPException, Query, Response
from bec_atlas.authentication import convert_to_user, get_current_user, get_current_user_sync
from bec_atlas.model.model import DeploymentAccess, User
from bec_atlas.model.model import BECAccessProfile, DeploymentAccess, User
from bec_atlas.router.base_router import BaseRouter
logger = bec_logger.logger
if TYPE_CHECKING:
if TYPE_CHECKING: # pragma: no cover
from bec_lib.redis_connector import RedisConnector
from bec_atlas.datasources.datasource_manager import DatasourceManager
from bec_atlas.main import AtlasApp
class RemoteAccess(enum.Enum):
READ = "read"
WRITE = "write"
READ_WRITE = "read_write"
NONE = "none"
@ -132,9 +139,10 @@ class RedisRouter(BaseRouter):
the API. For pub/sub and stream operations, a websocket connection can be used.
"""
def __init__(self, prefix="/api/v1", datasources=None):
def __init__(self, prefix="/api/v1", datasources: DatasourceManager = None):
super().__init__(prefix, datasources)
self.redis = self.datasources.datasources["redis"].async_connector
self.db = self.datasources.datasources["mongodb"]
self.router = APIRouter(prefix=prefix)
self.router.add_api_route(
@ -147,6 +155,7 @@ class RedisRouter(BaseRouter):
async def redis_get(
self, deployment: str, key: str = Query(...), current_user: User = Depends(get_current_user)
):
self.validate_user_bec_access(current_user, deployment, key, "get", "read")
request_id = uuid.uuid4().hex
response_endpoint = RedisAtlasEndpoints.redis_request_response(deployment, request_id)
request_endpoint = RedisAtlasEndpoints.redis_request(deployment)
@ -158,19 +167,206 @@ class RedisRouter(BaseRouter):
response = await pubsub.get_message(timeout=10)
print(response)
response = await pubsub.get_message(timeout=10)
if response is None:
return json_ext.dumps({"error": "Timeout waiting for response"})
out = MsgpackSerialization.loads(response["data"])
return json_ext.dumps({"data": out.content, "metadata": out.metadata})
@convert_to_user
async def redis_post(
self, key: str, value: str, current_user: User = Depends(get_current_user)
self,
deployment: str,
key: str,
value: dict,
redis_op: Literal["send", "set_and_publish", "lpush", "rpush", "set", "xadd"],
msg_type: str,
current_user: User = Depends(get_current_user),
):
return self.redis.set(key, value)
"""
Send a message to the BEC instance of the specified deployment.
Args:
deployment (str): The deployment id
key (str): The key in Redis
value (dict): The value to send
redis_op (str): The operation to perform
msg_type (str): The message type
current_user (User): The current user
"""
self.validate_user_bec_access(current_user, deployment, key, redis_op, "write")
msg_type = getattr(messages, msg_type, None)
if msg_type is None:
raise HTTPException(status_code=400, detail="Invalid message type")
try:
msg = msg_type(**value)
except TypeError as exc:
raise HTTPException(status_code=400, detail=str(exc))
# msg_dump = MsgpackSerialization.dumps(msg)
data = MsgpackSerialization.dumps(
messages.RawMessage(data={"action": redis_op, "key": key, "value": msg})
)
request_endpoint = RedisAtlasEndpoints.redis_request(deployment)
pubsub = self.redis.pubsub()
pubsub.ignore_subscribe_messages = True
await self.redis.publish(request_endpoint, data)
return {"status": "success"}
@convert_to_user
async def redis_delete(self, key: str, current_user: User = Depends(get_current_user)):
return self.redis.delete(key)
async def redis_delete(
self, deployment: str, key: str, current_user: User = Depends(get_current_user)
):
request_endpoint = RedisAtlasEndpoints.redis_request(deployment)
pubsub = self.redis.pubsub()
pubsub.ignore_subscribe_messages = True
data = {"action": "delete", "key": key}
await self.redis.publish(request_endpoint, data)
def validate_user_bec_access(
self,
user: User,
deployment: str,
key: str,
redis_op: str,
operation_type: Literal["read", "write"],
):
"""
Validate the user access to the specified key and operation for the specified deployment.
Args:
user (User): The user object
deployment (str): The deployment name
key (str): The key in Redis
operation_type (str): The operation type (read or write)
Raises:
HTTPException: If the user does not have access to the key
"""
deployment_access = self.db.find_one(
"deployment_access", {"_id": ObjectId(deployment)}, DeploymentAccess
)
if not deployment_access:
raise ValueError("Deployment not found")
access = self.get_access(user, deployment_access)
# check if the user has access to the deployment
if access == RemoteAccess.NONE:
raise HTTPException(
status_code=403, detail="User does not have remote access to the deployment"
)
if operation_type == "read":
if access not in [RemoteAccess.READ, RemoteAccess.READ_WRITE]:
raise HTTPException(status_code=403, detail="User does not have read access")
elif operation_type == "write":
if access != RemoteAccess.READ_WRITE:
raise HTTPException(status_code=403, detail="User does not have write access")
else:
raise ValueError("Invalid operation type")
# check if the user has access to the key
bec_access = self.db.find_one(
"bec_access_profiles",
{"deployment_id": deployment, "username": {"$in": [user.email, user.username]}},
BECAccessProfile,
user=user,
)
if not bec_access:
raise HTTPException(status_code=403, detail="User does not have access to the key")
self.bec_access_profile_allows_op(bec_access, key, redis_op)
def bec_access_profile_allows_op(self, bec_access: BECAccessProfile, key: str, redis_op: str):
"""
Check if the BEC access profile allows the operation on the key.
Args:
bec_access (BECAccessProfile): The BEC access profile
key (str): The key in Redis
redis_op (str): The operation to perform
"""
if redis_op in ["lpush", "rpush", "set", "xadd", "delete"]:
access = self.get_key_pattern_access(key, bec_access.keys)
if access not in [RemoteAccess.WRITE, RemoteAccess.READ_WRITE]:
raise HTTPException(status_code=403, detail="User does not have access to the key")
elif redis_op == "send":
access = self.get_channel_pattern_access(key, bec_access.channels)
if access not in [RemoteAccess.WRITE, RemoteAccess.READ_WRITE]:
raise HTTPException(status_code=403, detail="User does not have access to the key")
elif redis_op == "set_and_publish":
access = self.get_key_pattern_access(key, bec_access.keys)
if access not in [RemoteAccess.WRITE, RemoteAccess.READ_WRITE]:
raise HTTPException(status_code=403, detail="User does not have access to the key")
access = self.get_channel_pattern_access(key, bec_access.channels)
if access not in [RemoteAccess.WRITE, RemoteAccess.READ_WRITE]:
raise HTTPException(status_code=403, detail="User does not have access to the key")
elif redis_op == "get":
access = self.get_key_pattern_access(key, bec_access.keys)
if access not in [RemoteAccess.READ, RemoteAccess.READ_WRITE]:
raise HTTPException(status_code=403, detail="User does not have access to the key")
else:
raise ValueError("Invalid operation")
@staticmethod
def get_key_pattern_access(key: str, patterns: list[str]) -> bool:
"""
Check if the key matches the pattern.
Args:
key (str): The key
patterns (list[str]): The patterns
Returns:
bool: True if the key matches the pattern, False otherwise
"""
if "*" in patterns:
return RemoteAccess.READ_WRITE
for pattern in patterns:
components = pattern.split("~")
rule = components[0]
subpattern = "".join(components[1:]).split("*", maxsplit=1)[0]
if subpattern in key:
if rule == "%R":
return RemoteAccess.READ
if rule == "%W":
return RemoteAccess.WRITE
if rule == "%RW":
return RemoteAccess.READ_WRITE
return RemoteAccess.NONE
@staticmethod
def get_channel_pattern_access(channel: str, patterns: list[str]) -> bool:
"""
Check if the channel matches the pattern.
Args:
channel (str): The channel
patterns (list[str]): The patterns
Returns:
bool: True if the channel matches the pattern, False otherwise
"""
for pattern in patterns:
prefix = pattern.split("*")[0]
if prefix in channel:
return RemoteAccess.READ_WRITE
return RemoteAccess.NONE
@staticmethod
def get_access(user: User, deployment_access: DeploymentAccess) -> RemoteAccess:
"""
Get the access level of the user to the deployment.
"""
access = RemoteAccess.NONE
groups = set(user.groups)
if user.username is not None:
groups.add(user.username)
if user.email is not None:
groups.add(user.email)
if groups & set(deployment_access.remote_read_access):
access = RemoteAccess.READ
if groups & set(deployment_access.remote_write_access):
access = RemoteAccess.READ_WRITE
return access
def safe_socket(fcn):
@ -277,10 +473,11 @@ class RedisWebsocket:
the websocket.
"""
def __init__(self, prefix="/api/v1", datasources=None, app=None):
def __init__(self, prefix="/api/v1", datasources=None, app: AtlasApp = None):
self.redis: RedisConnector = datasources.datasources["redis"].connector
self.prefix = prefix
self.fastapi_app = app
self.redis_router = app.redis_router
self.active_connections = set()
redis_host = datasources.datasources["redis"].config["host"]
redis_port = datasources.datasources["redis"].config["port"]
@ -336,34 +533,17 @@ class RedisWebsocket:
raise ValueError("Deployment not found in query parameters")
deployment_access = self.db.find_one(
"deployments", {"_id": ObjectId(deployment)}, DeploymentAccess
"deployment_access", {"_id": ObjectId(deployment)}, DeploymentAccess
)
if not deployment_access:
raise ValueError("Deployment not found")
access = self.get_access(user, deployment_access)
access = self.redis_router.get_access(user, deployment_access)
if access == RemoteAccess.NONE:
raise ValueError("User does not have remote access to the deployment")
return user, deployment, access
def get_access(self, user: User, deployment_access: DeploymentAccess) -> RemoteAccess:
"""
Get the access level of the user to the deployment.
"""
access = RemoteAccess.NONE
groups = set(user.groups)
if user.username is not None:
groups.add(user.username)
if user.email is not None:
groups.add(user.email)
if groups & set(deployment_access.remote_read_access):
access = RemoteAccess.READ
if groups & set(deployment_access.remote_write_access):
access = RemoteAccess.READ_WRITE
return access
@safe_socket
async def connect_client(self, sid, environ=None, auth=None, **kwargs):
if sid in self.users:

View File

@ -86,6 +86,11 @@ def backend(redis_server):
return fakeredis.FakeStrictRedis(server=redis_server)
mongo_client = mongomock.MongoClient("localhost", 27027)
fake_async_redis = fakeredis.FakeAsyncRedis(
server=redis_server, username="ingestor", password="ingestor"
)
fake_async_redis.connection_pool.connection_kwargs["username"] = "ingestor"
fake_async_redis.connection_pool.connection_kwargs["password"] = "ingestor"
config = {
"redis": {
@ -94,7 +99,7 @@ def backend(redis_server):
"username": "ingestor",
"password": "ingestor",
"sync_instance": RedisConnector("localhost:1", redis_cls=_fake_redis),
"async_instance": fakeredis.FakeAsyncRedis(server=redis_server),
"async_instance": fake_async_redis,
},
"mongodb": {"host": "localhost", "port": 27027, "mongodb_client": mongo_client},
}
@ -105,13 +110,7 @@ def backend(redis_server):
class PatchedBECAsyncRedisManager(BECAsyncRedisManager):
def _redis_connect(self):
self.redis = fakeredis.FakeAsyncRedis(
server=redis_server,
username=config["redis"]["username"],
password=config["redis"]["password"],
)
self.redis.connection_pool.connection_kwargs["username"] = config["redis"]["username"]
self.redis.connection_pool.connection_kwargs["password"] = config["redis"]["password"]
self.redis = fake_async_redis
self.pubsub = self.redis.pubsub(ignore_subscribe_messages=True)
with mock.patch(

View File

@ -1,3 +1,5 @@
from unittest import mock
import pytest
from bec_atlas.model.model import DeploymentAccess
@ -42,10 +44,11 @@ def test_deployment_access_router(logged_in_client):
out = DeploymentAccess(**out)
def test_patch_deployment_access(logged_in_client):
def test_patch_deployment_access(logged_in_client, backend):
"""
Test that the deployment access endpoint returns a 200 when the deployment id is valid.
"""
_, app = backend
deployments = logged_in_client.get(
"/api/v1/deployments/realm", params={"realm": "demo_beamline_1"}
).json()
@ -58,18 +61,19 @@ def test_patch_deployment_access(logged_in_client):
out = response.json()
out = DeploymentAccess(**out)
response = logged_in_client.patch(
"/api/v1/deployment_access",
params={"deployment_id": deployment_id},
json={
"user_read_access": ["test1"],
"user_write_access": ["test2"],
"su_read_access": ["test3"],
"su_write_access": ["test4"],
"remote_read_access": ["test5"],
"remote_write_access": ["test6"],
},
)
with mock.patch.object(app.deployment_access_router, "_is_valid_user", return_value=True):
response = logged_in_client.patch(
"/api/v1/deployment_access",
params={"deployment_id": deployment_id},
json={
"user_read_access": ["test1"],
"user_write_access": ["test2"],
"su_read_access": ["test3"],
"su_write_access": ["test4"],
"remote_read_access": ["test5"],
"remote_write_access": ["test6"],
},
)
assert response.status_code == 200
out = response.json()
out = DeploymentAccess(**out)

View File

@ -0,0 +1,370 @@
import asyncio
from unittest import mock
import pytest
from bec_lib import messages
from bec_lib.serialization import MsgpackSerialization
from bec_atlas.model.model import BECAccessProfile, DeploymentAccess, User
from bec_atlas.router.redis_router import RemoteAccess
@pytest.fixture
def logged_in_client(backend):
client, _ = backend
response = client.post(
"/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"}
)
assert response.status_code == 200
token = response.json()
assert isinstance(token, str)
assert len(token) > 20
return client
@pytest.fixture
def deployment(logged_in_client):
client = logged_in_client
response = client.get("/api/v1/deployments/realm", params={"realm": "demo_beamline_1"})
assert response.status_code == 200
return response.json()[0]
@pytest.mark.parametrize(
"key, patterns, access",
[
(
"public/some/key",
[
"%R~public/*", # Read-only access
"%R~info/*", # Read-only access
"%RW~personal/test_username/*", # Read/Write access
"%RW~user/*", # Read/Write access
],
RemoteAccess.READ,
),
(
"info/some/key",
[
"%R~public/*", # Read-only access
"%R~info/*", # Read-only access
"%RW~personal/test_username/*", # Read/Write access
"%RW~user/*", # Read/Write access
],
RemoteAccess.READ,
),
(
"personal/test_username/some/key",
[
"%R~public/*", # Read-only access
"%R~info/*", # Read-only access
"%RW~personal/test_username/*", # Read/Write access
"%RW~user/*", # Read/Write access
],
RemoteAccess.READ_WRITE,
),
(
"user/some/key",
[
"%R~public/*", # Read-only access
"%R~info/*", # Read-only access
"%RW~personal/test_username/*", # Read/Write access
"%RW~user/*", # Read/Write access
],
RemoteAccess.READ_WRITE,
),
("user/some/key", ["*"], RemoteAccess.READ_WRITE),
("public/some/key", ["%W~public/*"], RemoteAccess.WRITE),
("some/key", ["%W~public/*"], RemoteAccess.NONE),
],
)
def test_get_key_pattern_access(backend, key, patterns, access):
_, app = backend
assert app.redis_router.get_key_pattern_access(key, patterns) == access
@pytest.mark.parametrize(
"channel, patterns, access",
[
(
"public/some/channel",
["public/*", "info/*", "personal/test_username/*", "user/*"],
RemoteAccess.READ_WRITE,
),
("some/channel", ["public/*"], RemoteAccess.NONE),
],
)
def test_get_channel_pattern_access(backend, channel, patterns, access):
_, app = backend
assert app.redis_router.get_channel_pattern_access(channel, patterns) == access
@pytest.mark.parametrize(
"user, deployment_access, expected_access",
[
(
User(
owner_groups=["admin"],
access_groups=["admin"],
email="admin@bec_atlas.ch",
groups=["admin"],
first_name="admin",
last_name="admin",
),
DeploymentAccess(
owner_groups=["admin"], access_groups=["admin"], user_read_access=["admin"]
),
RemoteAccess.NONE,
),
(
User(
owner_groups=["admin"],
access_groups=["admin"],
email="admin@bec_atlas.ch",
groups=["admin"],
first_name="admin",
last_name="admin",
),
DeploymentAccess(
owner_groups=["admin"], access_groups=["admin"], remote_read_access=["admin"]
),
RemoteAccess.READ,
),
(
User(
owner_groups=["admin"],
access_groups=["admin"],
email="admin@bec_atlas.ch",
groups=["admin"],
first_name="admin",
last_name="admin",
),
DeploymentAccess(
owner_groups=["admin"], access_groups=["admin"], remote_write_access=["admin"]
),
RemoteAccess.READ_WRITE,
),
],
)
def test_get_access(backend, user, deployment_access, expected_access):
_, app = backend
assert app.redis_router.get_access(user, deployment_access) == expected_access
@pytest.mark.parametrize(
"bec_access, key, redis_op, raise_exception",
[
# Full access profile - should allow all operations
(
BECAccessProfile(
deployment_id="test_id",
username="admin",
owner_groups=["admin"],
keys=["*"],
channels=["*"],
commands=["*"],
),
"some/key",
"get",
False,
),
# Read-only access to keys
(
BECAccessProfile(
deployment_id="test_id",
username="reader",
owner_groups=["readers"],
keys=["%R~data/*"],
channels=["*"],
commands=["*"],
),
"data/sensor1",
"get",
False,
),
# Write operation with read-only access should fail
(
BECAccessProfile(
deployment_id="test_id",
username="reader",
owner_groups=["readers"],
keys=["%R~data/*"],
channels=["*"],
commands=["*"],
),
"data/sensor1",
"set",
True,
),
# Send operation to allowed channel
(
BECAccessProfile(
deployment_id="test_id",
username="writer",
owner_groups=["writers"],
keys=["*"],
channels=["commands/*"],
commands=["*"],
),
"commands/motor1",
"send",
False,
),
# Testing set_and_publish with mixed permissions
(
BECAccessProfile(
deployment_id="test_id",
username="user",
owner_groups=["users"],
keys=["%RW~status/*"],
channels=["status/*"],
commands=["*"],
),
"status/device1",
"set_and_publish",
False,
),
# Testing set_and_publish with insufficient key permissions
(
BECAccessProfile(
deployment_id="test_id",
username="user",
owner_groups=["users"],
keys=["%R~status/*"],
channels=["status/*"],
commands=["*"],
),
"status/device1",
"set_and_publish",
True,
),
# Testing invalid operation
(
BECAccessProfile(
deployment_id="test_id",
username="admin",
owner_groups=["admin"],
keys=["*"],
channels=["*"],
commands=["*"],
),
"some/key",
"invalid_op",
True,
),
# Test send operation with insufficient channel permissions
(
BECAccessProfile(
deployment_id="test_id",
username="user",
owner_groups=["users"],
keys=["*"],
channels=["internal/*"],
commands=["*"],
),
"status/device1",
"send",
True,
),
# Test set_and_publish with insufficient write permissions
(
BECAccessProfile(
deployment_id="test_id",
username="user",
owner_groups=["users"],
keys=["%R~status/*"],
channels=["status/*"],
commands=["*"],
),
"status/device1",
"set_and_publish",
True,
),
# Test set_and_publish with insufficient channel permissions
(
BECAccessProfile(
deployment_id="test_id",
username="user",
owner_groups=["users"],
keys=["%RW~status*"],
channels=["internal/*"],
commands=["*"],
),
"status/device1",
"set_and_publish",
True,
),
# Test get operation with insufficient read permissions
(
BECAccessProfile(
deployment_id="test_id",
username="user",
owner_groups=["users"],
keys=["%W~status/*"],
channels=["status/*"],
commands=["*"],
),
"status/device1",
"get",
True,
),
],
ids=[
"Full access profile - should allow all operations",
"Read-only access to keys",
"Write operation with read-only access should fail",
"Send operation to allowed channel",
"Testing set_and_publish with mixed permissions",
"Testing set_and_publish with insufficient key permissions",
"Testing invalid operation",
"Test send operation with insufficient channel permissions",
"Test set_and_publish with insufficient write permissions",
"Test set_and_publish with insufficient channel permissions",
"Test get operation with insufficient read permissions",
],
)
def test_bec_access_profile_allows_op(backend, bec_access, key, redis_op, raise_exception):
_, app = backend
if raise_exception:
with pytest.raises(Exception):
app.redis_router.bec_access_profile_allows_op(bec_access, key, redis_op)
else:
app.redis_router.bec_access_profile_allows_op(bec_access, key, redis_op)
# @pytest.mark.asyncio
def test_redis_get(logged_in_client, deployment, backend):
client = logged_in_client
_, app = backend
response = client.patch(
"/api/v1/deployment_access",
params={"deployment_id": deployment["_id"]},
json={
"user_read_access": ["admin@bec_atlas.ch"],
"remote_read_access": ["admin@bec_atlas.ch"],
},
)
assert response.status_code == 200
with mock.patch.object(app.redis_router.redis, "pubsub") as pubsub_mock:
msg = MsgpackSerialization.dumps(
messages.RawMessage(data={"test_key": "test"}, metadata={"message": "test"})
)
response = {
"type": "message",
"pattern": None,
"channel": "internal/deployment",
"data": msg,
}
pubsub_mock().subscribe = mock.AsyncMock()
ret_msg = pubsub_mock().get_message = mock.AsyncMock()
ret_msg.side_effect = [None, response]
response = client.get(
"/api/v1/redis", params={"deployment": deployment["_id"], "key": "test_key"}
)
assert response.status_code == 200
assert response.json() == {
"data": {"data": {"test_key": "test"}},
"metadata": {"message": "test"},
}

View File

@ -28,7 +28,7 @@ def backend_client(backend):
async def connected_ws(backend_client):
client, app = backend_client
deployment = client.get("/api/v1/deployments/realm", params={"realm": "demo_beamline_1"}).json()
with mock.patch.object(app.redis_websocket, "get_access", return_value=RemoteAccess.READ):
with mock.patch.object(app.redis_router, "get_access", return_value=RemoteAccess.READ):
await app.redis_websocket.socket.handlers["/"]["connect"](
"sid",
{