aaredb/backend/main.py
GotthardG 86d03285e4 Update server config, SSL handling, and port mapping logic
Refactored `run_server` to accept explicit config and SSL paths. Added dynamic environment-based config loading and stricter SSL path checks for production. Updated `docker-compose.yml` to use an environment variable for port mapping and adjusted `config_prod.json` to reflect correct port usage.
2025-04-11 12:37:18 +02:00

280 lines
9.6 KiB
Python

import os
import json
import tomllib
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from app import ssl_heidi
from app.routers import (
proposal,
puck,
spreadsheet,
logistics,
auth,
sample,
processing,
)
from app.database import Base, engine, SessionLocal
from app.routers.protected_router import protected_router
# Make sure the directory later mounted at /images exists before StaticFiles
# is pointed at it (the path is relative to the current working directory).
Path("images").mkdir(parents=True, exist_ok=True)
# Utility function to fetch metadata from pyproject.toml
def get_project_metadata():
script_dir = Path(__file__).resolve().parent
pyproject_path = script_dir / "pyproject.toml" # Check current directory first
if pyproject_path.exists():
with open(pyproject_path, "rb") as f:
pyproject = tomllib.load(f)
name = pyproject["project"]["name"]
version = pyproject["project"]["version"]
return name, version
# Search in parent directories
for parent in script_dir.parents:
pyproject_path = parent / "pyproject.toml"
if pyproject_path.exists():
with open(pyproject_path, "rb") as f:
pyproject = tomllib.load(f)
name = pyproject["project"]["name"]
version = pyproject["project"]["version"]
return name, version
raise FileNotFoundError(
f"pyproject.toml not found in any parent directory of {script_dir}"
)
def run_server(config, cert_path, key_path):
    """Run the module's ``app`` with uvicorn over HTTPS.

    Args:
        config: Parsed configuration mapping. Must contain a truthy ``"PORT"``
            entry (an int or a numeric string).
        cert_path: Path to the SSL certificate file handed to uvicorn.
        key_path: Path to the SSL private-key file handed to uvicorn.

    Exits the process with status 1 when no port is configured.
    """
    # Imported locally: at module level ``sys`` is only imported inside the
    # ``__main__`` guard, so it would be undefined when this function runs in
    # a spawned child process or is imported from another module.
    import sys

    environment = os.getenv(
        "ENVIRONMENT", "dev"
    )  # needs to be set explicitly here if not globally available
    print(f"[INFO] Starting server in {environment} environment...")
    print(f"[INFO] SSL Certificate Path: {cert_path}")
    print(f"[INFO] SSL Key Path: {key_path}")

    # Only the config mapping is consulted for the port.
    port = config.get("PORT")
    if not port:
        print("[ERROR] No port defined in config. Aborting!")
        sys.exit(1)  # Exit if no port is defined
    port = int(port)
    print(f"[INFO] Running on port {port}")

    # Deferred until the configuration has been validated.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        log_level="debug",
        ssl_keyfile=key_path,
        ssl_certfile=cert_path,
    )
# Get project metadata from pyproject.toml
project_name, project_version = get_project_metadata()

# Determine environment and configuration file path
environment = os.getenv("ENVIRONMENT", "dev")
config_file = Path(__file__).resolve().parent / f"config_{environment}.json"
if not config_file.exists():
    raise FileNotFoundError(f"Config file '{config_file}' does not exist.")

# Load configuration
with open(config_file, encoding="utf-8") as f:
    config = json.load(f)

# Set SSL paths based on environment. Note that dev/test read lower-case
# keys (with defaults) while prod requires upper-case keys.
if environment in ["test", "dev"]:
    cert_path = config.get("ssl_cert_path", "ssl/cert.pem")
    key_path = config.get("ssl_key_path", "ssl/key.pem")
    ssl_dir = Path(cert_path).parent
    # Ensure the directories exist before file operations (the key may live
    # in a different directory than the certificate).
    ssl_dir.mkdir(parents=True, exist_ok=True)
    Path(key_path).parent.mkdir(parents=True, exist_ok=True)
elif environment == "prod":
    cert_path = config.get("SSL_CERT_PATH")
    key_path = config.get("SSL_KEY_PATH")
    if not cert_path or not key_path:
        raise ValueError(
            "SSL_CERT_PATH and SSL_KEY_PATH must be set in config_prod.json"
            " for production."
        )
    if not Path(cert_path).exists() or not Path(key_path).exists():
        raise FileNotFoundError(
            f"Missing SSL files in production. Ensure the following files exist:\n"
            f"SSL Certificate: {cert_path}\nSSL Key: {key_path}"
        )
else:
    raise ValueError(f"Unknown environment: {environment}")

# Generate a self-signed key and certificate for dev AND test when missing.
# Both environments use the same default paths, and the ``__main__`` entry
# point also generates them for both, so importing the module in test no
# longer leaves it without SSL files.
if environment in ["dev", "test"]:
    if not Path(cert_path).exists() or not Path(key_path).exists():
        ssl_heidi.generate_self_signed_cert(cert_path, key_path)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Prepare the database at startup and close the session at shutdown.

    Behaviour by ``environment``:

    - ``prod``: never drops data. If the database has no tables yet, the
      schema is created and slot data is loaded. Otherwise only tables that
      do not exist yet are created; existing tables are left untouched, so
      column changes still need a migration.
    - ``dev`` / ``test``: every table is dropped and recreated, then sample
      data (dev) or slot data (test) is loaded.

    Args:
        app: The FastAPI application (required by the lifespan protocol;
            not used here).
    """
    print("[INFO] Running application startup tasks...")
    db = SessionLocal()
    try:
        if environment == "prod":
            from sqlalchemy import inspect as sa_inspect

            # Inspect BEFORE touching the schema. Previously drop_all +
            # create_all ran first, which wiped production data on every
            # start and made the "database is empty" branch unreachable, so
            # prod was never seeded.
            tables_exist = sa_inspect(engine).get_table_names()

            # Ensure the production database is initialized
            if not tables_exist:
                print("Production database is empty. Initializing...")
                Base.metadata.create_all(bind=engine)
                # Seed the database (slots + proposals)
                from app.database import load_slots_data

                load_slots_data(db)
            else:
                # create_all skips tables that already exist, so this only
                # adds newly declared tables and does not drop data.
                Base.metadata.create_all(bind=engine)
        else:  # dev or test environments
            print(f"{environment.capitalize()} environment: Regenerating database.")
            Base.metadata.drop_all(bind=engine)
            Base.metadata.create_all(bind=engine)
            if environment == "dev":
                from app.database import load_sample_data

                load_sample_data(db)
            elif environment == "test":
                from app.database import load_slots_data

                load_slots_data(db)
        yield
    finally:
        db.close()
# The FastAPI application: metadata comes from pyproject.toml and startup/
# shutdown work is delegated to ``lifespan``.
app = FastAPI(
    title=project_name,
    version=project_version,
    description="Backend for next-gen sample management system",
    lifespan=lifespan,
    servers=[
        {"url": "https://mx-aare-test.psi.ch:8000", "description": "Default server"}
    ],
)

# Apply CORS middleware.
# NOTE(review): wildcard origins combined with allow_credentials=True lets
# any site issue credentialed cross-origin requests; confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Routers and their include_router keyword arguments, registered in exactly
# this order (route matching and OpenAPI ordering follow registration order).
for _router, _options in (
    (protected_router, {"prefix": "/protected"}),
    (auth.router, {"prefix": "/auth", "tags": ["auth"]}),
    (proposal.router, {"prefix": "/proposals", "tags": ["proposals"]}),
    (puck.router, {"prefix": "/pucks", "tags": ["pucks"]}),
    (spreadsheet.router, {"tags": ["spreadsheet"]}),
    (logistics.router, {"prefix": "/logistics", "tags": ["logistics"]}),
    (sample.router, {"prefix": "/samples", "tags": ["samples"]}),
    (processing.router, {"prefix": "/processing", "tags": ["processing"]}),
):
    app.include_router(_router, **_options)
del _router, _options

# Serve files from the local "images" directory under /images.
app.mount("/images", StaticFiles(directory="images"), name="images")
if __name__ == "__main__":
    import sys
    from dotenv import load_dotenv
    from multiprocessing import Process
    from time import sleep

    # Load environment variables from .env file.
    # NOTE(review): the module-level config/SSL setup already ran at import
    # time, before this call, so it only saw ENVIRONMENT from the OS
    # environment; the values recomputed below may differ if ENVIRONMENT is
    # set only in .env — confirm this is acceptable.
    load_dotenv()
    environment = os.getenv("ENVIRONMENT", "dev")
    config_file = Path(__file__).resolve().parent / f"config_{environment}.json"

    # Check if `generate-openapi` option is passed
    if len(sys.argv) > 1 and sys.argv[1] == "generate-openapi":
        from fastapi.openapi.utils import get_openapi

        # Generate and save OpenAPI JSON file
        openapi_schema = get_openapi(
            title=app.title,
            version=app.version,
            description=app.description,
            routes=app.routes,
        )
        with open("openapi.json", "w") as f:
            json.dump(openapi_schema, f, indent=4)
        print("openapi.json generated successfully.")
        sys.exit(0)  # Exit after generating the file

    # Give the same explicit error as the module-level loader, in case the
    # environment selected here has no config file.
    if not config_file.exists():
        raise FileNotFoundError(f"Config file '{config_file}' does not exist.")

    # Explicitly load the configuration file
    with open(config_file, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Explicitly obtain SSL paths from config
    if environment in ["test", "dev"]:
        cert_path = config.get("ssl_cert_path", "ssl/cert.pem")
        key_path = config.get("ssl_key_path", "ssl/key.pem")
    elif environment == "prod":
        cert_path = config.get("SSL_CERT_PATH")
        key_path = config.get("SSL_KEY_PATH")
        if not cert_path or not key_path:
            # Trailing space matters: the two literals are concatenated.
            raise ValueError(
                "SSL_CERT_PATH and SSL_KEY_PATH must be explicitly "
                "set in config_prod.json for production."
            )
    else:
        raise ValueError(f"Unknown environment: {environment}")

    is_ci = os.getenv("CI", "false").lower() == "true"

    # Handle certificates for dev/test if not available. The SSL directory is
    # only created when files are about to be generated into it; prod
    # certificate files are expected to exist already.
    ssl_dir = Path(cert_path).parent
    if environment in ["dev", "test"] and (
        not Path(cert_path).exists() or not Path(key_path).exists()
    ):
        ssl_dir.mkdir(parents=True, exist_ok=True)
        print(f"[INFO] Generating SSL certificates at {ssl_dir}")
        ssl_heidi.generate_self_signed_cert(cert_path, key_path)

    if is_ci or environment == "test":
        # Smoke run: start the server in a child process, let it run for a
        # few seconds, then stop it.
        server_process = Process(target=run_server, args=(config, cert_path, key_path))
        server_process.start()
        sleep(5)
        server_process.terminate()
        server_process.join()
        print("CI/Test environment: Server started and terminated successfully.")
    else:
        run_server(config, cert_path, key_path)