Refactor environment-specific configurations and data loading.

Refactored database and server configuration to handle environments (dev, test, prod) explicitly, including tailored database setup and SSL management. Separated slot and sample data loading for better control during initialization. Improved environment variable usage and error handling for production certificates.
This commit is contained in:
GotthardG
2024-12-17 13:11:26 +01:00
parent a3f85c6dda
commit 19c5d7f880
3 changed files with 124 additions and 82 deletions

View File

@@ -59,7 +59,7 @@ slotQRCodes = [
def timedelta_to_str(td: timedelta) -> str: def timedelta_to_str(td: timedelta) -> str:
days, seconds = td.days, td.seconds days, seconds = td.days, td.seconds
hours = days * 24 + seconds // 3600 hours = days * 24 + seconds // 3600
minutes = (seconds % 3600) // 60 minutes = (seconds % 172800) // 60
return f"PT{hours}H{minutes}M" return f"PT{hours}H{minutes}M"

View File

@@ -6,18 +6,31 @@ from sqlalchemy.orm import sessionmaker
from . import models from . import models
import os import os
# Get username and password from environment variables # Fetch the environment (default to "dev")
db_username = os.getenv("DB_USERNAME") environment = os.getenv("ENVIRONMENT", "dev")
db_password = os.getenv("DB_PASSWORD")
# Construct the database URL # Configure database per environment
SQLALCHEMY_DATABASE_URL = f"mysql://{db_username}:{db_password}@localhost:3306/aare_db" if environment == "prod":
db_username = os.getenv("DB_USERNAME", "prod_user")
db_password = os.getenv("DB_PASSWORD", "prod_password")
db_host = os.getenv("DB_HOST", "localhost")
db_name = os.getenv("DB_NAME", "aare_prod_db")
elif environment == "test":
db_username = os.getenv("DB_USERNAME", "test_user")
db_password = os.getenv("DB_PASSWORD", "test_password")
db_host = os.getenv("DB_HOST", "localhost")
db_name = os.getenv("DB_NAME", "aare_test_db")
else: # Default is dev
db_username = os.getenv("DB_USERNAME", "dev_user")
db_password = os.getenv("DB_PASSWORD", "dev_password")
db_host = os.getenv("DB_HOST", "localhost")
db_name = os.getenv("DB_NAME", "aare_dev_db")
# Remove the `connect_args` parameter SQLALCHEMY_DATABASE_URL = f"mysql://{db_username}:{db_password}@{db_host}/{db_name}"
# Create engine and session
engine = create_engine(SQLALCHEMY_DATABASE_URL) engine = create_engine(SQLALCHEMY_DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base() Base = declarative_base()
@@ -30,12 +43,17 @@ def get_db():
db.close() db.close()
def init_db(): # Load only slots (minimal data)
Base.metadata.create_all(bind=engine) def load_slots_data(session: Session):
from .data import slots
if not session.query(models.Slot).first(): # Load only if no slots exist
session.add_all(slots)
session.commit()
# Load full sample data (used in dev/test)
def load_sample_data(session: Session): def load_sample_data(session: Session):
# Import models inside function to avoid circular dependency
from .data import ( from .data import (
contacts, contacts,
return_addresses, return_addresses,
@@ -50,7 +68,7 @@ def load_sample_data(session: Session):
sample_events, sample_events,
) )
# If any data already exists, skip seeding # If any data exists, don't reseed
if session.query(models.ContactPerson).first(): if session.query(models.ContactPerson).first():
return return

View File

@@ -1,5 +1,4 @@
# app/main.py # app/main.py
import sys
import os import os
import json import json
import tomllib import tomllib
@@ -21,7 +20,7 @@ from app.routers import (
auth, auth,
sample, sample,
) )
from app.database import Base, engine, SessionLocal, load_sample_data from app.database import Base, engine, SessionLocal, load_sample_data, load_slots_data
# Utility function to fetch metadata from pyproject.toml # Utility function to fetch metadata from pyproject.toml
@@ -80,13 +79,29 @@ app.add_middleware(
@app.on_event("startup") @app.on_event("startup")
def on_startup(): def on_startup():
# Drop and recreate database schema
Base.metadata.drop_all(bind=engine)
Base.metadata.create_all(bind=engine)
db = SessionLocal() db = SessionLocal()
try: try:
if environment == "prod":
from sqlalchemy.engine import reflection
inspector = reflection.Inspector.from_engine(engine)
tables_exist = inspector.get_table_names()
if not tables_exist:
print("Production database is empty. Initializing...")
Base.metadata.create_all(bind=engine)
load_slots_data(db)
else:
print("Production database already initialized.")
else: # dev or test
print(f"{environment.capitalize()} environment: Regenerating database.")
Base.metadata.drop_all(bind=engine)
Base.metadata.create_all(bind=engine)
if environment == "dev":
load_sample_data(db) load_sample_data(db)
elif environment == "test":
load_slots_data(db)
finally: finally:
db.close() db.close()
@@ -106,52 +121,57 @@ app.include_router(sample.router, prefix="/samples", tags=["samples"])
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
from pathlib import Path
from app import ssl_heidi
from dotenv import load_dotenv
# Check if the user has passed "generate-openapi" as the first CLI argument # Load environment variables from .env file
if len(sys.argv) > 1 and sys.argv[1] == "generate-openapi": load_dotenv()
# Generate and save the OpenAPI schema
openapi_schema = app.openapi()
with open("openapi.json", "w") as openapi_file:
json.dump(openapi_schema, openapi_file, indent=2)
print("OpenAPI schema has been generated and saved to 'openapi.json'.")
else:
# Default behavior: Run the FastAPI server
import os
from multiprocessing import Process
from time import sleep
# Get environment from an environment variable # Fetch environment values
environment = os.getenv("ENVIRONMENT", "dev") environment = os.getenv("ENVIRONMENT", "dev")
is_ci = os.getenv("CI", "false").lower() == "true" # Check if running in CI
port = int(os.getenv("PORT", 8000)) # Default to 8000 if PORT is not set port = int(os.getenv("PORT", 8000)) # Default to 8000 if PORT is not set
is_ci = (
os.getenv("CI", "false").lower() == "true"
) # Detect if running in CI environment
# Paths for SSL certificates # Determine SSL certificate and key paths
if is_ci: if environment == "prod":
cert_path = "ssl/cert.pem" # Production environment must use proper SSL cert and key paths
key_path = "ssl/key.pem" cert_path = os.getenv("VITE_SSL_CERT_PATH", "ssl/prod-cert.pem")
host = "127.0.0.1" key_path = os.getenv("VITE_SSL_KEY_PATH", "ssl/prod-key.pem")
print("Running in CI mode with self-signed certificates...")
# Ensure SSL certificate and key are generated
if is_ci or environment == "dev":
Path("ssl").mkdir(exist_ok=True)
if not Path(cert_path).exists() or not Path(key_path).exists(): if not Path(cert_path).exists() or not Path(key_path).exists():
print( raise FileNotFoundError(
f"Generating self-signed SSL certificate" f"Production certificates not found."
f"at {cert_path} and {key_path}" f"Make sure the following files exist:\n"
f"Certificate: {cert_path}\nKey: {key_path}"
) )
ssl_heidi.generate_self_signed_cert(cert_path, key_path) host = "0.0.0.0" # Allow external traffic
elif environment == "test": print(
cert_path = "ssl/mx-aare-test.psi.ch.pem" f"Running in production mode with provided SSL certificates:\n"
key_path = "ssl/mx-aare-test.psi.ch.key" f" - Certificate: {cert_path}\n - Key: {key_path}"
host = "0.0.0.0" )
print("Using test SSL certificates...")
else: elif environment in ["test", "dev"]:
# Test/Development environments use self-signed certificates
cert_path = "ssl/cert.pem" cert_path = "ssl/cert.pem"
key_path = "ssl/key.pem" key_path = "ssl/key.pem"
host = "127.0.0.1" host = "127.0.0.1" # Restrict to localhost
print("Using development SSL certificates...") print(f"Running in {environment} mode with self-signed certificates...")
# Ensure self-signed certificates exist or generate them
Path("ssl").mkdir(parents=True, exist_ok=True)
if not Path(cert_path).exists() or not Path(key_path).exists():
print(f"Generating self-signed SSL certificate at {cert_path}...")
ssl_heidi.generate_self_signed_cert(cert_path, key_path)
else:
raise ValueError(
f"Unknown environment: {environment}. "
f"Must be one of 'prod', 'test', or 'dev'."
)
# Function to run the server
def run_server(): def run_server():
uvicorn.run( uvicorn.run(
app, app,
@@ -162,14 +182,18 @@ if __name__ == "__main__":
ssl_certfile=cert_path, ssl_certfile=cert_path,
) )
# Continuous Integration handling
if is_ci: if is_ci:
# In CI, start server in a subprocess and exit after a short delay from multiprocessing import Process
from time import sleep
print("CI mode detected: Starting server in a subprocess...")
server_process = Process(target=run_server) server_process = Process(target=run_server)
server_process.start() server_process.start()
sleep(5) # Wait for 5 seconds to ensure the server starts without errors sleep(5) # Wait 5 seconds to ensure the server starts without errors
server_process.terminate() # Terminate the server process server_process.terminate() # Terminate the server (test purposes)
server_process.join() # Clean up the process server_process.join() # Ensure proper cleanup
print("CI: Server started successfully and exited.") print("CI: Server started and terminated successfully for test validation.")
else: else:
# Normal behavior for running the FastAPI server # Run the server normally
run_server() run_server()