aaredb/backend/main.py
GotthardG 9b6ad107bb Refactor OpenAPI fetcher for improved clarity and robustness
Reorganized and enhanced the OpenAPI fetch logic for better maintainability and error handling. Key updates include improved environment variable validation, more detailed error messages, streamlined configuration loading, and additional safety checks for file paths and directories. Added proper logging and ensured the process flow is easy to trace.
2024-12-17 16:25:53 +01:00

217 lines
7.5 KiB
Python

import os
import json
import tomllib
from pathlib import Path
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app import ssl_heidi
from app.routers import (
address,
contact,
proposal,
dewar,
shipment,
puck,
spreadsheet,
logistics,
auth,
sample,
)
from app.database import Base, engine, SessionLocal
# Utility function to fetch metadata from pyproject.toml
def get_project_metadata():
script_dir = Path(__file__).resolve().parent
for parent in script_dir.parents:
pyproject_path = parent / "pyproject.toml"
if pyproject_path.exists():
with open(pyproject_path, "rb") as f:
pyproject = tomllib.load(f)
name = pyproject["project"]["name"]
version = pyproject["project"]["version"]
return name, version
raise FileNotFoundError(
f"pyproject.toml not found in any parent directory of {script_dir}"
)
# Title and version of the API are taken from pyproject.toml
project_name, project_version = get_project_metadata()

# The application object; middleware and routers are attached to it below
app = FastAPI(
    version=project_version,
    title=project_name,
    description="Backend for next-gen sample management system",
)
# Determine environment and configuration file path.
# ENVIRONMENT defaults to "dev"; the matching config_<environment>.json is
# expected one directory above the directory containing this file.
environment = os.getenv("ENVIRONMENT", "dev")
config_file = Path(__file__).resolve().parent.parent / f"config_{environment}.json"

if not config_file.exists():
    raise FileNotFoundError(f"Config file '{config_file}' does not exist.")

# Load configuration. The encoding is given explicitly so the file is not
# decoded with whatever the platform's locale default happens to be.
with open(config_file, encoding="utf-8") as f:
    config = json.load(f)
# Resolve the TLS certificate and key locations for the active environment
if environment == "prod":
    # Production has no fallback paths: both must be configured and present
    cert_path = config.get("SSL_CERT_PATH")
    key_path = config.get("SSL_KEY_PATH")

    if not (cert_path and key_path):
        raise ValueError(
            "SSL_CERT_PATH and SSL_KEY_PATH must be set in config_prod.json"
            " for production."
        )

    if not (Path(cert_path).exists() and Path(key_path).exists()):
        raise FileNotFoundError(
            f"Missing SSL files in production. Ensure the following files exist:\n"
            f"SSL Certificate: {cert_path}\nSSL Key: {key_path}"
        )
elif environment in ("test", "dev"):
    # Non-production environments fall back to files under ./ssl
    cert_path = config.get("ssl_cert_path", "ssl/cert.pem")
    key_path = config.get("ssl_key_path", "ssl/key.pem")
else:
    raise ValueError(f"Unknown environment: {environment}")
# Generate SSL Key and Certificate if not exist (only for development)
if environment == "dev":
    # Create the directories the configured paths actually point into. With
    # the default paths this is still "ssl", but a custom ssl_cert_path or
    # ssl_key_path may live elsewhere.
    for _pem_path in (cert_path, key_path):
        Path(_pem_path).parent.mkdir(parents=True, exist_ok=True)

    # Only generate when at least one of the two files is missing.
    # NOTE(review): presumably generate_self_signed_cert writes both files at
    # the given paths - confirm against app.ssl_heidi.
    if not Path(cert_path).exists() or not Path(key_path).exists():
        ssl_heidi.generate_self_signed_cert(cert_path, key_path)
# Register the CORS middleware with fully permissive settings.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive - confirm this is intended outside of development.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
@app.on_event("startup")
def on_startup() -> None:
    """Prepare the database schema and seed data when the application starts.

    * ``prod``: the schema is created and slot data is loaded only when the
      database has no tables yet; a database that already has tables is left
      untouched.
    * any other environment (``dev`` / ``test``): every table is dropped and
      recreated, then sample data (``dev``) or slot data (``test``) is loaded.

    The session opened here is always closed, even when seeding raises.
    """
    print("[INFO] Running application startup tasks...")
    db = SessionLocal()
    try:
        if environment == "prod":
            # sqlalchemy.inspect() is the supported way to obtain an
            # Inspector; Inspector.from_engine() is deprecated.
            from sqlalchemy import inspect

            existing_tables = inspect(engine).get_table_names()

            # Ensure the production database is initialized
            if not existing_tables:
                print("Production database is empty. Initializing...")
                Base.metadata.create_all(bind=engine)
                # Seed the database (slots + proposals)
                # NOTE(review): seeding is assumed to run only right after
                # the schema was created - confirm this is the intended scope.
                from app.database import load_slots_data

                load_slots_data(db)
        else:  # dev or test environments
            print(f"{environment.capitalize()} environment: Regenerating database.")
            # Start from a clean schema on every launch.
            Base.metadata.drop_all(bind=engine)
            Base.metadata.create_all(bind=engine)
            if environment == "dev":
                from app.database import load_sample_data

                load_sample_data(db)
            elif environment == "test":
                from app.database import load_slots_data

                load_slots_data(db)
    finally:
        db.close()
# Router registration table: each entry is (router, include_router options).
# Order is preserved; the spreadsheet router is mounted without a prefix.
_ROUTER_TABLE = (
    (auth.router, {"prefix": "/auth", "tags": ["auth"]}),
    (contact.router, {"prefix": "/contacts", "tags": ["contacts"]}),
    (address.router, {"prefix": "/addresses", "tags": ["addresses"]}),
    (proposal.router, {"prefix": "/proposals", "tags": ["proposals"]}),
    (dewar.router, {"prefix": "/dewars", "tags": ["dewars"]}),
    (shipment.router, {"prefix": "/shipments", "tags": ["shipments"]}),
    (puck.router, {"prefix": "/pucks", "tags": ["pucks"]}),
    (spreadsheet.router, {"tags": ["spreadsheet"]}),
    (logistics.router, {"prefix": "/logistics", "tags": ["logistics"]}),
    (sample.router, {"prefix": "/samples", "tags": ["samples"]}),
)
for _router, _options in _ROUTER_TABLE:
    app.include_router(_router, **_options)
if __name__ == "__main__":
    # Script entry point. Two modes:
    #   * `generate-openapi` as first argument: write openapi.json and exit.
    #   * otherwise: run the HTTPS server (in CI, only as a short smoke test).
    import sys
    import uvicorn
    from dotenv import load_dotenv
    from multiprocessing import Process
    from time import sleep

    # Load environment variables from .env file.
    # NOTE(review): the module-level configuration above (config file and SSL
    # paths) was already resolved before this call, so an ENVIRONMENT value
    # that exists only in .env does not influence it - confirm this ordering
    # is intended.
    load_dotenv()

    # Check if `generate-openapi` option is passed
    if len(sys.argv) > 1 and sys.argv[1] == "generate-openapi":
        from fastapi.openapi.utils import get_openapi

        # Generate and save OpenAPI JSON file
        openapi_schema = get_openapi(
            title=app.title,
            version=app.version,
            description=app.description,
            routes=app.routes,
        )
        # Explicit encoding so the output does not depend on the locale.
        with open("openapi.json", "w", encoding="utf-8") as f:
            json.dump(openapi_schema, f, indent=4)
        print("openapi.json generated successfully.")
        sys.exit(0)  # Exit after generating the file

    # Default behavior: Run the server
    environment = os.getenv("ENVIRONMENT", "dev")
    is_ci = os.getenv("CI", "false").lower() == "true"

    def run_server():
        """Start uvicorn with TLS on the port from the config or PORT.

        The port is read from the config's "PORT" entry, falling back to the
        PORT environment variable; the process exits with status 1 when
        neither is set.
        """
        print(f"[INFO] Starting server in {environment} environment...")
        print(f"[INFO] SSL Certificate Path: {cert_path}")
        print(f"[INFO] SSL Key Path: {key_path}")
        port = config.get("PORT", os.getenv("PORT"))
        if not port:
            print(
                "[ERROR] No port defined in config or environment variables. Aborting!"
            )
            sys.exit(1)  # Exit if no port is defined
        port = int(port)
        print(f"[INFO] Running on port {port}")
        uvicorn.run(
            app,
            # Loopback only for dev/test; all interfaces otherwise.
            host="127.0.0.1" if environment in ["dev", "test"] else "0.0.0.0",
            port=port,
            log_level="debug",
            ssl_keyfile=key_path,
            ssl_certfile=cert_path,
        )

    # Run in CI mode
    if is_ci:  # CI mode
        print("CI mode detected: Starting server in a subprocess...")
        # Ensure the directories holding the certificate and key exist
        # ("ssl" with the default paths).
        for pem_path in (cert_path, key_path):
            Path(pem_path).parent.mkdir(parents=True, exist_ok=True)
        # Ensure SSL certificate and key exist
        if not Path(cert_path).exists() or not Path(key_path).exists():
            print("Generating SSL certificates for CI mode...")
            ssl_heidi.generate_self_signed_cert(cert_path, key_path)

        # NOTE(review): run_server is defined under the __main__ guard; with
        # the "spawn" start method the child may not be able to resolve it -
        # confirm CI runs with "fork".
        server_process = Process(target=run_server)
        server_process.start()
        sleep(5)  # Wait 5 seconds to ensure the server starts without errors
        # A server that already exited within the grace period failed to start;
        # record that before terminating so the failure is not masked.
        still_running = server_process.is_alive()
        server_process.terminate()  # Terminate the server (test purposes)
        server_process.join()  # Ensure proper cleanup
        if not still_running:
            print(
                "[ERROR] CI: Server exited during startup "
                f"(exit code {server_process.exitcode})."
            )
            sys.exit(1)
        print("CI: Server started and terminated successfully for test validation.")
    else:
        run_server()