diff --git a/backend/bec_atlas/authentication.py b/backend/bec_atlas/authentication.py index c5e0e9d..4f58c56 100644 --- a/backend/bec_atlas/authentication.py +++ b/backend/bec_atlas/authentication.py @@ -2,10 +2,10 @@ from __future__ import annotations import os from datetime import datetime, timedelta -from typing import Annotated +from functools import wraps import jwt -from fastapi import Depends, HTTPException, status +from fastapi import HTTPException, Request, status from fastapi.security import OAuth2PasswordBearer from jwt.exceptions import InvalidTokenError from pwdlib import PasswordHash @@ -19,6 +19,24 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/user/login/form") password_hash = PasswordHash.recommended() +def convert_to_user(func): + """ + Decorator to convert the current_user parameter to a User object. + """ + + @wraps(func) + async def wrapper(*args, **kwargs): + if "current_user" in kwargs: + current_user = kwargs["current_user"] + if current_user: + router = args[0] + user = router.get_user_from_db(current_user.token, current_user.email) + kwargs["current_user"] = user + return await func(*args, **kwargs) + + return wrapper + + def get_secret_key(): val = os.getenv("SECRET_KEY", "test_secret") return val @@ -56,7 +74,12 @@ def decode_token(token: str): raise credentials_exception from exc -async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> UserInfo: +async def get_current_user(request: Request) -> UserInfo: + token = request.cookies.get("access_token") + return get_current_user_sync(token) + + +def get_current_user_sync(token: str) -> UserInfo: credentials_exception = HTTPException( status_code=401, detail="Could not validate credentials", @@ -70,4 +93,4 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> Use raise credentials_exception except Exception as exc: raise credentials_exception from exc - return UserInfo(groups=groups, email=email) + return UserInfo(email=email, token=token) diff --git a/backend/bec_atlas/model/model.py b/backend/bec_atlas/model/model.py index b1ca975..77c6d07 100644 --- a/backend/bec_atlas/model/model.py +++ b/backend/bec_atlas/model/model.py @@ -67,7 +67,7 @@ class User(MongoBaseModel, AccessProfile): class UserInfo(BaseModel): email: str - groups: list[str] + token: str class Deployments(MongoBaseModel, AccessProfile): diff --git a/backend/bec_atlas/router/base_router.py b/backend/bec_atlas/router/base_router.py index 8edd8cf..c8f4d8b 100644 --- a/backend/bec_atlas/router/base_router.py +++ b/backend/bec_atlas/router/base_router.py @@ -1,4 +1,23 @@ +from functools import lru_cache + +from bec_atlas.model.model import User + + class BaseRouter: def __init__(self, prefix: str = "/api/v1", datasources=None) -> None: self.datasources = datasources self.prefix = prefix + + @lru_cache(maxsize=128) + def get_user_from_db(self, _token: str, email: str) -> User: + """ + Get the user from the database. This is a helper function to be used by the + convert_to_user decorator. The function is cached to avoid repeated database + queries. To scope the cache to the current request, the token and email are + used as the cache key. + + Args: + _token (str): The token + email (str): The email + """ + return self.datasources.datasources["mongodb"].get_user_by_email(email) diff --git a/backend/bec_atlas/router/bec_access_router.py b/backend/bec_atlas/router/bec_access_router.py index 486b7c3..81be5d6 100644 --- a/backend/bec_atlas/router/bec_access_router.py +++ b/backend/bec_atlas/router/bec_access_router.py @@ -1,8 +1,8 @@ from fastapi import APIRouter, Depends, HTTPException, Query -from bec_atlas.authentication import get_current_user +from bec_atlas.authentication import convert_to_user, get_current_user from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource -from bec_atlas.model.model import BECAccessProfile, UserInfo +from bec_atlas.model.model import BECAccessProfile, User, UserInfo from bec_atlas.router.base_router import BaseRouter @@ -18,11 +18,12 @@ class BECAccessRouter(BaseRouter): description="Retrieve the access key for a specific deployment and user.", ) + @convert_to_user async def get_bec_access( self, deployment_id: str, user: str = Query(None), - current_user: UserInfo = Depends(get_current_user), + current_user: User = Depends(get_current_user), ) -> dict: """ Retrieve the access key for a specific deployment and user. @@ -31,7 +32,7 @@ class BECAccessRouter(BaseRouter): deployment_id (str): The deployment id user (str): The user name to retrieve the access key for. If not provided, the access key for the current user will be retrieved. - current_user (UserInfo): The current user + current_user (User): The current user """ if not user: user = current_user.email diff --git a/backend/bec_atlas/router/deployment_access_router.py b/backend/bec_atlas/router/deployment_access_router.py index 9510420..bb7766b 100644 --- a/backend/bec_atlas/router/deployment_access_router.py +++ b/backend/bec_atlas/router/deployment_access_router.py @@ -7,9 +7,9 @@ from bec_lib.serialization import MsgpackSerialization from bson import ObjectId from fastapi import APIRouter, Depends, HTTPException -from bec_atlas.authentication import get_current_user +from bec_atlas.authentication import convert_to_user, get_current_user from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource -from bec_atlas.model.model import BECAccessProfile, DeploymentAccess, UserInfo +from bec_atlas.model.model import BECAccessProfile, DeploymentAccess, User from bec_atlas.router.base_router import BaseRouter from bec_atlas.router.redis_router import RedisAtlasEndpoints @@ -37,8 +37,9 @@ class DeploymentAccessRouter(BaseRouter): response_model=DeploymentAccess, ) + @convert_to_user async def get_deployment_access( - self, deployment_id: str, current_user: UserInfo = Depends(get_current_user) + self, deployment_id: str, current_user: User = Depends(get_current_user) ) -> DeploymentAccess: """ Get the access lists for a specific deployment. @@ -56,11 +57,12 @@ class DeploymentAccessRouter(BaseRouter): "deployments", {"_id": ObjectId(deployment_id)}, DeploymentAccess, user=current_user ) + @convert_to_user async def patch_deployment_access( self, deployment_id: str, deployment_access: dict, - current_user: UserInfo = Depends(get_current_user), + current_user: User = Depends(get_current_user), ) -> DeploymentAccess: """ Update the access lists for a specific deployment. diff --git a/backend/bec_atlas/router/deployment_credentials.py b/backend/bec_atlas/router/deployment_credentials.py index 64a7c46..69d2645 100644 --- a/backend/bec_atlas/router/deployment_credentials.py +++ b/backend/bec_atlas/router/deployment_credentials.py @@ -4,9 +4,9 @@ from typing import TYPE_CHECKING from bson import ObjectId from fastapi import APIRouter, Depends, HTTPException -from bec_atlas.authentication import get_current_user +from bec_atlas.authentication import convert_to_user, get_current_user from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource -from bec_atlas.model.model import DeploymentCredential, UserInfo +from bec_atlas.model.model import DeploymentCredential, User from bec_atlas.router.base_router import BaseRouter if TYPE_CHECKING: # pragma: no cover @@ -33,8 +33,9 @@ class DeploymentCredentialsRouter(BaseRouter): response_model=DeploymentCredential, ) + @convert_to_user async def deployment_credential( - self, deployment_id: str, current_user: UserInfo = Depends(get_current_user) + self, deployment_id: str, current_user: User = Depends(get_current_user) ) -> DeploymentCredential: """ Get the credentials for a deployment. @@ -56,8 +57,9 @@ class DeploymentCredentialsRouter(BaseRouter): status_code=403, detail="User does not have permission to access this resource." ) + @convert_to_user async def refresh_deployment_credentials( - self, deployment_id: str, current_user: UserInfo = Depends(get_current_user) + self, deployment_id: str, current_user: User = Depends(get_current_user) ): """ Refresh the deployment credentials. diff --git a/backend/bec_atlas/router/deployments_router.py b/backend/bec_atlas/router/deployments_router.py index 518a782..78468b5 100644 --- a/backend/bec_atlas/router/deployments_router.py +++ b/backend/bec_atlas/router/deployments_router.py @@ -4,9 +4,9 @@ from typing import TYPE_CHECKING from bson import ObjectId from fastapi import APIRouter, Depends, HTTPException -from bec_atlas.authentication import get_current_user +from bec_atlas.authentication import convert_to_user, get_current_user from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource -from bec_atlas.model.model import DeploymentCredential, Deployments, UserInfo +from bec_atlas.model.model import DeploymentCredential, Deployments, User from bec_atlas.router.base_router import BaseRouter if TYPE_CHECKING: # pragma: no cover @@ -34,8 +34,9 @@ class DeploymentsRouter(BaseRouter): ) self.update_available_deployments() + @convert_to_user async def deployments( - self, realm: str, current_user: UserInfo = Depends(get_current_user) + self, realm: str, current_user: User = Depends(get_current_user) ) -> list[Deployments]: """ Get all deployments for a realm. @@ -48,8 +49,9 @@ class DeploymentsRouter(BaseRouter): """ return self.db.find("deployments", {"realm_id": realm}, Deployments, user=current_user) + @convert_to_user async def deployment_with_id( - self, deployment_id: str, current_user: UserInfo = Depends(get_current_user) + self, deployment_id: str, current_user: User = Depends(get_current_user) ): """ Get deployment with id from realm diff --git a/backend/bec_atlas/router/realm_router.py b/backend/bec_atlas/router/realm_router.py index b6c4700..f0b0eda 100644 --- a/backend/bec_atlas/router/realm_router.py +++ b/backend/bec_atlas/router/realm_router.py @@ -1,8 +1,8 @@ from fastapi import APIRouter, Depends -from bec_atlas.authentication import get_current_user +from bec_atlas.authentication import convert_to_user, get_current_user from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource -from bec_atlas.model.model import DeploymentAccess, Realm, UserInfo +from bec_atlas.model.model import DeploymentAccess, Realm, User from bec_atlas.router.base_router import BaseRouter @@ -36,8 +36,9 @@ class RealmRouter(BaseRouter): response_model_exclude_none=True, ) + @convert_to_user async def realms( - self, include_deployments: bool = False, current_user: UserInfo = Depends(get_current_user) + self, include_deployments: bool = False, current_user: User = Depends(get_current_user) ) -> list[Realm]: """ Get all realms. @@ -62,8 +63,9 @@ class RealmRouter(BaseRouter): return self.db.aggregate("realms", include, Realm, user=current_user) return self.db.find("realms", {}, Realm, user=current_user) + @convert_to_user async def realm_with_deployment_access( - self, owner_only: bool = False, current_user: UserInfo = Depends(get_current_user) + self, owner_only: bool = False, current_user: User = Depends(get_current_user) ): """ Get all realms with deployment access. @@ -96,9 +98,8 @@ class RealmRouter(BaseRouter): ] return self.db.aggregate("realms", include, Realm, user=current_user) - async def realm_with_id( - self, realm_id: str, current_user: UserInfo = Depends(get_current_user) - ): + @convert_to_user + async def realm_with_id(self, realm_id: str, current_user: User = Depends(get_current_user)): """ Get realm with id. diff --git a/backend/bec_atlas/router/redis_router.py b/backend/bec_atlas/router/redis_router.py index 2a14504..6d66c7d 100644 --- a/backend/bec_atlas/router/redis_router.py +++ b/backend/bec_atlas/router/redis_router.py @@ -1,4 +1,5 @@ import asyncio +import enum import functools import inspect import json @@ -10,8 +11,11 @@ import socketio from bec_lib.endpoints import EndpointInfo, MessageEndpoints, MessageOp from bec_lib.logger import bec_logger from bec_lib.serialization import MsgpackSerialization, json_ext -from fastapi import APIRouter, Query, Response +from bson import ObjectId +from fastapi import APIRouter, Depends, Query, Response +from bec_atlas.authentication import convert_to_user, get_current_user, get_current_user_sync +from bec_atlas.model.model import DeploymentAccess, User from bec_atlas.router.base_router import BaseRouter logger = bec_logger.logger @@ -20,6 +24,12 @@ if TYPE_CHECKING: from bec_lib.redis_connector import RedisConnector +class RemoteAccess(enum.Enum): + READ = "read" + READ_WRITE = "read_write" + NONE = "none" + + class RedisAtlasEndpoints: """ This class contains the endpoints for the Redis API. It is used to @@ -133,7 +143,10 @@ class RedisRouter(BaseRouter): self.router.add_api_route("/redis", self.redis_post, methods=["POST"]) self.router.add_api_route("/redis", self.redis_delete, methods=["DELETE"]) - async def redis_get(self, deployment: str, key: str = Query(...)): + @convert_to_user + async def redis_get( + self, deployment: str, key: str = Query(...), current_user: User = Depends(get_current_user) + ): request_id = uuid.uuid4().hex response_endpoint = RedisAtlasEndpoints.redis_request_response(deployment, request_id) request_endpoint = RedisAtlasEndpoints.redis_request(deployment) @@ -149,10 +162,14 @@ class RedisRouter(BaseRouter): return json_ext.dumps({"data": out.content, "metadata": out.metadata}) - async def redis_post(self, key: str, value: str): + @convert_to_user + async def redis_post( + self, key: str, value: str, current_user: User = Depends(get_current_user) + ): return self.redis.set(key, value) - async def redis_delete(self, key: str): + @convert_to_user + async def redis_delete(self, key: str, current_user: User = Depends(get_current_user)): return self.redis.delete(key) @@ -268,6 +285,7 @@ class RedisWebsocket: redis_host = datasources.datasources["redis"].config["host"] redis_port = datasources.datasources["redis"].config["port"] redis_password = datasources.datasources["redis"].config.get("password", "ingestor") + self.db = datasources.datasources["mongodb"] self.socket = socketio.AsyncServer( transports=["websocket"], ping_timeout=60, @@ -289,7 +307,7 @@ class RedisWebsocket: self.socket.on("disconnect", self.disconnect_client) print("Redis websocket started") - def _validate_new_user(self, http_query: str | None): + def _validate_new_user(self, http_query: str | None, auth_token: str) -> tuple: """ Validate the connection of a new user. In particular, the user must provide a valid token as well as have access @@ -310,19 +328,41 @@ class RedisWebsocket: else: query = http_query - if "user" not in query: - raise ValueError("User not found in query parameters") - user = query["user"] - - # TODO: Validate the user token + user_info = get_current_user_sync(auth_token) + user = self.db.find_one("users", {"email": user_info.email}, User) deployment = query.get("deployment") if not deployment: raise ValueError("Deployment not found in query parameters") - # TODO: Validate the user has access to the deployment + deployment_access = self.db.find_one( + "deployments", {"_id": ObjectId(deployment)}, DeploymentAccess + ) + if not deployment_access: + raise ValueError("Deployment not found") - return user, deployment + access = self.get_access(user, deployment_access) + if access == RemoteAccess.NONE: + raise ValueError("User does not have remote access to the deployment") + + return user, deployment, access + + def get_access(self, user: User, deployment_access: DeploymentAccess) -> RemoteAccess: + """ + Get the access level of the user to the deployment. + """ + access = RemoteAccess.NONE + groups = set(user.groups) + if user.username is not None: + groups.add(user.username) + if user.email is not None: + groups.add(user.email) + + if groups & set(deployment_access.remote_read_access): + access = RemoteAccess.READ + if groups & set(deployment_access.remote_write_access): + access = RemoteAccess.READ_WRITE + return access @safe_socket async def connect_client(self, sid, environ=None, auth=None, **kwargs): @@ -332,7 +372,23 @@ class RedisWebsocket: http_query = environ.get("HTTP_QUERY") or auth - user, deployment = self._validate_new_user(http_query) + cookies = environ.get("HTTP_COOKIE", "") + auth_token = None + + for cookie in cookies.split("; "): + if cookie.startswith("access_token="): + auth_token = cookie.split("=")[1] + break + + if not auth_token: + await self.disconnect_client(sid) # Reject connection + return + + try: + user, deployment, access = self._validate_new_user(http_query, auth_token) + except ValueError: + await self.disconnect_client(sid, reason="Invalid user or deployment") + return # check if the user was already registered in redis socketio_server_keys = await self.socket.manager.redis.keys( @@ -350,22 +406,25 @@ class RedisWebsocket: for value in obj.values(): info[value["user"]] = value["subscriptions"] - if user in info: - self.users[sid] = {"user": user, "subscriptions": [], "deployment": deployment} + if user.email in info: + self.users[sid] = {"user": user.email, "subscriptions": [], "deployment": deployment} for endpoint, endpoint_request in info[user]: print(f"Registering {endpoint}") await self._update_user_subscriptions(sid, endpoint, endpoint_request) else: - self.users[sid] = {"user": user, "subscriptions": [], "deployment": deployment} + self.users[sid] = {"user": user.email, "subscriptions": [], "deployment": deployment} await self.socket.manager.update_websocket_states() - async def disconnect_client(self, sid, _environ=None): - print("Client disconnected") + async def disconnect_client(self, sid, reason: str = None, _environ=None): is_exit = self.fastapi_app.server.should_exit if is_exit: return - await self.socket.manager.remove_user(sid) + if reason: + await self.socket.emit("error", {"error": reason}, room=sid) + if sid in self.users: + del self.users[sid] + await self.socket.disconnect(sid) @safe_socket async def redis_register(self, sid: str, msg: str): diff --git a/backend/bec_atlas/router/scan_router.py b/backend/bec_atlas/router/scan_router.py index 99497e4..189025a 100644 --- a/backend/bec_atlas/router/scan_router.py +++ b/backend/bec_atlas/router/scan_router.py @@ -5,9 +5,9 @@ import json from bson import ObjectId from fastapi import APIRouter, Depends, HTTPException, Query -from bec_atlas.authentication import get_current_user +from bec_atlas.authentication import convert_to_user, get_current_user from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource -from bec_atlas.model.model import ScanStatusPartial, ScanUserData, UserInfo +from bec_atlas.model.model import ScanStatusPartial, ScanUserData, User, UserInfo from bec_atlas.router.base_router import BaseRouter @@ -47,6 +47,7 @@ class ScanRouter(BaseRouter): response_model=dict, ) + @convert_to_user async def scans( self, session_id: str, @@ -55,7 +56,7 @@ class ScanRouter(BaseRouter): offset: int = 0, limit: int = 100, sort: str | None = None, - current_user: UserInfo = Depends(get_current_user), + current_user: User = Depends(get_current_user), ) -> list[ScanStatusPartial]: """ Get all scans for a session. @@ -69,7 +70,7 @@ class ScanRouter(BaseRouter): sort (str): Sort order for the query, e.g. '{"name": 1}' for ascending order, '{"name": -1}' for descending order. Multiple fields can be sorted by separating them with a comma, e.g. '{"name": 1, "description": -1}' - current_user (UserInfo): The current user + current_user (User): The current user Returns: list[ScanStatusPartial]: List of scans @@ -100,11 +101,12 @@ class ScanRouter(BaseRouter): user=current_user, ) + @convert_to_user async def scans_with_id( self, scan_id: str, fields: list[str] = Query(default=None), - current_user: UserInfo = Depends(get_current_user), + current_user: User = Depends(get_current_user), ): """ Get scan with id from session @@ -135,6 +137,7 @@ class ScanRouter(BaseRouter): scan_id (str): The scan id user_data (dict): The user data to update """ + current_user = self.get_user_from_db(current_user.token, current_user.email) out = self.db.patch( "scans", id=scan_id, @@ -147,15 +150,16 @@ class ScanRouter(BaseRouter): raise HTTPException(status_code=404, detail="Scan not found") return {"message": "Scan user data updated."} + @convert_to_user async def count_scans( - self, filter: str | None = None, current_user: UserInfo = Depends(get_current_user) + self, filter: str | None = None, current_user: User = Depends(get_current_user) ) -> int: """ Count the number of scans. Args: filter (str): JSON filter for the query, e.g. '{"name": "test"}' - current_user (UserInfo): The current user + current_user (User): The current user Returns: int: The number of scans diff --git a/backend/bec_atlas/router/session_router.py b/backend/bec_atlas/router/session_router.py index 8a919a1..da988d5 100644 --- a/backend/bec_atlas/router/session_router.py +++ b/backend/bec_atlas/router/session_router.py @@ -1,15 +1,10 @@ import json -from typing import Annotated from fastapi import APIRouter, Depends, Query -from fastapi.exceptions import HTTPException -from fastapi.security import OAuth2PasswordRequestForm -from pydantic import BaseModel -from bec_atlas.authentication import create_access_token, get_current_user, verify_password +from bec_atlas.authentication import convert_to_user, get_current_user from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource -from bec_atlas.model import UserInfo -from bec_atlas.model.model import Session +from bec_atlas.model.model import Session, User from bec_atlas.router.base_router import BaseRouter @@ -35,6 +30,7 @@ class SessionRouter(BaseRouter): response_model_exclude_none=True, ) + @convert_to_user async def sessions( self, filter: str | None = None, @@ -42,7 +38,7 @@ class SessionRouter(BaseRouter): offset: int = 0, limit: int = 100, sort: str | None = None, - current_user: UserInfo = Depends(get_current_user), + current_user: User = Depends(get_current_user), ) -> list[Session]: """ Get all sessions. @@ -55,7 +51,7 @@ class SessionRouter(BaseRouter): sort (str): Sort order for the query, e.g. '{"name": 1}' for ascending order, '{"name": -1}' for descending order. Multiple fields can be sorted by separating them with a comma, e.g. '{"name": 1, "description": -1}' - current_user (UserInfo): The current user + current_user (User): The current user Returns: list[Sessions]: List of sessions @@ -82,6 +78,7 @@ class SessionRouter(BaseRouter): user=current_user, ) + @convert_to_user async def sessions_by_realm( self, realm_id: str, @@ -90,7 +87,7 @@ class SessionRouter(BaseRouter): offset: int = 0, limit: int = 100, sort: str | None = None, - current_user: UserInfo = Depends(get_current_user), + current_user: User = Depends(get_current_user), ) -> list[Session]: """ Get all sessions for a realm. diff --git a/backend/bec_atlas/router/user_router.py b/backend/bec_atlas/router/user_router.py index 5f9aed5..48d82d0 100644 --- a/backend/bec_atlas/router/user_router.py +++ b/backend/bec_atlas/router/user_router.py @@ -1,11 +1,16 @@ from typing import Annotated -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Response from fastapi.exceptions import HTTPException from fastapi.security import OAuth2PasswordRequestForm from pydantic import BaseModel -from bec_atlas.authentication import create_access_token, get_current_user, verify_password +from bec_atlas.authentication import ( + convert_to_user, + create_access_token, + get_current_user, + verify_password, +) from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource from bec_atlas.model import UserInfo from bec_atlas.model.model import User @@ -19,8 +24,9 @@ class UserLoginRequest(BaseModel): class UserRouter(BaseRouter): - def __init__(self, prefix="/api/v1", datasources=None): + def __init__(self, prefix="/api/v1", datasources=None, use_ssl=True): super().__init__(prefix, datasources) + self.use_ssl = use_ssl self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb") self.ldap = LDAPUserService( ldap_server="ldaps://d.psi.ch", base_dn="OU=users,OU=psi,DC=d,DC=psi,DC=ch" @@ -31,23 +37,30 @@ class UserRouter(BaseRouter): self.router.add_api_route( "/user/login/form", self.form_login, methods=["POST"], dependencies=[] ) + self.router.add_api_route("/user/logout", self.user_logout, methods=["POST"]) - async def user_me(self, user: UserInfo = Depends(get_current_user)): - data = self.db.get_user_by_email(user.email) - if data is None: - raise HTTPException(status_code=404, detail="User not found") - return data + @convert_to_user + async def user_me(self, user: User = Depends(get_current_user)): + return user - async def form_login(self, form_data: Annotated[OAuth2PasswordRequestForm, Depends()]): + async def form_login( + self, form_data: Annotated[OAuth2PasswordRequestForm, Depends()], response: Response + ): user_login = UserLoginRequest(username=form_data.username, password=form_data.password) - out = await self.user_login(user_login) + out = await self.user_login(user_login, response) return {"access_token": out, "token_type": "bearer"} - async def user_login(self, user_login: UserLoginRequest): + async def user_login(self, user_login: UserLoginRequest, response: Response): user = self._get_user(user_login) if user is None: raise HTTPException(status_code=401, detail="User not found or password is incorrect") - return create_access_token(data={"groups": list(user.groups), "email": user.email}) + token = create_access_token(data={"groups": list(user.groups), "email": user.email}) + response.set_cookie(key="access_token", value=token, httponly=True, secure=self.use_ssl) + return token + + async def user_logout(self, response: Response): + response.delete_cookie("access_token") + return {"message": "Logged out"} def _get_user(self, user_login: UserLoginRequest) -> UserInfo | None: user = self._get_functional_account(user_login) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index e546b3b..f24ffdf 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -118,4 +118,5 @@ def backend(redis_server): "bec_atlas.router.redis_router.BECAsyncRedisManager", PatchedBECAsyncRedisManager ): with TestClient(app.app) as _client: + app.user_router.use_ssl = False # disable ssl to allow for httponly cookies yield _client, app diff --git a/backend/tests/test_deployment_access_router.py b/backend/tests/test_deployment_access_router.py index 3e2d884..3158d0b 100644 --- a/backend/tests/test_deployment_access_router.py +++ b/backend/tests/test_deployment_access_router.py @@ -13,7 +13,6 @@ def logged_in_client(backend): token = response.json() assert isinstance(token, str) assert len(token) > 20 - client.headers.update({"Authorization": f"Bearer {token}"}) return client diff --git a/backend/tests/test_deployment_credentials.py b/backend/tests/test_deployment_credentials.py index f1808c9..4e677ab 100644 --- a/backend/tests/test_deployment_credentials.py +++ b/backend/tests/test_deployment_credentials.py @@ -11,7 +11,6 @@ def logged_in_client(backend): token = response.json() assert isinstance(token, str) assert len(token) > 20 - client.headers.update({"Authorization": f"Bearer {token}"}) return client @@ -70,7 +69,6 @@ def test_deployment_credential_rejects_unauthorized_user(backend): token = response.json() assert isinstance(token, str) assert len(token) > 20 - client.headers.update({"Authorization": f"Bearer {token}"}) deployments = client.get( "/api/v1/deployments/realm", params={"realm": "demo_beamline_1"} @@ -96,7 +94,6 @@ def test_refresh_deployment_credentials_rejects_unauthorized_user(backend): token = response.json() assert isinstance(token, str) assert len(token) > 20 - client.headers.update({"Authorization": f"Bearer {token}"}) deployments = client.get( "/api/v1/deployments/realm", params={"realm": "demo_beamline_1"} diff --git a/backend/tests/test_deployment_router.py b/backend/tests/test_deployment_router.py index 5002401..4695dc2 100644 --- a/backend/tests/test_deployment_router.py +++ b/backend/tests/test_deployment_router.py @@ -11,7 +11,6 @@ def logged_in_client(backend): token = response.json() assert isinstance(token, str) assert len(token) > 20 - client.headers.update({"Authorization": f"Bearer {token}"}) return client diff --git a/backend/tests/test_redis_websocket.py b/backend/tests/test_redis_websocket.py index 64959f2..6e2d8ee 100644 --- a/backend/tests/test_redis_websocket.py +++ b/backend/tests/test_redis_websocket.py @@ -2,9 +2,10 @@ import json from unittest import mock import pytest -from bec_atlas.router.redis_router import RedisAtlasEndpoints from bec_lib.endpoints import MessageEndpoints +from bec_atlas.router.redis_router import RedisAtlasEndpoints, RemoteAccess + @pytest.fixture def backend_client(backend): @@ -12,48 +13,75 @@ def backend_client(backend): app.server = mock.Mock() app.server.should_exit = False app.redis_websocket.users = {} - yield client, app - # app.redis_websocket.redis._redis_conn.flushall() + response = client.post( + "/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"} + ) + assert response.status_code == 200 + token = response.json() + assert isinstance(token, str) + assert len(token) > 20 + return client, app + + +@pytest.fixture +@pytest.mark.asyncio(loop_scope="session") +async def connected_ws(backend_client): + client, app = backend_client + deployment = client.get("/api/v1/deployments/realm", params={"realm": "demo_beamline_1"}).json() + with mock.patch.object(app.redis_websocket, "get_access", return_value=RemoteAccess.READ): + await app.redis_websocket.socket.handlers["/"]["connect"]( + "sid", + { + "HTTP_QUERY": json.dumps({"deployment": deployment[0]["_id"]}), + "HTTP_COOKIE": f"access_token={client.cookies.get('access_token')}", + }, + ) + yield backend_client @pytest.mark.asyncio(loop_scope="session") -async def test_redis_websocket_connect(backend_client): - client, app = backend_client - await app.redis_websocket.socket.handlers["/"]["connect"]( - "sid", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'} - ) +async def test_redis_websocket_connect(connected_ws): + _, app = await anext(connected_ws) assert "sid" in app.redis_websocket.users @pytest.mark.asyncio(loop_scope="session") -async def test_redis_websocket_disconnect(backend_client): - client, app = backend_client - app.redis_websocket.users["sid"] = {"user": "test", "subscriptions": []} +async def test_redis_websocket_disconnect(connected_ws): + _, app = await anext(connected_ws) await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid") assert "sid" not in app.redis_websocket.users @pytest.mark.asyncio(loop_scope="session") -async def test_redis_websocket_multiple_connect(backend_client): - client, app = backend_client +async def test_redis_websocket_multiple_connect(connected_ws): + client, app = await anext(connected_ws) + await app.redis_websocket.socket.handlers["/"]["connect"]( - "sid1", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'} + "sid2", + { + "HTTP_QUERY": json.dumps( + {"deployment": app.redis_websocket.users["sid"]["deployment"]} + ), + "HTTP_COOKIE": f"access_token={client.cookies.get('access_token')}", + }, ) - await app.redis_websocket.socket.handlers["/"]["connect"]( - "sid2", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'} - ) - assert "sid1" in app.redis_websocket.users + + assert "sid" in app.redis_websocket.users assert "sid2" in app.redis_websocket.users @pytest.mark.asyncio(loop_scope="session") -async def test_redis_websocket_multiple_connect_same_sid(backend_client): - client, app = backend_client +async def test_redis_websocket_multiple_connect_same_sid(connected_ws): + client, app = await anext(connected_ws) + await app.redis_websocket.socket.handlers["/"]["connect"]( - "sid", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'} - ) - await app.redis_websocket.socket.handlers["/"]["connect"]( - "sid", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'} + "sid", + { + "HTTP_QUERY": json.dumps( + {"deployment": app.redis_websocket.users["sid"]["deployment"]} + ), + "HTTP_COOKIE": f"access_token={client.cookies.get('access_token')}", + }, ) assert "sid" in app.redis_websocket.users @@ -61,9 +89,8 @@ async def test_redis_websocket_multiple_connect_same_sid(backend_client): @pytest.mark.asyncio(loop_scope="session") -async def test_redis_websocket_multiple_disconnect_same_sid(backend_client): - client, app = backend_client - app.redis_websocket.users["sid"] = {"user": "test", "subscriptions": []} +async def test_redis_websocket_multiple_disconnect_same_sid(connected_ws): + client, app = await anext(connected_ws) await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid") await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid") assert "sid" not in app.redis_websocket.users @@ -82,14 +109,10 @@ async def test_redis_websocket_register_wrong_endpoint_raises(backend_client): @pytest.mark.asyncio(loop_scope="session") -async def test_redis_websocket_register(backend_client): - client, app = backend_client +async def test_redis_websocket_register(connected_ws): + client, app = await anext(connected_ws) with mock.patch.object(app.redis_websocket.socket, "emit") as emit: with mock.patch.object(app.redis_websocket.socket, "enter_room") as enter_room: - await app.redis_websocket.socket.handlers["/"]["connect"]( - "sid", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'} - ) - await app.redis_websocket.socket.handlers["/"]["register"]( "sid", json.dumps({"endpoint": "scan_status"}) ) @@ -97,7 +120,8 @@ async def test_redis_websocket_register(backend_client): enter_room.assert_called_with( "sid", RedisAtlasEndpoints.socketio_endpoint_room( - "test", MessageEndpoints.scan_status().endpoint + app.redis_websocket.users["sid"]["deployment"], + MessageEndpoints.scan_status().endpoint, ), ) diff --git a/backend/tests/test_scan_ingestor.py b/backend/tests/test_scan_ingestor.py index e919951..a19cd08 100644 --- a/backend/tests/test_scan_ingestor.py +++ b/backend/tests/test_scan_ingestor.py @@ -91,7 +91,6 @@ def test_scan_ingestor_create_scan(scan_ingestor, backend): response = client.post( "/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"} ) - client.headers.update({"Authorization": f"Bearer {response.json()}"}) session_id = msg.info.get("session_id") scan_id = msg.scan_id diff --git a/backend/tests/test_scans_router.py b/backend/tests/test_scans_router.py index 941b996..687d1a1 100644 --- a/backend/tests/test_scans_router.py +++ b/backend/tests/test_scans_router.py @@ -14,7 +14,6 @@ def logged_in_client(backend): token = response.json() assert isinstance(token, str) assert len(token) > 20 - client.headers.update({"Authorization": f"Bearer {token}"}) return client diff --git a/frontend/bec_atlas/src/app/core/auth.interceptor.ts b/frontend/bec_atlas/src/app/core/auth.interceptor.ts index 7a85a36..5989e7d 100644 --- a/frontend/bec_atlas/src/app/core/auth.interceptor.ts +++ b/frontend/bec_atlas/src/app/core/auth.interceptor.ts @@ -11,29 +11,28 @@ import { Observable } from 'rxjs'; import { tap } from 'rxjs/operators'; function logout() { - localStorage.removeItem('id_token'); localStorage.removeItem('id_session'); location.href = '/login'; } function handle_request(handler: HttpHandler, req: HttpRequest) { return handler.handle(req).pipe( - tap( - (event: HttpEvent) => { + tap({ + next: (event: HttpEvent) => { if (event instanceof HttpResponse) { // console.log(cloned); // console.log("Service Response thr Interceptor"); } }, - (err: any) => { + error: (err: any) => { if (err instanceof HttpErrorResponse) { console.log('err.status', err); if (err.status === 401) { logout(); } } - } - ) + }, + }) ); } @@ -45,16 +44,10 @@ export class AuthInterceptor implements HttpInterceptor { req: HttpRequest, next: HttpHandler ): Observable> { - const idToken = localStorage.getItem('id_token'); + const cloned = req.clone({ + withCredentials: true, + }); - if (idToken) { - const cloned = req.clone({ - headers: req.headers.set('Authorization', 'Bearer ' + idToken), - }); - - return handle_request(next, cloned); - } else { - return handle_request(next, req); - } + return handle_request(next, cloned); } } diff --git a/frontend/bec_atlas/src/app/core/auth.service.ts b/frontend/bec_atlas/src/app/core/auth.service.ts index 61347d0..5a212f2 100644 --- a/frontend/bec_atlas/src/app/core/auth.service.ts +++ b/frontend/bec_atlas/src/app/core/auth.service.ts @@ -23,12 +23,11 @@ export class AuthService { setSession(authResult: string) { console.log(authResult); // it would be good to get an expiration date for the token... - localStorage.setItem('id_token', authResult); localStorage.setItem('id_session', this.getRandomId()); } logout() { - localStorage.removeItem('id_token'); + this.dataService.logout(); localStorage.removeItem('id_session'); this.forceReload = true; } diff --git a/frontend/bec_atlas/src/app/core/redis-connector.service.ts b/frontend/bec_atlas/src/app/core/redis-connector.service.ts index e1142f6..ded577b 100644 --- a/frontend/bec_atlas/src/app/core/redis-connector.service.ts +++ b/frontend/bec_atlas/src/app/core/redis-connector.service.ts @@ -1,8 +1,7 @@ import { Injectable, signal, WritableSignal } from '@angular/core'; import { io, Socket } from 'socket.io-client'; import { Observable } from 'rxjs'; -import { MessageEndpoints, EndpointInfo } from './redis_endpoints'; -import { AppConfigService } from '../app-config.service'; +import { EndpointInfo } from './redis_endpoints'; import { ServerSettingsService } from '../server-settings.service'; import { DeploymentService } from '../deployment.service'; @@ -36,8 +35,6 @@ export class RedisConnectorService { timeout: 500, // Connection timeout in milliseconds path: '/api/v1/ws', // Path to the WebSocket server auth: { - user: 'john_doe', - token: '1234', deployment: id, }, }); diff --git a/frontend/bec_atlas/src/app/core/remote-data.service.ts b/frontend/bec_atlas/src/app/core/remote-data.service.ts index f7ef987..579be8e 100644 --- a/frontend/bec_atlas/src/app/core/remote-data.service.ts +++ b/frontend/bec_atlas/src/app/core/remote-data.service.ts @@ -104,6 +104,16 @@ export class AuthDataService extends RemoteDataService { headers ); } + + /** + * Method for logging out of BEC + * @returns response from the server + */ + logout(): Promise { + let headers = new HttpHeaders(); + headers = headers.set('Content-Type', 'application/json; charset=utf-8'); + return firstValueFrom(this.post('user/logout', {}, headers)); + } } @Injectable({ diff --git a/frontend/bec_atlas/src/app/dashboard/dashboard.component.html b/frontend/bec_atlas/src/app/dashboard/dashboard.component.html index 5b5fe00..54ba06e 100644 --- a/frontend/bec_atlas/src/app/dashboard/dashboard.component.html +++ b/frontend/bec_atlas/src/app/dashboard/dashboard.component.html @@ -20,7 +20,7 @@ - + @@ -36,15 +36,10 @@ mat-button class="menu-item" [routerLink]="['/dashboard/scan-table']" - > Scan Data + > + Scan Data - - - + diff --git a/frontend/bec_atlas/src/app/dashboard/dashboard.component.ts b/frontend/bec_atlas/src/app/dashboard/dashboard.component.ts index ef8c26e..2d73d74 100644 --- a/frontend/bec_atlas/src/app/dashboard/dashboard.component.ts +++ b/frontend/bec_atlas/src/app/dashboard/dashboard.component.ts @@ -5,7 +5,7 @@ import { BreakpointObserver } from '@angular/cdk/layout'; import { CommonModule } from '@angular/common'; import { MatButtonModule } from '@angular/material/button'; import { MatDividerModule } from '@angular/material/divider'; -import { RouterModule } from '@angular/router'; +import { Router, RouterModule } from '@angular/router'; import { MatExpansionModule } from '@angular/material/expansion'; import { DeploymentService } from '../deployment.service'; import { @@ -15,6 +15,7 @@ import { } from '@angular/material/dialog'; import { DeploymentSelectionComponent } from '../deployment-selection/deployment-selection.component'; import { RedisConnectorService } from '../core/redis-connector.service'; +import { AuthDataService } from '../core/remote-data.service'; @Component({ selector: 'app-dashboard', imports: [ @@ -41,7 +42,9 @@ export class DashboardComponent { constructor( private breakpointObserver: BreakpointObserver, - private deploymentService: DeploymentService + private deploymentService: DeploymentService, + private authDataService: AuthDataService, + private router: Router ) {} ngOnInit(): void { @@ -89,4 +92,9 @@ export class DashboardComponent { this.openDeploymentDialog(); } } + + logout() { + this.authDataService.logout(); + this.router.navigate(['/login']); + } }