diff --git a/backend/app/data/slots_data.py b/backend/app/data/slots_data.py index 4aa61e3..4341fdb 100644 --- a/backend/app/data/slots_data.py +++ b/backend/app/data/slots_data.py @@ -59,7 +59,7 @@ slotQRCodes = [ def timedelta_to_str(td: timedelta) -> str: days, seconds = td.days, td.seconds hours = days * 24 + seconds // 3600 - minutes = (seconds % 3600) // 60 + minutes = (seconds % 172800) // 60 return f"PT{hours}H{minutes}M" diff --git a/backend/app/database.py b/backend/app/database.py index d5b6090..aba08a5 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -6,18 +6,31 @@ from sqlalchemy.orm import sessionmaker from . import models import os -# Get username and password from environment variables -db_username = os.getenv("DB_USERNAME") -db_password = os.getenv("DB_PASSWORD") +# Fetch the environment (default to "dev") +environment = os.getenv("ENVIRONMENT", "dev") -# Construct the database URL -SQLALCHEMY_DATABASE_URL = f"mysql://{db_username}:{db_password}@localhost:3306/aare_db" +# Configure database per environment +if environment == "prod": + db_username = os.getenv("DB_USERNAME", "prod_user") + db_password = os.getenv("DB_PASSWORD", "prod_password") + db_host = os.getenv("DB_HOST", "localhost") + db_name = os.getenv("DB_NAME", "aare_prod_db") +elif environment == "test": + db_username = os.getenv("DB_USERNAME", "test_user") + db_password = os.getenv("DB_PASSWORD", "test_password") + db_host = os.getenv("DB_HOST", "localhost") + db_name = os.getenv("DB_NAME", "aare_test_db") +else: # Default is dev + db_username = os.getenv("DB_USERNAME", "dev_user") + db_password = os.getenv("DB_PASSWORD", "dev_password") + db_host = os.getenv("DB_HOST", "localhost") + db_name = os.getenv("DB_NAME", "aare_dev_db") -# Remove the `connect_args` parameter +SQLALCHEMY_DATABASE_URL = f"mysql://{db_username}:{db_password}@{db_host}/{db_name}" + +# Create engine and session engine = create_engine(SQLALCHEMY_DATABASE_URL) - SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - Base = declarative_base() @@ -30,12 +43,17 @@ def get_db(): db.close() -def init_db(): - Base.metadata.create_all(bind=engine) +# Load only slots (minimal data) +def load_slots_data(session: Session): + from .data import slots + + if not session.query(models.Slot).first(): # Load only if no slots exist + session.add_all(slots) + session.commit() +# Load full sample data (used in dev/test) def load_sample_data(session: Session): - # Import models inside function to avoid circular dependency from .data import ( contacts, return_addresses, @@ -50,7 +68,7 @@ def load_sample_data(session: Session): sample_events, ) - # If any data already exists, skip seeding + # If any data exists, don't reseed if session.query(models.ContactPerson).first(): return diff --git a/backend/main.py b/backend/main.py index 9d4c825..d4ed263 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,5 +1,4 @@ # app/main.py -import sys import os import json import tomllib @@ -21,7 +20,7 @@ from app.routers import ( auth, sample, ) -from app.database import Base, engine, SessionLocal, load_sample_data +from app.database import Base, engine, SessionLocal, load_sample_data, load_slots_data # Utility function to fetch metadata from pyproject.toml @@ -80,13 +79,29 @@ app.add_middleware( @app.on_event("startup") def on_startup(): - # Drop and recreate database schema - Base.metadata.drop_all(bind=engine) - Base.metadata.create_all(bind=engine) - db = SessionLocal() try: - load_sample_data(db) + if environment == "prod": + from sqlalchemy.engine import reflection + + inspector = reflection.Inspector.from_engine(engine) + tables_exist = inspector.get_table_names() + + if not tables_exist: + print("Production database is empty. Initializing...") + Base.metadata.create_all(bind=engine) + load_slots_data(db) + else: + print("Production database already initialized.") + + else: # dev or test + print(f"{environment.capitalize()} environment: Regenerating database.") + Base.metadata.drop_all(bind=engine) + Base.metadata.create_all(bind=engine) + if environment == "dev": + load_sample_data(db) + elif environment == "test": + load_slots_data(db) finally: db.close() @@ -106,70 +121,79 @@ app.include_router(sample.router, prefix="/samples", tags=["samples"]) if __name__ == "__main__": import uvicorn + from pathlib import Path + from app import ssl_heidi + from dotenv import load_dotenv + + # Load environment variables from .env file + load_dotenv() + + # Fetch environment values + environment = os.getenv("ENVIRONMENT", "dev") + port = int(os.getenv("PORT", 8000)) # Default to 8000 if PORT is not set + is_ci = ( + os.getenv("CI", "false").lower() == "true" + ) # Detect if running in CI environment + + # Determine SSL certificate and key paths + if environment == "prod": + # Production environment must use proper SSL cert and key paths + cert_path = os.getenv("VITE_SSL_CERT_PATH", "ssl/prod-cert.pem") + key_path = os.getenv("VITE_SSL_KEY_PATH", "ssl/prod-key.pem") + if not Path(cert_path).exists() or not Path(key_path).exists(): + raise FileNotFoundError( + f"Production certificates not found." + f"Make sure the following files exist:\n" + f"Certificate: {cert_path}\nKey: {key_path}" + ) + host = "0.0.0.0" # Allow external traffic + print( + f"Running in production mode with provided SSL certificates:\n" + f" - Certificate: {cert_path}\n - Key: {key_path}" + ) + + elif environment in ["test", "dev"]: + # Test/Development environments use self-signed certificates + cert_path = "ssl/cert.pem" + key_path = "ssl/key.pem" + host = "127.0.0.1" # Restrict to localhost + print(f"Running in {environment} mode with self-signed certificates...") + + # Ensure self-signed certificates exist or generate them + Path("ssl").mkdir(parents=True, exist_ok=True) + if not Path(cert_path).exists() or not Path(key_path).exists(): + print(f"Generating self-signed SSL certificate at {cert_path}...") + ssl_heidi.generate_self_signed_cert(cert_path, key_path) - # Check if the user has passed "generate-openapi" as the first CLI argument - if len(sys.argv) > 1 and sys.argv[1] == "generate-openapi": - # Generate and save the OpenAPI schema - openapi_schema = app.openapi() - with open("openapi.json", "w") as openapi_file: - json.dump(openapi_schema, openapi_file, indent=2) - print("OpenAPI schema has been generated and saved to 'openapi.json'.") else: - # Default behavior: Run the FastAPI server - import os + raise ValueError( + f"Unknown environment: {environment}. " + f"Must be one of 'prod', 'test', or 'dev'." + ) + + # Function to run the server + def run_server(): + uvicorn.run( + app, + host=host, + port=port, + log_level="debug", + ssl_keyfile=key_path, + ssl_certfile=cert_path, + ) + + # Continuous Integration handling + if is_ci: from multiprocessing import Process from time import sleep - # Get environment from an environment variable - environment = os.getenv("ENVIRONMENT", "dev") - is_ci = os.getenv("CI", "false").lower() == "true" # Check if running in CI - port = int(os.getenv("PORT", 8000)) # Default to 8000 if PORT is not set - - # Paths for SSL certificates - if is_ci: - cert_path = "ssl/cert.pem" - key_path = "ssl/key.pem" - host = "127.0.0.1" - print("Running in CI mode with self-signed certificates...") - # Ensure SSL certificate and key are generated - if is_ci or environment == "dev": - Path("ssl").mkdir(exist_ok=True) - - if not Path(cert_path).exists() or not Path(key_path).exists(): - print( - f"Generating self-signed SSL certificate" - f"at {cert_path} and {key_path}" - ) - ssl_heidi.generate_self_signed_cert(cert_path, key_path) - elif environment == "test": - cert_path = "ssl/mx-aare-test.psi.ch.pem" - key_path = "ssl/mx-aare-test.psi.ch.key" - host = "0.0.0.0" - print("Using test SSL certificates...") - else: - cert_path = "ssl/cert.pem" - key_path = "ssl/key.pem" - host = "127.0.0.1" - print("Using development SSL certificates...") - - def run_server(): - uvicorn.run( - app, - host=host, - port=port, - log_level="debug", - ssl_keyfile=key_path, - ssl_certfile=cert_path, - ) - - if is_ci: - # In CI, start server in a subprocess and exit after a short delay - server_process = Process(target=run_server) - server_process.start() - sleep(5) # Wait for 5 seconds to ensure the server starts without errors - server_process.terminate() # Terminate the server process - server_process.join() # Clean up the process - print("CI: Server started successfully and exited.") - else: - # Normal behavior for running the FastAPI server - run_server() + print("CI mode detected: Starting server in a subprocess...") + server_process = Process(target=run_server) + server_process.start() + sleep(5) # Wait 5 seconds to ensure the server starts without errors + server_process.terminate() # Terminate the server (test purposes) + server_process.join() # Ensure proper cleanup + print("CI: Server started and terminated successfully for test validation.") + else: + # Run the server normally + run_server()