aaredb/backend/main.py
GotthardG c2215860bf Refactor Dewar service methods and improve field handling
Updated Dewar API methods to use protected endpoints for enhanced security and consistency. Added `pgroups` handling in various frontend components and modified the LogisticsView contact field for clarity. Simplified backend router imports for better readability.
2025-01-30 13:39:49 +01:00

212 lines
7.3 KiB
Python

import os
import json
import tomllib
from pathlib import Path
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app import ssl_heidi
from app.routers import (
proposal,
puck,
spreadsheet,
logistics,
auth,
sample,
)
from app.database import Base, engine, SessionLocal
from app.routers.protected_router import protected_router
# Utility function to fetch metadata from pyproject.toml
def get_project_metadata():
script_dir = Path(__file__).resolve().parent
for parent in script_dir.parents:
pyproject_path = parent / "pyproject.toml"
if pyproject_path.exists():
with open(pyproject_path, "rb") as f:
pyproject = tomllib.load(f)
name = pyproject["project"]["name"]
version = pyproject["project"]["version"]
return name, version
raise FileNotFoundError(
f"pyproject.toml not found in any parent directory of {script_dir}"
)
def run_server():
    """Run the module-level ``app`` under uvicorn with TLS.

    Uses the module-level ``environment``, ``config``, ``cert_path`` and
    ``key_path`` values resolved at import time. The port is
    ``config["PORT"]`` when that key exists, otherwise the ``PORT``
    environment variable; if the resulting value is empty/unset the process
    exits with status 1.
    """
    # Imported locally: at module level ``sys`` is only imported inside the
    # ``__main__`` guard, which does not run when this function is called
    # from another importer or inside a spawned multiprocessing child.
    import sys

    import uvicorn

    print(f"[INFO] Starting server in {environment} environment...")
    print(f"[INFO] SSL Certificate Path: {cert_path}")
    print(f"[INFO] SSL Key Path: {key_path}")
    port = config.get("PORT", os.getenv("PORT"))
    if not port:
        print("[ERROR] No port defined in config or environment variables. Aborting!")
        sys.exit(1)  # Exit if no port is defined
    port = int(port)
    print(f"[INFO] Running on port {port}")
    uvicorn.run(
        app,
        # Bind to loopback only for dev/test; listen on all interfaces otherwise.
        host="127.0.0.1" if environment in ["dev", "test"] else "0.0.0.0",
        port=port,
        log_level="debug",
        ssl_keyfile=key_path,
        ssl_certfile=cert_path,
    )
# Build the FastAPI application. Title and version are taken from
# pyproject.toml so the OpenAPI metadata follows the packaged project.
project_name, project_version = get_project_metadata()
app = FastAPI(
    title=project_name,
    version=project_version,
    description="Backend for next-gen sample management system",
    servers=[
        dict(url="https://mx-aare-test.psi.ch:1492", description="Default server"),
    ],
)
# Determine environment and configuration file path.
# ENVIRONMENT (default "dev") selects config_<env>.json, located one directory
# above this file.
environment = os.getenv("ENVIRONMENT", "dev")
config_file = Path(__file__).resolve().parent.parent / f"config_{environment}.json"
if not config_file.exists():
    raise FileNotFoundError(f"Config file '{config_file}' does not exist.")
# Load configuration. Decode explicitly as UTF-8 (the JSON interchange
# encoding) rather than with the platform's locale-dependent default.
with open(config_file, encoding="utf-8") as f:
    config = json.load(f)
# Resolve the SSL certificate/key paths for the selected environment.
# NOTE(review): dev/test read lowercase keys (ssl_cert_path / ssl_key_path)
# while prod reads uppercase keys (SSL_CERT_PATH / SSL_KEY_PATH) — confirm the
# config files follow this split.
if environment in ("test", "dev"):
    # Non-production: keys are optional and default to files under ./ssl/.
    cert_path = config.get("ssl_cert_path", "ssl/cert.pem")
    key_path = config.get("ssl_key_path", "ssl/key.pem")
elif environment != "prod":
    raise ValueError(f"Unknown environment: {environment}")
else:
    # Production: both paths are mandatory and must point at existing files.
    cert_path = config.get("SSL_CERT_PATH")
    key_path = config.get("SSL_KEY_PATH")
    if not (cert_path and key_path):
        raise ValueError(
            "SSL_CERT_PATH and SSL_KEY_PATH must be set in config_prod.json"
            " for production."
        )
    if not all(Path(p).exists() for p in (cert_path, key_path)):
        raise FileNotFoundError(
            f"Missing SSL files in production. Ensure the following files exist:\n"
            f"SSL Certificate: {cert_path}\nSSL Key: {key_path}"
        )
# Generate a self-signed key/certificate pair if either file is missing
# (development only).
if environment == "dev":
    # Create the directories that will hold the configured files rather than a
    # fixed ./ssl, so custom ssl_cert_path / ssl_key_path locations also work.
    Path(cert_path).parent.mkdir(parents=True, exist_ok=True)
    Path(key_path).parent.mkdir(parents=True, exist_ok=True)
    if not Path(cert_path).exists() or not Path(key_path).exists():
        ssl_heidi.generate_self_signed_cert(cert_path, key_path)
# CORS: accept any origin, method and header, and allow credentials.
# NOTE(review): a wildcard origin combined with allow_credentials=True means
# credentialed cross-origin requests are accepted from any site (browsers do
# not accept a literal "*" with credentials, so the middleware presumably
# echoes the request origin) — confirm this is intended outside development.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
@app.on_event("startup")
def on_startup():
    """Prepare the database when the application starts.

    * prod: if the database has no tables, create the schema from
      ``Base.metadata`` and seed it with ``load_slots_data``.
    * dev: load data via ``load_sample_data``.
    * test: load data via ``load_slots_data``.

    A session from ``SessionLocal`` is passed to the loaders and is always
    closed afterwards.
    """
    print("[INFO] Running application startup tasks...")
    db = SessionLocal()
    try:
        if environment == "prod":
            # sqlalchemy.inspect() is the supported way to get an Inspector;
            # Inspector.from_engine() is deprecated since SQLAlchemy 1.4.
            from sqlalchemy import inspect

            table_names = inspect(engine).get_table_names()
            # Only initialize and seed a completely empty production database.
            if not table_names:
                print("Production database is empty. Initializing...")
                Base.metadata.create_all(bind=engine)
                # Seed the freshly created schema.
                from app.database import load_slots_data

                load_slots_data(db)
        else:  # dev or test environments
            # NOTE(review): this message says the database is regenerated, but
            # the drop_all/create_all calls are disabled, so existing tables
            # are reused and only the data loaders below run — confirm intent.
            print(f"{environment.capitalize()} environment: Regenerating database.")
            if environment == "dev":
                from app.database import load_sample_data

                load_sample_data(db)
            elif environment == "test":
                from app.database import load_slots_data

                load_slots_data(db)
    finally:
        db.close()
# Mount the routers in a fixed order (routes are matched in registration
# order). Each entry is (router, URL prefix, OpenAPI tags); "" and None are
# include_router's own defaults for prefix and tags.
for _router, _prefix, _tags in (
    (protected_router, "/protected", None),
    (auth.router, "/auth", ["auth"]),
    (proposal.router, "/proposals", ["proposals"]),
    (puck.router, "/pucks", ["pucks"]),
    (spreadsheet.router, "", ["spreadsheet"]),
    (logistics.router, "/logistics", ["logistics"]),
    (sample.router, "/samples", ["samples"]),
):
    app.include_router(_router, prefix=_prefix, tags=_tags)
# Keep the loop variables out of the module namespace.
del _router, _prefix, _tags
if __name__ == "__main__":
    import sys
    from dotenv import load_dotenv
    from multiprocessing import Process
    from time import sleep

    # Load environment variables from the .env file.
    # NOTE: environment/config/SSL paths were already resolved at import time,
    # so an ENVIRONMENT value coming only from .env does not affect them.
    # Variables read later (e.g. PORT inside run_server) do see .env values.
    load_dotenv()

    # `generate-openapi`: write the OpenAPI schema to openapi.json and exit.
    if len(sys.argv) > 1 and sys.argv[1] == "generate-openapi":
        from fastapi.openapi.utils import get_openapi

        openapi_schema = get_openapi(
            title=app.title,
            version=app.version,
            description=app.description,
            routes=app.routes,
        )
        with open("openapi.json", "w") as f:
            json.dump(openapi_schema, f, indent=4)
        print("openapi.json generated successfully.")
        sys.exit(0)  # Exit after generating the file

    # Default behavior: run the server. The module-level `environment` is used
    # as-is (not re-read after load_dotenv) so the mode chosen here matches the
    # config and certificate paths already loaded and the value run_server uses.
    is_ci = os.getenv("CI", "false").lower() == "true"
    if is_ci or environment == "test":
        # Test or CI mode: start the server briefly to validate that it boots.
        ssl_dir = Path(cert_path).parent
        ssl_dir.mkdir(parents=True, exist_ok=True)
        # Generate self-signed certs if missing.
        if not Path(cert_path).exists() or not Path(key_path).exists():
            print(f"[INFO] Generating self-signed SSL certificates at {ssl_dir}")
            ssl_heidi.generate_self_signed_cert(cert_path, key_path)
        server_process = Process(target=run_server)
        server_process.start()
        sleep(5)  # Give the server time to start up
        # A server that crashed during startup has already exited; report it as
        # a failure instead of claiming success.
        if not server_process.is_alive():
            server_process.join()
            print(
                "[ERROR] CI: Server process exited during startup "
                f"(exit code {server_process.exitcode})."
            )
            sys.exit(1)
        server_process.terminate()  # Stop the server once startup is verified
        server_process.join()  # Ensure proper cleanup
        print("CI: Server started and terminated successfully for test validation.")
    else:
        # Dev or Prod: start the server as usual.
        run_server()