diff --git a/backend/bec_atlas/router/bec_access_router.py b/backend/bec_atlas/router/bec_access_router.py index 3dae769..486b7c3 100644 --- a/backend/bec_atlas/router/bec_access_router.py +++ b/backend/bec_atlas/router/bec_access_router.py @@ -1,4 +1,3 @@ -from bson import ObjectId from fastapi import APIRouter, Depends, HTTPException, Query from bec_atlas.authentication import get_current_user diff --git a/backend/bec_atlas/router/deployment_access_router.py b/backend/bec_atlas/router/deployment_access_router.py index 3652049..9510420 100644 --- a/backend/bec_atlas/router/deployment_access_router.py +++ b/backend/bec_atlas/router/deployment_access_router.py @@ -5,15 +5,15 @@ from typing import TYPE_CHECKING, Any from bec_lib.endpoints import EndpointInfo, MessageOp from bec_lib.serialization import MsgpackSerialization from bson import ObjectId -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, HTTPException from bec_atlas.authentication import get_current_user from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource -from bec_atlas.model.model import BECAccessProfile, DeploymentAccess, Deployments, UserInfo +from bec_atlas.model.model import BECAccessProfile, DeploymentAccess, UserInfo from bec_atlas.router.base_router import BaseRouter from bec_atlas.router.redis_router import RedisAtlasEndpoints -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from bec_atlas.datasources.redis_datasource import RedisDatasource @@ -50,6 +50,8 @@ class DeploymentAccessRouter(BaseRouter): Returns: DeploymentAccess: The access lists for the deployment """ + if not ObjectId.is_valid(deployment_id): + raise HTTPException(status_code=400, detail="Invalid deployment ID") return self.db.find_one( "deployments", {"_id": ObjectId(deployment_id)}, DeploymentAccess, user=current_user ) @@ -119,17 +121,17 @@ class DeploymentAccessRouter(BaseRouter): db.delete_one("bec_access_profiles", {"username": profile, "deployment_id": updated.id}) for profile in new_profiles: if profile in updated.su_write_access: - access = self._get_redis_access_profile("su_write", profile, updated.id) + access = self._get_redis_access_profile("su_write", profile, str(updated.id)) elif profile in updated.su_read_access: - access = self._get_redis_access_profile("su_read", profile, updated.id) + access = self._get_redis_access_profile("su_read", profile, str(updated.id)) elif profile in updated.user_write_access: - access = self._get_redis_access_profile("user_write", profile, updated.id) + access = self._get_redis_access_profile("user_write", profile, str(updated.id)) else: - access = self._get_redis_access_profile("user_read", profile, updated.id) + access = self._get_redis_access_profile("user_read", profile, str(updated.id)) existing_profile = db.find_one( "bec_access_profiles", - {"username": profile, "deployment_id": updated.id}, + {"username": profile, "deployment_id": str(updated.id)}, BECAccessProfile, ) if existing_profile: diff --git a/backend/bec_atlas/router/deployment_credentials.py b/backend/bec_atlas/router/deployment_credentials.py index 16d370b..64a7c46 100644 --- a/backend/bec_atlas/router/deployment_credentials.py +++ b/backend/bec_atlas/router/deployment_credentials.py @@ -9,7 +9,7 @@ from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource from bec_atlas.model.model import DeploymentCredential, UserInfo from bec_atlas.router.base_router import BaseRouter -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from bec_atlas.datasources.redis_datasource import RedisDatasource @@ -23,7 +23,7 @@ class DeploymentCredentialsRouter(BaseRouter): self.deployment_credential, methods=["GET"], description="Retrieve the deployment key for a specific deployment.", - response_model=DeploymentCredential, + response_model=DeploymentCredential | None, ) self.router.add_api_route( "/deploymentCredentials/refresh", @@ -42,6 +42,8 @@ class DeploymentCredentialsRouter(BaseRouter): Args: deployment_id (str): The deployment id """ + if not ObjectId.is_valid(deployment_id): + raise HTTPException(status_code=400, detail="Invalid deployment ID") if set(current_user.groups) & set(["admin", "bec_group"]): out = self.db.find( "deployment_credentials", {"_id": ObjectId(deployment_id)}, DeploymentCredential diff --git a/backend/tests/test_deployment_access_router.py b/backend/tests/test_deployment_access_router.py new file mode 100644 index 0000000..3e2d884 --- /dev/null +++ b/backend/tests/test_deployment_access_router.py @@ -0,0 +1,96 @@ +import pytest + +from bec_atlas.model.model import DeploymentAccess + + +@pytest.fixture +def logged_in_client(backend): + client, _ = backend + response = client.post( + "/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"} + ) + assert response.status_code == 200 + token = response.json() + assert isinstance(token, str) + assert len(token) > 20 + client.headers.update({"Authorization": f"Bearer {token}"}) + return client + + +def test_deployment_access_router_invalid_deployment_id(logged_in_client): + """ + Test that the deployment access endpoint returns a 400 when the deployment id is invalid. + """ + response = logged_in_client.get("/api/v1/deployment_access", params={"deployment_id": "test"}) + assert response.status_code == 400 + assert response.json() == {"detail": "Invalid deployment ID"} + + +def test_deployment_access_router(logged_in_client): + """ + Test that the deployment access endpoint returns a 200 when the deployment id is valid. + """ + deployments = logged_in_client.get( + "/api/v1/deployments/realm", params={"realm": "demo_beamline_1"} + ).json() + deployment_id = deployments[0]["_id"] + + response = logged_in_client.get( + "/api/v1/deployment_access", params={"deployment_id": deployment_id} + ) + assert response.status_code == 200 + out = response.json() + out = DeploymentAccess(**out) + + +def test_patch_deployment_access(logged_in_client): + """ + Test that the deployment access endpoint returns a 200 when the deployment id is valid. + """ + deployments = logged_in_client.get( + "/api/v1/deployments/realm", params={"realm": "demo_beamline_1"} + ).json() + deployment_id = deployments[0]["_id"] + + response = logged_in_client.get( + "/api/v1/deployment_access", params={"deployment_id": deployment_id} + ) + assert response.status_code == 200 + out = response.json() + out = DeploymentAccess(**out) + + response = logged_in_client.patch( + "/api/v1/deployment_access", + params={"deployment_id": deployment_id}, + json={ + "user_read_access": ["test1"], + "user_write_access": ["test2"], + "su_read_access": ["test3"], + "su_write_access": ["test4"], + "remote_read_access": ["test5"], + "remote_write_access": ["test6"], + }, + ) + assert response.status_code == 200 + out = response.json() + out = DeploymentAccess(**out) + assert out.user_read_access == ["test1"] + assert out.user_write_access == ["test2"] + assert out.su_read_access == ["test3"] + assert out.su_write_access == ["test4"] + assert out.remote_read_access == ["test5"] + assert out.remote_write_access == ["test6"] + + for user in ["test1", "test2", "test3", "test4"]: + out = logged_in_client.get( + "/api/v1/bec_access", params={"deployment_id": deployment_id, "user": user} + ) + assert out.status_code == 200 + out = out.json() + assert "token" in out + + for user in ["test5", "test6"]: + out = logged_in_client.get( + "/api/v1/bec_access", params={"deployment_id": deployment_id, "user": user} + ) + assert out.status_code == 404 diff --git a/backend/tests/test_deployment_credentials.py b/backend/tests/test_deployment_credentials.py index 9486d57..f1808c9 100644 --- a/backend/tests/test_deployment_credentials.py +++ b/backend/tests/test_deployment_credentials.py @@ -18,7 +18,7 @@ def logged_in_client(backend): @pytest.mark.timeout(60) def test_get_deployment_credentials(logged_in_client): """ - Test that the login endpoint returns a token. + Test that the deployment credentials endpoint returns a token. """ client = logged_in_client @@ -34,7 +34,7 @@ def test_get_deployment_credentials(logged_in_client): @pytest.mark.timeout(60) def test_refresh_deployment_credentials(logged_in_client): """ - Test that the login endpoint returns a token. + Test that the refresh deployment credentials endpoint returns a new token. """ client = logged_in_client @@ -54,3 +54,87 @@ def test_refresh_deployment_credentials(logged_in_client): out = response.json() assert out == {"_id": deployment_id, "credential": out["credential"]} assert out["credential"] != old_token + + +@pytest.mark.timeout(60) +def test_deployment_credential_rejects_unauthorized_user(backend): + """ + Test that the deployment credentials endpoint returns a 403 + when the user is not authorized. + """ + client, _ = backend + response = client.post( + "/api/v1/user/login", json={"username": "jane.doe@bec_atlas.ch", "password": "atlas"} + ) + assert response.status_code == 200 + token = response.json() + assert isinstance(token, str) + assert len(token) > 20 + client.headers.update({"Authorization": f"Bearer {token}"}) + + deployments = client.get( + "/api/v1/deployments/realm", params={"realm": "demo_beamline_1"} + ).json() + deployment_id = deployments[0]["_id"] + + response = client.get("/api/v1/deploymentCredentials", params={"deployment_id": deployment_id}) + assert response.status_code == 403 + assert response.json() == {"detail": "User does not have permission to access this resource."} + + +@pytest.mark.timeout(60) +def test_refresh_deployment_credentials_rejects_unauthorized_user(backend): + """ + Test that the refresh deployment credentials endpoint returns a 403 + when the user is not authorized. + """ + client, _ = backend + response = client.post( + "/api/v1/user/login", json={"username": "jane.doe@bec_atlas.ch", "password": "atlas"} + ) + assert response.status_code == 200 + token = response.json() + assert isinstance(token, str) + assert len(token) > 20 + client.headers.update({"Authorization": f"Bearer {token}"}) + + deployments = client.get( + "/api/v1/deployments/realm", params={"realm": "demo_beamline_1"} + ).json() + deployment_id = deployments[0]["_id"] + + response = client.post( + "/api/v1/deploymentCredentials/refresh", params={"deployment_id": deployment_id} + ) + assert response.status_code == 403 + assert response.json() == {"detail": "User does not have permission to access this resource."} + + +@pytest.mark.timeout(60) +def test_get_deployment_credentials_wrong_id(logged_in_client): + """ + Test that the deployment credentials endpoint returns a 400 + when the deployment ID is invalid. + """ + client = logged_in_client + + response = client.get("/api/v1/deploymentCredentials", params={"deployment_id": "wrong_id"}) + assert response.status_code == 400 + assert response.json() == {"detail": "Invalid deployment ID"} + + +@pytest.mark.timeout(60) +def test_deployment_credentials_refresh_not_found(logged_in_client): + """ + Test that the deployment credentials refresh endpoint returns a 404 + when the deployment is not found. + """ + client = logged_in_client + + response = client.post( + "/api/v1/deploymentCredentials/refresh", + params={"deployment_id": "678aa8d4875568640bd92000"}, + ) + assert response.status_code == 404 + out = response.json() + assert out == {"detail": "Deployment not found"} diff --git a/backend/tests/test_login.py b/backend/tests/test_login.py index 0131476..0f33b74 100644 --- a/backend/tests/test_login.py +++ b/backend/tests/test_login.py @@ -7,7 +7,7 @@ def backend_client(backend): return client -@pytest.mark.timeout(10) +@pytest.mark.timeout(20) def test_login(backend_client): """ Test that the login endpoint returns a token. @@ -33,7 +33,7 @@ def test_login_wrong_password(backend_client): assert response.json() == {"detail": "User not found or password is incorrect"} -@pytest.mark.timeout(10) +@pytest.mark.timeout(20) def test_login_unknown_user(backend_client): """ Test that the login returns a 401 when the user is unknown.