
This change introduces a fallback mechanism for development environments: when the submitted username is not found in the mock database, authentication simulates that user from a local file named "user". The file must contain the username on its first line and a space-delimited list of pgroups on its second line. Any password is accepted for the user defined in this file, so the file should only be present in development environments. This streamlines development workflows while still reporting missing or malformed files as errors.
137 lines
5.2 KiB
Python
137 lines
5.2 KiB
Python
import logging
import os
import secrets
from datetime import datetime, timedelta, timezone

import jwt
from fastapi import APIRouter, HTTPException, status, Depends
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.security import OAuth2AuthorizationCodeBearer

from app.schemas import loginToken, loginData
|
|
|
|
# Router that groups this module's authentication endpoints
# (/token/login and /protected-route).
router: APIRouter = APIRouter()
|
|
|
|
|
|
# In-memory stand-in for a user store, keyed by username.
# Passwords are stored in plain text here; in a real scenario store only
# a hash of each password.
mock_users_db = {
    record["username"]: record
    for record in (
        {
            "username": "testuser",
            "password": "testpass",
            "pgroups": ["p20000", "p20001", "p20002", "p20003"],
        },
        {
            "username": "testuser2",
            "password": "testpass2",
            "pgroups": ["p20004", "p20005", "p20006"],
        },
        {
            "username": "admin",
            "password": "adminpass",
            "pgroups": ["p20007"],
            # "role": "admin",  # role support is currently commented out
        },
    )
}
|
|
|
|
|
|
# JWT signing settings, following
# https://fastapi.tiangolo.com/tutorial/security/oauth2-jwt/#hash-and-verify-the-passwords
#
# The fallback SECRET_KEY below is the example key from the FastAPI
# documentation and is therefore public: anyone who knows it can forge valid
# tokens with arbitrary claims. Set the SECRET_KEY environment variable
# (e.g. to the output of `openssl rand -hex 32`) for anything beyond local
# development. An empty value is treated as unset.
SECRET_KEY = (
    os.environ.get("SECRET_KEY")
    or "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
)

# HMAC-SHA256 token signing algorithm.
ALGORITHM = "HS256"

# Lifetime of newly issued access tokens, in minutes.
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
|
|
|
# Security dependency that supplies the raw bearer token string to
# get_current_user.
# NOTE(review): login accepts an OAuth2 password-grant form at /token/login;
# confirm that /login is a real authorization-code endpoint, otherwise
# OAuth2PasswordBearer may describe this flow more accurately.
oauth2_scheme = OAuth2AuthorizationCodeBearer(
    tokenUrl="/token/login",
    authorizationUrl="/login",
)
|
|
|
|
|
|
def create_access_token(data: dict) -> str:
    """Return a signed JWT carrying ``data`` plus an ``exp`` expiry claim.

    Args:
        data: Claims to embed in the token (``login`` passes ``sub`` and
            ``pgroups``). The dict is shallow-copied, so the caller's object
            is not modified.

    Returns:
        The encoded token, signed with ``SECRET_KEY`` using ``ALGORITHM`` and
        expiring ``ACCESS_TOKEN_EXPIRE_MINUTES`` from now (UTC).
    """
    claims = data.copy()
    claims["exp"] = datetime.now(timezone.utc) + timedelta(
        minutes=ACCESS_TOKEN_EXPIRE_MINUTES
    )
    # Sign with the same ALGORITHM constant that get_current_user accepts
    # when decoding, so the two sides cannot drift apart.
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
|
|
|
|
|
|
async def get_current_user(token: str = Depends(oauth2_scheme)) -> loginData:
    """Resolve the authenticated user from a bearer JWT.

    Args:
        token: Encoded JWT supplied by the ``oauth2_scheme`` dependency.

    Returns:
        ``loginData`` built from the token's ``sub`` (username) and
        ``pgroups`` claims.

    Raises:
        HTTPException: 401 when the token has expired, fails signature or
            format validation, or carries no ``sub`` claim.
    """
    logger = logging.getLogger(__name__)
    # Advertise the expected auth scheme on every 401, as the login endpoint does.
    bearer_challenge = {"WWW-Authenticate": "Bearer"}

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    # ExpiredSignatureError subclasses InvalidTokenError, so it is caught first.
    except jwt.ExpiredSignatureError:
        logger.debug("Token expired")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token expired",
            headers=bearer_challenge,
        )
    except jwt.InvalidTokenError:
        logger.debug("Invalid token")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
            headers=bearer_challenge,
        )

    username = payload.get("sub")
    if username is None:
        # A correctly signed token without a subject identifies nobody; reject
        # it here instead of passing username=None on to loginData.
        logger.debug("Token has no 'sub' claim")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
            headers=bearer_challenge,
        )

    logger.debug("Username decoded from token: %s", username)
    # TODO: also pass role=payload.get("role") once loginData supports roles.
    return loginData(username=username, pgroups=payload.get("pgroups"))
|
|
|
|
|
|
# async def get_user_role(token: str = Depends(oauth2_scheme)) -> str:
|
|
# try:
|
|
# payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
|
# return payload.get("role")
|
|
# except jwt.ExpiredSignatureError:
|
|
# raise HTTPException(status_code=401, detail="Token expired")
|
|
|
|
|
|
# Development fallback user file, resolved relative to the process working
# directory. Adjust the path if your file is somewhere else.
_DEV_USER_FILE = "user"


def _unauthorized_login() -> HTTPException:
    """Build the generic 401 returned for every failed login attempt.

    Using one message for unknown users and wrong passwords avoids revealing
    which usernames exist.
    """
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Incorrect username or password",
        headers={"WWW-Authenticate": "Bearer"},
    )


def _load_dev_user(username: str) -> dict:
    """Return the user defined in the development user file if it is ``username``.

    The file's first line holds the username and its second line a
    space-delimited list of pgroups; any further lines are ignored.

    Returns:
        A dict with ``username`` and ``pgroups`` keys, shaped like the
        ``mock_users_db`` entries minus the password.

    Raises:
        HTTPException: 401 if the file does not exist or names a different
            user; 500 if the file exists but is malformed.
    """
    logger = logging.getLogger(__name__)
    try:
        with open(_DEV_USER_FILE, "r", encoding="utf-8") as f:
            lines = f.read().splitlines()
    except FileNotFoundError:
        # The fallback is optional: without the file (e.g. outside development)
        # an unknown username is an ordinary failed login, not a server error.
        logger.debug("Development user file %r not found", _DEV_USER_FILE)
        raise _unauthorized_login() from None

    if len(lines) < 2:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="User file must have at least two lines:"
            " one for username and one for pgroups",
        )

    file_username = lines[0].strip()
    if not file_username:
        # Without this check an empty submitted username would match a blank line.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="User file must have a non-empty username on its first line",
        )

    if username != file_username:
        raise _unauthorized_login()

    # str.split() with no argument also strips surrounding whitespace.
    return {"username": file_username, "pgroups": lines[1].split()}


@router.post("/token/login", response_model=loginToken)
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Authenticate a user and issue a JWT access token.

    Users in ``mock_users_db`` must supply their stored password. Any other
    username falls back to the development user file: if that file names the
    user, login succeeds with *any* password and the pgroups listed there.

    Returns:
        ``loginToken`` holding the encoded access token and token type
        ``"bearer"``.

    Raises:
        HTTPException: 401 for a wrong password or unknown username; 500 if
            the development user file exists but is malformed.
    """
    user = mock_users_db.get(form_data.username)
    if user is None:
        # Development fallback: any provided password is accepted here.
        user = _load_dev_user(form_data.username)
    elif not secrets.compare_digest(
        user["password"].encode("utf-8"), form_data.password.encode("utf-8")
    ):
        # compare_digest avoids leaking how much of the password matched via
        # timing; comparing bytes keeps it safe for non-ASCII input.
        raise _unauthorized_login()

    access_token = create_access_token(
        data={"sub": user["username"], "pgroups": user["pgroups"]}
    )
    return loginToken(access_token=access_token, token_type="bearer")
|
|
|
|
|
|
@router.get("/protected-route")
async def read_protected_data(current_user: loginData = Depends(get_current_user)):
    """Echo back the identity of the authenticated caller.

    This body runs only after ``get_current_user`` has resolved the caller
    from the bearer token; otherwise that dependency raises a 401 first.
    """
    # TODO: include current_user.role once loginData carries a role.
    identity = dict(username=current_user.username, pgroups=current_user.pgroups)
    return identity
|