Refactor environment-specific configurations and data loading.

Refactored database and server configuration to handle environments (dev, test, prod) explicitly, including tailored database setup and SSL management. Separated slot and sample data loading for better control during initialization. Improved environment variable usage and error handling for production certificates.
This commit is contained in:
GotthardG 2024-12-17 13:11:26 +01:00
parent a3f85c6dda
commit 19c5d7f880
3 changed files with 124 additions and 82 deletions

View File

@ -59,7 +59,7 @@ slotQRCodes = [
def timedelta_to_str(td: timedelta) -> str:
days, seconds = td.days, td.seconds
hours = days * 24 + seconds // 3600
minutes = (seconds % 3600) // 60
minutes = (seconds % 172800) // 60
return f"PT{hours}H{minutes}M"

View File

@ -6,18 +6,31 @@ from sqlalchemy.orm import sessionmaker
from . import models
import os
# Get username and password from environment variables
db_username = os.getenv("DB_USERNAME")
db_password = os.getenv("DB_PASSWORD")
# Fetch the environment (default to "dev")
environment = os.getenv("ENVIRONMENT", "dev")
# Construct the database URL
SQLALCHEMY_DATABASE_URL = f"mysql://{db_username}:{db_password}@localhost:3306/aare_db"
# Configure database per environment
if environment == "prod":
db_username = os.getenv("DB_USERNAME", "prod_user")
db_password = os.getenv("DB_PASSWORD", "prod_password")
db_host = os.getenv("DB_HOST", "localhost")
db_name = os.getenv("DB_NAME", "aare_prod_db")
elif environment == "test":
db_username = os.getenv("DB_USERNAME", "test_user")
db_password = os.getenv("DB_PASSWORD", "test_password")
db_host = os.getenv("DB_HOST", "localhost")
db_name = os.getenv("DB_NAME", "aare_test_db")
else: # Default is dev
db_username = os.getenv("DB_USERNAME", "dev_user")
db_password = os.getenv("DB_PASSWORD", "dev_password")
db_host = os.getenv("DB_HOST", "localhost")
db_name = os.getenv("DB_NAME", "aare_dev_db")
# Remove the `connect_args` parameter
SQLALCHEMY_DATABASE_URL = f"mysql://{db_username}:{db_password}@{db_host}/{db_name}"
# Create engine and session
engine = create_engine(SQLALCHEMY_DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
@ -30,12 +43,17 @@ def get_db():
db.close()
def init_db():
Base.metadata.create_all(bind=engine)
# Load only slots (minimal data)
def load_slots_data(session: Session):
from .data import slots
if not session.query(models.Slot).first(): # Load only if no slots exist
session.add_all(slots)
session.commit()
# Load full sample data (used in dev/test)
def load_sample_data(session: Session):
# Import models inside function to avoid circular dependency
from .data import (
contacts,
return_addresses,
@ -50,7 +68,7 @@ def load_sample_data(session: Session):
sample_events,
)
# If any data already exists, skip seeding
# If any data exists, don't reseed
if session.query(models.ContactPerson).first():
return

View File

@ -1,5 +1,4 @@
# app/main.py
import sys
import os
import json
import tomllib
@ -21,7 +20,7 @@ from app.routers import (
auth,
sample,
)
from app.database import Base, engine, SessionLocal, load_sample_data
from app.database import Base, engine, SessionLocal, load_sample_data, load_slots_data
# Utility function to fetch metadata from pyproject.toml
@ -80,13 +79,29 @@ app.add_middleware(
@app.on_event("startup")
def on_startup():
# Drop and recreate database schema
Base.metadata.drop_all(bind=engine)
Base.metadata.create_all(bind=engine)
db = SessionLocal()
try:
load_sample_data(db)
if environment == "prod":
from sqlalchemy.engine import reflection
inspector = reflection.Inspector.from_engine(engine)
tables_exist = inspector.get_table_names()
if not tables_exist:
print("Production database is empty. Initializing...")
Base.metadata.create_all(bind=engine)
load_slots_data(db)
else:
print("Production database already initialized.")
else: # dev or test
print(f"{environment.capitalize()} environment: Regenerating database.")
Base.metadata.drop_all(bind=engine)
Base.metadata.create_all(bind=engine)
if environment == "dev":
load_sample_data(db)
elif environment == "test":
load_slots_data(db)
finally:
db.close()
@ -106,70 +121,79 @@ app.include_router(sample.router, prefix="/samples", tags=["samples"])
if __name__ == "__main__":
import uvicorn
from pathlib import Path
from app import ssl_heidi
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
# Fetch environment values
environment = os.getenv("ENVIRONMENT", "dev")
port = int(os.getenv("PORT", 8000)) # Default to 8000 if PORT is not set
is_ci = (
os.getenv("CI", "false").lower() == "true"
) # Detect if running in CI environment
# Determine SSL certificate and key paths
if environment == "prod":
# Production environment must use proper SSL cert and key paths
cert_path = os.getenv("VITE_SSL_CERT_PATH", "ssl/prod-cert.pem")
key_path = os.getenv("VITE_SSL_KEY_PATH", "ssl/prod-key.pem")
if not Path(cert_path).exists() or not Path(key_path).exists():
raise FileNotFoundError(
f"Production certificates not found."
f"Make sure the following files exist:\n"
f"Certificate: {cert_path}\nKey: {key_path}"
)
host = "0.0.0.0" # Allow external traffic
print(
f"Running in production mode with provided SSL certificates:\n"
f" - Certificate: {cert_path}\n - Key: {key_path}"
)
elif environment in ["test", "dev"]:
# Test/Development environments use self-signed certificates
cert_path = "ssl/cert.pem"
key_path = "ssl/key.pem"
host = "127.0.0.1" # Restrict to localhost
print(f"Running in {environment} mode with self-signed certificates...")
# Ensure self-signed certificates exist or generate them
Path("ssl").mkdir(parents=True, exist_ok=True)
if not Path(cert_path).exists() or not Path(key_path).exists():
print(f"Generating self-signed SSL certificate at {cert_path}...")
ssl_heidi.generate_self_signed_cert(cert_path, key_path)
# Check if the user has passed "generate-openapi" as the first CLI argument
if len(sys.argv) > 1 and sys.argv[1] == "generate-openapi":
# Generate and save the OpenAPI schema
openapi_schema = app.openapi()
with open("openapi.json", "w") as openapi_file:
json.dump(openapi_schema, openapi_file, indent=2)
print("OpenAPI schema has been generated and saved to 'openapi.json'.")
else:
# Default behavior: Run the FastAPI server
import os
raise ValueError(
f"Unknown environment: {environment}. "
f"Must be one of 'prod', 'test', or 'dev'."
)
# Function to run the server
def run_server():
uvicorn.run(
app,
host=host,
port=port,
log_level="debug",
ssl_keyfile=key_path,
ssl_certfile=cert_path,
)
# Continuous Integration handling
if is_ci:
from multiprocessing import Process
from time import sleep
# Get environment from an environment variable
environment = os.getenv("ENVIRONMENT", "dev")
is_ci = os.getenv("CI", "false").lower() == "true" # Check if running in CI
port = int(os.getenv("PORT", 8000)) # Default to 8000 if PORT is not set
# Paths for SSL certificates
if is_ci:
cert_path = "ssl/cert.pem"
key_path = "ssl/key.pem"
host = "127.0.0.1"
print("Running in CI mode with self-signed certificates...")
# Ensure SSL certificate and key are generated
if is_ci or environment == "dev":
Path("ssl").mkdir(exist_ok=True)
if not Path(cert_path).exists() or not Path(key_path).exists():
print(
f"Generating self-signed SSL certificate"
f"at {cert_path} and {key_path}"
)
ssl_heidi.generate_self_signed_cert(cert_path, key_path)
elif environment == "test":
cert_path = "ssl/mx-aare-test.psi.ch.pem"
key_path = "ssl/mx-aare-test.psi.ch.key"
host = "0.0.0.0"
print("Using test SSL certificates...")
else:
cert_path = "ssl/cert.pem"
key_path = "ssl/key.pem"
host = "127.0.0.1"
print("Using development SSL certificates...")
def run_server():
uvicorn.run(
app,
host=host,
port=port,
log_level="debug",
ssl_keyfile=key_path,
ssl_certfile=cert_path,
)
if is_ci:
# In CI, start server in a subprocess and exit after a short delay
server_process = Process(target=run_server)
server_process.start()
sleep(5) # Wait for 5 seconds to ensure the server starts without errors
server_process.terminate() # Terminate the server process
server_process.join() # Clean up the process
print("CI: Server started successfully and exited.")
else:
# Normal behavior for running the FastAPI server
run_server()
print("CI mode detected: Starting server in a subprocess...")
server_process = Process(target=run_server)
server_process.start()
sleep(5) # Wait 5 seconds to ensure the server starts without errors
server_process.terminate() # Terminate the server (test purposes)
server_process.join() # Ensure proper cleanup
print("CI: Server started and terminated successfully for test validation.")
else:
# Run the server normally
run_server()