From 19c5d7f8809cdc185f48507f25518d7c0a3c2f9c Mon Sep 17 00:00:00 2001 From: GotthardG <51994228+GotthardG@users.noreply.github.com> Date: Tue, 17 Dec 2024 13:11:26 +0100 Subject: [PATCH] Refactor environment-specific configurations and data loading. Refactored database and server configuration to handle environments (dev, test, prod) explicitly, including tailored database setup and SSL management. Separated slot and sample data loading for better control during initialization. Improved environment variable usage and error handling for production certificates. --- backend/app/data/slots_data.py | 2 +- backend/app/database.py | 42 ++++++--- backend/main.py | 162 +++++++++++++++++++-------------- 3 files changed, 124 insertions(+), 82 deletions(-) diff --git a/backend/app/data/slots_data.py b/backend/app/data/slots_data.py index 4aa61e3..4341fdb 100644 --- a/backend/app/data/slots_data.py +++ b/backend/app/data/slots_data.py @@ -59,7 +59,7 @@ slotQRCodes = [ def timedelta_to_str(td: timedelta) -> str: days, seconds = td.days, td.seconds hours = days * 24 + seconds // 3600 - minutes = (seconds % 3600) // 60 + minutes = (seconds % 172800) // 60 return f"PT{hours}H{minutes}M" diff --git a/backend/app/database.py b/backend/app/database.py index d5b6090..aba08a5 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -6,18 +6,31 @@ from sqlalchemy.orm import sessionmaker from . import models import os -# Get username and password from environment variables -db_username = os.getenv("DB_USERNAME") -db_password = os.getenv("DB_PASSWORD") +# Fetch the environment (default to "dev") +environment = os.getenv("ENVIRONMENT", "dev") -# Construct the database URL -SQLALCHEMY_DATABASE_URL = f"mysql://{db_username}:{db_password}@localhost:3306/aare_db" +# Configure database per environment +if environment == "prod": + db_username = os.getenv("DB_USERNAME", "prod_user") + db_password = os.getenv("DB_PASSWORD", "prod_password") + db_host = os.getenv("DB_HOST", "localhost") + db_name = os.getenv("DB_NAME", "aare_prod_db") +elif environment == "test": + db_username = os.getenv("DB_USERNAME", "test_user") + db_password = os.getenv("DB_PASSWORD", "test_password") + db_host = os.getenv("DB_HOST", "localhost") + db_name = os.getenv("DB_NAME", "aare_test_db") +else: # Default is dev + db_username = os.getenv("DB_USERNAME", "dev_user") + db_password = os.getenv("DB_PASSWORD", "dev_password") + db_host = os.getenv("DB_HOST", "localhost") + db_name = os.getenv("DB_NAME", "aare_dev_db") -# Remove the `connect_args` parameter +SQLALCHEMY_DATABASE_URL = f"mysql://{db_username}:{db_password}@{db_host}/{db_name}" + +# Create engine and session engine = create_engine(SQLALCHEMY_DATABASE_URL) - SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - Base = declarative_base() @@ -30,12 +43,17 @@ def get_db(): db.close() -def init_db(): - Base.metadata.create_all(bind=engine) +# Load only slots (minimal data) +def load_slots_data(session: Session): + from .data import slots + + if not session.query(models.Slot).first(): # Load only if no slots exist + session.add_all(slots) + session.commit() +# Load full sample data (used in dev/test) def load_sample_data(session: Session): - # Import models inside function to avoid circular dependency from .data import ( contacts, return_addresses, @@ -50,7 +68,7 @@ def load_sample_data(session: Session): sample_events, ) - # If any data already exists, skip seeding + # If any data exists, don't reseed if session.query(models.ContactPerson).first(): return diff --git a/backend/main.py b/backend/main.py index 9d4c825..d4ed263 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,5 +1,4 @@ # app/main.py -import sys import os import json import tomllib @@ -21,7 +20,7 @@ from app.routers import ( auth, sample, ) -from app.database import Base, engine, SessionLocal, load_sample_data +from app.database import Base, engine, SessionLocal, load_sample_data, load_slots_data # Utility function to fetch metadata from pyproject.toml @@ -80,13 +79,29 @@ app.add_middleware( @app.on_event("startup") def on_startup(): - # Drop and recreate database schema - Base.metadata.drop_all(bind=engine) - Base.metadata.create_all(bind=engine) - db = SessionLocal() try: - load_sample_data(db) + if environment == "prod": + from sqlalchemy.engine import reflection + + inspector = reflection.Inspector.from_engine(engine) + tables_exist = inspector.get_table_names() + + if not tables_exist: + print("Production database is empty. Initializing...") + Base.metadata.create_all(bind=engine) + load_slots_data(db) + else: + print("Production database already initialized.") + + else: # dev or test + print(f"{environment.capitalize()} environment: Regenerating database.") + Base.metadata.drop_all(bind=engine) + Base.metadata.create_all(bind=engine) + if environment == "dev": + load_sample_data(db) + elif environment == "test": + load_slots_data(db) finally: db.close() @@ -106,70 +121,79 @@ app.include_router(sample.router, prefix="/samples", tags=["samples"]) if __name__ == "__main__": import uvicorn + from pathlib import Path + from app import ssl_heidi + from dotenv import load_dotenv + + # Load environment variables from .env file + load_dotenv() + + # Fetch environment values + environment = os.getenv("ENVIRONMENT", "dev") + port = int(os.getenv("PORT", 8000)) # Default to 8000 if PORT is not set + is_ci = ( + os.getenv("CI", "false").lower() == "true" + ) # Detect if running in CI environment + + # Determine SSL certificate and key paths + if environment == "prod": + # Production environment must use proper SSL cert and key paths + cert_path = os.getenv("VITE_SSL_CERT_PATH", "ssl/prod-cert.pem") + key_path = os.getenv("VITE_SSL_KEY_PATH", "ssl/prod-key.pem") + if not Path(cert_path).exists() or not Path(key_path).exists(): + raise FileNotFoundError( + f"Production certificates not found." + f"Make sure the following files exist:\n" + f"Certificate: {cert_path}\nKey: {key_path}" + ) + host = "0.0.0.0" # Allow external traffic + print( + f"Running in production mode with provided SSL certificates:\n" + f" - Certificate: {cert_path}\n - Key: {key_path}" + ) + + elif environment in ["test", "dev"]: + # Test/Development environments use self-signed certificates + cert_path = "ssl/cert.pem" + key_path = "ssl/key.pem" + host = "127.0.0.1" # Restrict to localhost + print(f"Running in {environment} mode with self-signed certificates...") + + # Ensure self-signed certificates exist or generate them + Path("ssl").mkdir(parents=True, exist_ok=True) + if not Path(cert_path).exists() or not Path(key_path).exists(): + print(f"Generating self-signed SSL certificate at {cert_path}...") + ssl_heidi.generate_self_signed_cert(cert_path, key_path) - # Check if the user has passed "generate-openapi" as the first CLI argument - if len(sys.argv) > 1 and sys.argv[1] == "generate-openapi": - # Generate and save the OpenAPI schema - openapi_schema = app.openapi() - with open("openapi.json", "w") as openapi_file: - json.dump(openapi_schema, openapi_file, indent=2) - print("OpenAPI schema has been generated and saved to 'openapi.json'.") else: - # Default behavior: Run the FastAPI server - import os + raise ValueError( + f"Unknown environment: {environment}. " + f"Must be one of 'prod', 'test', or 'dev'." + ) + + # Function to run the server + def run_server(): + uvicorn.run( + app, + host=host, + port=port, + log_level="debug", + ssl_keyfile=key_path, + ssl_certfile=cert_path, + ) + + # Continuous Integration handling + if is_ci: from multiprocessing import Process from time import sleep - # Get environment from an environment variable - environment = os.getenv("ENVIRONMENT", "dev") - is_ci = os.getenv("CI", "false").lower() == "true" # Check if running in CI - port = int(os.getenv("PORT", 8000)) # Default to 8000 if PORT is not set - - # Paths for SSL certificates - if is_ci: - cert_path = "ssl/cert.pem" - key_path = "ssl/key.pem" - host = "127.0.0.1" - print("Running in CI mode with self-signed certificates...") - # Ensure SSL certificate and key are generated - if is_ci or environment == "dev": - Path("ssl").mkdir(exist_ok=True) - - if not Path(cert_path).exists() or not Path(key_path).exists(): - print( - f"Generating self-signed SSL certificate" - f"at {cert_path} and {key_path}" - ) - ssl_heidi.generate_self_signed_cert(cert_path, key_path) - elif environment == "test": - cert_path = "ssl/mx-aare-test.psi.ch.pem" - key_path = "ssl/mx-aare-test.psi.ch.key" - host = "0.0.0.0" - print("Using test SSL certificates...") - else: - cert_path = "ssl/cert.pem" - key_path = "ssl/key.pem" - host = "127.0.0.1" - print("Using development SSL certificates...") - - def run_server(): - uvicorn.run( - app, - host=host, - port=port, - log_level="debug", - ssl_keyfile=key_path, - ssl_certfile=cert_path, - ) - - if is_ci: - # In CI, start server in a subprocess and exit after a short delay - server_process = Process(target=run_server) - server_process.start() - sleep(5) # Wait for 5 seconds to ensure the server starts without errors - server_process.terminate() # Terminate the server process - server_process.join() # Clean up the process - print("CI: Server started successfully and exited.") - else: - # Normal behavior for running the FastAPI server - run_server() + print("CI mode detected: Starting server in a subprocess...") + server_process = Process(target=run_server) + server_process.start() + sleep(5) # Wait 5 seconds to ensure the server starts without errors + server_process.terminate() # Terminate the server (test purposes) + server_process.join() # Ensure proper cleanup + print("CI: Server started and terminated successfully for test validation.") + else: + # Run the server normally + run_server()