
This change adds a development-only authentication fallback: when a username is not found in the mock database, login simulates a user read from a local file named "user". The file must contain a username on the first line and a space-delimited list of pgroups on the second line. This streamlines development workflows while still treating a missing or malformed file as an error.
130 lines
4.8 KiB
Python
130 lines
4.8 KiB
Python
import logging
import os
from datetime import datetime, timedelta, timezone

import jwt
from fastapi import APIRouter, HTTPException, status, Depends
from fastapi.security import OAuth2AuthorizationCodeBearer, OAuth2PasswordRequestForm

from app.schemas import loginToken, loginData
|
|
|
|
# Router collecting the authentication endpoints of this module
# (the token login route and the protected test route).
router: APIRouter = APIRouter()
|
|
|
|
|
|
# In-memory user store for development, keyed by username.
# Passwords are kept in plain text here only because this is mock data;
# a real store would hold password hashes instead.
mock_users_db = {
    "testuser": dict(
        username="testuser",
        password="testpass",
        pgroups=["p20000", "p20001", "p20002", "p20003"],
    ),
    "testuser2": dict(
        username="testuser2",
        password="testpass2",
        pgroups=["p20004", "p20005", "p20006"],
    ),
    "admin": dict(
        username="admin",
        password="adminpass",
        pgroups=["p20007"],
        # role="admin",  # planned role support, currently disabled
    ),
}
|
|
|
|
|
|
# https://fastapi.tiangolo.com/tutorial/security/oauth2-jwt/#hash-and-verify-the-passwords
|
|
# SECRET_KEY taken from FastAPI documentation, so not that secret :D
|
|
# openssl rand -hex 32
|
|
SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
|
|
ALGORITHM = "HS256"
|
|
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
|
|
|
# Dependency that extracts the bearer token from the Authorization header;
# consumed by get_current_user.
# NOTE(review): the login endpoint in this module accepts a password form at
# /token/login, and no /login route is visible here — confirm whether the
# authorization-code flow (vs. OAuth2PasswordBearer) is really intended.
oauth2_scheme = OAuth2AuthorizationCodeBearer(
    tokenUrl="/token/login",
    authorizationUrl="/login",
)
|
|
|
|
|
|
def create_access_token(data: dict) -> str:
    """Return a signed JWT containing *data* plus an ``exp`` claim.

    The token expires ACCESS_TOKEN_EXPIRE_MINUTES from now (UTC) and is signed
    with SECRET_KEY using ALGORITHM, the same algorithm tokens are decoded with.

    Args:
        data: Claims to embed in the token; the caller's dict is not modified.

    Returns:
        The encoded JWT string.
    """
    to_encode = data.copy()  # copy so the caller's dict does not gain "exp"
    expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode.update({"exp": expire})
    # Use the shared ALGORITHM constant rather than a literal, so encoding and
    # decoding cannot drift apart if the algorithm is ever changed.
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
|
|
|
|
|
async def get_current_user(token: str = Depends(oauth2_scheme)) -> loginData:
    """Decode the bearer token and return the authenticated user.

    Args:
        token: Raw JWT taken from the request by ``oauth2_scheme``.

    Returns:
        A ``loginData`` built from the token's ``sub`` and ``pgroups`` claims.

    Raises:
        HTTPException: 401 ("Token expired") if the token has expired; 401
            ("Invalid token") if it fails validation or carries no ``sub`` claim.
    """
    log = logging.getLogger(__name__)
    # Per OAuth2 bearer-token usage, 401 responses advertise the expected scheme.
    auth_headers = {"WWW-Authenticate": "Bearer"}

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.ExpiredSignatureError:
        # ExpiredSignatureError subclasses InvalidTokenError, so it must be caught first.
        log.debug("Token expired")
        raise HTTPException(
            status_code=401, detail="Token expired", headers=auth_headers
        ) from None
    except jwt.InvalidTokenError:
        log.debug("Invalid token")
        raise HTTPException(
            status_code=401, detail="Invalid token", headers=auth_headers
        ) from None

    username = payload.get("sub")
    if username is None:
        # A validly signed token without a subject cannot identify a user.
        log.debug("Token has no 'sub' claim")
        raise HTTPException(status_code=401, detail="Invalid token", headers=auth_headers)

    log.debug("Username decoded from token: %s", username)
    # TODO: pass role=payload.get("role") once loginData supports roles.
    return loginData(username=username, pgroups=payload.get("pgroups"))
|
|
|
|
|
|
# async def get_user_role(token: str = Depends(oauth2_scheme)) -> str:
|
|
# try:
|
|
# payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
|
# return payload.get("role")
|
|
# except jwt.ExpiredSignatureError:
|
|
# raise HTTPException(status_code=401, detail="Token expired")
|
|
|
|
|
|
def _credentials_error() -> HTTPException:
    """Build the 401 returned for any rejected username/password pair."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Incorrect username or password",
        headers={"WWW-Authenticate": "Bearer"},
    )


def _load_dev_user(file_path: str = "user") -> dict:
    """Read the development-only fallback user from a local file.

    The file's first line is the username; the second line is a
    space-delimited list of pgroups. A relative *file_path* is resolved
    against the process's current working directory.

    Returns:
        A dict with "username" and "pgroups" keys, shaped like a mock_users_db entry.

    Raises:
        FileNotFoundError: if the file does not exist.
        HTTPException: 500 if the file has fewer than two lines or an empty username.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        lines = f.read().splitlines()

    if len(lines) < 2 or not lines[0].strip():
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="User file must have at least two lines: "
            "a username on the first and the list of pgroups on the second",
        )

    return {
        "username": lines[0].strip(),
        "pgroups": lines[1].split(),  # split() with no argument also trims
    }


@router.post("/token/login", response_model=loginToken)
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Authenticate a username/password form and issue a bearer access token.

    Users in mock_users_db must supply their matching password. For
    development only, a username that is NOT in mock_users_db is
    authenticated as the user described by the local "user" file. If that file
    does not exist, no fallback is configured and the login is rejected with 401.

    Raises:
        HTTPException: 401 for a wrong password, or for an unknown user when no
            dev file exists; 500 if the dev file is malformed.
    """
    user = mock_users_db.get(form_data.username)

    if user is None:
        # Development fallback: simulate authentication from the local "user"
        # file that lives only on the developer's machine.
        # NOTE: this path accepts ANY password — never ship that file to production.
        try:
            user = _load_dev_user()
        except FileNotFoundError:
            raise _credentials_error() from None
    elif user["password"] != form_data.password:
        raise _credentials_error()

    # TODO: include user["role"] in the claims once role support is enabled.
    access_token = create_access_token(
        data={"sub": user["username"], "pgroups": user["pgroups"]}
    )
    return loginToken(access_token=access_token, token_type="bearer")
|
|
|
|
|
|
@router.get("/protected-route")
async def read_protected_data(current_user: loginData = Depends(get_current_user)):
    """Echo the identity of the caller, as resolved by get_current_user."""
    # TODO: also return current_user.role once role support is enabled.
    return dict(username=current_user.username, pgroups=current_user.pgroups)
|