aaredb/backend/main.py
GotthardG 19c5d7f880 Refactor environment-specific configurations and data loading.
Refactored database and server configuration to handle environments (dev, test, prod) explicitly, including tailored database setup and SSL management. Separated slot and sample data loading for better control during initialization. Improved environment variable usage and error handling for production certificates.
2024-12-17 13:11:26 +01:00

200 lines
6.8 KiB
Python

# app/main.py
import os
import json
import tomllib
from pathlib import Path
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app import ssl_heidi
from app.routers import (
address,
contact,
proposal,
dewar,
shipment,
puck,
spreadsheet,
logistics,
auth,
sample,
)
from app.database import Base, engine, SessionLocal, load_sample_data, load_slots_data
# Utility function to fetch metadata from pyproject.toml
def get_project_metadata():
# Start from the current script's directory and search for pyproject.toml
script_dir = Path(__file__).resolve().parent
for parent in script_dir.parents:
pyproject_path = parent / "pyproject.toml"
if pyproject_path.exists():
with open(pyproject_path, "rb") as f:
pyproject = tomllib.load(f)
name = pyproject["project"]["name"]
version = pyproject["project"]["version"]
return name, version
# If no pyproject.toml is found, raise FileNotFoundError
raise FileNotFoundError(
f"pyproject.toml not found in any parent directory of {script_dir}"
)
# The API title and version stay in sync with the project's name and version,
# both read from pyproject.toml.
project_name, project_version = get_project_metadata()
app = FastAPI(
    title=project_name,
    version=project_version,
    description="Backend for next-gen sample management system",
)
# Determine environment and configuration file path.
# ENVIRONMENT (default "dev") selects config_<environment>.json, which lives in
# the parent of the directory containing this file.
environment = os.getenv("ENVIRONMENT", "dev")
config_file = Path(__file__).resolve().parent.parent / f"config_{environment}.json"
# Load configuration. JSON is UTF-8, so don't depend on the platform locale.
with open(config_file, encoding="utf-8") as f:
    config = json.load(f)
cert_path = config["ssl_cert_path"]
key_path = config["ssl_key_path"]
# Generate a self-signed SSL key and certificate if missing (development only).
if environment == "dev":
    # Create the directories the configured paths actually point into; a
    # hard-coded "ssl/" only works when the config happens to use it.
    Path(cert_path).parent.mkdir(parents=True, exist_ok=True)
    Path(key_path).parent.mkdir(parents=True, exist_ok=True)
    if not Path(cert_path).exists() or not Path(key_path).exists():
        ssl_heidi.generate_self_signed_cert(cert_path, key_path)
# Apply CORS middleware.
# Allowed origins can be restricted per environment with an optional
# "cors_origins" list in the loaded config file; when the key is absent every
# origin is allowed, as before.
# NOTE(review): the FastAPI/Starlette docs state that allow_origins,
# allow_methods and allow_headers should not be ["*"] when
# allow_credentials=True. Set explicit "cors_origins" at least for prod.
app.add_middleware(
    CORSMiddleware,
    allow_origins=config.get("cors_origins", ["*"]),
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.on_event("startup")
def on_startup():
    """Prepare the database schema (and seed data) when the app starts.

    - prod: if the database has no tables yet, create them and load slot
      data; an already-initialized database is left untouched.
    - dev / test: drop and recreate every table (all existing data is lost),
      then seed sample data (dev) or slot data (test).

    Raises:
        ValueError: If ``environment`` is not one of 'prod', 'test', 'dev'.
    """
    db = SessionLocal()
    try:
        if environment == "prod":
            # Inspector.from_engine() is deprecated since SQLAlchemy 1.4;
            # sqlalchemy.inspect(engine) is the supported way to get an Inspector.
            from sqlalchemy import inspect as sa_inspect

            existing_tables = sa_inspect(engine).get_table_names()
            if not existing_tables:
                print("Production database is empty. Initializing...")
                Base.metadata.create_all(bind=engine)
                load_slots_data(db)
            else:
                print("Production database already initialized.")
        elif environment in ("dev", "test"):
            # Rebuild the schema from scratch on every start.
            print(f"{environment.capitalize()} environment: Regenerating database.")
            Base.metadata.drop_all(bind=engine)
            Base.metadata.create_all(bind=engine)
            if environment == "dev":
                load_sample_data(db)
            else:
                load_slots_data(db)
        else:
            # Refuse to start rather than drop every table for an
            # unrecognized environment name.
            raise ValueError(
                f"Unknown environment: {environment}. "
                "Must be one of 'prod', 'test', or 'dev'."
            )
    finally:
        db.close()
# Mount every feature router. Each entry is (router module, URL prefix, tag);
# routers are registered in the order listed.
for _router_module, _url_prefix, _tag in (
    (auth, "/auth", "auth"),
    (contact, "/contacts", "contacts"),
    (address, "/addresses", "addresses"),
    (proposal, "/proposals", "proposals"),
    (dewar, "/dewars", "dewars"),
    (shipment, "/shipments", "shipments"),
    (puck, "/pucks", "pucks"),
    (spreadsheet, "", "spreadsheet"),  # empty prefix: routes mounted as declared
    (logistics, "/logistics", "logistics"),
    (sample, "/samples", "samples"),
):
    app.include_router(_router_module.router, prefix=_url_prefix, tags=[_tag])
# Don't leak the loop variables into the module namespace.
del _router_module, _url_prefix, _tag
if __name__ == "__main__":
    # Script entry point (`python main.py`): choose host and SSL files for the
    # selected environment, then run uvicorn (or, in CI, smoke-test startup).
    import sys

    import uvicorn
    from dotenv import load_dotenv

    # Load environment variables from a .env file.
    # NOTE(review): the module-level setup above has already read ENVIRONMENT
    # and loaded config_<environment>.json before this call. An ENVIRONMENT
    # defined only in .env is therefore seen here, and by on_startup (which
    # reads the module-level `environment` when it runs), but not by that
    # earlier config/SSL setup.
    load_dotenv()

    # Fetch environment values
    environment = os.getenv("ENVIRONMENT", "dev")
    port = int(os.getenv("PORT", 8000))  # Default to 8000 if PORT is not set
    is_ci = os.getenv("CI", "false").lower() == "true"  # Running in CI?

    # Determine SSL certificate and key paths
    if environment == "prod":
        # Production must use real certificates, located via env variables.
        cert_path = os.getenv("VITE_SSL_CERT_PATH", "ssl/prod-cert.pem")
        key_path = os.getenv("VITE_SSL_KEY_PATH", "ssl/prod-key.pem")
        if not Path(cert_path).exists() or not Path(key_path).exists():
            raise FileNotFoundError(
                "Production certificates not found. "
                "Make sure the following files exist:\n"
                f"Certificate: {cert_path}\nKey: {key_path}"
            )
        host = "0.0.0.0"  # Allow external traffic
        print(
            "Running in production mode with provided SSL certificates:\n"
            f" - Certificate: {cert_path}\n - Key: {key_path}"
        )
    elif environment in ("test", "dev"):
        # Test/Development environments use self-signed certificates.
        # NOTE(review): these fixed paths ignore ssl_cert_path/ssl_key_path from
        # config_<environment>.json, which the module-level dev setup uses.
        cert_path = "ssl/cert.pem"
        key_path = "ssl/key.pem"
        host = "127.0.0.1"  # Restrict to localhost
        print(f"Running in {environment} mode with self-signed certificates...")
        # Ensure self-signed certificates exist or generate them
        Path("ssl").mkdir(parents=True, exist_ok=True)
        if not Path(cert_path).exists() or not Path(key_path).exists():
            print(f"Generating self-signed SSL certificate at {cert_path}...")
            ssl_heidi.generate_self_signed_cert(cert_path, key_path)
    else:
        raise ValueError(
            f"Unknown environment: {environment}. "
            "Must be one of 'prod', 'test', or 'dev'."
        )

    def run_server():
        """Run uvicorn in the foreground with the host/port/SSL chosen above."""
        uvicorn.run(
            app,
            host=host,
            port=port,
            log_level="debug",
            ssl_keyfile=key_path,
            ssl_certfile=cert_path,
        )

    if is_ci:
        # CI smoke test: start the server in a child process, give it time to
        # boot, and fail the job if it has already exited.
        from multiprocessing import Process
        from time import sleep

        print("CI mode detected: Starting server in a subprocess...")
        # NOTE(review): a nested function as the Process target relies on the
        # "fork" start method; under "spawn" (the default on macOS and Windows)
        # the target must be picklable, which a nested function is not.
        server_process = Process(target=run_server)
        server_process.start()
        sleep(5)  # Give the server time to start (or to fail)
        if not server_process.is_alive():
            # Startup failed: report it and exit non-zero instead of claiming success.
            server_process.join()
            sys.exit(
                "CI: Server exited during startup "
                f"(exit code {server_process.exitcode})."
            )
        server_process.terminate()  # Stop the server (test purposes)
        server_process.join()  # Ensure proper cleanup
        print("CI: Server started and terminated successfully for test validation.")
    else:
        # Run the server normally
        run_server()