diff --git a/backend/bec_atlas/main.py b/backend/bec_atlas/main.py index b92a7e5..36857d7 100644 --- a/backend/bec_atlas/main.py +++ b/backend/bec_atlas/main.py @@ -108,14 +108,14 @@ class AtlasApp: ) self.app.mount("/", self.redis_websocket.app) - def run(self, port=8000): + def run(self, port=8000): # pragma: no cover config = uvicorn.Config(self.app, host="localhost", port=port) self.server = uvicorn.Server(config=config) self.server.run() # uvicorn.run(self.app, host="localhost", port=port) -def main(): +def main(): # pragma: no cover import argparse import logging @@ -132,5 +132,5 @@ def main(): horizon_app.run(port=args.port) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/backend/bec_atlas/router/deployments_router.py b/backend/bec_atlas/router/deployments_router.py index df187e5..ef356b1 100644 --- a/backend/bec_atlas/router/deployments_router.py +++ b/backend/bec_atlas/router/deployments_router.py @@ -2,14 +2,14 @@ import json from typing import TYPE_CHECKING from bson import ObjectId -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, HTTPException from bec_atlas.authentication import get_current_user from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource from bec_atlas.model.model import DeploymentCredential, Deployments, UserInfo from bec_atlas.router.base_router import BaseRouter -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from bec_atlas.datasources.redis_datasource import RedisDatasource @@ -57,6 +57,8 @@ class DeploymentsRouter(BaseRouter): Args: scan_id (str): The scan id """ + if not ObjectId.is_valid(deployment_id): + raise HTTPException(status_code=400, detail="Invalid deployment id") return self.db.find_one( "deployments", {"_id": ObjectId(deployment_id)}, Deployments, user=current_user ) diff --git a/backend/bec_atlas/utils/ldap_auth.py b/backend/bec_atlas/utils/ldap_auth.py index 3fdb961..e03c085 100644 --- a/backend/bec_atlas/utils/ldap_auth.py +++ b/backend/bec_atlas/utils/ldap_auth.py @@ -1,10 +1,10 @@ from ldap3 import ALL, SUBTREE, Connection, Server -from ldap3.core.exceptions import LDAPBindError +from ldap3.core.exceptions import LDAPBindError, LDAPSocketOpenError class LDAPUserService: def __init__(self, ldap_server, base_dn): - self.server = Server(ldap_server, get_info=ALL) + self.server = Server(ldap_server, get_info=ALL, connect_timeout=5) self.base_dn = base_dn def authenticate_and_get_info(self, principal, password): @@ -52,12 +52,12 @@ class LDAPUserService: } return user_data - except LDAPBindError as e: + except (LDAPBindError, LDAPSocketOpenError) as e: print(f"LDAP authentication failed: {e}") return None -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover ldap_service = LDAPUserService( ldap_server="ldaps://d.psi.ch", base_dn="OU=users,OU=psi,DC=d,DC=psi,DC=ch" ) diff --git a/backend/tests/test_deployment_credentials.py b/backend/tests/test_deployment_credentials.py new file mode 100644 index 0000000..9486d57 --- /dev/null +++ b/backend/tests/test_deployment_credentials.py @@ -0,0 +1,56 @@ +import pytest + + +@pytest.fixture +def logged_in_client(backend): + client, _ = backend + response = client.post( + "/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"} + ) + assert response.status_code == 200 + token = response.json() + assert isinstance(token, str) + assert len(token) > 20 + client.headers.update({"Authorization": f"Bearer {token}"}) + return client + + +@pytest.mark.timeout(60) +def test_get_deployment_credentials(logged_in_client): + """ + Test that the login endpoint returns a token. + """ + client = logged_in_client + + deployments = client.get( + "/api/v1/deployments/realm", params={"realm": "demo_beamline_1"} + ).json() + deployment_id = deployments[0]["_id"] + + response = client.get("/api/v1/deploymentCredentials", params={"deployment_id": deployment_id}) + assert response.status_code == 200 + + +@pytest.mark.timeout(60) +def test_refresh_deployment_credentials(logged_in_client): + """ + Test that the login endpoint returns a token. + """ + client = logged_in_client + + deployments = client.get( + "/api/v1/deployments/realm", params={"realm": "demo_beamline_1"} + ).json() + deployment_id = deployments[0]["_id"] + + old_token = client.get( + "/api/v1/deploymentCredentials", params={"deployment_id": deployment_id} + ).json()["credential"] + + response = client.post( + "/api/v1/deploymentCredentials/refresh", params={"deployment_id": deployment_id} + ) + assert response.status_code == 200 + out = response.json() + assert out == {"_id": deployment_id, "credential": out["credential"]} + assert out["credential"] != old_token diff --git a/backend/tests/test_deployment_router.py b/backend/tests/test_deployment_router.py new file mode 100644 index 0000000..5002401 --- /dev/null +++ b/backend/tests/test_deployment_router.py @@ -0,0 +1,59 @@ +import pytest + + +@pytest.fixture +def logged_in_client(backend): + client, _ = backend + response = client.post( + "/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"} + ) + assert response.status_code == 200 + token = response.json() + assert isinstance(token, str) + assert len(token) > 20 + client.headers.update({"Authorization": f"Bearer {token}"}) + return client + + +@pytest.mark.timeout(60) +@pytest.mark.parametrize("realm, num_deployments", [("test", 0), ("demo_beamline_1", 1)]) +def test_get_deployment_by_realm(logged_in_client, realm, num_deployments): + """ + Test that the login endpoint returns a token. + """ + client = logged_in_client + response = client.get("/api/v1/deployments/realm", params={"realm": realm}) + assert response.status_code == 200 + deployments = response.json() + assert len(deployments) == num_deployments + + +@pytest.mark.timeout(60) +def test_get_deployment_by_id(logged_in_client): + """ + Test that the login endpoint returns a token. + """ + client = logged_in_client + + deployments = client.get( + "/api/v1/deployments/realm", params={"realm": "demo_beamline_1"} + ).json() + deployment_id = deployments[0]["_id"] + + response = client.get("/api/v1/deployments/id", params={"deployment_id": deployment_id}) + assert response.status_code == 200 + deployment = response.json() + assert deployment["_id"] == deployment_id + assert deployment["realm_id"] == "demo_beamline_1" + + +@pytest.mark.timeout(60) +def test_get_deployment_by_id_wrong_id(logged_in_client): + """ + Test that the login endpoint returns a token. + """ + client = logged_in_client + + response = client.get("/api/v1/deployments/id", params={"deployment_id": "wrong_id"}) + assert response.status_code == 400 + assert response.json() == {"detail": "Invalid deployment id"} diff --git a/backend/tests/test_login.py b/backend/tests/test_login.py index 30b840b..0131476 100644 --- a/backend/tests/test_login.py +++ b/backend/tests/test_login.py @@ -7,7 +7,7 @@ def backend_client(backend): return client -@pytest.mark.timeout(60) +@pytest.mark.timeout(10) def test_login(backend_client): """ Test that the login endpoint returns a token. @@ -21,7 +21,7 @@ def test_login(backend_client): assert len(token) > 20 -@pytest.mark.timeout(60) +@pytest.mark.timeout(20) def test_login_wrong_password(backend_client): """ Test that the login returns a 401 when the password is wrong. @@ -33,7 +33,7 @@ def test_login_wrong_password(backend_client): assert response.json() == {"detail": "User not found or password is incorrect"} -@pytest.mark.timeout(60) +@pytest.mark.timeout(10) def test_login_unknown_user(backend_client): """ Test that the login returns a 401 when the user is unknown.