feat(auth): moved to httponly token

This commit is contained in:
2025-02-17 15:14:28 +01:00
parent 5abbec19c8
commit 0b107c9882
25 changed files with 293 additions and 150 deletions

View File

@ -2,10 +2,10 @@ from __future__ import annotations
import os import os
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Annotated from functools import wraps
import jwt import jwt
from fastapi import Depends, HTTPException, status from fastapi import HTTPException, Request, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from jwt.exceptions import InvalidTokenError from jwt.exceptions import InvalidTokenError
from pwdlib import PasswordHash from pwdlib import PasswordHash
@ -19,6 +19,24 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/user/login/form")
password_hash = PasswordHash.recommended() password_hash = PasswordHash.recommended()
def convert_to_user(func):
"""
Decorator to convert the current_user parameter to a User object.
"""
@wraps(func)
async def wrapper(*args, **kwargs):
if "current_user" in kwargs:
current_user = kwargs["current_user"]
if current_user:
router = args[0]
user = router.get_user_from_db(current_user.token, current_user.email)
kwargs["current_user"] = user
return await func(*args, **kwargs)
return wrapper
def get_secret_key(): def get_secret_key():
val = os.getenv("SECRET_KEY", "test_secret") val = os.getenv("SECRET_KEY", "test_secret")
return val return val
@ -56,7 +74,12 @@ def decode_token(token: str):
raise credentials_exception from exc raise credentials_exception from exc
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> UserInfo: async def get_current_user(request: Request) -> UserInfo:
token = request.cookies.get("access_token")
return get_current_user_sync(token)
def get_current_user_sync(token: str) -> UserInfo:
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=401, status_code=401,
detail="Could not validate credentials", detail="Could not validate credentials",
@ -70,4 +93,4 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> Use
raise credentials_exception raise credentials_exception
except Exception as exc: except Exception as exc:
raise credentials_exception from exc raise credentials_exception from exc
return UserInfo(groups=groups, email=email) return UserInfo(email=email, token=token)

View File

@ -67,7 +67,7 @@ class User(MongoBaseModel, AccessProfile):
class UserInfo(BaseModel): class UserInfo(BaseModel):
email: str email: str
groups: list[str] token: str
class Deployments(MongoBaseModel, AccessProfile): class Deployments(MongoBaseModel, AccessProfile):

View File

@ -1,4 +1,23 @@
from functools import lru_cache
from bec_atlas.model.model import User
class BaseRouter: class BaseRouter:
def __init__(self, prefix: str = "/api/v1", datasources=None) -> None: def __init__(self, prefix: str = "/api/v1", datasources=None) -> None:
self.datasources = datasources self.datasources = datasources
self.prefix = prefix self.prefix = prefix
@lru_cache(maxsize=128)
def get_user_from_db(self, _token: str, email: str) -> User:
"""
Get the user from the database. This is a helper function to be used by the
convert_to_user decorator. The function is cached to avoid repeated database
queries. To scope the cache to the current request, the token and email are
used as the cache key.
Args:
_token (str): The token
email (str): The email
"""
return self.datasources.datasources["mongodb"].get_user_by_email(email)

View File

@ -1,8 +1,8 @@
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from bec_atlas.authentication import get_current_user from bec_atlas.authentication import convert_to_user, get_current_user
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model.model import BECAccessProfile, UserInfo from bec_atlas.model.model import BECAccessProfile, User, UserInfo
from bec_atlas.router.base_router import BaseRouter from bec_atlas.router.base_router import BaseRouter
@ -18,11 +18,12 @@ class BECAccessRouter(BaseRouter):
description="Retrieve the access key for a specific deployment and user.", description="Retrieve the access key for a specific deployment and user.",
) )
@convert_to_user
async def get_bec_access( async def get_bec_access(
self, self,
deployment_id: str, deployment_id: str,
user: str = Query(None), user: str = Query(None),
current_user: UserInfo = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
""" """
Retrieve the access key for a specific deployment and user. Retrieve the access key for a specific deployment and user.
@ -31,7 +32,7 @@ class BECAccessRouter(BaseRouter):
deployment_id (str): The deployment id deployment_id (str): The deployment id
user (str): The user name to retrieve the access key for. If not provided, user (str): The user name to retrieve the access key for. If not provided,
the access key for the current user will be retrieved. the access key for the current user will be retrieved.
current_user (UserInfo): The current user current_user (User): The current user
""" """
if not user: if not user:
user = current_user.email user = current_user.email

View File

@ -7,9 +7,9 @@ from bec_lib.serialization import MsgpackSerialization
from bson import ObjectId from bson import ObjectId
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from bec_atlas.authentication import get_current_user from bec_atlas.authentication import convert_to_user, get_current_user
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model.model import BECAccessProfile, DeploymentAccess, UserInfo from bec_atlas.model.model import BECAccessProfile, DeploymentAccess, User
from bec_atlas.router.base_router import BaseRouter from bec_atlas.router.base_router import BaseRouter
from bec_atlas.router.redis_router import RedisAtlasEndpoints from bec_atlas.router.redis_router import RedisAtlasEndpoints
@ -37,8 +37,9 @@ class DeploymentAccessRouter(BaseRouter):
response_model=DeploymentAccess, response_model=DeploymentAccess,
) )
@convert_to_user
async def get_deployment_access( async def get_deployment_access(
self, deployment_id: str, current_user: UserInfo = Depends(get_current_user) self, deployment_id: str, current_user: User = Depends(get_current_user)
) -> DeploymentAccess: ) -> DeploymentAccess:
""" """
Get the access lists for a specific deployment. Get the access lists for a specific deployment.
@ -56,11 +57,12 @@ class DeploymentAccessRouter(BaseRouter):
"deployments", {"_id": ObjectId(deployment_id)}, DeploymentAccess, user=current_user "deployments", {"_id": ObjectId(deployment_id)}, DeploymentAccess, user=current_user
) )
@convert_to_user
async def patch_deployment_access( async def patch_deployment_access(
self, self,
deployment_id: str, deployment_id: str,
deployment_access: dict, deployment_access: dict,
current_user: UserInfo = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> DeploymentAccess: ) -> DeploymentAccess:
""" """
Update the access lists for a specific deployment. Update the access lists for a specific deployment.

View File

@ -4,9 +4,9 @@ from typing import TYPE_CHECKING
from bson import ObjectId from bson import ObjectId
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from bec_atlas.authentication import get_current_user from bec_atlas.authentication import convert_to_user, get_current_user
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model.model import DeploymentCredential, UserInfo from bec_atlas.model.model import DeploymentCredential, User
from bec_atlas.router.base_router import BaseRouter from bec_atlas.router.base_router import BaseRouter
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
@ -33,8 +33,9 @@ class DeploymentCredentialsRouter(BaseRouter):
response_model=DeploymentCredential, response_model=DeploymentCredential,
) )
@convert_to_user
async def deployment_credential( async def deployment_credential(
self, deployment_id: str, current_user: UserInfo = Depends(get_current_user) self, deployment_id: str, current_user: User = Depends(get_current_user)
) -> DeploymentCredential: ) -> DeploymentCredential:
""" """
Get the credentials for a deployment. Get the credentials for a deployment.
@ -56,8 +57,9 @@ class DeploymentCredentialsRouter(BaseRouter):
status_code=403, detail="User does not have permission to access this resource." status_code=403, detail="User does not have permission to access this resource."
) )
@convert_to_user
async def refresh_deployment_credentials( async def refresh_deployment_credentials(
self, deployment_id: str, current_user: UserInfo = Depends(get_current_user) self, deployment_id: str, current_user: User = Depends(get_current_user)
): ):
""" """
Refresh the deployment credentials. Refresh the deployment credentials.

View File

@ -4,9 +4,9 @@ from typing import TYPE_CHECKING
from bson import ObjectId from bson import ObjectId
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from bec_atlas.authentication import get_current_user from bec_atlas.authentication import convert_to_user, get_current_user
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model.model import DeploymentCredential, Deployments, UserInfo from bec_atlas.model.model import DeploymentCredential, Deployments, User
from bec_atlas.router.base_router import BaseRouter from bec_atlas.router.base_router import BaseRouter
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
@ -34,8 +34,9 @@ class DeploymentsRouter(BaseRouter):
) )
self.update_available_deployments() self.update_available_deployments()
@convert_to_user
async def deployments( async def deployments(
self, realm: str, current_user: UserInfo = Depends(get_current_user) self, realm: str, current_user: User = Depends(get_current_user)
) -> list[Deployments]: ) -> list[Deployments]:
""" """
Get all deployments for a realm. Get all deployments for a realm.
@ -48,8 +49,9 @@ class DeploymentsRouter(BaseRouter):
""" """
return self.db.find("deployments", {"realm_id": realm}, Deployments, user=current_user) return self.db.find("deployments", {"realm_id": realm}, Deployments, user=current_user)
@convert_to_user
async def deployment_with_id( async def deployment_with_id(
self, deployment_id: str, current_user: UserInfo = Depends(get_current_user) self, deployment_id: str, current_user: User = Depends(get_current_user)
): ):
""" """
Get deployment with id from realm Get deployment with id from realm

View File

@ -1,8 +1,8 @@
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from bec_atlas.authentication import get_current_user from bec_atlas.authentication import convert_to_user, get_current_user
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model.model import DeploymentAccess, Realm, UserInfo from bec_atlas.model.model import DeploymentAccess, Realm, User
from bec_atlas.router.base_router import BaseRouter from bec_atlas.router.base_router import BaseRouter
@ -36,8 +36,9 @@ class RealmRouter(BaseRouter):
response_model_exclude_none=True, response_model_exclude_none=True,
) )
@convert_to_user
async def realms( async def realms(
self, include_deployments: bool = False, current_user: UserInfo = Depends(get_current_user) self, include_deployments: bool = False, current_user: User = Depends(get_current_user)
) -> list[Realm]: ) -> list[Realm]:
""" """
Get all realms. Get all realms.
@ -62,8 +63,9 @@ class RealmRouter(BaseRouter):
return self.db.aggregate("realms", include, Realm, user=current_user) return self.db.aggregate("realms", include, Realm, user=current_user)
return self.db.find("realms", {}, Realm, user=current_user) return self.db.find("realms", {}, Realm, user=current_user)
@convert_to_user
async def realm_with_deployment_access( async def realm_with_deployment_access(
self, owner_only: bool = False, current_user: UserInfo = Depends(get_current_user) self, owner_only: bool = False, current_user: User = Depends(get_current_user)
): ):
""" """
Get all realms with deployment access. Get all realms with deployment access.
@ -96,9 +98,8 @@ class RealmRouter(BaseRouter):
] ]
return self.db.aggregate("realms", include, Realm, user=current_user) return self.db.aggregate("realms", include, Realm, user=current_user)
async def realm_with_id( @convert_to_user
self, realm_id: str, current_user: UserInfo = Depends(get_current_user) async def realm_with_id(self, realm_id: str, current_user: User = Depends(get_current_user)):
):
""" """
Get realm with id. Get realm with id.

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import enum
import functools import functools
import inspect import inspect
import json import json
@ -10,8 +11,11 @@ import socketio
from bec_lib.endpoints import EndpointInfo, MessageEndpoints, MessageOp from bec_lib.endpoints import EndpointInfo, MessageEndpoints, MessageOp
from bec_lib.logger import bec_logger from bec_lib.logger import bec_logger
from bec_lib.serialization import MsgpackSerialization, json_ext from bec_lib.serialization import MsgpackSerialization, json_ext
from fastapi import APIRouter, Query, Response from bson import ObjectId
from fastapi import APIRouter, Depends, Query, Response
from bec_atlas.authentication import convert_to_user, get_current_user, get_current_user_sync
from bec_atlas.model.model import DeploymentAccess, User
from bec_atlas.router.base_router import BaseRouter from bec_atlas.router.base_router import BaseRouter
logger = bec_logger.logger logger = bec_logger.logger
@ -20,6 +24,12 @@ if TYPE_CHECKING:
from bec_lib.redis_connector import RedisConnector from bec_lib.redis_connector import RedisConnector
class RemoteAccess(enum.Enum):
READ = "read"
READ_WRITE = "read_write"
NONE = "none"
class RedisAtlasEndpoints: class RedisAtlasEndpoints:
""" """
This class contains the endpoints for the Redis API. It is used to This class contains the endpoints for the Redis API. It is used to
@ -133,7 +143,10 @@ class RedisRouter(BaseRouter):
self.router.add_api_route("/redis", self.redis_post, methods=["POST"]) self.router.add_api_route("/redis", self.redis_post, methods=["POST"])
self.router.add_api_route("/redis", self.redis_delete, methods=["DELETE"]) self.router.add_api_route("/redis", self.redis_delete, methods=["DELETE"])
async def redis_get(self, deployment: str, key: str = Query(...)): @convert_to_user
async def redis_get(
self, deployment: str, key: str = Query(...), current_user: User = Depends(get_current_user)
):
request_id = uuid.uuid4().hex request_id = uuid.uuid4().hex
response_endpoint = RedisAtlasEndpoints.redis_request_response(deployment, request_id) response_endpoint = RedisAtlasEndpoints.redis_request_response(deployment, request_id)
request_endpoint = RedisAtlasEndpoints.redis_request(deployment) request_endpoint = RedisAtlasEndpoints.redis_request(deployment)
@ -149,10 +162,14 @@ class RedisRouter(BaseRouter):
return json_ext.dumps({"data": out.content, "metadata": out.metadata}) return json_ext.dumps({"data": out.content, "metadata": out.metadata})
async def redis_post(self, key: str, value: str): @convert_to_user
async def redis_post(
self, key: str, value: str, current_user: User = Depends(get_current_user)
):
return self.redis.set(key, value) return self.redis.set(key, value)
async def redis_delete(self, key: str): @convert_to_user
async def redis_delete(self, key: str, current_user: User = Depends(get_current_user)):
return self.redis.delete(key) return self.redis.delete(key)
@ -268,6 +285,7 @@ class RedisWebsocket:
redis_host = datasources.datasources["redis"].config["host"] redis_host = datasources.datasources["redis"].config["host"]
redis_port = datasources.datasources["redis"].config["port"] redis_port = datasources.datasources["redis"].config["port"]
redis_password = datasources.datasources["redis"].config.get("password", "ingestor") redis_password = datasources.datasources["redis"].config.get("password", "ingestor")
self.db = datasources.datasources["mongodb"]
self.socket = socketio.AsyncServer( self.socket = socketio.AsyncServer(
transports=["websocket"], transports=["websocket"],
ping_timeout=60, ping_timeout=60,
@ -289,7 +307,7 @@ class RedisWebsocket:
self.socket.on("disconnect", self.disconnect_client) self.socket.on("disconnect", self.disconnect_client)
print("Redis websocket started") print("Redis websocket started")
def _validate_new_user(self, http_query: str | None): def _validate_new_user(self, http_query: str | None, auth_token: str) -> tuple:
""" """
Validate the connection of a new user. In particular, Validate the connection of a new user. In particular,
the user must provide a valid token as well as have access the user must provide a valid token as well as have access
@ -310,19 +328,41 @@ class RedisWebsocket:
else: else:
query = http_query query = http_query
if "user" not in query: user_info = get_current_user_sync(auth_token)
raise ValueError("User not found in query parameters") user = self.db.find_one("users", {"email": user_info.email}, User)
user = query["user"]
# TODO: Validate the user token
deployment = query.get("deployment") deployment = query.get("deployment")
if not deployment: if not deployment:
raise ValueError("Deployment not found in query parameters") raise ValueError("Deployment not found in query parameters")
# TODO: Validate the user has access to the deployment deployment_access = self.db.find_one(
"deployments", {"_id": ObjectId(deployment)}, DeploymentAccess
)
if not deployment_access:
raise ValueError("Deployment not found")
return user, deployment access = self.get_access(user, deployment_access)
if access == RemoteAccess.NONE:
raise ValueError("User does not have remote access to the deployment")
return user, deployment, access
def get_access(self, user: User, deployment_access: DeploymentAccess) -> RemoteAccess:
"""
Get the access level of the user to the deployment.
"""
access = RemoteAccess.NONE
groups = set(user.groups)
if user.username is not None:
groups.add(user.username)
if user.email is not None:
groups.add(user.email)
if groups & set(deployment_access.remote_read_access):
access = RemoteAccess.READ
if groups & set(deployment_access.remote_write_access):
access = RemoteAccess.READ_WRITE
return access
@safe_socket @safe_socket
async def connect_client(self, sid, environ=None, auth=None, **kwargs): async def connect_client(self, sid, environ=None, auth=None, **kwargs):
@ -332,7 +372,23 @@ class RedisWebsocket:
http_query = environ.get("HTTP_QUERY") or auth http_query = environ.get("HTTP_QUERY") or auth
user, deployment = self._validate_new_user(http_query) cookies = environ.get("HTTP_COOKIE", "")
auth_token = None
for cookie in cookies.split("; "):
if cookie.startswith("access_token="):
auth_token = cookie.split("=")[1]
break
if not auth_token:
await self.disconnect_client(sid) # Reject connection
return
try:
user, deployment, access = self._validate_new_user(http_query, auth_token)
except ValueError:
await self.disconnect_client(sid, reason="Invalid user or deployment")
return
# check if the user was already registered in redis # check if the user was already registered in redis
socketio_server_keys = await self.socket.manager.redis.keys( socketio_server_keys = await self.socket.manager.redis.keys(
@ -350,22 +406,25 @@ class RedisWebsocket:
for value in obj.values(): for value in obj.values():
info[value["user"]] = value["subscriptions"] info[value["user"]] = value["subscriptions"]
if user in info: if user.email in info:
self.users[sid] = {"user": user, "subscriptions": [], "deployment": deployment} self.users[sid] = {"user": user.email, "subscriptions": [], "deployment": deployment}
for endpoint, endpoint_request in info[user]: for endpoint, endpoint_request in info[user]:
print(f"Registering {endpoint}") print(f"Registering {endpoint}")
await self._update_user_subscriptions(sid, endpoint, endpoint_request) await self._update_user_subscriptions(sid, endpoint, endpoint_request)
else: else:
self.users[sid] = {"user": user, "subscriptions": [], "deployment": deployment} self.users[sid] = {"user": user.email, "subscriptions": [], "deployment": deployment}
await self.socket.manager.update_websocket_states() await self.socket.manager.update_websocket_states()
async def disconnect_client(self, sid, _environ=None): async def disconnect_client(self, sid, reason: str = None, _environ=None):
print("Client disconnected")
is_exit = self.fastapi_app.server.should_exit is_exit = self.fastapi_app.server.should_exit
if is_exit: if is_exit:
return return
await self.socket.manager.remove_user(sid) if reason:
await self.socket.emit("error", {"error": reason}, room=sid)
if sid in self.users:
del self.users[sid]
await self.socket.disconnect(sid)
@safe_socket @safe_socket
async def redis_register(self, sid: str, msg: str): async def redis_register(self, sid: str, msg: str):

View File

@ -5,9 +5,9 @@ import json
from bson import ObjectId from bson import ObjectId
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from bec_atlas.authentication import get_current_user from bec_atlas.authentication import convert_to_user, get_current_user
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model.model import ScanStatusPartial, ScanUserData, UserInfo from bec_atlas.model.model import ScanStatusPartial, ScanUserData, User, UserInfo
from bec_atlas.router.base_router import BaseRouter from bec_atlas.router.base_router import BaseRouter
@ -47,6 +47,7 @@ class ScanRouter(BaseRouter):
response_model=dict, response_model=dict,
) )
@convert_to_user
async def scans( async def scans(
self, self,
session_id: str, session_id: str,
@ -55,7 +56,7 @@ class ScanRouter(BaseRouter):
offset: int = 0, offset: int = 0,
limit: int = 100, limit: int = 100,
sort: str | None = None, sort: str | None = None,
current_user: UserInfo = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> list[ScanStatusPartial]: ) -> list[ScanStatusPartial]:
""" """
Get all scans for a session. Get all scans for a session.
@ -69,7 +70,7 @@ class ScanRouter(BaseRouter):
sort (str): Sort order for the query, e.g. '{"name": 1}' for ascending order, sort (str): Sort order for the query, e.g. '{"name": 1}' for ascending order,
'{"name": -1}' for descending order. Multiple fields can be sorted by '{"name": -1}' for descending order. Multiple fields can be sorted by
separating them with a comma, e.g. '{"name": 1, "description": -1}' separating them with a comma, e.g. '{"name": 1, "description": -1}'
current_user (UserInfo): The current user current_user (User): The current user
Returns: Returns:
list[ScanStatusPartial]: List of scans list[ScanStatusPartial]: List of scans
@ -100,11 +101,12 @@ class ScanRouter(BaseRouter):
user=current_user, user=current_user,
) )
@convert_to_user
async def scans_with_id( async def scans_with_id(
self, self,
scan_id: str, scan_id: str,
fields: list[str] = Query(default=None), fields: list[str] = Query(default=None),
current_user: UserInfo = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
""" """
Get scan with id from session Get scan with id from session
@ -135,6 +137,7 @@ class ScanRouter(BaseRouter):
scan_id (str): The scan id scan_id (str): The scan id
user_data (dict): The user data to update user_data (dict): The user data to update
""" """
current_user = self.get_user_from_db(current_user.token, current_user.email)
out = self.db.patch( out = self.db.patch(
"scans", "scans",
id=scan_id, id=scan_id,
@ -147,15 +150,16 @@ class ScanRouter(BaseRouter):
raise HTTPException(status_code=404, detail="Scan not found") raise HTTPException(status_code=404, detail="Scan not found")
return {"message": "Scan user data updated."} return {"message": "Scan user data updated."}
@convert_to_user
async def count_scans( async def count_scans(
self, filter: str | None = None, current_user: UserInfo = Depends(get_current_user) self, filter: str | None = None, current_user: User = Depends(get_current_user)
) -> int: ) -> int:
""" """
Count the number of scans. Count the number of scans.
Args: Args:
filter (str): JSON filter for the query, e.g. '{"name": "test"}' filter (str): JSON filter for the query, e.g. '{"name": "test"}'
current_user (UserInfo): The current user current_user (User): The current user
Returns: Returns:
int: The number of scans int: The number of scans

View File

@ -1,15 +1,10 @@
import json import json
from typing import Annotated
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
from fastapi.exceptions import HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel
from bec_atlas.authentication import create_access_token, get_current_user, verify_password from bec_atlas.authentication import convert_to_user, get_current_user
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model import UserInfo from bec_atlas.model.model import Session, User
from bec_atlas.model.model import Session
from bec_atlas.router.base_router import BaseRouter from bec_atlas.router.base_router import BaseRouter
@ -35,6 +30,7 @@ class SessionRouter(BaseRouter):
response_model_exclude_none=True, response_model_exclude_none=True,
) )
@convert_to_user
async def sessions( async def sessions(
self, self,
filter: str | None = None, filter: str | None = None,
@ -42,7 +38,7 @@ class SessionRouter(BaseRouter):
offset: int = 0, offset: int = 0,
limit: int = 100, limit: int = 100,
sort: str | None = None, sort: str | None = None,
current_user: UserInfo = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> list[Session]: ) -> list[Session]:
""" """
Get all sessions. Get all sessions.
@ -55,7 +51,7 @@ class SessionRouter(BaseRouter):
sort (str): Sort order for the query, e.g. '{"name": 1}' for ascending order, sort (str): Sort order for the query, e.g. '{"name": 1}' for ascending order,
'{"name": -1}' for descending order. Multiple fields can be sorted by '{"name": -1}' for descending order. Multiple fields can be sorted by
separating them with a comma, e.g. '{"name": 1, "description": -1}' separating them with a comma, e.g. '{"name": 1, "description": -1}'
current_user (UserInfo): The current user current_user (User): The current user
Returns: Returns:
list[Sessions]: List of sessions list[Sessions]: List of sessions
@ -82,6 +78,7 @@ class SessionRouter(BaseRouter):
user=current_user, user=current_user,
) )
@convert_to_user
async def sessions_by_realm( async def sessions_by_realm(
self, self,
realm_id: str, realm_id: str,
@ -90,7 +87,7 @@ class SessionRouter(BaseRouter):
offset: int = 0, offset: int = 0,
limit: int = 100, limit: int = 100,
sort: str | None = None, sort: str | None = None,
current_user: UserInfo = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> list[Session]: ) -> list[Session]:
""" """
Get all sessions for a realm. Get all sessions for a realm.

View File

@ -1,11 +1,16 @@
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends, Response
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel from pydantic import BaseModel
from bec_atlas.authentication import create_access_token, get_current_user, verify_password from bec_atlas.authentication import (
convert_to_user,
create_access_token,
get_current_user,
verify_password,
)
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model import UserInfo from bec_atlas.model import UserInfo
from bec_atlas.model.model import User from bec_atlas.model.model import User
@ -19,8 +24,9 @@ class UserLoginRequest(BaseModel):
class UserRouter(BaseRouter): class UserRouter(BaseRouter):
def __init__(self, prefix="/api/v1", datasources=None): def __init__(self, prefix="/api/v1", datasources=None, use_ssl=True):
super().__init__(prefix, datasources) super().__init__(prefix, datasources)
self.use_ssl = use_ssl
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb") self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
self.ldap = LDAPUserService( self.ldap = LDAPUserService(
ldap_server="ldaps://d.psi.ch", base_dn="OU=users,OU=psi,DC=d,DC=psi,DC=ch" ldap_server="ldaps://d.psi.ch", base_dn="OU=users,OU=psi,DC=d,DC=psi,DC=ch"
@ -31,23 +37,30 @@ class UserRouter(BaseRouter):
self.router.add_api_route( self.router.add_api_route(
"/user/login/form", self.form_login, methods=["POST"], dependencies=[] "/user/login/form", self.form_login, methods=["POST"], dependencies=[]
) )
self.router.add_api_route("/user/logout", self.user_logout, methods=["POST"])
async def user_me(self, user: UserInfo = Depends(get_current_user)): @convert_to_user
data = self.db.get_user_by_email(user.email) async def user_me(self, user: User = Depends(get_current_user)):
if data is None: return user
raise HTTPException(status_code=404, detail="User not found")
return data
async def form_login(self, form_data: Annotated[OAuth2PasswordRequestForm, Depends()]): async def form_login(
self, form_data: Annotated[OAuth2PasswordRequestForm, Depends()], response: Response
):
user_login = UserLoginRequest(username=form_data.username, password=form_data.password) user_login = UserLoginRequest(username=form_data.username, password=form_data.password)
out = await self.user_login(user_login) out = await self.user_login(user_login, response)
return {"access_token": out, "token_type": "bearer"} return {"access_token": out, "token_type": "bearer"}
async def user_login(self, user_login: UserLoginRequest): async def user_login(self, user_login: UserLoginRequest, response: Response):
user = self._get_user(user_login) user = self._get_user(user_login)
if user is None: if user is None:
raise HTTPException(status_code=401, detail="User not found or password is incorrect") raise HTTPException(status_code=401, detail="User not found or password is incorrect")
return create_access_token(data={"groups": list(user.groups), "email": user.email}) token = create_access_token(data={"groups": list(user.groups), "email": user.email})
response.set_cookie(key="access_token", value=token, httponly=True, secure=self.use_ssl)
return token
async def user_logout(self, response: Response):
response.delete_cookie("access_token")
return {"message": "Logged out"}
def _get_user(self, user_login: UserLoginRequest) -> UserInfo | None: def _get_user(self, user_login: UserLoginRequest) -> UserInfo | None:
user = self._get_functional_account(user_login) user = self._get_functional_account(user_login)

View File

@ -118,4 +118,5 @@ def backend(redis_server):
"bec_atlas.router.redis_router.BECAsyncRedisManager", PatchedBECAsyncRedisManager "bec_atlas.router.redis_router.BECAsyncRedisManager", PatchedBECAsyncRedisManager
): ):
with TestClient(app.app) as _client: with TestClient(app.app) as _client:
app.user_router.use_ssl = False # disable ssl to allow for httponly cookies
yield _client, app yield _client, app

View File

@ -13,7 +13,6 @@ def logged_in_client(backend):
token = response.json() token = response.json()
assert isinstance(token, str) assert isinstance(token, str)
assert len(token) > 20 assert len(token) > 20
client.headers.update({"Authorization": f"Bearer {token}"})
return client return client

View File

@ -11,7 +11,6 @@ def logged_in_client(backend):
token = response.json() token = response.json()
assert isinstance(token, str) assert isinstance(token, str)
assert len(token) > 20 assert len(token) > 20
client.headers.update({"Authorization": f"Bearer {token}"})
return client return client
@ -70,7 +69,6 @@ def test_deployment_credential_rejects_unauthorized_user(backend):
token = response.json() token = response.json()
assert isinstance(token, str) assert isinstance(token, str)
assert len(token) > 20 assert len(token) > 20
client.headers.update({"Authorization": f"Bearer {token}"})
deployments = client.get( deployments = client.get(
"/api/v1/deployments/realm", params={"realm": "demo_beamline_1"} "/api/v1/deployments/realm", params={"realm": "demo_beamline_1"}
@ -96,7 +94,6 @@ def test_refresh_deployment_credentials_rejects_unauthorized_user(backend):
token = response.json() token = response.json()
assert isinstance(token, str) assert isinstance(token, str)
assert len(token) > 20 assert len(token) > 20
client.headers.update({"Authorization": f"Bearer {token}"})
deployments = client.get( deployments = client.get(
"/api/v1/deployments/realm", params={"realm": "demo_beamline_1"} "/api/v1/deployments/realm", params={"realm": "demo_beamline_1"}

View File

@ -11,7 +11,6 @@ def logged_in_client(backend):
token = response.json() token = response.json()
assert isinstance(token, str) assert isinstance(token, str)
assert len(token) > 20 assert len(token) > 20
client.headers.update({"Authorization": f"Bearer {token}"})
return client return client

View File

@ -2,9 +2,10 @@ import json
from unittest import mock from unittest import mock
import pytest import pytest
from bec_atlas.router.redis_router import RedisAtlasEndpoints
from bec_lib.endpoints import MessageEndpoints from bec_lib.endpoints import MessageEndpoints
from bec_atlas.router.redis_router import RedisAtlasEndpoints, RemoteAccess
@pytest.fixture @pytest.fixture
def backend_client(backend): def backend_client(backend):
@ -12,48 +13,75 @@ def backend_client(backend):
app.server = mock.Mock() app.server = mock.Mock()
app.server.should_exit = False app.server.should_exit = False
app.redis_websocket.users = {} app.redis_websocket.users = {}
yield client, app response = client.post(
# app.redis_websocket.redis._redis_conn.flushall() "/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"}
)
assert response.status_code == 200
token = response.json()
assert isinstance(token, str)
assert len(token) > 20
return client, app
@pytest.fixture
@pytest.mark.asyncio(loop_scope="session")
async def connected_ws(backend_client):
client, app = backend_client
deployment = client.get("/api/v1/deployments/realm", params={"realm": "demo_beamline_1"}).json()
with mock.patch.object(app.redis_websocket, "get_access", return_value=RemoteAccess.READ):
await app.redis_websocket.socket.handlers["/"]["connect"](
"sid",
{
"HTTP_QUERY": json.dumps({"deployment": deployment[0]["_id"]}),
"HTTP_COOKIE": f"access_token={client.cookies.get('access_token')}",
},
)
yield backend_client
@pytest.mark.asyncio(loop_scope="session") @pytest.mark.asyncio(loop_scope="session")
async def test_redis_websocket_connect(backend_client): async def test_redis_websocket_connect(connected_ws):
client, app = backend_client _, app = await anext(connected_ws)
await app.redis_websocket.socket.handlers["/"]["connect"](
"sid", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'}
)
assert "sid" in app.redis_websocket.users assert "sid" in app.redis_websocket.users
@pytest.mark.asyncio(loop_scope="session") @pytest.mark.asyncio(loop_scope="session")
async def test_redis_websocket_disconnect(backend_client): async def test_redis_websocket_disconnect(connected_ws):
client, app = backend_client _, app = await anext(connected_ws)
app.redis_websocket.users["sid"] = {"user": "test", "subscriptions": []}
await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid") await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid")
assert "sid" not in app.redis_websocket.users assert "sid" not in app.redis_websocket.users
@pytest.mark.asyncio(loop_scope="session") @pytest.mark.asyncio(loop_scope="session")
async def test_redis_websocket_multiple_connect(backend_client): async def test_redis_websocket_multiple_connect(connected_ws):
client, app = backend_client client, app = await anext(connected_ws)
await app.redis_websocket.socket.handlers["/"]["connect"]( await app.redis_websocket.socket.handlers["/"]["connect"](
"sid1", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'} "sid2",
{
"HTTP_QUERY": json.dumps(
{"deployment": app.redis_websocket.users["sid"]["deployment"]}
),
"HTTP_COOKIE": f"access_token={client.cookies.get('access_token')}",
},
) )
await app.redis_websocket.socket.handlers["/"]["connect"](
"sid2", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'} assert "sid" in app.redis_websocket.users
)
assert "sid1" in app.redis_websocket.users
assert "sid2" in app.redis_websocket.users assert "sid2" in app.redis_websocket.users
@pytest.mark.asyncio(loop_scope="session") @pytest.mark.asyncio(loop_scope="session")
async def test_redis_websocket_multiple_connect_same_sid(backend_client): async def test_redis_websocket_multiple_connect_same_sid(connected_ws):
client, app = backend_client client, app = await anext(connected_ws)
await app.redis_websocket.socket.handlers["/"]["connect"]( await app.redis_websocket.socket.handlers["/"]["connect"](
"sid", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'} "sid",
) {
await app.redis_websocket.socket.handlers["/"]["connect"]( "HTTP_QUERY": json.dumps(
"sid", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'} {"deployment": app.redis_websocket.users["sid"]["deployment"]}
),
"HTTP_COOKIE": f"access_token={client.cookies.get('access_token')}",
},
) )
assert "sid" in app.redis_websocket.users assert "sid" in app.redis_websocket.users
@ -61,9 +89,8 @@ async def test_redis_websocket_multiple_connect_same_sid(backend_client):
@pytest.mark.asyncio(loop_scope="session") @pytest.mark.asyncio(loop_scope="session")
async def test_redis_websocket_multiple_disconnect_same_sid(backend_client): async def test_redis_websocket_multiple_disconnect_same_sid(connected_ws):
client, app = backend_client client, app = await anext(connected_ws)
app.redis_websocket.users["sid"] = {"user": "test", "subscriptions": []}
await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid") await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid")
await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid") await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid")
assert "sid" not in app.redis_websocket.users assert "sid" not in app.redis_websocket.users
@ -82,14 +109,10 @@ async def test_redis_websocket_register_wrong_endpoint_raises(backend_client):
@pytest.mark.asyncio(loop_scope="session") @pytest.mark.asyncio(loop_scope="session")
async def test_redis_websocket_register(backend_client): async def test_redis_websocket_register(connected_ws):
client, app = backend_client client, app = await anext(connected_ws)
with mock.patch.object(app.redis_websocket.socket, "emit") as emit: with mock.patch.object(app.redis_websocket.socket, "emit") as emit:
with mock.patch.object(app.redis_websocket.socket, "enter_room") as enter_room: with mock.patch.object(app.redis_websocket.socket, "enter_room") as enter_room:
await app.redis_websocket.socket.handlers["/"]["connect"](
"sid", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'}
)
await app.redis_websocket.socket.handlers["/"]["register"]( await app.redis_websocket.socket.handlers["/"]["register"](
"sid", json.dumps({"endpoint": "scan_status"}) "sid", json.dumps({"endpoint": "scan_status"})
) )
@ -97,7 +120,8 @@ async def test_redis_websocket_register(backend_client):
enter_room.assert_called_with( enter_room.assert_called_with(
"sid", "sid",
RedisAtlasEndpoints.socketio_endpoint_room( RedisAtlasEndpoints.socketio_endpoint_room(
"test", MessageEndpoints.scan_status().endpoint app.redis_websocket.users["sid"]["deployment"],
MessageEndpoints.scan_status().endpoint,
), ),
) )

View File

@ -91,7 +91,6 @@ def test_scan_ingestor_create_scan(scan_ingestor, backend):
response = client.post( response = client.post(
"/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"} "/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"}
) )
client.headers.update({"Authorization": f"Bearer {response.json()}"})
session_id = msg.info.get("session_id") session_id = msg.info.get("session_id")
scan_id = msg.scan_id scan_id = msg.scan_id

View File

@ -14,7 +14,6 @@ def logged_in_client(backend):
token = response.json() token = response.json()
assert isinstance(token, str) assert isinstance(token, str)
assert len(token) > 20 assert len(token) > 20
client.headers.update({"Authorization": f"Bearer {token}"})
return client return client

View File

@ -11,29 +11,28 @@ import { Observable } from 'rxjs';
import { tap } from 'rxjs/operators'; import { tap } from 'rxjs/operators';
function logout() { function logout() {
localStorage.removeItem('id_token');
localStorage.removeItem('id_session'); localStorage.removeItem('id_session');
location.href = '/login'; location.href = '/login';
} }
function handle_request(handler: HttpHandler, req: HttpRequest<any>) { function handle_request(handler: HttpHandler, req: HttpRequest<any>) {
return handler.handle(req).pipe( return handler.handle(req).pipe(
tap( tap({
(event: HttpEvent<any>) => { next: (event: HttpEvent<any>) => {
if (event instanceof HttpResponse) { if (event instanceof HttpResponse) {
// console.log(cloned); // console.log(cloned);
// console.log("Service Response thr Interceptor"); // console.log("Service Response thr Interceptor");
} }
}, },
(err: any) => { error: (err: any) => {
if (err instanceof HttpErrorResponse) { if (err instanceof HttpErrorResponse) {
console.log('err.status', err); console.log('err.status', err);
if (err.status === 401) { if (err.status === 401) {
logout(); logout();
} }
} }
} },
) })
); );
} }
@ -45,16 +44,10 @@ export class AuthInterceptor implements HttpInterceptor {
req: HttpRequest<any>, req: HttpRequest<any>,
next: HttpHandler next: HttpHandler
): Observable<HttpEvent<any>> { ): Observable<HttpEvent<any>> {
const idToken = localStorage.getItem('id_token'); const cloned = req.clone({
withCredentials: true,
});
if (idToken) { return handle_request(next, cloned);
const cloned = req.clone({
headers: req.headers.set('Authorization', 'Bearer ' + idToken),
});
return handle_request(next, cloned);
} else {
return handle_request(next, req);
}
} }
} }

View File

@ -23,12 +23,11 @@ export class AuthService {
setSession(authResult: string) { setSession(authResult: string) {
console.log(authResult); console.log(authResult);
// it would be good to get an expiration date for the token... // it would be good to get an expiration date for the token...
localStorage.setItem('id_token', authResult);
localStorage.setItem('id_session', this.getRandomId()); localStorage.setItem('id_session', this.getRandomId());
} }
logout() { logout() {
localStorage.removeItem('id_token'); this.dataService.logout();
localStorage.removeItem('id_session'); localStorage.removeItem('id_session');
this.forceReload = true; this.forceReload = true;
} }

View File

@ -1,8 +1,7 @@
import { Injectable, signal, WritableSignal } from '@angular/core'; import { Injectable, signal, WritableSignal } from '@angular/core';
import { io, Socket } from 'socket.io-client'; import { io, Socket } from 'socket.io-client';
import { Observable } from 'rxjs'; import { Observable } from 'rxjs';
import { MessageEndpoints, EndpointInfo } from './redis_endpoints'; import { EndpointInfo } from './redis_endpoints';
import { AppConfigService } from '../app-config.service';
import { ServerSettingsService } from '../server-settings.service'; import { ServerSettingsService } from '../server-settings.service';
import { DeploymentService } from '../deployment.service'; import { DeploymentService } from '../deployment.service';
@ -36,8 +35,6 @@ export class RedisConnectorService {
timeout: 500, // Connection timeout in milliseconds timeout: 500, // Connection timeout in milliseconds
path: '/api/v1/ws', // Path to the WebSocket server path: '/api/v1/ws', // Path to the WebSocket server
auth: { auth: {
user: 'john_doe',
token: '1234',
deployment: id, deployment: id,
}, },
}); });

View File

@ -104,6 +104,16 @@ export class AuthDataService extends RemoteDataService {
headers headers
); );
} }
/**
* Method for logging out of BEC
* @returns response from the server
*/
logout(): Promise<string> {
let headers = new HttpHeaders();
headers = headers.set('Content-Type', 'application/json; charset=utf-8');
return firstValueFrom(this.post<string>('user/logout', {}, headers));
}
} }
@Injectable({ @Injectable({

View File

@ -20,7 +20,7 @@
<button mat-button>Profile</button> <button mat-button>Profile</button>
<button mat-button>Settings</button> <button mat-button>Settings</button>
<button mat-button>Logout</button> <button mat-button (click)="logout()">Logout</button>
</mat-expansion-panel> </mat-expansion-panel>
<mat-divider></mat-divider> <mat-divider></mat-divider>
@ -36,15 +36,10 @@
mat-button mat-button
class="menu-item" class="menu-item"
[routerLink]="['/dashboard/scan-table']" [routerLink]="['/dashboard/scan-table']"
> Scan Data >
Scan Data
</button> </button>
<button <button mat-button class="menu-item">Device Data</button>
mat-button
class="menu-item"
> Device Data
</button>
</mat-expansion-panel> </mat-expansion-panel>
<!-- Experiment Control Expansion --> <!-- Experiment Control Expansion -->

View File

@ -5,7 +5,7 @@ import { BreakpointObserver } from '@angular/cdk/layout';
import { CommonModule } from '@angular/common'; import { CommonModule } from '@angular/common';
import { MatButtonModule } from '@angular/material/button'; import { MatButtonModule } from '@angular/material/button';
import { MatDividerModule } from '@angular/material/divider'; import { MatDividerModule } from '@angular/material/divider';
import { RouterModule } from '@angular/router'; import { Router, RouterModule } from '@angular/router';
import { MatExpansionModule } from '@angular/material/expansion'; import { MatExpansionModule } from '@angular/material/expansion';
import { DeploymentService } from '../deployment.service'; import { DeploymentService } from '../deployment.service';
import { import {
@ -15,6 +15,7 @@ import {
} from '@angular/material/dialog'; } from '@angular/material/dialog';
import { DeploymentSelectionComponent } from '../deployment-selection/deployment-selection.component'; import { DeploymentSelectionComponent } from '../deployment-selection/deployment-selection.component';
import { RedisConnectorService } from '../core/redis-connector.service'; import { RedisConnectorService } from '../core/redis-connector.service';
import { AuthDataService } from '../core/remote-data.service';
@Component({ @Component({
selector: 'app-dashboard', selector: 'app-dashboard',
imports: [ imports: [
@ -41,7 +42,9 @@ export class DashboardComponent {
constructor( constructor(
private breakpointObserver: BreakpointObserver, private breakpointObserver: BreakpointObserver,
private deploymentService: DeploymentService private deploymentService: DeploymentService,
private authDataService: AuthDataService,
private router: Router
) {} ) {}
ngOnInit(): void { ngOnInit(): void {
@ -89,4 +92,9 @@ export class DashboardComponent {
this.openDeploymentDialog(); this.openDeploymentDialog();
} }
} }
logout() {
this.authDataService.logout();
this.router.navigate(['/login']);
}
} }