
Refactored database and server configuration to handle environments (dev, test, prod) explicitly, including tailored database setup and SSL management. Separated slot and sample data loading for better control during initialization. Improved environment variable usage and error handling for production certificates.
200 lines
6.8 KiB
Python
200 lines
6.8 KiB
Python
# app/main.py
|
|
import os
|
|
import json
|
|
import tomllib
|
|
from pathlib import Path
|
|
from fastapi import FastAPI
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from app import ssl_heidi
|
|
|
|
|
|
from app.routers import (
|
|
address,
|
|
contact,
|
|
proposal,
|
|
dewar,
|
|
shipment,
|
|
puck,
|
|
spreadsheet,
|
|
logistics,
|
|
auth,
|
|
sample,
|
|
)
|
|
from app.database import Base, engine, SessionLocal, load_sample_data, load_slots_data
|
|
|
|
|
|
# Utility function to fetch metadata from pyproject.toml
|
|
def get_project_metadata():
|
|
# Start from the current script's directory and search for pyproject.toml
|
|
script_dir = Path(__file__).resolve().parent
|
|
for parent in script_dir.parents:
|
|
pyproject_path = parent / "pyproject.toml"
|
|
if pyproject_path.exists():
|
|
with open(pyproject_path, "rb") as f:
|
|
pyproject = tomllib.load(f)
|
|
name = pyproject["project"]["name"]
|
|
version = pyproject["project"]["version"]
|
|
return name, version
|
|
|
|
# If no pyproject.toml is found, raise FileNotFoundError
|
|
raise FileNotFoundError(
|
|
f"pyproject.toml not found in any parent directory of {script_dir}"
|
|
)
|
|
|
|
|
|
# Read the project name and version from pyproject.toml.
project_name, project_version = get_project_metadata()

# Keyword arguments for the application object; the title and version mirror
# the project `name` and `version` obtained above.
_app_kwargs = {
    "title": project_name,
    "description": "Backend for next-gen sample management system",
    "version": project_version,
}
app = FastAPI(**_app_kwargs)
|
|
|
|
# Determine environment and configuration file path.
# ENVIRONMENT (default "dev") selects which config_<environment>.json is read;
# the file is looked up one directory above the directory holding this module.
environment = os.getenv("ENVIRONMENT", "dev")
config_file = Path(__file__).resolve().parent.parent / f"config_{environment}.json"

# Load configuration. The encoding is given explicitly so parsing does not
# depend on the platform's default locale encoding.
with open(config_file, encoding="utf-8") as f:
    config = json.load(f)

# Both keys are required: a missing one raises KeyError while importing.
cert_path = config["ssl_cert_path"]
key_path = config["ssl_key_path"]

# Generate SSL key and certificate if they do not exist (development only).
if environment == "dev":
    # The ./ssl directory (relative to the working directory) is still
    # created unconditionally, exactly as before.
    Path("ssl").mkdir(parents=True, exist_ok=True)
    if not Path(cert_path).exists() or not Path(key_path).exists():
        # The configured paths are not required to live under ./ssl, so make
        # sure their own parent directories exist before generating files.
        Path(cert_path).parent.mkdir(parents=True, exist_ok=True)
        Path(key_path).parent.mkdir(parents=True, exist_ok=True)
        ssl_heidi.generate_self_signed_cert(cert_path, key_path)
|
|
|
|
# Apply CORS middleware.
# NOTE(review): every origin, method and header is accepted while credentials
# are enabled as well -- confirm this is intended outside development.
_cors_options = dict(
    allow_origins=["*"],  # Enable CORS for all origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
|
|
@app.on_event("startup")
def on_startup():
    """Prepare the database schema and seed data when the application starts.

    Behaviour depends on the module-level ``environment`` value:

    - ``prod``: if the database reports no tables, create them and call
      ``load_slots_data``; otherwise leave the database untouched.
    - ``dev`` / ``test``: drop and recreate all tables on every start, then
      call ``load_sample_data`` (``dev``) or ``load_slots_data`` (``test``).

    Raises:
        ValueError: If ``environment`` is not ``prod``, ``test`` or ``dev``.
            The check runs before the database is touched so that a mistyped
            ENVIRONMENT value can never fall into the destructive
            drop-and-recreate branch.
    """
    if environment not in ("prod", "test", "dev"):
        raise ValueError(
            f"Unknown environment: {environment}. "
            f"Must be one of 'prod', 'test', or 'dev'."
        )

    db = SessionLocal()
    try:
        if environment == "prod":
            # Imported lazily because only the production path needs it.
            # ``inspect(engine)`` replaces the deprecated
            # ``Inspector.from_engine(engine)`` and returns the same Inspector.
            from sqlalchemy import inspect as sa_inspect

            existing_tables = sa_inspect(engine).get_table_names()

            if not existing_tables:
                print("Production database is empty. Initializing...")
                Base.metadata.create_all(bind=engine)
                load_slots_data(db)
            else:
                print("Production database already initialized.")

        else:  # dev or test (validated above)
            print(f"{environment.capitalize()} environment: Regenerating database.")
            Base.metadata.drop_all(bind=engine)
            Base.metadata.create_all(bind=engine)
            if environment == "dev":
                load_sample_data(db)
            else:
                load_slots_data(db)
    finally:
        # Always release the session, including when seeding fails.
        db.close()
|
|
|
|
|
|
# Include routers with correct configuration.
# Each entry is (router module, URL prefix, tag); a prefix of None means the
# router is registered without a prefix argument. Registration order matches
# the order of the entries.
_ROUTER_SPECS = (
    (auth, "/auth", "auth"),
    (contact, "/contacts", "contacts"),
    (address, "/addresses", "addresses"),
    (proposal, "/proposals", "proposals"),
    (dewar, "/dewars", "dewars"),
    (shipment, "/shipments", "shipments"),
    (puck, "/pucks", "pucks"),
    (spreadsheet, None, "spreadsheet"),
    (logistics, "/logistics", "logistics"),
    (sample, "/samples", "samples"),
)
for _module, _prefix, _tag in _ROUTER_SPECS:
    _options = {"tags": [_tag]}
    if _prefix is not None:
        _options["prefix"] = _prefix
    app.include_router(_module.router, **_options)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: resolve host, port and TLS files for the selected
    # environment, then serve ``app`` with uvicorn (or smoke-test it in CI).
    import uvicorn
    from dotenv import load_dotenv

    # Load environment variables from .env file.
    # NOTE(review): the module-level code earlier in this file has already
    # read ENVIRONMENT and loaded its config file by the time this runs, so a
    # value that exists only in .env affects this block but not that earlier
    # configuration -- confirm this ordering is intended.
    load_dotenv()

    # Fetch environment values
    environment = os.getenv("ENVIRONMENT", "dev")
    port = int(os.getenv("PORT", "8000"))  # Default to 8000 if PORT is not set
    is_ci = os.getenv("CI", "false").lower() == "true"  # Running in CI?

    # Determine SSL certificate and key paths
    if environment == "prod":
        # Production must use existing certificate and key files; nothing is
        # generated here.
        cert_path = os.getenv("VITE_SSL_CERT_PATH", "ssl/prod-cert.pem")
        key_path = os.getenv("VITE_SSL_KEY_PATH", "ssl/prod-key.pem")
        if not Path(cert_path).exists() or not Path(key_path).exists():
            raise FileNotFoundError(
                "Production certificates not found. "
                "Make sure the following files exist:\n"
                f"Certificate: {cert_path}\nKey: {key_path}"
            )
        host = "0.0.0.0"  # Allow external traffic
        print(
            f"Running in production mode with provided SSL certificates:\n"
            f" - Certificate: {cert_path}\n - Key: {key_path}"
        )

    elif environment in ("test", "dev"):
        # Test/Development environments use self-signed certificates
        cert_path = "ssl/cert.pem"
        key_path = "ssl/key.pem"
        host = "127.0.0.1"  # Restrict to localhost
        print(f"Running in {environment} mode with self-signed certificates...")

        # Ensure self-signed certificates exist or generate them
        Path("ssl").mkdir(parents=True, exist_ok=True)
        if not Path(cert_path).exists() or not Path(key_path).exists():
            print(f"Generating self-signed SSL certificate at {cert_path}...")
            ssl_heidi.generate_self_signed_cert(cert_path, key_path)

    else:
        raise ValueError(
            f"Unknown environment: {environment}. "
            f"Must be one of 'prod', 'test', or 'dev'."
        )

    def run_server():
        """Serve ``app`` over TLS with the host, port and files chosen above."""
        uvicorn.run(
            app,
            host=host,
            port=port,
            log_level="debug",
            ssl_keyfile=key_path,
            ssl_certfile=cert_path,
        )

    # Continuous Integration handling
    if is_ci:
        from multiprocessing import Process
        from time import sleep

        # NOTE(review): ``run_server`` only exists inside this guarded block;
        # with the "spawn"/"forkserver" start methods the child re-imports the
        # module without executing it -- verify that CI uses a start method
        # under which the target can be resolved (e.g. "fork").
        print("CI mode detected: Starting server in a subprocess...")
        server_process = Process(target=run_server)
        server_process.start()
        sleep(5)  # Give the server 5 seconds to start (or to fail)

        # A server that crashed during startup is no longer alive here; report
        # that as a failure instead of printing the success message.
        if not server_process.is_alive():
            server_process.join()
            raise SystemExit(
                "CI: Server exited prematurely with exit code "
                f"{server_process.exitcode}."
            )

        server_process.terminate()  # Terminate the server (test purposes)
        server_process.join()  # Ensure proper cleanup
        print("CI: Server started and terminated successfully for test validation.")
    else:
        # Run the server normally
        run_server()
|