mirror of
https://github.com/bec-project/bec_atlas.git
synced 2025-07-14 07:01:48 +02:00
tests: added more backend tests
This commit is contained in:
@ -1,4 +1,3 @@
|
|||||||
from bson import ObjectId
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
|
|
||||||
from bec_atlas.authentication import get_current_user
|
from bec_atlas.authentication import get_current_user
|
||||||
|
@ -5,15 +5,15 @@ from typing import TYPE_CHECKING, Any
|
|||||||
from bec_lib.endpoints import EndpointInfo, MessageOp
|
from bec_lib.endpoints import EndpointInfo, MessageOp
|
||||||
from bec_lib.serialization import MsgpackSerialization
|
from bec_lib.serialization import MsgpackSerialization
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
|
||||||
from bec_atlas.authentication import get_current_user
|
from bec_atlas.authentication import get_current_user
|
||||||
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
|
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
|
||||||
from bec_atlas.model.model import BECAccessProfile, DeploymentAccess, Deployments, UserInfo
|
from bec_atlas.model.model import BECAccessProfile, DeploymentAccess, UserInfo
|
||||||
from bec_atlas.router.base_router import BaseRouter
|
from bec_atlas.router.base_router import BaseRouter
|
||||||
from bec_atlas.router.redis_router import RedisAtlasEndpoints
|
from bec_atlas.router.redis_router import RedisAtlasEndpoints
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
from bec_atlas.datasources.redis_datasource import RedisDatasource
|
from bec_atlas.datasources.redis_datasource import RedisDatasource
|
||||||
|
|
||||||
|
|
||||||
@ -50,6 +50,8 @@ class DeploymentAccessRouter(BaseRouter):
|
|||||||
Returns:
|
Returns:
|
||||||
DeploymentAccess: The access lists for the deployment
|
DeploymentAccess: The access lists for the deployment
|
||||||
"""
|
"""
|
||||||
|
if not ObjectId.is_valid(deployment_id):
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid deployment ID")
|
||||||
return self.db.find_one(
|
return self.db.find_one(
|
||||||
"deployments", {"_id": ObjectId(deployment_id)}, DeploymentAccess, user=current_user
|
"deployments", {"_id": ObjectId(deployment_id)}, DeploymentAccess, user=current_user
|
||||||
)
|
)
|
||||||
@ -119,17 +121,17 @@ class DeploymentAccessRouter(BaseRouter):
|
|||||||
db.delete_one("bec_access_profiles", {"username": profile, "deployment_id": updated.id})
|
db.delete_one("bec_access_profiles", {"username": profile, "deployment_id": updated.id})
|
||||||
for profile in new_profiles:
|
for profile in new_profiles:
|
||||||
if profile in updated.su_write_access:
|
if profile in updated.su_write_access:
|
||||||
access = self._get_redis_access_profile("su_write", profile, updated.id)
|
access = self._get_redis_access_profile("su_write", profile, str(updated.id))
|
||||||
elif profile in updated.su_read_access:
|
elif profile in updated.su_read_access:
|
||||||
access = self._get_redis_access_profile("su_read", profile, updated.id)
|
access = self._get_redis_access_profile("su_read", profile, str(updated.id))
|
||||||
elif profile in updated.user_write_access:
|
elif profile in updated.user_write_access:
|
||||||
access = self._get_redis_access_profile("user_write", profile, updated.id)
|
access = self._get_redis_access_profile("user_write", profile, str(updated.id))
|
||||||
else:
|
else:
|
||||||
access = self._get_redis_access_profile("user_read", profile, updated.id)
|
access = self._get_redis_access_profile("user_read", profile, str(updated.id))
|
||||||
|
|
||||||
existing_profile = db.find_one(
|
existing_profile = db.find_one(
|
||||||
"bec_access_profiles",
|
"bec_access_profiles",
|
||||||
{"username": profile, "deployment_id": updated.id},
|
{"username": profile, "deployment_id": str(updated.id)},
|
||||||
BECAccessProfile,
|
BECAccessProfile,
|
||||||
)
|
)
|
||||||
if existing_profile:
|
if existing_profile:
|
||||||
|
@ -9,7 +9,7 @@ from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
|
|||||||
from bec_atlas.model.model import DeploymentCredential, UserInfo
|
from bec_atlas.model.model import DeploymentCredential, UserInfo
|
||||||
from bec_atlas.router.base_router import BaseRouter
|
from bec_atlas.router.base_router import BaseRouter
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
from bec_atlas.datasources.redis_datasource import RedisDatasource
|
from bec_atlas.datasources.redis_datasource import RedisDatasource
|
||||||
|
|
||||||
|
|
||||||
@ -23,7 +23,7 @@ class DeploymentCredentialsRouter(BaseRouter):
|
|||||||
self.deployment_credential,
|
self.deployment_credential,
|
||||||
methods=["GET"],
|
methods=["GET"],
|
||||||
description="Retrieve the deployment key for a specific deployment.",
|
description="Retrieve the deployment key for a specific deployment.",
|
||||||
response_model=DeploymentCredential,
|
response_model=DeploymentCredential | None,
|
||||||
)
|
)
|
||||||
self.router.add_api_route(
|
self.router.add_api_route(
|
||||||
"/deploymentCredentials/refresh",
|
"/deploymentCredentials/refresh",
|
||||||
@ -42,6 +42,8 @@ class DeploymentCredentialsRouter(BaseRouter):
|
|||||||
Args:
|
Args:
|
||||||
deployment_id (str): The deployment id
|
deployment_id (str): The deployment id
|
||||||
"""
|
"""
|
||||||
|
if not ObjectId.is_valid(deployment_id):
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid deployment ID")
|
||||||
if set(current_user.groups) & set(["admin", "bec_group"]):
|
if set(current_user.groups) & set(["admin", "bec_group"]):
|
||||||
out = self.db.find(
|
out = self.db.find(
|
||||||
"deployment_credentials", {"_id": ObjectId(deployment_id)}, DeploymentCredential
|
"deployment_credentials", {"_id": ObjectId(deployment_id)}, DeploymentCredential
|
||||||
|
96
backend/tests/test_deployment_access_router.py
Normal file
96
backend/tests/test_deployment_access_router.py
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from bec_atlas.model.model import DeploymentAccess
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def logged_in_client(backend):
|
||||||
|
client, _ = backend
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
token = response.json()
|
||||||
|
assert isinstance(token, str)
|
||||||
|
assert len(token) > 20
|
||||||
|
client.headers.update({"Authorization": f"Bearer {token}"})
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def test_deployment_access_router_invalid_deployment_id(logged_in_client):
|
||||||
|
"""
|
||||||
|
Test that the deployment access endpoint returns a 400 when the deployment id is invalid.
|
||||||
|
"""
|
||||||
|
response = logged_in_client.get("/api/v1/deployment_access", params={"deployment_id": "test"})
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert response.json() == {"detail": "Invalid deployment ID"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_deployment_access_router(logged_in_client):
|
||||||
|
"""
|
||||||
|
Test that the deployment access endpoint returns a 200 when the deployment id is valid.
|
||||||
|
"""
|
||||||
|
deployments = logged_in_client.get(
|
||||||
|
"/api/v1/deployments/realm", params={"realm": "demo_beamline_1"}
|
||||||
|
).json()
|
||||||
|
deployment_id = deployments[0]["_id"]
|
||||||
|
|
||||||
|
response = logged_in_client.get(
|
||||||
|
"/api/v1/deployment_access", params={"deployment_id": deployment_id}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
out = response.json()
|
||||||
|
out = DeploymentAccess(**out)
|
||||||
|
|
||||||
|
|
||||||
|
def test_patch_deployment_access(logged_in_client):
|
||||||
|
"""
|
||||||
|
Test that the deployment access endpoint returns a 200 when the deployment id is valid.
|
||||||
|
"""
|
||||||
|
deployments = logged_in_client.get(
|
||||||
|
"/api/v1/deployments/realm", params={"realm": "demo_beamline_1"}
|
||||||
|
).json()
|
||||||
|
deployment_id = deployments[0]["_id"]
|
||||||
|
|
||||||
|
response = logged_in_client.get(
|
||||||
|
"/api/v1/deployment_access", params={"deployment_id": deployment_id}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
out = response.json()
|
||||||
|
out = DeploymentAccess(**out)
|
||||||
|
|
||||||
|
response = logged_in_client.patch(
|
||||||
|
"/api/v1/deployment_access",
|
||||||
|
params={"deployment_id": deployment_id},
|
||||||
|
json={
|
||||||
|
"user_read_access": ["test1"],
|
||||||
|
"user_write_access": ["test2"],
|
||||||
|
"su_read_access": ["test3"],
|
||||||
|
"su_write_access": ["test4"],
|
||||||
|
"remote_read_access": ["test5"],
|
||||||
|
"remote_write_access": ["test6"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
out = response.json()
|
||||||
|
out = DeploymentAccess(**out)
|
||||||
|
assert out.user_read_access == ["test1"]
|
||||||
|
assert out.user_write_access == ["test2"]
|
||||||
|
assert out.su_read_access == ["test3"]
|
||||||
|
assert out.su_write_access == ["test4"]
|
||||||
|
assert out.remote_read_access == ["test5"]
|
||||||
|
assert out.remote_write_access == ["test6"]
|
||||||
|
|
||||||
|
for user in ["test1", "test2", "test3", "test4"]:
|
||||||
|
out = logged_in_client.get(
|
||||||
|
"/api/v1/bec_access", params={"deployment_id": deployment_id, "user": user}
|
||||||
|
)
|
||||||
|
assert out.status_code == 200
|
||||||
|
out = out.json()
|
||||||
|
assert "token" in out
|
||||||
|
|
||||||
|
for user in ["test5", "test6"]:
|
||||||
|
out = logged_in_client.get(
|
||||||
|
"/api/v1/bec_access", params={"deployment_id": deployment_id, "user": user}
|
||||||
|
)
|
||||||
|
assert out.status_code == 404
|
@ -18,7 +18,7 @@ def logged_in_client(backend):
|
|||||||
@pytest.mark.timeout(60)
|
@pytest.mark.timeout(60)
|
||||||
def test_get_deployment_credentials(logged_in_client):
|
def test_get_deployment_credentials(logged_in_client):
|
||||||
"""
|
"""
|
||||||
Test that the login endpoint returns a token.
|
Test that the deployment credentials endpoint returns a token.
|
||||||
"""
|
"""
|
||||||
client = logged_in_client
|
client = logged_in_client
|
||||||
|
|
||||||
@ -34,7 +34,7 @@ def test_get_deployment_credentials(logged_in_client):
|
|||||||
@pytest.mark.timeout(60)
|
@pytest.mark.timeout(60)
|
||||||
def test_refresh_deployment_credentials(logged_in_client):
|
def test_refresh_deployment_credentials(logged_in_client):
|
||||||
"""
|
"""
|
||||||
Test that the login endpoint returns a token.
|
Test that the refresh deployment credentials endpoint returns a new token.
|
||||||
"""
|
"""
|
||||||
client = logged_in_client
|
client = logged_in_client
|
||||||
|
|
||||||
@ -54,3 +54,87 @@ def test_refresh_deployment_credentials(logged_in_client):
|
|||||||
out = response.json()
|
out = response.json()
|
||||||
assert out == {"_id": deployment_id, "credential": out["credential"]}
|
assert out == {"_id": deployment_id, "credential": out["credential"]}
|
||||||
assert out["credential"] != old_token
|
assert out["credential"] != old_token
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.timeout(60)
|
||||||
|
def test_deployment_credential_rejects_unauthorized_user(backend):
|
||||||
|
"""
|
||||||
|
Test that the deployment credentials endpoint returns a 403
|
||||||
|
when the user is not authorized.
|
||||||
|
"""
|
||||||
|
client, _ = backend
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/user/login", json={"username": "jane.doe@bec_atlas.ch", "password": "atlas"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
token = response.json()
|
||||||
|
assert isinstance(token, str)
|
||||||
|
assert len(token) > 20
|
||||||
|
client.headers.update({"Authorization": f"Bearer {token}"})
|
||||||
|
|
||||||
|
deployments = client.get(
|
||||||
|
"/api/v1/deployments/realm", params={"realm": "demo_beamline_1"}
|
||||||
|
).json()
|
||||||
|
deployment_id = deployments[0]["_id"]
|
||||||
|
|
||||||
|
response = client.get("/api/v1/deploymentCredentials", params={"deployment_id": deployment_id})
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert response.json() == {"detail": "User does not have permission to access this resource."}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.timeout(60)
|
||||||
|
def test_refresh_deployment_credentials_rejects_unauthorized_user(backend):
|
||||||
|
"""
|
||||||
|
Test that the refresh deployment credentials endpoint returns a 403
|
||||||
|
when the user is not authorized.
|
||||||
|
"""
|
||||||
|
client, _ = backend
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/user/login", json={"username": "jane.doe@bec_atlas.ch", "password": "atlas"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
token = response.json()
|
||||||
|
assert isinstance(token, str)
|
||||||
|
assert len(token) > 20
|
||||||
|
client.headers.update({"Authorization": f"Bearer {token}"})
|
||||||
|
|
||||||
|
deployments = client.get(
|
||||||
|
"/api/v1/deployments/realm", params={"realm": "demo_beamline_1"}
|
||||||
|
).json()
|
||||||
|
deployment_id = deployments[0]["_id"]
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/deploymentCredentials/refresh", params={"deployment_id": deployment_id}
|
||||||
|
)
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert response.json() == {"detail": "User does not have permission to access this resource."}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.timeout(60)
|
||||||
|
def test_get_deployment_credentials_wrong_id(logged_in_client):
|
||||||
|
"""
|
||||||
|
Test that the deployment credentials endpoint returns a 400
|
||||||
|
when the deployment ID is invalid.
|
||||||
|
"""
|
||||||
|
client = logged_in_client
|
||||||
|
|
||||||
|
response = client.get("/api/v1/deploymentCredentials", params={"deployment_id": "wrong_id"})
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert response.json() == {"detail": "Invalid deployment ID"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.timeout(60)
|
||||||
|
def test_deployment_credentials_refresh_not_found(logged_in_client):
|
||||||
|
"""
|
||||||
|
Test that the deployment credentials refresh endpoint returns a 404
|
||||||
|
when the deployment is not found.
|
||||||
|
"""
|
||||||
|
client = logged_in_client
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/deploymentCredentials/refresh",
|
||||||
|
params={"deployment_id": "678aa8d4875568640bd92000"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 404
|
||||||
|
out = response.json()
|
||||||
|
assert out == {"detail": "Deployment not found"}
|
||||||
|
@ -7,7 +7,7 @@ def backend_client(backend):
|
|||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(10)
|
@pytest.mark.timeout(20)
|
||||||
def test_login(backend_client):
|
def test_login(backend_client):
|
||||||
"""
|
"""
|
||||||
Test that the login endpoint returns a token.
|
Test that the login endpoint returns a token.
|
||||||
@ -33,7 +33,7 @@ def test_login_wrong_password(backend_client):
|
|||||||
assert response.json() == {"detail": "User not found or password is incorrect"}
|
assert response.json() == {"detail": "User not found or password is incorrect"}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(10)
|
@pytest.mark.timeout(20)
|
||||||
def test_login_unknown_user(backend_client):
|
def test_login_unknown_user(backend_client):
|
||||||
"""
|
"""
|
||||||
Test that the login returns a 401 when the user is unknown.
|
Test that the login returns a 401 when the user is unknown.
|
||||||
|
Reference in New Issue
Block a user