
Reorganized and enhanced the application entry point (configuration loading, SSL setup, and OpenAPI generation) for better maintainability and error handling. Key updates include stricter environment-variable validation, more detailed error messages, streamlined configuration loading, and additional safety checks for file paths and directories. Added status logging so the startup flow is easy to trace.
217 lines · 7.5 KiB · Python
import os
|
|
import json
|
|
import tomllib
|
|
from pathlib import Path
|
|
from fastapi import FastAPI
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from app import ssl_heidi
|
|
from app.routers import (
|
|
address,
|
|
contact,
|
|
proposal,
|
|
dewar,
|
|
shipment,
|
|
puck,
|
|
spreadsheet,
|
|
logistics,
|
|
auth,
|
|
sample,
|
|
)
|
|
from app.database import Base, engine, SessionLocal
|
|
|
|
|
|
# Utility function to fetch metadata from pyproject.toml
|
|
def get_project_metadata():
|
|
script_dir = Path(__file__).resolve().parent
|
|
for parent in script_dir.parents:
|
|
pyproject_path = parent / "pyproject.toml"
|
|
if pyproject_path.exists():
|
|
with open(pyproject_path, "rb") as f:
|
|
pyproject = tomllib.load(f)
|
|
name = pyproject["project"]["name"]
|
|
version = pyproject["project"]["version"]
|
|
return name, version
|
|
raise FileNotFoundError(
|
|
f"pyproject.toml not found in any parent directory of {script_dir}"
|
|
)
|
|
|
|
|
|
# The OpenAPI title and version are taken from the package metadata so they
# track pyproject.toml instead of being duplicated here.
project_name, project_version = get_project_metadata()

app = FastAPI(
    description="Backend for next-gen sample management system",
    title=project_name,
    version=project_version,
)
|
|
|
|
# Determine the runtime environment; it selects the config file below and how
# the SSL certificate/key paths are obtained.
environment = os.getenv("ENVIRONMENT", "dev")

# Validate before touching the filesystem so a typo such as "staging" is
# reported as an unknown environment rather than as a missing config file.
if environment not in ("dev", "test", "prod"):
    raise ValueError(f"Unknown environment: {environment}")

# The config file sits one level above this module's directory,
# e.g. <backend>/config_dev.json.
config_file = Path(__file__).resolve().parent.parent / f"config_{environment}.json"

if not config_file.exists():
    raise FileNotFoundError(f"Config file '{config_file}' does not exist.")

# Load configuration; name the offending file when its content is unusable.
try:
    with config_file.open(encoding="utf-8") as f:
        config = json.load(f)
except json.JSONDecodeError as err:
    raise ValueError(f"Config file '{config_file}' is not valid JSON: {err}") from err
if not isinstance(config, dict):
    raise ValueError(f"Config file '{config_file}' must contain a JSON object.")

# Set SSL paths based on environment.
# NOTE(review): dev/test read lowercase keys ("ssl_cert_path") while prod reads
# uppercase keys ("SSL_CERT_PATH"); relative paths resolve against the current
# working directory, not against the config file's directory.
if environment == "prod":
    cert_path = config.get("SSL_CERT_PATH")
    key_path = config.get("SSL_KEY_PATH")

    # Validate production SSL paths
    if not cert_path or not key_path:
        raise ValueError(
            "SSL_CERT_PATH and SSL_KEY_PATH must be set in config_prod.json"
            " for production."
        )
    if not Path(cert_path).exists() or not Path(key_path).exists():
        raise FileNotFoundError(
            f"Missing SSL files in production. Ensure the following files exist:\n"
            f"SSL Certificate: {cert_path}\nSSL Key: {key_path}"
        )
else:  # "dev" or "test"
    cert_path = config.get("ssl_cert_path", "ssl/cert.pem")
    key_path = config.get("ssl_key_path", "ssl/key.pem")

# Generate a self-signed key/certificate pair if missing (development only).
# Create the directories that will actually hold the files: the config may
# point them somewhere other than the default "ssl/" directory.
if environment == "dev":
    for ssl_file in (cert_path, key_path):
        Path(ssl_file).parent.mkdir(parents=True, exist_ok=True)
    if not Path(cert_path).exists() or not Path(key_path).exists():
        ssl_heidi.generate_self_signed_cert(cert_path, key_path)
|
|
|
|
# Apply CORS middleware.
# With allow_credentials=True, a wildcard origin lets any website issue
# credentialed cross-origin requests; Starlette's documentation calls for
# explicit origins when credentials are allowed. Deployments can pin trusted
# origins via the optional "CORS_ALLOW_ORIGINS" config key (a list of origin
# URLs, or a single string). Omitting the key keeps the previous allow-all
# behaviour.
cors_allow_origins = config.get("CORS_ALLOW_ORIGINS", ["*"])
if isinstance(cors_allow_origins, str):
    # A bare string would otherwise be iterated character by character.
    cors_allow_origins = [cors_allow_origins]

app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
@app.on_event("startup")
def on_startup():
    """Prepare the database before the application starts serving requests.

    * prod: create the schema only when the database has no tables, then call
      ``load_slots_data``.
    * dev: drop and recreate every table, then call ``load_sample_data``.
    * test: drop and recreate every table, then call ``load_slots_data``.

    The session opened here is closed even if one of the steps raises.
    """
    print("[INFO] Running application startup tasks...")
    db = SessionLocal()
    try:
        if environment == "prod":
            # Inspector.from_engine() is deprecated since SQLAlchemy 1.4;
            # sqlalchemy.inspect(engine) returns the equivalent Inspector.
            import sqlalchemy

            existing_tables = sqlalchemy.inspect(engine).get_table_names()

            # Ensure the production database is initialized
            if not existing_tables:
                print("Production database is empty. Initializing...")
                Base.metadata.create_all(bind=engine)

            # Seed the database (slots + proposals).
            # NOTE(review): this runs on every prod startup, not only right
            # after initialization; confirm load_slots_data is idempotent.
            from app.database import load_slots_data

            load_slots_data(db)
        else:  # dev or test environments
            print(f"{environment.capitalize()} environment: Regenerating database.")
            Base.metadata.drop_all(bind=engine)
            Base.metadata.create_all(bind=engine)
            if environment == "dev":
                from app.database import load_sample_data

                load_sample_data(db)
            elif environment == "test":
                from app.database import load_slots_data

                load_slots_data(db)
    finally:
        db.close()
|
|
|
|
|
|
def _register_routers(application):
    """Mount every feature router on *application*.

    Each table row is ``(router_module, url_prefix, openapi_tag)``. A prefix
    of ``None`` means no ``prefix`` argument is passed, so that router is
    mounted at the application root.
    """
    route_table = (
        (auth, "/auth", "auth"),
        (contact, "/contacts", "contacts"),
        (address, "/addresses", "addresses"),
        (proposal, "/proposals", "proposals"),
        (dewar, "/dewars", "dewars"),
        (shipment, "/shipments", "shipments"),
        (puck, "/pucks", "pucks"),
        (spreadsheet, None, "spreadsheet"),
        (logistics, "/logistics", "logistics"),
        (sample, "/samples", "samples"),
    )
    for module, url_prefix, tag in route_table:
        options = {"tags": [tag]}
        if url_prefix is not None:
            options["prefix"] = url_prefix
        application.include_router(module.router, **options)


# Include routers with correct configuration
_register_routers(app)
|
|
|
|
if __name__ == "__main__":
    import sys
    from multiprocessing import Process
    from time import sleep

    import uvicorn
    from dotenv import load_dotenv

    # Load environment variables from the .env file.
    # NOTE: the module-level setup above (environment, config, cert_path,
    # key_path) already ran at import time, before this call. An ENVIRONMENT
    # value that exists only in .env therefore cannot change it. Values read
    # below (PORT, CI) do pick up .env entries.
    load_dotenv()

    # Check if the `generate-openapi` option is passed
    if len(sys.argv) > 1 and sys.argv[1] == "generate-openapi":
        from fastapi.openapi.utils import get_openapi

        # Generate and save the OpenAPI JSON file
        openapi_schema = get_openapi(
            title=app.title,
            version=app.version,
            description=app.description,
            routes=app.routes,
        )
        with open("openapi.json", "w", encoding="utf-8") as f:
            json.dump(openapi_schema, f, indent=4)
        print("openapi.json generated successfully.")
        sys.exit(0)  # Exit after generating the file

    # Default behavior: run the server.
    # Do NOT re-read ENVIRONMENT into the module-level name here. Doing so
    # made on_startup() and the log lines disagree with the config and SSL
    # files that were actually loaded. Warn about the mismatch instead.
    dotenv_environment = os.getenv("ENVIRONMENT", "dev")
    if dotenv_environment != environment:
        print(
            f"[WARNING] ENVIRONMENT is '{dotenv_environment}' after loading .env, "
            f"but configuration was loaded for '{environment}'. Export "
            "ENVIRONMENT before starting the process to switch environments."
        )

    is_ci = os.getenv("CI", "false").lower() == "true"

    def run_server():
        """Run uvicorn with the module-level config and SSL cert/key paths.

        The port comes from the config's ``PORT`` key, falling back to the
        ``PORT`` environment variable. The process exits with status 1 if
        neither yields a usable integer.
        """
        print(f"[INFO] Starting server in {environment} environment...")
        print(f"[INFO] SSL Certificate Path: {cert_path}")
        print(f"[INFO] SSL Key Path: {key_path}")

        raw_port = config.get("PORT", os.getenv("PORT"))
        if not raw_port:
            print(
                "[ERROR] No port defined in config or environment variables. Aborting!"
            )
            sys.exit(1)  # Exit if no port is defined
        try:
            port = int(raw_port)
        except (TypeError, ValueError):
            print(f"[ERROR] Invalid port value {raw_port!r}. Aborting!")
            sys.exit(1)
        print(f"[INFO] Running on port {port}")

        uvicorn.run(
            app,
            host="127.0.0.1" if environment in ["dev", "test"] else "0.0.0.0",
            port=port,
            log_level="debug",
            ssl_keyfile=key_path,
            ssl_certfile=cert_path,
        )

    # Run in CI mode
    if is_ci:
        print("CI mode detected: Starting server in a subprocess...")

        # Ensure the directories holding the SSL files exist (the config may
        # place them outside the default "ssl/" directory).
        for ssl_file in (cert_path, key_path):
            Path(ssl_file).parent.mkdir(parents=True, exist_ok=True)

        # Ensure SSL certificate and key exist
        if not Path(cert_path).exists() or not Path(key_path).exists():
            print("Generating SSL certificates for CI mode...")
            ssl_heidi.generate_self_signed_cert(cert_path, key_path)

        # NOTE(review): a module-level run_server as Process target likely
        # requires the "fork" start method. Under "spawn" the child re-imports
        # this file with __name__ != "__main__", so run_server would be
        # undefined there.
        server_process = Process(target=run_server)
        server_process.start()
        sleep(5)  # Grace period for the server to start (or to fail fast)

        # If the server is no longer running after the grace period, it
        # failed to start (bad port, bind error, SSL problem, ...). Fail the
        # CI job instead of reporting success.
        if not server_process.is_alive():
            server_process.join()
            print(
                "CI: Server exited prematurely "
                f"(exit code {server_process.exitcode})."
            )
            sys.exit(1)

        server_process.terminate()  # Terminate the server (test purposes)
        server_process.join()  # Ensure proper cleanup
        print("CI: Server started and terminated successfully for test validation.")
    else:
        run_server()
|