diff --git a/backend/bec_atlas/datasources/mongodb/mongodb.py b/backend/bec_atlas/datasources/mongodb/mongodb.py index 776c08b..e0d60e6 100644 --- a/backend/bec_atlas/datasources/mongodb/mongodb.py +++ b/backend/bec_atlas/datasources/mongodb/mongodb.py @@ -196,6 +196,8 @@ class MongoDBDatasource: if user is not None: data = self.add_user_filter(user, data, operation="w") out = self.db[collection].insert_one(data) + if dtype is None: + return data return dtype(**data) def patch( @@ -229,6 +231,8 @@ class MongoDBDatasource: ) if out is None: return None + if dtype is None: + return out return dtype(**out) def delete_one(self, collection: str, filter: dict, user: User | None = None) -> bool: diff --git a/backend/bec_atlas/model/model.py b/backend/bec_atlas/model/model.py index 319d1ba..b1ca975 100644 --- a/backend/bec_atlas/model/model.py +++ b/backend/bec_atlas/model/model.py @@ -62,6 +62,7 @@ class User(MongoBaseModel, AccessProfile): groups: list[str] first_name: str last_name: str + username: str | None = None class UserInfo(BaseModel): diff --git a/backend/bec_atlas/router/user_router.py b/backend/bec_atlas/router/user_router.py index f6867ce..5f9aed5 100644 --- a/backend/bec_atlas/router/user_router.py +++ b/backend/bec_atlas/router/user_router.py @@ -8,7 +8,9 @@ from pydantic import BaseModel from bec_atlas.authentication import create_access_token, get_current_user, verify_password from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource from bec_atlas.model import UserInfo +from bec_atlas.model.model import User from bec_atlas.router.base_router import BaseRouter +from bec_atlas.utils.ldap_auth import LDAPUserService class UserLoginRequest(BaseModel): @@ -20,6 +22,9 @@ class UserRouter(BaseRouter): def __init__(self, prefix="/api/v1", datasources=None): super().__init__(prefix, datasources) self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb") + self.ldap = LDAPUserService( + ldap_server="ldaps://d.psi.ch", base_dn="OU=users,OU=psi,DC=d,DC=psi,DC=ch" + ) self.router = APIRouter(prefix=prefix) self.router.add_api_route("/user/me", self.user_me, methods=["GET"]) self.router.add_api_route("/user/login", self.user_login, methods=["POST"], dependencies=[]) @@ -39,14 +44,48 @@ class UserRouter(BaseRouter): return {"access_token": out, "token_type": "bearer"} async def user_login(self, user_login: UserLoginRequest): - exc = HTTPException(status_code=401, detail="User not found or password is incorrect") + user = self._get_user(user_login) + if user is None: + raise HTTPException(status_code=401, detail="User not found or password is incorrect") + return create_access_token(data={"groups": list(user.groups), "email": user.email}) + + def _get_user(self, user_login: UserLoginRequest) -> UserInfo | None: + user = self._get_functional_account(user_login) + if user is None: + user = self._get_ad_account(user_login) + return user + + def _get_functional_account(self, user_login: UserLoginRequest) -> UserInfo | None: user = self.db.get_user_by_email(user_login.username) if user is None: - raise exc + return None credentials = self.db.get_user_credentials(user.id) if credentials is None: - raise exc + return None if not verify_password(user_login.password, credentials.password): - raise exc + return None + return user - return create_access_token(data={"groups": list(user.groups), "email": user.email}) + def _get_ad_account(self, user_login: UserLoginRequest) -> User | None: + user = self.ldap.authenticate_and_get_info(user_login.username, user_login.password) + if user is None: + return None + user_info = User( + owner_groups=["admin"], + email=user["email"], + first_name=user["first_name"], + last_name=user["last_name"], + username=user["username"], + groups=user["roles"], + ) + # update the user info in the database + user = self.db.get_user_by_email(user_info.email) + if user is None: + self.db.post( + collection="users", data=user_info.model_dump(exclude_none=True), dtype=None + ) + else: + self.db.patch( + collection="users", id=user.id, update={"groups": user_info.groups}, dtype=None + ) + return user_info