mirror of
https://github.com/bec-project/bec_atlas.git
synced 2025-07-13 22:51:49 +02:00
feat(auth): moved to httponly token
This commit is contained in:
@ -2,10 +2,10 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Annotated
|
||||
from functools import wraps
|
||||
|
||||
import jwt
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi import HTTPException, Request, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
from pwdlib import PasswordHash
|
||||
@ -19,6 +19,24 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/user/login/form")
|
||||
password_hash = PasswordHash.recommended()
|
||||
|
||||
|
||||
def convert_to_user(func):
|
||||
"""
|
||||
Decorator to convert the current_user parameter to a User object.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
if "current_user" in kwargs:
|
||||
current_user = kwargs["current_user"]
|
||||
if current_user:
|
||||
router = args[0]
|
||||
user = router.get_user_from_db(current_user.token, current_user.email)
|
||||
kwargs["current_user"] = user
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def get_secret_key():
|
||||
val = os.getenv("SECRET_KEY", "test_secret")
|
||||
return val
|
||||
@ -56,7 +74,12 @@ def decode_token(token: str):
|
||||
raise credentials_exception from exc
|
||||
|
||||
|
||||
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> UserInfo:
|
||||
async def get_current_user(request: Request) -> UserInfo:
|
||||
token = request.cookies.get("access_token")
|
||||
return get_current_user_sync(token)
|
||||
|
||||
|
||||
def get_current_user_sync(token: str) -> UserInfo:
|
||||
credentials_exception = HTTPException(
|
||||
status_code=401,
|
||||
detail="Could not validate credentials",
|
||||
@ -70,4 +93,4 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> Use
|
||||
raise credentials_exception
|
||||
except Exception as exc:
|
||||
raise credentials_exception from exc
|
||||
return UserInfo(groups=groups, email=email)
|
||||
return UserInfo(email=email, token=token)
|
||||
|
@ -67,7 +67,7 @@ class User(MongoBaseModel, AccessProfile):
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
email: str
|
||||
groups: list[str]
|
||||
token: str
|
||||
|
||||
|
||||
class Deployments(MongoBaseModel, AccessProfile):
|
||||
|
@ -1,4 +1,23 @@
|
||||
from functools import lru_cache
|
||||
|
||||
from bec_atlas.model.model import User
|
||||
|
||||
|
||||
class BaseRouter:
|
||||
def __init__(self, prefix: str = "/api/v1", datasources=None) -> None:
|
||||
self.datasources = datasources
|
||||
self.prefix = prefix
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def get_user_from_db(self, _token: str, email: str) -> User:
|
||||
"""
|
||||
Get the user from the database. This is a helper function to be used by the
|
||||
convert_to_user decorator. The function is cached to avoid repeated database
|
||||
queries. To scope the cache to the current request, the token and email are
|
||||
used as the cache key.
|
||||
|
||||
Args:
|
||||
_token (str): The token
|
||||
email (str): The email
|
||||
"""
|
||||
return self.datasources.datasources["mongodb"].get_user_by_email(email)
|
||||
|
@ -1,8 +1,8 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from bec_atlas.authentication import get_current_user
|
||||
from bec_atlas.authentication import convert_to_user, get_current_user
|
||||
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
|
||||
from bec_atlas.model.model import BECAccessProfile, UserInfo
|
||||
from bec_atlas.model.model import BECAccessProfile, User, UserInfo
|
||||
from bec_atlas.router.base_router import BaseRouter
|
||||
|
||||
|
||||
@ -18,11 +18,12 @@ class BECAccessRouter(BaseRouter):
|
||||
description="Retrieve the access key for a specific deployment and user.",
|
||||
)
|
||||
|
||||
@convert_to_user
|
||||
async def get_bec_access(
|
||||
self,
|
||||
deployment_id: str,
|
||||
user: str = Query(None),
|
||||
current_user: UserInfo = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
Retrieve the access key for a specific deployment and user.
|
||||
@ -31,7 +32,7 @@ class BECAccessRouter(BaseRouter):
|
||||
deployment_id (str): The deployment id
|
||||
user (str): The user name to retrieve the access key for. If not provided,
|
||||
the access key for the current user will be retrieved.
|
||||
current_user (UserInfo): The current user
|
||||
current_user (User): The current user
|
||||
"""
|
||||
if not user:
|
||||
user = current_user.email
|
||||
|
@ -7,9 +7,9 @@ from bec_lib.serialization import MsgpackSerialization
|
||||
from bson import ObjectId
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from bec_atlas.authentication import get_current_user
|
||||
from bec_atlas.authentication import convert_to_user, get_current_user
|
||||
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
|
||||
from bec_atlas.model.model import BECAccessProfile, DeploymentAccess, UserInfo
|
||||
from bec_atlas.model.model import BECAccessProfile, DeploymentAccess, User
|
||||
from bec_atlas.router.base_router import BaseRouter
|
||||
from bec_atlas.router.redis_router import RedisAtlasEndpoints
|
||||
|
||||
@ -37,8 +37,9 @@ class DeploymentAccessRouter(BaseRouter):
|
||||
response_model=DeploymentAccess,
|
||||
)
|
||||
|
||||
@convert_to_user
|
||||
async def get_deployment_access(
|
||||
self, deployment_id: str, current_user: UserInfo = Depends(get_current_user)
|
||||
self, deployment_id: str, current_user: User = Depends(get_current_user)
|
||||
) -> DeploymentAccess:
|
||||
"""
|
||||
Get the access lists for a specific deployment.
|
||||
@ -56,11 +57,12 @@ class DeploymentAccessRouter(BaseRouter):
|
||||
"deployments", {"_id": ObjectId(deployment_id)}, DeploymentAccess, user=current_user
|
||||
)
|
||||
|
||||
@convert_to_user
|
||||
async def patch_deployment_access(
|
||||
self,
|
||||
deployment_id: str,
|
||||
deployment_access: dict,
|
||||
current_user: UserInfo = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> DeploymentAccess:
|
||||
"""
|
||||
Update the access lists for a specific deployment.
|
||||
|
@ -4,9 +4,9 @@ from typing import TYPE_CHECKING
|
||||
from bson import ObjectId
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from bec_atlas.authentication import get_current_user
|
||||
from bec_atlas.authentication import convert_to_user, get_current_user
|
||||
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
|
||||
from bec_atlas.model.model import DeploymentCredential, UserInfo
|
||||
from bec_atlas.model.model import DeploymentCredential, User
|
||||
from bec_atlas.router.base_router import BaseRouter
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
@ -33,8 +33,9 @@ class DeploymentCredentialsRouter(BaseRouter):
|
||||
response_model=DeploymentCredential,
|
||||
)
|
||||
|
||||
@convert_to_user
|
||||
async def deployment_credential(
|
||||
self, deployment_id: str, current_user: UserInfo = Depends(get_current_user)
|
||||
self, deployment_id: str, current_user: User = Depends(get_current_user)
|
||||
) -> DeploymentCredential:
|
||||
"""
|
||||
Get the credentials for a deployment.
|
||||
@ -56,8 +57,9 @@ class DeploymentCredentialsRouter(BaseRouter):
|
||||
status_code=403, detail="User does not have permission to access this resource."
|
||||
)
|
||||
|
||||
@convert_to_user
|
||||
async def refresh_deployment_credentials(
|
||||
self, deployment_id: str, current_user: UserInfo = Depends(get_current_user)
|
||||
self, deployment_id: str, current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Refresh the deployment credentials.
|
||||
|
@ -4,9 +4,9 @@ from typing import TYPE_CHECKING
|
||||
from bson import ObjectId
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from bec_atlas.authentication import get_current_user
|
||||
from bec_atlas.authentication import convert_to_user, get_current_user
|
||||
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
|
||||
from bec_atlas.model.model import DeploymentCredential, Deployments, UserInfo
|
||||
from bec_atlas.model.model import DeploymentCredential, Deployments, User
|
||||
from bec_atlas.router.base_router import BaseRouter
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
@ -34,8 +34,9 @@ class DeploymentsRouter(BaseRouter):
|
||||
)
|
||||
self.update_available_deployments()
|
||||
|
||||
@convert_to_user
|
||||
async def deployments(
|
||||
self, realm: str, current_user: UserInfo = Depends(get_current_user)
|
||||
self, realm: str, current_user: User = Depends(get_current_user)
|
||||
) -> list[Deployments]:
|
||||
"""
|
||||
Get all deployments for a realm.
|
||||
@ -48,8 +49,9 @@ class DeploymentsRouter(BaseRouter):
|
||||
"""
|
||||
return self.db.find("deployments", {"realm_id": realm}, Deployments, user=current_user)
|
||||
|
||||
@convert_to_user
|
||||
async def deployment_with_id(
|
||||
self, deployment_id: str, current_user: UserInfo = Depends(get_current_user)
|
||||
self, deployment_id: str, current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get deployment with id from realm
|
||||
|
@ -1,8 +1,8 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from bec_atlas.authentication import get_current_user
|
||||
from bec_atlas.authentication import convert_to_user, get_current_user
|
||||
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
|
||||
from bec_atlas.model.model import DeploymentAccess, Realm, UserInfo
|
||||
from bec_atlas.model.model import DeploymentAccess, Realm, User
|
||||
from bec_atlas.router.base_router import BaseRouter
|
||||
|
||||
|
||||
@ -36,8 +36,9 @@ class RealmRouter(BaseRouter):
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
|
||||
@convert_to_user
|
||||
async def realms(
|
||||
self, include_deployments: bool = False, current_user: UserInfo = Depends(get_current_user)
|
||||
self, include_deployments: bool = False, current_user: User = Depends(get_current_user)
|
||||
) -> list[Realm]:
|
||||
"""
|
||||
Get all realms.
|
||||
@ -62,8 +63,9 @@ class RealmRouter(BaseRouter):
|
||||
return self.db.aggregate("realms", include, Realm, user=current_user)
|
||||
return self.db.find("realms", {}, Realm, user=current_user)
|
||||
|
||||
@convert_to_user
|
||||
async def realm_with_deployment_access(
|
||||
self, owner_only: bool = False, current_user: UserInfo = Depends(get_current_user)
|
||||
self, owner_only: bool = False, current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get all realms with deployment access.
|
||||
@ -96,9 +98,8 @@ class RealmRouter(BaseRouter):
|
||||
]
|
||||
return self.db.aggregate("realms", include, Realm, user=current_user)
|
||||
|
||||
async def realm_with_id(
|
||||
self, realm_id: str, current_user: UserInfo = Depends(get_current_user)
|
||||
):
|
||||
@convert_to_user
|
||||
async def realm_with_id(self, realm_id: str, current_user: User = Depends(get_current_user)):
|
||||
"""
|
||||
Get realm with id.
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import enum
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
@ -10,8 +11,11 @@ import socketio
|
||||
from bec_lib.endpoints import EndpointInfo, MessageEndpoints, MessageOp
|
||||
from bec_lib.logger import bec_logger
|
||||
from bec_lib.serialization import MsgpackSerialization, json_ext
|
||||
from fastapi import APIRouter, Query, Response
|
||||
from bson import ObjectId
|
||||
from fastapi import APIRouter, Depends, Query, Response
|
||||
|
||||
from bec_atlas.authentication import convert_to_user, get_current_user, get_current_user_sync
|
||||
from bec_atlas.model.model import DeploymentAccess, User
|
||||
from bec_atlas.router.base_router import BaseRouter
|
||||
|
||||
logger = bec_logger.logger
|
||||
@ -20,6 +24,12 @@ if TYPE_CHECKING:
|
||||
from bec_lib.redis_connector import RedisConnector
|
||||
|
||||
|
||||
class RemoteAccess(enum.Enum):
|
||||
READ = "read"
|
||||
READ_WRITE = "read_write"
|
||||
NONE = "none"
|
||||
|
||||
|
||||
class RedisAtlasEndpoints:
|
||||
"""
|
||||
This class contains the endpoints for the Redis API. It is used to
|
||||
@ -133,7 +143,10 @@ class RedisRouter(BaseRouter):
|
||||
self.router.add_api_route("/redis", self.redis_post, methods=["POST"])
|
||||
self.router.add_api_route("/redis", self.redis_delete, methods=["DELETE"])
|
||||
|
||||
async def redis_get(self, deployment: str, key: str = Query(...)):
|
||||
@convert_to_user
|
||||
async def redis_get(
|
||||
self, deployment: str, key: str = Query(...), current_user: User = Depends(get_current_user)
|
||||
):
|
||||
request_id = uuid.uuid4().hex
|
||||
response_endpoint = RedisAtlasEndpoints.redis_request_response(deployment, request_id)
|
||||
request_endpoint = RedisAtlasEndpoints.redis_request(deployment)
|
||||
@ -149,10 +162,14 @@ class RedisRouter(BaseRouter):
|
||||
|
||||
return json_ext.dumps({"data": out.content, "metadata": out.metadata})
|
||||
|
||||
async def redis_post(self, key: str, value: str):
|
||||
@convert_to_user
|
||||
async def redis_post(
|
||||
self, key: str, value: str, current_user: User = Depends(get_current_user)
|
||||
):
|
||||
return self.redis.set(key, value)
|
||||
|
||||
async def redis_delete(self, key: str):
|
||||
@convert_to_user
|
||||
async def redis_delete(self, key: str, current_user: User = Depends(get_current_user)):
|
||||
return self.redis.delete(key)
|
||||
|
||||
|
||||
@ -268,6 +285,7 @@ class RedisWebsocket:
|
||||
redis_host = datasources.datasources["redis"].config["host"]
|
||||
redis_port = datasources.datasources["redis"].config["port"]
|
||||
redis_password = datasources.datasources["redis"].config.get("password", "ingestor")
|
||||
self.db = datasources.datasources["mongodb"]
|
||||
self.socket = socketio.AsyncServer(
|
||||
transports=["websocket"],
|
||||
ping_timeout=60,
|
||||
@ -289,7 +307,7 @@ class RedisWebsocket:
|
||||
self.socket.on("disconnect", self.disconnect_client)
|
||||
print("Redis websocket started")
|
||||
|
||||
def _validate_new_user(self, http_query: str | None):
|
||||
def _validate_new_user(self, http_query: str | None, auth_token: str) -> tuple:
|
||||
"""
|
||||
Validate the connection of a new user. In particular,
|
||||
the user must provide a valid token as well as have access
|
||||
@ -310,19 +328,41 @@ class RedisWebsocket:
|
||||
else:
|
||||
query = http_query
|
||||
|
||||
if "user" not in query:
|
||||
raise ValueError("User not found in query parameters")
|
||||
user = query["user"]
|
||||
|
||||
# TODO: Validate the user token
|
||||
user_info = get_current_user_sync(auth_token)
|
||||
user = self.db.find_one("users", {"email": user_info.email}, User)
|
||||
|
||||
deployment = query.get("deployment")
|
||||
if not deployment:
|
||||
raise ValueError("Deployment not found in query parameters")
|
||||
|
||||
# TODO: Validate the user has access to the deployment
|
||||
deployment_access = self.db.find_one(
|
||||
"deployments", {"_id": ObjectId(deployment)}, DeploymentAccess
|
||||
)
|
||||
if not deployment_access:
|
||||
raise ValueError("Deployment not found")
|
||||
|
||||
return user, deployment
|
||||
access = self.get_access(user, deployment_access)
|
||||
if access == RemoteAccess.NONE:
|
||||
raise ValueError("User does not have remote access to the deployment")
|
||||
|
||||
return user, deployment, access
|
||||
|
||||
def get_access(self, user: User, deployment_access: DeploymentAccess) -> RemoteAccess:
|
||||
"""
|
||||
Get the access level of the user to the deployment.
|
||||
"""
|
||||
access = RemoteAccess.NONE
|
||||
groups = set(user.groups)
|
||||
if user.username is not None:
|
||||
groups.add(user.username)
|
||||
if user.email is not None:
|
||||
groups.add(user.email)
|
||||
|
||||
if groups & set(deployment_access.remote_read_access):
|
||||
access = RemoteAccess.READ
|
||||
if groups & set(deployment_access.remote_write_access):
|
||||
access = RemoteAccess.READ_WRITE
|
||||
return access
|
||||
|
||||
@safe_socket
|
||||
async def connect_client(self, sid, environ=None, auth=None, **kwargs):
|
||||
@ -332,7 +372,23 @@ class RedisWebsocket:
|
||||
|
||||
http_query = environ.get("HTTP_QUERY") or auth
|
||||
|
||||
user, deployment = self._validate_new_user(http_query)
|
||||
cookies = environ.get("HTTP_COOKIE", "")
|
||||
auth_token = None
|
||||
|
||||
for cookie in cookies.split("; "):
|
||||
if cookie.startswith("access_token="):
|
||||
auth_token = cookie.split("=")[1]
|
||||
break
|
||||
|
||||
if not auth_token:
|
||||
await self.disconnect_client(sid) # Reject connection
|
||||
return
|
||||
|
||||
try:
|
||||
user, deployment, access = self._validate_new_user(http_query, auth_token)
|
||||
except ValueError:
|
||||
await self.disconnect_client(sid, reason="Invalid user or deployment")
|
||||
return
|
||||
|
||||
# check if the user was already registered in redis
|
||||
socketio_server_keys = await self.socket.manager.redis.keys(
|
||||
@ -350,22 +406,25 @@ class RedisWebsocket:
|
||||
for value in obj.values():
|
||||
info[value["user"]] = value["subscriptions"]
|
||||
|
||||
if user in info:
|
||||
self.users[sid] = {"user": user, "subscriptions": [], "deployment": deployment}
|
||||
if user.email in info:
|
||||
self.users[sid] = {"user": user.email, "subscriptions": [], "deployment": deployment}
|
||||
for endpoint, endpoint_request in info[user]:
|
||||
print(f"Registering {endpoint}")
|
||||
await self._update_user_subscriptions(sid, endpoint, endpoint_request)
|
||||
else:
|
||||
self.users[sid] = {"user": user, "subscriptions": [], "deployment": deployment}
|
||||
self.users[sid] = {"user": user.email, "subscriptions": [], "deployment": deployment}
|
||||
|
||||
await self.socket.manager.update_websocket_states()
|
||||
|
||||
async def disconnect_client(self, sid, _environ=None):
|
||||
print("Client disconnected")
|
||||
async def disconnect_client(self, sid, reason: str = None, _environ=None):
|
||||
is_exit = self.fastapi_app.server.should_exit
|
||||
if is_exit:
|
||||
return
|
||||
await self.socket.manager.remove_user(sid)
|
||||
if reason:
|
||||
await self.socket.emit("error", {"error": reason}, room=sid)
|
||||
if sid in self.users:
|
||||
del self.users[sid]
|
||||
await self.socket.disconnect(sid)
|
||||
|
||||
@safe_socket
|
||||
async def redis_register(self, sid: str, msg: str):
|
||||
|
@ -5,9 +5,9 @@ import json
|
||||
from bson import ObjectId
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from bec_atlas.authentication import get_current_user
|
||||
from bec_atlas.authentication import convert_to_user, get_current_user
|
||||
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
|
||||
from bec_atlas.model.model import ScanStatusPartial, ScanUserData, UserInfo
|
||||
from bec_atlas.model.model import ScanStatusPartial, ScanUserData, User, UserInfo
|
||||
from bec_atlas.router.base_router import BaseRouter
|
||||
|
||||
|
||||
@ -47,6 +47,7 @@ class ScanRouter(BaseRouter):
|
||||
response_model=dict,
|
||||
)
|
||||
|
||||
@convert_to_user
|
||||
async def scans(
|
||||
self,
|
||||
session_id: str,
|
||||
@ -55,7 +56,7 @@ class ScanRouter(BaseRouter):
|
||||
offset: int = 0,
|
||||
limit: int = 100,
|
||||
sort: str | None = None,
|
||||
current_user: UserInfo = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> list[ScanStatusPartial]:
|
||||
"""
|
||||
Get all scans for a session.
|
||||
@ -69,7 +70,7 @@ class ScanRouter(BaseRouter):
|
||||
sort (str): Sort order for the query, e.g. '{"name": 1}' for ascending order,
|
||||
'{"name": -1}' for descending order. Multiple fields can be sorted by
|
||||
separating them with a comma, e.g. '{"name": 1, "description": -1}'
|
||||
current_user (UserInfo): The current user
|
||||
current_user (User): The current user
|
||||
|
||||
Returns:
|
||||
list[ScanStatusPartial]: List of scans
|
||||
@ -100,11 +101,12 @@ class ScanRouter(BaseRouter):
|
||||
user=current_user,
|
||||
)
|
||||
|
||||
@convert_to_user
|
||||
async def scans_with_id(
|
||||
self,
|
||||
scan_id: str,
|
||||
fields: list[str] = Query(default=None),
|
||||
current_user: UserInfo = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Get scan with id from session
|
||||
@ -135,6 +137,7 @@ class ScanRouter(BaseRouter):
|
||||
scan_id (str): The scan id
|
||||
user_data (dict): The user data to update
|
||||
"""
|
||||
current_user = self.get_user_from_db(current_user.token, current_user.email)
|
||||
out = self.db.patch(
|
||||
"scans",
|
||||
id=scan_id,
|
||||
@ -147,15 +150,16 @@ class ScanRouter(BaseRouter):
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
return {"message": "Scan user data updated."}
|
||||
|
||||
@convert_to_user
|
||||
async def count_scans(
|
||||
self, filter: str | None = None, current_user: UserInfo = Depends(get_current_user)
|
||||
self, filter: str | None = None, current_user: User = Depends(get_current_user)
|
||||
) -> int:
|
||||
"""
|
||||
Count the number of scans.
|
||||
|
||||
Args:
|
||||
filter (str): JSON filter for the query, e.g. '{"name": "test"}'
|
||||
current_user (UserInfo): The current user
|
||||
current_user (User): The current user
|
||||
|
||||
Returns:
|
||||
int: The number of scans
|
||||
|
@ -1,15 +1,10 @@
|
||||
import json
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from pydantic import BaseModel
|
||||
|
||||
from bec_atlas.authentication import create_access_token, get_current_user, verify_password
|
||||
from bec_atlas.authentication import convert_to_user, get_current_user
|
||||
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
|
||||
from bec_atlas.model import UserInfo
|
||||
from bec_atlas.model.model import Session
|
||||
from bec_atlas.model.model import Session, User
|
||||
from bec_atlas.router.base_router import BaseRouter
|
||||
|
||||
|
||||
@ -35,6 +30,7 @@ class SessionRouter(BaseRouter):
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
|
||||
@convert_to_user
|
||||
async def sessions(
|
||||
self,
|
||||
filter: str | None = None,
|
||||
@ -42,7 +38,7 @@ class SessionRouter(BaseRouter):
|
||||
offset: int = 0,
|
||||
limit: int = 100,
|
||||
sort: str | None = None,
|
||||
current_user: UserInfo = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> list[Session]:
|
||||
"""
|
||||
Get all sessions.
|
||||
@ -55,7 +51,7 @@ class SessionRouter(BaseRouter):
|
||||
sort (str): Sort order for the query, e.g. '{"name": 1}' for ascending order,
|
||||
'{"name": -1}' for descending order. Multiple fields can be sorted by
|
||||
separating them with a comma, e.g. '{"name": 1, "description": -1}'
|
||||
current_user (UserInfo): The current user
|
||||
current_user (User): The current user
|
||||
|
||||
Returns:
|
||||
list[Sessions]: List of sessions
|
||||
@ -82,6 +78,7 @@ class SessionRouter(BaseRouter):
|
||||
user=current_user,
|
||||
)
|
||||
|
||||
@convert_to_user
|
||||
async def sessions_by_realm(
|
||||
self,
|
||||
realm_id: str,
|
||||
@ -90,7 +87,7 @@ class SessionRouter(BaseRouter):
|
||||
offset: int = 0,
|
||||
limit: int = 100,
|
||||
sort: str | None = None,
|
||||
current_user: UserInfo = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> list[Session]:
|
||||
"""
|
||||
Get all sessions for a realm.
|
||||
|
@ -1,11 +1,16 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Response
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from pydantic import BaseModel
|
||||
|
||||
from bec_atlas.authentication import create_access_token, get_current_user, verify_password
|
||||
from bec_atlas.authentication import (
|
||||
convert_to_user,
|
||||
create_access_token,
|
||||
get_current_user,
|
||||
verify_password,
|
||||
)
|
||||
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
|
||||
from bec_atlas.model import UserInfo
|
||||
from bec_atlas.model.model import User
|
||||
@ -19,8 +24,9 @@ class UserLoginRequest(BaseModel):
|
||||
|
||||
|
||||
class UserRouter(BaseRouter):
|
||||
def __init__(self, prefix="/api/v1", datasources=None):
|
||||
def __init__(self, prefix="/api/v1", datasources=None, use_ssl=True):
|
||||
super().__init__(prefix, datasources)
|
||||
self.use_ssl = use_ssl
|
||||
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
||||
self.ldap = LDAPUserService(
|
||||
ldap_server="ldaps://d.psi.ch", base_dn="OU=users,OU=psi,DC=d,DC=psi,DC=ch"
|
||||
@ -31,23 +37,30 @@ class UserRouter(BaseRouter):
|
||||
self.router.add_api_route(
|
||||
"/user/login/form", self.form_login, methods=["POST"], dependencies=[]
|
||||
)
|
||||
self.router.add_api_route("/user/logout", self.user_logout, methods=["POST"])
|
||||
|
||||
async def user_me(self, user: UserInfo = Depends(get_current_user)):
|
||||
data = self.db.get_user_by_email(user.email)
|
||||
if data is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return data
|
||||
@convert_to_user
|
||||
async def user_me(self, user: User = Depends(get_current_user)):
|
||||
return user
|
||||
|
||||
async def form_login(self, form_data: Annotated[OAuth2PasswordRequestForm, Depends()]):
|
||||
async def form_login(
|
||||
self, form_data: Annotated[OAuth2PasswordRequestForm, Depends()], response: Response
|
||||
):
|
||||
user_login = UserLoginRequest(username=form_data.username, password=form_data.password)
|
||||
out = await self.user_login(user_login)
|
||||
out = await self.user_login(user_login, response)
|
||||
return {"access_token": out, "token_type": "bearer"}
|
||||
|
||||
async def user_login(self, user_login: UserLoginRequest):
|
||||
async def user_login(self, user_login: UserLoginRequest, response: Response):
|
||||
user = self._get_user(user_login)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="User not found or password is incorrect")
|
||||
return create_access_token(data={"groups": list(user.groups), "email": user.email})
|
||||
token = create_access_token(data={"groups": list(user.groups), "email": user.email})
|
||||
response.set_cookie(key="access_token", value=token, httponly=True, secure=self.use_ssl)
|
||||
return token
|
||||
|
||||
async def user_logout(self, response: Response):
|
||||
response.delete_cookie("access_token")
|
||||
return {"message": "Logged out"}
|
||||
|
||||
def _get_user(self, user_login: UserLoginRequest) -> UserInfo | None:
|
||||
user = self._get_functional_account(user_login)
|
||||
|
@ -118,4 +118,5 @@ def backend(redis_server):
|
||||
"bec_atlas.router.redis_router.BECAsyncRedisManager", PatchedBECAsyncRedisManager
|
||||
):
|
||||
with TestClient(app.app) as _client:
|
||||
app.user_router.use_ssl = False # disable ssl to allow for httponly cookies
|
||||
yield _client, app
|
||||
|
@ -13,7 +13,6 @@ def logged_in_client(backend):
|
||||
token = response.json()
|
||||
assert isinstance(token, str)
|
||||
assert len(token) > 20
|
||||
client.headers.update({"Authorization": f"Bearer {token}"})
|
||||
return client
|
||||
|
||||
|
||||
|
@ -11,7 +11,6 @@ def logged_in_client(backend):
|
||||
token = response.json()
|
||||
assert isinstance(token, str)
|
||||
assert len(token) > 20
|
||||
client.headers.update({"Authorization": f"Bearer {token}"})
|
||||
return client
|
||||
|
||||
|
||||
@ -70,7 +69,6 @@ def test_deployment_credential_rejects_unauthorized_user(backend):
|
||||
token = response.json()
|
||||
assert isinstance(token, str)
|
||||
assert len(token) > 20
|
||||
client.headers.update({"Authorization": f"Bearer {token}"})
|
||||
|
||||
deployments = client.get(
|
||||
"/api/v1/deployments/realm", params={"realm": "demo_beamline_1"}
|
||||
@ -96,7 +94,6 @@ def test_refresh_deployment_credentials_rejects_unauthorized_user(backend):
|
||||
token = response.json()
|
||||
assert isinstance(token, str)
|
||||
assert len(token) > 20
|
||||
client.headers.update({"Authorization": f"Bearer {token}"})
|
||||
|
||||
deployments = client.get(
|
||||
"/api/v1/deployments/realm", params={"realm": "demo_beamline_1"}
|
||||
|
@ -11,7 +11,6 @@ def logged_in_client(backend):
|
||||
token = response.json()
|
||||
assert isinstance(token, str)
|
||||
assert len(token) > 20
|
||||
client.headers.update({"Authorization": f"Bearer {token}"})
|
||||
return client
|
||||
|
||||
|
||||
|
@ -2,9 +2,10 @@ import json
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from bec_atlas.router.redis_router import RedisAtlasEndpoints
|
||||
from bec_lib.endpoints import MessageEndpoints
|
||||
|
||||
from bec_atlas.router.redis_router import RedisAtlasEndpoints, RemoteAccess
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def backend_client(backend):
|
||||
@ -12,48 +13,75 @@ def backend_client(backend):
|
||||
app.server = mock.Mock()
|
||||
app.server.should_exit = False
|
||||
app.redis_websocket.users = {}
|
||||
yield client, app
|
||||
# app.redis_websocket.redis._redis_conn.flushall()
|
||||
response = client.post(
|
||||
"/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
token = response.json()
|
||||
assert isinstance(token, str)
|
||||
assert len(token) > 20
|
||||
return client, app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def connected_ws(backend_client):
|
||||
client, app = backend_client
|
||||
deployment = client.get("/api/v1/deployments/realm", params={"realm": "demo_beamline_1"}).json()
|
||||
with mock.patch.object(app.redis_websocket, "get_access", return_value=RemoteAccess.READ):
|
||||
await app.redis_websocket.socket.handlers["/"]["connect"](
|
||||
"sid",
|
||||
{
|
||||
"HTTP_QUERY": json.dumps({"deployment": deployment[0]["_id"]}),
|
||||
"HTTP_COOKIE": f"access_token={client.cookies.get('access_token')}",
|
||||
},
|
||||
)
|
||||
yield backend_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_redis_websocket_connect(backend_client):
|
||||
client, app = backend_client
|
||||
await app.redis_websocket.socket.handlers["/"]["connect"](
|
||||
"sid", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'}
|
||||
)
|
||||
async def test_redis_websocket_connect(connected_ws):
|
||||
_, app = await anext(connected_ws)
|
||||
assert "sid" in app.redis_websocket.users
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_redis_websocket_disconnect(backend_client):
|
||||
client, app = backend_client
|
||||
app.redis_websocket.users["sid"] = {"user": "test", "subscriptions": []}
|
||||
async def test_redis_websocket_disconnect(connected_ws):
|
||||
_, app = await anext(connected_ws)
|
||||
await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid")
|
||||
assert "sid" not in app.redis_websocket.users
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_redis_websocket_multiple_connect(backend_client):
|
||||
client, app = backend_client
|
||||
async def test_redis_websocket_multiple_connect(connected_ws):
|
||||
client, app = await anext(connected_ws)
|
||||
|
||||
await app.redis_websocket.socket.handlers["/"]["connect"](
|
||||
"sid1", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'}
|
||||
"sid2",
|
||||
{
|
||||
"HTTP_QUERY": json.dumps(
|
||||
{"deployment": app.redis_websocket.users["sid"]["deployment"]}
|
||||
),
|
||||
"HTTP_COOKIE": f"access_token={client.cookies.get('access_token')}",
|
||||
},
|
||||
)
|
||||
await app.redis_websocket.socket.handlers["/"]["connect"](
|
||||
"sid2", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'}
|
||||
)
|
||||
assert "sid1" in app.redis_websocket.users
|
||||
|
||||
assert "sid" in app.redis_websocket.users
|
||||
assert "sid2" in app.redis_websocket.users
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_redis_websocket_multiple_connect_same_sid(backend_client):
|
||||
client, app = backend_client
|
||||
async def test_redis_websocket_multiple_connect_same_sid(connected_ws):
|
||||
client, app = await anext(connected_ws)
|
||||
|
||||
await app.redis_websocket.socket.handlers["/"]["connect"](
|
||||
"sid", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'}
|
||||
)
|
||||
await app.redis_websocket.socket.handlers["/"]["connect"](
|
||||
"sid", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'}
|
||||
"sid",
|
||||
{
|
||||
"HTTP_QUERY": json.dumps(
|
||||
{"deployment": app.redis_websocket.users["sid"]["deployment"]}
|
||||
),
|
||||
"HTTP_COOKIE": f"access_token={client.cookies.get('access_token')}",
|
||||
},
|
||||
)
|
||||
|
||||
assert "sid" in app.redis_websocket.users
|
||||
@ -61,9 +89,8 @@ async def test_redis_websocket_multiple_connect_same_sid(backend_client):
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_redis_websocket_multiple_disconnect_same_sid(backend_client):
|
||||
client, app = backend_client
|
||||
app.redis_websocket.users["sid"] = {"user": "test", "subscriptions": []}
|
||||
async def test_redis_websocket_multiple_disconnect_same_sid(connected_ws):
|
||||
client, app = await anext(connected_ws)
|
||||
await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid")
|
||||
await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid")
|
||||
assert "sid" not in app.redis_websocket.users
|
||||
@ -82,14 +109,10 @@ async def test_redis_websocket_register_wrong_endpoint_raises(backend_client):
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_redis_websocket_register(backend_client):
|
||||
client, app = backend_client
|
||||
async def test_redis_websocket_register(connected_ws):
|
||||
client, app = await anext(connected_ws)
|
||||
with mock.patch.object(app.redis_websocket.socket, "emit") as emit:
|
||||
with mock.patch.object(app.redis_websocket.socket, "enter_room") as enter_room:
|
||||
await app.redis_websocket.socket.handlers["/"]["connect"](
|
||||
"sid", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'}
|
||||
)
|
||||
|
||||
await app.redis_websocket.socket.handlers["/"]["register"](
|
||||
"sid", json.dumps({"endpoint": "scan_status"})
|
||||
)
|
||||
@ -97,7 +120,8 @@ async def test_redis_websocket_register(backend_client):
|
||||
enter_room.assert_called_with(
|
||||
"sid",
|
||||
RedisAtlasEndpoints.socketio_endpoint_room(
|
||||
"test", MessageEndpoints.scan_status().endpoint
|
||||
app.redis_websocket.users["sid"]["deployment"],
|
||||
MessageEndpoints.scan_status().endpoint,
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -91,7 +91,6 @@ def test_scan_ingestor_create_scan(scan_ingestor, backend):
|
||||
response = client.post(
|
||||
"/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"}
|
||||
)
|
||||
client.headers.update({"Authorization": f"Bearer {response.json()}"})
|
||||
|
||||
session_id = msg.info.get("session_id")
|
||||
scan_id = msg.scan_id
|
||||
|
@ -14,7 +14,6 @@ def logged_in_client(backend):
|
||||
token = response.json()
|
||||
assert isinstance(token, str)
|
||||
assert len(token) > 20
|
||||
client.headers.update({"Authorization": f"Bearer {token}"})
|
||||
return client
|
||||
|
||||
|
||||
|
@ -11,29 +11,28 @@ import { Observable } from 'rxjs';
|
||||
import { tap } from 'rxjs/operators';
|
||||
|
||||
function logout() {
|
||||
localStorage.removeItem('id_token');
|
||||
localStorage.removeItem('id_session');
|
||||
location.href = '/login';
|
||||
}
|
||||
|
||||
function handle_request(handler: HttpHandler, req: HttpRequest<any>) {
|
||||
return handler.handle(req).pipe(
|
||||
tap(
|
||||
(event: HttpEvent<any>) => {
|
||||
tap({
|
||||
next: (event: HttpEvent<any>) => {
|
||||
if (event instanceof HttpResponse) {
|
||||
// console.log(cloned);
|
||||
// console.log("Service Response thr Interceptor");
|
||||
}
|
||||
},
|
||||
(err: any) => {
|
||||
error: (err: any) => {
|
||||
if (err instanceof HttpErrorResponse) {
|
||||
console.log('err.status', err);
|
||||
if (err.status === 401) {
|
||||
logout();
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
},
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
@ -45,16 +44,10 @@ export class AuthInterceptor implements HttpInterceptor {
|
||||
req: HttpRequest<any>,
|
||||
next: HttpHandler
|
||||
): Observable<HttpEvent<any>> {
|
||||
const idToken = localStorage.getItem('id_token');
|
||||
const cloned = req.clone({
|
||||
withCredentials: true,
|
||||
});
|
||||
|
||||
if (idToken) {
|
||||
const cloned = req.clone({
|
||||
headers: req.headers.set('Authorization', 'Bearer ' + idToken),
|
||||
});
|
||||
|
||||
return handle_request(next, cloned);
|
||||
} else {
|
||||
return handle_request(next, req);
|
||||
}
|
||||
return handle_request(next, cloned);
|
||||
}
|
||||
}
|
||||
|
@ -23,12 +23,11 @@ export class AuthService {
|
||||
setSession(authResult: string) {
|
||||
console.log(authResult);
|
||||
// it would be good to get an expiration date for the token...
|
||||
localStorage.setItem('id_token', authResult);
|
||||
localStorage.setItem('id_session', this.getRandomId());
|
||||
}
|
||||
|
||||
logout() {
|
||||
localStorage.removeItem('id_token');
|
||||
this.dataService.logout();
|
||||
localStorage.removeItem('id_session');
|
||||
this.forceReload = true;
|
||||
}
|
||||
|
@ -1,8 +1,7 @@
|
||||
import { Injectable, signal, WritableSignal } from '@angular/core';
|
||||
import { io, Socket } from 'socket.io-client';
|
||||
import { Observable } from 'rxjs';
|
||||
import { MessageEndpoints, EndpointInfo } from './redis_endpoints';
|
||||
import { AppConfigService } from '../app-config.service';
|
||||
import { EndpointInfo } from './redis_endpoints';
|
||||
import { ServerSettingsService } from '../server-settings.service';
|
||||
import { DeploymentService } from '../deployment.service';
|
||||
|
||||
@ -36,8 +35,6 @@ export class RedisConnectorService {
|
||||
timeout: 500, // Connection timeout in milliseconds
|
||||
path: '/api/v1/ws', // Path to the WebSocket server
|
||||
auth: {
|
||||
user: 'john_doe',
|
||||
token: '1234',
|
||||
deployment: id,
|
||||
},
|
||||
});
|
||||
|
@ -104,6 +104,16 @@ export class AuthDataService extends RemoteDataService {
|
||||
headers
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Method for logging out of BEC
|
||||
* @returns response from the server
|
||||
*/
|
||||
logout(): Promise<string> {
|
||||
let headers = new HttpHeaders();
|
||||
headers = headers.set('Content-Type', 'application/json; charset=utf-8');
|
||||
return firstValueFrom(this.post<string>('user/logout', {}, headers));
|
||||
}
|
||||
}
|
||||
|
||||
@Injectable({
|
||||
|
@ -20,7 +20,7 @@
|
||||
|
||||
<button mat-button>Profile</button>
|
||||
<button mat-button>Settings</button>
|
||||
<button mat-button>Logout</button>
|
||||
<button mat-button (click)="logout()">Logout</button>
|
||||
</mat-expansion-panel>
|
||||
|
||||
<mat-divider></mat-divider>
|
||||
@ -36,15 +36,10 @@
|
||||
mat-button
|
||||
class="menu-item"
|
||||
[routerLink]="['/dashboard/scan-table']"
|
||||
> Scan Data
|
||||
>
|
||||
Scan Data
|
||||
</button>
|
||||
<button
|
||||
mat-button
|
||||
class="menu-item"
|
||||
> Device Data
|
||||
</button>
|
||||
|
||||
|
||||
<button mat-button class="menu-item">Device Data</button>
|
||||
</mat-expansion-panel>
|
||||
|
||||
<!-- Experiment Control Expansion -->
|
||||
|
@ -5,7 +5,7 @@ import { BreakpointObserver } from '@angular/cdk/layout';
|
||||
import { CommonModule } from '@angular/common';
|
||||
import { MatButtonModule } from '@angular/material/button';
|
||||
import { MatDividerModule } from '@angular/material/divider';
|
||||
import { RouterModule } from '@angular/router';
|
||||
import { Router, RouterModule } from '@angular/router';
|
||||
import { MatExpansionModule } from '@angular/material/expansion';
|
||||
import { DeploymentService } from '../deployment.service';
|
||||
import {
|
||||
@ -15,6 +15,7 @@ import {
|
||||
} from '@angular/material/dialog';
|
||||
import { DeploymentSelectionComponent } from '../deployment-selection/deployment-selection.component';
|
||||
import { RedisConnectorService } from '../core/redis-connector.service';
|
||||
import { AuthDataService } from '../core/remote-data.service';
|
||||
@Component({
|
||||
selector: 'app-dashboard',
|
||||
imports: [
|
||||
@ -41,7 +42,9 @@ export class DashboardComponent {
|
||||
|
||||
constructor(
|
||||
private breakpointObserver: BreakpointObserver,
|
||||
private deploymentService: DeploymentService
|
||||
private deploymentService: DeploymentService,
|
||||
private authDataService: AuthDataService,
|
||||
private router: Router
|
||||
) {}
|
||||
|
||||
ngOnInit(): void {
|
||||
@ -89,4 +92,9 @@ export class DashboardComponent {
|
||||
this.openDeploymentDialog();
|
||||
}
|
||||
}
|
||||
|
||||
logout() {
|
||||
this.authDataService.logout();
|
||||
this.router.navigate(['/login']);
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user