feat(auth): moved to httponly token

This commit is contained in:
2025-02-17 15:14:28 +01:00
parent 5abbec19c8
commit 0b107c9882
25 changed files with 293 additions and 150 deletions

View File

@ -2,10 +2,10 @@ from __future__ import annotations
import os
from datetime import datetime, timedelta
from typing import Annotated
from functools import wraps
import jwt
from fastapi import Depends, HTTPException, status
from fastapi import HTTPException, Request, status
from fastapi.security import OAuth2PasswordBearer
from jwt.exceptions import InvalidTokenError
from pwdlib import PasswordHash
@ -19,6 +19,24 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/user/login/form")
password_hash = PasswordHash.recommended()
def convert_to_user(func):
"""
Decorator to convert the current_user parameter to a User object.
"""
@wraps(func)
async def wrapper(*args, **kwargs):
if "current_user" in kwargs:
current_user = kwargs["current_user"]
if current_user:
router = args[0]
user = router.get_user_from_db(current_user.token, current_user.email)
kwargs["current_user"] = user
return await func(*args, **kwargs)
return wrapper
def get_secret_key():
val = os.getenv("SECRET_KEY", "test_secret")
return val
@ -56,7 +74,12 @@ def decode_token(token: str):
raise credentials_exception from exc
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> UserInfo:
async def get_current_user(request: Request) -> UserInfo:
token = request.cookies.get("access_token")
return get_current_user_sync(token)
def get_current_user_sync(token: str) -> UserInfo:
credentials_exception = HTTPException(
status_code=401,
detail="Could not validate credentials",
@ -70,4 +93,4 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> Use
raise credentials_exception
except Exception as exc:
raise credentials_exception from exc
return UserInfo(groups=groups, email=email)
return UserInfo(email=email, token=token)

View File

@ -67,7 +67,7 @@ class User(MongoBaseModel, AccessProfile):
class UserInfo(BaseModel):
email: str
groups: list[str]
token: str
class Deployments(MongoBaseModel, AccessProfile):

View File

@ -1,4 +1,23 @@
from functools import lru_cache
from bec_atlas.model.model import User
class BaseRouter:
def __init__(self, prefix: str = "/api/v1", datasources=None) -> None:
self.datasources = datasources
self.prefix = prefix
@lru_cache(maxsize=128)
def get_user_from_db(self, _token: str, email: str) -> User:
"""
Get the user from the database. This is a helper function to be used by the
convert_to_user decorator. The function is cached to avoid repeated database
queries. To scope the cache to the current request, the token and email are
used as the cache key.
Args:
_token (str): The token
email (str): The email
"""
return self.datasources.datasources["mongodb"].get_user_by_email(email)

View File

@ -1,8 +1,8 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from bec_atlas.authentication import get_current_user
from bec_atlas.authentication import convert_to_user, get_current_user
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model.model import BECAccessProfile, UserInfo
from bec_atlas.model.model import BECAccessProfile, User, UserInfo
from bec_atlas.router.base_router import BaseRouter
@ -18,11 +18,12 @@ class BECAccessRouter(BaseRouter):
description="Retrieve the access key for a specific deployment and user.",
)
@convert_to_user
async def get_bec_access(
self,
deployment_id: str,
user: str = Query(None),
current_user: UserInfo = Depends(get_current_user),
current_user: User = Depends(get_current_user),
) -> dict:
"""
Retrieve the access key for a specific deployment and user.
@ -31,7 +32,7 @@ class BECAccessRouter(BaseRouter):
deployment_id (str): The deployment id
user (str): The user name to retrieve the access key for. If not provided,
the access key for the current user will be retrieved.
current_user (UserInfo): The current user
current_user (User): The current user
"""
if not user:
user = current_user.email

View File

@ -7,9 +7,9 @@ from bec_lib.serialization import MsgpackSerialization
from bson import ObjectId
from fastapi import APIRouter, Depends, HTTPException
from bec_atlas.authentication import get_current_user
from bec_atlas.authentication import convert_to_user, get_current_user
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model.model import BECAccessProfile, DeploymentAccess, UserInfo
from bec_atlas.model.model import BECAccessProfile, DeploymentAccess, User
from bec_atlas.router.base_router import BaseRouter
from bec_atlas.router.redis_router import RedisAtlasEndpoints
@ -37,8 +37,9 @@ class DeploymentAccessRouter(BaseRouter):
response_model=DeploymentAccess,
)
@convert_to_user
async def get_deployment_access(
self, deployment_id: str, current_user: UserInfo = Depends(get_current_user)
self, deployment_id: str, current_user: User = Depends(get_current_user)
) -> DeploymentAccess:
"""
Get the access lists for a specific deployment.
@ -56,11 +57,12 @@ class DeploymentAccessRouter(BaseRouter):
"deployments", {"_id": ObjectId(deployment_id)}, DeploymentAccess, user=current_user
)
@convert_to_user
async def patch_deployment_access(
self,
deployment_id: str,
deployment_access: dict,
current_user: UserInfo = Depends(get_current_user),
current_user: User = Depends(get_current_user),
) -> DeploymentAccess:
"""
Update the access lists for a specific deployment.

View File

@ -4,9 +4,9 @@ from typing import TYPE_CHECKING
from bson import ObjectId
from fastapi import APIRouter, Depends, HTTPException
from bec_atlas.authentication import get_current_user
from bec_atlas.authentication import convert_to_user, get_current_user
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model.model import DeploymentCredential, UserInfo
from bec_atlas.model.model import DeploymentCredential, User
from bec_atlas.router.base_router import BaseRouter
if TYPE_CHECKING: # pragma: no cover
@ -33,8 +33,9 @@ class DeploymentCredentialsRouter(BaseRouter):
response_model=DeploymentCredential,
)
@convert_to_user
async def deployment_credential(
self, deployment_id: str, current_user: UserInfo = Depends(get_current_user)
self, deployment_id: str, current_user: User = Depends(get_current_user)
) -> DeploymentCredential:
"""
Get the credentials for a deployment.
@ -56,8 +57,9 @@ class DeploymentCredentialsRouter(BaseRouter):
status_code=403, detail="User does not have permission to access this resource."
)
@convert_to_user
async def refresh_deployment_credentials(
self, deployment_id: str, current_user: UserInfo = Depends(get_current_user)
self, deployment_id: str, current_user: User = Depends(get_current_user)
):
"""
Refresh the deployment credentials.

View File

@ -4,9 +4,9 @@ from typing import TYPE_CHECKING
from bson import ObjectId
from fastapi import APIRouter, Depends, HTTPException
from bec_atlas.authentication import get_current_user
from bec_atlas.authentication import convert_to_user, get_current_user
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model.model import DeploymentCredential, Deployments, UserInfo
from bec_atlas.model.model import DeploymentCredential, Deployments, User
from bec_atlas.router.base_router import BaseRouter
if TYPE_CHECKING: # pragma: no cover
@ -34,8 +34,9 @@ class DeploymentsRouter(BaseRouter):
)
self.update_available_deployments()
@convert_to_user
async def deployments(
self, realm: str, current_user: UserInfo = Depends(get_current_user)
self, realm: str, current_user: User = Depends(get_current_user)
) -> list[Deployments]:
"""
Get all deployments for a realm.
@ -48,8 +49,9 @@ class DeploymentsRouter(BaseRouter):
"""
return self.db.find("deployments", {"realm_id": realm}, Deployments, user=current_user)
@convert_to_user
async def deployment_with_id(
self, deployment_id: str, current_user: UserInfo = Depends(get_current_user)
self, deployment_id: str, current_user: User = Depends(get_current_user)
):
"""
Get deployment with id from realm

View File

@ -1,8 +1,8 @@
from fastapi import APIRouter, Depends
from bec_atlas.authentication import get_current_user
from bec_atlas.authentication import convert_to_user, get_current_user
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model.model import DeploymentAccess, Realm, UserInfo
from bec_atlas.model.model import DeploymentAccess, Realm, User
from bec_atlas.router.base_router import BaseRouter
@ -36,8 +36,9 @@ class RealmRouter(BaseRouter):
response_model_exclude_none=True,
)
@convert_to_user
async def realms(
self, include_deployments: bool = False, current_user: UserInfo = Depends(get_current_user)
self, include_deployments: bool = False, current_user: User = Depends(get_current_user)
) -> list[Realm]:
"""
Get all realms.
@ -62,8 +63,9 @@ class RealmRouter(BaseRouter):
return self.db.aggregate("realms", include, Realm, user=current_user)
return self.db.find("realms", {}, Realm, user=current_user)
@convert_to_user
async def realm_with_deployment_access(
self, owner_only: bool = False, current_user: UserInfo = Depends(get_current_user)
self, owner_only: bool = False, current_user: User = Depends(get_current_user)
):
"""
Get all realms with deployment access.
@ -96,9 +98,8 @@ class RealmRouter(BaseRouter):
]
return self.db.aggregate("realms", include, Realm, user=current_user)
async def realm_with_id(
self, realm_id: str, current_user: UserInfo = Depends(get_current_user)
):
@convert_to_user
async def realm_with_id(self, realm_id: str, current_user: User = Depends(get_current_user)):
"""
Get realm with id.

View File

@ -1,4 +1,5 @@
import asyncio
import enum
import functools
import inspect
import json
@ -10,8 +11,11 @@ import socketio
from bec_lib.endpoints import EndpointInfo, MessageEndpoints, MessageOp
from bec_lib.logger import bec_logger
from bec_lib.serialization import MsgpackSerialization, json_ext
from fastapi import APIRouter, Query, Response
from bson import ObjectId
from fastapi import APIRouter, Depends, Query, Response
from bec_atlas.authentication import convert_to_user, get_current_user, get_current_user_sync
from bec_atlas.model.model import DeploymentAccess, User
from bec_atlas.router.base_router import BaseRouter
logger = bec_logger.logger
@ -20,6 +24,12 @@ if TYPE_CHECKING:
from bec_lib.redis_connector import RedisConnector
class RemoteAccess(enum.Enum):
READ = "read"
READ_WRITE = "read_write"
NONE = "none"
class RedisAtlasEndpoints:
"""
This class contains the endpoints for the Redis API. It is used to
@ -133,7 +143,10 @@ class RedisRouter(BaseRouter):
self.router.add_api_route("/redis", self.redis_post, methods=["POST"])
self.router.add_api_route("/redis", self.redis_delete, methods=["DELETE"])
async def redis_get(self, deployment: str, key: str = Query(...)):
@convert_to_user
async def redis_get(
self, deployment: str, key: str = Query(...), current_user: User = Depends(get_current_user)
):
request_id = uuid.uuid4().hex
response_endpoint = RedisAtlasEndpoints.redis_request_response(deployment, request_id)
request_endpoint = RedisAtlasEndpoints.redis_request(deployment)
@ -149,10 +162,14 @@ class RedisRouter(BaseRouter):
return json_ext.dumps({"data": out.content, "metadata": out.metadata})
async def redis_post(self, key: str, value: str):
@convert_to_user
async def redis_post(
self, key: str, value: str, current_user: User = Depends(get_current_user)
):
return self.redis.set(key, value)
async def redis_delete(self, key: str):
@convert_to_user
async def redis_delete(self, key: str, current_user: User = Depends(get_current_user)):
return self.redis.delete(key)
@ -268,6 +285,7 @@ class RedisWebsocket:
redis_host = datasources.datasources["redis"].config["host"]
redis_port = datasources.datasources["redis"].config["port"]
redis_password = datasources.datasources["redis"].config.get("password", "ingestor")
self.db = datasources.datasources["mongodb"]
self.socket = socketio.AsyncServer(
transports=["websocket"],
ping_timeout=60,
@ -289,7 +307,7 @@ class RedisWebsocket:
self.socket.on("disconnect", self.disconnect_client)
print("Redis websocket started")
def _validate_new_user(self, http_query: str | None):
def _validate_new_user(self, http_query: str | None, auth_token: str) -> tuple:
"""
Validate the connection of a new user. In particular,
the user must provide a valid token as well as have access
@ -310,19 +328,41 @@ class RedisWebsocket:
else:
query = http_query
if "user" not in query:
raise ValueError("User not found in query parameters")
user = query["user"]
# TODO: Validate the user token
user_info = get_current_user_sync(auth_token)
user = self.db.find_one("users", {"email": user_info.email}, User)
deployment = query.get("deployment")
if not deployment:
raise ValueError("Deployment not found in query parameters")
# TODO: Validate the user has access to the deployment
deployment_access = self.db.find_one(
"deployments", {"_id": ObjectId(deployment)}, DeploymentAccess
)
if not deployment_access:
raise ValueError("Deployment not found")
return user, deployment
access = self.get_access(user, deployment_access)
if access == RemoteAccess.NONE:
raise ValueError("User does not have remote access to the deployment")
return user, deployment, access
def get_access(self, user: User, deployment_access: DeploymentAccess) -> RemoteAccess:
"""
Get the access level of the user to the deployment.
"""
access = RemoteAccess.NONE
groups = set(user.groups)
if user.username is not None:
groups.add(user.username)
if user.email is not None:
groups.add(user.email)
if groups & set(deployment_access.remote_read_access):
access = RemoteAccess.READ
if groups & set(deployment_access.remote_write_access):
access = RemoteAccess.READ_WRITE
return access
@safe_socket
async def connect_client(self, sid, environ=None, auth=None, **kwargs):
@ -332,7 +372,23 @@ class RedisWebsocket:
http_query = environ.get("HTTP_QUERY") or auth
user, deployment = self._validate_new_user(http_query)
cookies = environ.get("HTTP_COOKIE", "")
auth_token = None
for cookie in cookies.split("; "):
if cookie.startswith("access_token="):
auth_token = cookie.split("=")[1]
break
if not auth_token:
await self.disconnect_client(sid) # Reject connection
return
try:
user, deployment, access = self._validate_new_user(http_query, auth_token)
except ValueError:
await self.disconnect_client(sid, reason="Invalid user or deployment")
return
# check if the user was already registered in redis
socketio_server_keys = await self.socket.manager.redis.keys(
@ -350,22 +406,25 @@ class RedisWebsocket:
for value in obj.values():
info[value["user"]] = value["subscriptions"]
if user in info:
self.users[sid] = {"user": user, "subscriptions": [], "deployment": deployment}
if user.email in info:
self.users[sid] = {"user": user.email, "subscriptions": [], "deployment": deployment}
for endpoint, endpoint_request in info[user]:
print(f"Registering {endpoint}")
await self._update_user_subscriptions(sid, endpoint, endpoint_request)
else:
self.users[sid] = {"user": user, "subscriptions": [], "deployment": deployment}
self.users[sid] = {"user": user.email, "subscriptions": [], "deployment": deployment}
await self.socket.manager.update_websocket_states()
async def disconnect_client(self, sid, _environ=None):
print("Client disconnected")
async def disconnect_client(self, sid, reason: str = None, _environ=None):
is_exit = self.fastapi_app.server.should_exit
if is_exit:
return
await self.socket.manager.remove_user(sid)
if reason:
await self.socket.emit("error", {"error": reason}, room=sid)
if sid in self.users:
del self.users[sid]
await self.socket.disconnect(sid)
@safe_socket
async def redis_register(self, sid: str, msg: str):

View File

@ -5,9 +5,9 @@ import json
from bson import ObjectId
from fastapi import APIRouter, Depends, HTTPException, Query
from bec_atlas.authentication import get_current_user
from bec_atlas.authentication import convert_to_user, get_current_user
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model.model import ScanStatusPartial, ScanUserData, UserInfo
from bec_atlas.model.model import ScanStatusPartial, ScanUserData, User, UserInfo
from bec_atlas.router.base_router import BaseRouter
@ -47,6 +47,7 @@ class ScanRouter(BaseRouter):
response_model=dict,
)
@convert_to_user
async def scans(
self,
session_id: str,
@ -55,7 +56,7 @@ class ScanRouter(BaseRouter):
offset: int = 0,
limit: int = 100,
sort: str | None = None,
current_user: UserInfo = Depends(get_current_user),
current_user: User = Depends(get_current_user),
) -> list[ScanStatusPartial]:
"""
Get all scans for a session.
@ -69,7 +70,7 @@ class ScanRouter(BaseRouter):
sort (str): Sort order for the query, e.g. '{"name": 1}' for ascending order,
'{"name": -1}' for descending order. Multiple fields can be sorted by
separating them with a comma, e.g. '{"name": 1, "description": -1}'
current_user (UserInfo): The current user
current_user (User): The current user
Returns:
list[ScanStatusPartial]: List of scans
@ -100,11 +101,12 @@ class ScanRouter(BaseRouter):
user=current_user,
)
@convert_to_user
async def scans_with_id(
self,
scan_id: str,
fields: list[str] = Query(default=None),
current_user: UserInfo = Depends(get_current_user),
current_user: User = Depends(get_current_user),
):
"""
Get scan with id from session
@ -135,6 +137,7 @@ class ScanRouter(BaseRouter):
scan_id (str): The scan id
user_data (dict): The user data to update
"""
current_user = self.get_user_from_db(current_user.token, current_user.email)
out = self.db.patch(
"scans",
id=scan_id,
@ -147,15 +150,16 @@ class ScanRouter(BaseRouter):
raise HTTPException(status_code=404, detail="Scan not found")
return {"message": "Scan user data updated."}
@convert_to_user
async def count_scans(
self, filter: str | None = None, current_user: UserInfo = Depends(get_current_user)
self, filter: str | None = None, current_user: User = Depends(get_current_user)
) -> int:
"""
Count the number of scans.
Args:
filter (str): JSON filter for the query, e.g. '{"name": "test"}'
current_user (UserInfo): The current user
current_user (User): The current user
Returns:
int: The number of scans

View File

@ -1,15 +1,10 @@
import json
from typing import Annotated
from fastapi import APIRouter, Depends, Query
from fastapi.exceptions import HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel
from bec_atlas.authentication import create_access_token, get_current_user, verify_password
from bec_atlas.authentication import convert_to_user, get_current_user
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model import UserInfo
from bec_atlas.model.model import Session
from bec_atlas.model.model import Session, User
from bec_atlas.router.base_router import BaseRouter
@ -35,6 +30,7 @@ class SessionRouter(BaseRouter):
response_model_exclude_none=True,
)
@convert_to_user
async def sessions(
self,
filter: str | None = None,
@ -42,7 +38,7 @@ class SessionRouter(BaseRouter):
offset: int = 0,
limit: int = 100,
sort: str | None = None,
current_user: UserInfo = Depends(get_current_user),
current_user: User = Depends(get_current_user),
) -> list[Session]:
"""
Get all sessions.
@ -55,7 +51,7 @@ class SessionRouter(BaseRouter):
sort (str): Sort order for the query, e.g. '{"name": 1}' for ascending order,
'{"name": -1}' for descending order. Multiple fields can be sorted by
separating them with a comma, e.g. '{"name": 1, "description": -1}'
current_user (UserInfo): The current user
current_user (User): The current user
Returns:
list[Sessions]: List of sessions
@ -82,6 +78,7 @@ class SessionRouter(BaseRouter):
user=current_user,
)
@convert_to_user
async def sessions_by_realm(
self,
realm_id: str,
@ -90,7 +87,7 @@ class SessionRouter(BaseRouter):
offset: int = 0,
limit: int = 100,
sort: str | None = None,
current_user: UserInfo = Depends(get_current_user),
current_user: User = Depends(get_current_user),
) -> list[Session]:
"""
Get all sessions for a realm.

View File

@ -1,11 +1,16 @@
from typing import Annotated
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Response
from fastapi.exceptions import HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel
from bec_atlas.authentication import create_access_token, get_current_user, verify_password
from bec_atlas.authentication import (
convert_to_user,
create_access_token,
get_current_user,
verify_password,
)
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model import UserInfo
from bec_atlas.model.model import User
@ -19,8 +24,9 @@ class UserLoginRequest(BaseModel):
class UserRouter(BaseRouter):
def __init__(self, prefix="/api/v1", datasources=None):
def __init__(self, prefix="/api/v1", datasources=None, use_ssl=True):
super().__init__(prefix, datasources)
self.use_ssl = use_ssl
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
self.ldap = LDAPUserService(
ldap_server="ldaps://d.psi.ch", base_dn="OU=users,OU=psi,DC=d,DC=psi,DC=ch"
@ -31,23 +37,30 @@ class UserRouter(BaseRouter):
self.router.add_api_route(
"/user/login/form", self.form_login, methods=["POST"], dependencies=[]
)
self.router.add_api_route("/user/logout", self.user_logout, methods=["POST"])
async def user_me(self, user: UserInfo = Depends(get_current_user)):
data = self.db.get_user_by_email(user.email)
if data is None:
raise HTTPException(status_code=404, detail="User not found")
return data
@convert_to_user
async def user_me(self, user: User = Depends(get_current_user)):
return user
async def form_login(self, form_data: Annotated[OAuth2PasswordRequestForm, Depends()]):
async def form_login(
self, form_data: Annotated[OAuth2PasswordRequestForm, Depends()], response: Response
):
user_login = UserLoginRequest(username=form_data.username, password=form_data.password)
out = await self.user_login(user_login)
out = await self.user_login(user_login, response)
return {"access_token": out, "token_type": "bearer"}
async def user_login(self, user_login: UserLoginRequest):
async def user_login(self, user_login: UserLoginRequest, response: Response):
user = self._get_user(user_login)
if user is None:
raise HTTPException(status_code=401, detail="User not found or password is incorrect")
return create_access_token(data={"groups": list(user.groups), "email": user.email})
token = create_access_token(data={"groups": list(user.groups), "email": user.email})
response.set_cookie(key="access_token", value=token, httponly=True, secure=self.use_ssl)
return token
async def user_logout(self, response: Response):
response.delete_cookie("access_token")
return {"message": "Logged out"}
def _get_user(self, user_login: UserLoginRequest) -> UserInfo | None:
user = self._get_functional_account(user_login)

View File

@ -118,4 +118,5 @@ def backend(redis_server):
"bec_atlas.router.redis_router.BECAsyncRedisManager", PatchedBECAsyncRedisManager
):
with TestClient(app.app) as _client:
app.user_router.use_ssl = False # disable ssl to allow for httponly cookies
yield _client, app

View File

@ -13,7 +13,6 @@ def logged_in_client(backend):
token = response.json()
assert isinstance(token, str)
assert len(token) > 20
client.headers.update({"Authorization": f"Bearer {token}"})
return client

View File

@ -11,7 +11,6 @@ def logged_in_client(backend):
token = response.json()
assert isinstance(token, str)
assert len(token) > 20
client.headers.update({"Authorization": f"Bearer {token}"})
return client
@ -70,7 +69,6 @@ def test_deployment_credential_rejects_unauthorized_user(backend):
token = response.json()
assert isinstance(token, str)
assert len(token) > 20
client.headers.update({"Authorization": f"Bearer {token}"})
deployments = client.get(
"/api/v1/deployments/realm", params={"realm": "demo_beamline_1"}
@ -96,7 +94,6 @@ def test_refresh_deployment_credentials_rejects_unauthorized_user(backend):
token = response.json()
assert isinstance(token, str)
assert len(token) > 20
client.headers.update({"Authorization": f"Bearer {token}"})
deployments = client.get(
"/api/v1/deployments/realm", params={"realm": "demo_beamline_1"}

View File

@ -11,7 +11,6 @@ def logged_in_client(backend):
token = response.json()
assert isinstance(token, str)
assert len(token) > 20
client.headers.update({"Authorization": f"Bearer {token}"})
return client

View File

@ -2,9 +2,10 @@ import json
from unittest import mock
import pytest
from bec_atlas.router.redis_router import RedisAtlasEndpoints
from bec_lib.endpoints import MessageEndpoints
from bec_atlas.router.redis_router import RedisAtlasEndpoints, RemoteAccess
@pytest.fixture
def backend_client(backend):
@ -12,48 +13,75 @@ def backend_client(backend):
app.server = mock.Mock()
app.server.should_exit = False
app.redis_websocket.users = {}
yield client, app
# app.redis_websocket.redis._redis_conn.flushall()
response = client.post(
"/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"}
)
assert response.status_code == 200
token = response.json()
assert isinstance(token, str)
assert len(token) > 20
return client, app
@pytest.fixture
@pytest.mark.asyncio(loop_scope="session")
async def connected_ws(backend_client):
client, app = backend_client
deployment = client.get("/api/v1/deployments/realm", params={"realm": "demo_beamline_1"}).json()
with mock.patch.object(app.redis_websocket, "get_access", return_value=RemoteAccess.READ):
await app.redis_websocket.socket.handlers["/"]["connect"](
"sid",
{
"HTTP_QUERY": json.dumps({"deployment": deployment[0]["_id"]}),
"HTTP_COOKIE": f"access_token={client.cookies.get('access_token')}",
},
)
yield backend_client
@pytest.mark.asyncio(loop_scope="session")
async def test_redis_websocket_connect(backend_client):
client, app = backend_client
await app.redis_websocket.socket.handlers["/"]["connect"](
"sid", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'}
)
async def test_redis_websocket_connect(connected_ws):
_, app = await anext(connected_ws)
assert "sid" in app.redis_websocket.users
@pytest.mark.asyncio(loop_scope="session")
async def test_redis_websocket_disconnect(backend_client):
client, app = backend_client
app.redis_websocket.users["sid"] = {"user": "test", "subscriptions": []}
async def test_redis_websocket_disconnect(connected_ws):
_, app = await anext(connected_ws)
await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid")
assert "sid" not in app.redis_websocket.users
@pytest.mark.asyncio(loop_scope="session")
async def test_redis_websocket_multiple_connect(backend_client):
client, app = backend_client
async def test_redis_websocket_multiple_connect(connected_ws):
client, app = await anext(connected_ws)
await app.redis_websocket.socket.handlers["/"]["connect"](
"sid1", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'}
"sid2",
{
"HTTP_QUERY": json.dumps(
{"deployment": app.redis_websocket.users["sid"]["deployment"]}
),
"HTTP_COOKIE": f"access_token={client.cookies.get('access_token')}",
},
)
await app.redis_websocket.socket.handlers["/"]["connect"](
"sid2", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'}
)
assert "sid1" in app.redis_websocket.users
assert "sid" in app.redis_websocket.users
assert "sid2" in app.redis_websocket.users
@pytest.mark.asyncio(loop_scope="session")
async def test_redis_websocket_multiple_connect_same_sid(backend_client):
client, app = backend_client
async def test_redis_websocket_multiple_connect_same_sid(connected_ws):
client, app = await anext(connected_ws)
await app.redis_websocket.socket.handlers["/"]["connect"](
"sid", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'}
)
await app.redis_websocket.socket.handlers["/"]["connect"](
"sid", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'}
"sid",
{
"HTTP_QUERY": json.dumps(
{"deployment": app.redis_websocket.users["sid"]["deployment"]}
),
"HTTP_COOKIE": f"access_token={client.cookies.get('access_token')}",
},
)
assert "sid" in app.redis_websocket.users
@ -61,9 +89,8 @@ async def test_redis_websocket_multiple_connect_same_sid(backend_client):
@pytest.mark.asyncio(loop_scope="session")
async def test_redis_websocket_multiple_disconnect_same_sid(backend_client):
client, app = backend_client
app.redis_websocket.users["sid"] = {"user": "test", "subscriptions": []}
async def test_redis_websocket_multiple_disconnect_same_sid(connected_ws):
client, app = await anext(connected_ws)
await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid")
await app.redis_websocket.socket.handlers["/"]["disconnect"]("sid")
assert "sid" not in app.redis_websocket.users
@ -82,14 +109,10 @@ async def test_redis_websocket_register_wrong_endpoint_raises(backend_client):
@pytest.mark.asyncio(loop_scope="session")
async def test_redis_websocket_register(backend_client):
client, app = backend_client
async def test_redis_websocket_register(connected_ws):
client, app = await anext(connected_ws)
with mock.patch.object(app.redis_websocket.socket, "emit") as emit:
with mock.patch.object(app.redis_websocket.socket, "enter_room") as enter_room:
await app.redis_websocket.socket.handlers["/"]["connect"](
"sid", {"HTTP_QUERY": '{"user": "test", "deployment": "test"}'}
)
await app.redis_websocket.socket.handlers["/"]["register"](
"sid", json.dumps({"endpoint": "scan_status"})
)
@ -97,7 +120,8 @@ async def test_redis_websocket_register(backend_client):
enter_room.assert_called_with(
"sid",
RedisAtlasEndpoints.socketio_endpoint_room(
"test", MessageEndpoints.scan_status().endpoint
app.redis_websocket.users["sid"]["deployment"],
MessageEndpoints.scan_status().endpoint,
),
)

View File

@ -91,7 +91,6 @@ def test_scan_ingestor_create_scan(scan_ingestor, backend):
response = client.post(
"/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"}
)
client.headers.update({"Authorization": f"Bearer {response.json()}"})
session_id = msg.info.get("session_id")
scan_id = msg.scan_id

View File

@ -14,7 +14,6 @@ def logged_in_client(backend):
token = response.json()
assert isinstance(token, str)
assert len(token) > 20
client.headers.update({"Authorization": f"Bearer {token}"})
return client

View File

@ -11,29 +11,28 @@ import { Observable } from 'rxjs';
import { tap } from 'rxjs/operators';
function logout() {
localStorage.removeItem('id_token');
localStorage.removeItem('id_session');
location.href = '/login';
}
function handle_request(handler: HttpHandler, req: HttpRequest<any>) {
return handler.handle(req).pipe(
tap(
(event: HttpEvent<any>) => {
tap({
next: (event: HttpEvent<any>) => {
if (event instanceof HttpResponse) {
// console.log(cloned);
// console.log("Service Response thr Interceptor");
}
},
(err: any) => {
error: (err: any) => {
if (err instanceof HttpErrorResponse) {
console.log('err.status', err);
if (err.status === 401) {
logout();
}
}
}
)
},
})
);
}
@ -45,16 +44,10 @@ export class AuthInterceptor implements HttpInterceptor {
req: HttpRequest<any>,
next: HttpHandler
): Observable<HttpEvent<any>> {
const idToken = localStorage.getItem('id_token');
const cloned = req.clone({
withCredentials: true,
});
if (idToken) {
const cloned = req.clone({
headers: req.headers.set('Authorization', 'Bearer ' + idToken),
});
return handle_request(next, cloned);
} else {
return handle_request(next, req);
}
return handle_request(next, cloned);
}
}

View File

@ -23,12 +23,11 @@ export class AuthService {
setSession(authResult: string) {
console.log(authResult);
// it would be good to get an expiration date for the token...
localStorage.setItem('id_token', authResult);
localStorage.setItem('id_session', this.getRandomId());
}
logout() {
localStorage.removeItem('id_token');
this.dataService.logout();
localStorage.removeItem('id_session');
this.forceReload = true;
}

View File

@ -1,8 +1,7 @@
import { Injectable, signal, WritableSignal } from '@angular/core';
import { io, Socket } from 'socket.io-client';
import { Observable } from 'rxjs';
import { MessageEndpoints, EndpointInfo } from './redis_endpoints';
import { AppConfigService } from '../app-config.service';
import { EndpointInfo } from './redis_endpoints';
import { ServerSettingsService } from '../server-settings.service';
import { DeploymentService } from '../deployment.service';
@ -36,8 +35,6 @@ export class RedisConnectorService {
timeout: 500, // Connection timeout in milliseconds
path: '/api/v1/ws', // Path to the WebSocket server
auth: {
user: 'john_doe',
token: '1234',
deployment: id,
},
});

View File

@ -104,6 +104,16 @@ export class AuthDataService extends RemoteDataService {
headers
);
}
/**
* Method for logging out of BEC
* @returns response from the server
*/
logout(): Promise<string> {
let headers = new HttpHeaders();
headers = headers.set('Content-Type', 'application/json; charset=utf-8');
return firstValueFrom(this.post<string>('user/logout', {}, headers));
}
}
@Injectable({

View File

@ -20,7 +20,7 @@
<button mat-button>Profile</button>
<button mat-button>Settings</button>
<button mat-button>Logout</button>
<button mat-button (click)="logout()">Logout</button>
</mat-expansion-panel>
<mat-divider></mat-divider>
@ -36,15 +36,10 @@
mat-button
class="menu-item"
[routerLink]="['/dashboard/scan-table']"
> Scan Data
>
Scan Data
</button>
<button
mat-button
class="menu-item"
> Device Data
</button>
<button mat-button class="menu-item">Device Data</button>
</mat-expansion-panel>
<!-- Experiment Control Expansion -->

View File

@ -5,7 +5,7 @@ import { BreakpointObserver } from '@angular/cdk/layout';
import { CommonModule } from '@angular/common';
import { MatButtonModule } from '@angular/material/button';
import { MatDividerModule } from '@angular/material/divider';
import { RouterModule } from '@angular/router';
import { Router, RouterModule } from '@angular/router';
import { MatExpansionModule } from '@angular/material/expansion';
import { DeploymentService } from '../deployment.service';
import {
@ -15,6 +15,7 @@ import {
} from '@angular/material/dialog';
import { DeploymentSelectionComponent } from '../deployment-selection/deployment-selection.component';
import { RedisConnectorService } from '../core/redis-connector.service';
import { AuthDataService } from '../core/remote-data.service';
@Component({
selector: 'app-dashboard',
imports: [
@ -41,7 +42,9 @@ export class DashboardComponent {
constructor(
private breakpointObserver: BreakpointObserver,
private deploymentService: DeploymentService
private deploymentService: DeploymentService,
private authDataService: AuthDataService,
private router: Router
) {}
ngOnInit(): void {
@ -89,4 +92,9 @@ export class DashboardComponent {
this.openDeploymentDialog();
}
}
logout() {
this.authDataService.logout();
this.router.navigate(['/login']);
}
}