Refactor environment-specific configurations and data loading.
Refactored database and server configuration to handle environments (dev, test, prod) explicitly, including tailored database setup and SSL management. Separated slot and sample data loading for better control during initialization. Improved environment variable usage and error handling for production certificates.
This commit is contained in:
@@ -59,7 +59,7 @@ slotQRCodes = [
|
|||||||
def timedelta_to_str(td: timedelta) -> str:
|
def timedelta_to_str(td: timedelta) -> str:
|
||||||
days, seconds = td.days, td.seconds
|
days, seconds = td.days, td.seconds
|
||||||
hours = days * 24 + seconds // 3600
|
hours = days * 24 + seconds // 3600
|
||||||
minutes = (seconds % 3600) // 60
|
minutes = (seconds % 172800) // 60
|
||||||
return f"PT{hours}H{minutes}M"
|
return f"PT{hours}H{minutes}M"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,18 +6,31 @@ from sqlalchemy.orm import sessionmaker
|
|||||||
from . import models
|
from . import models
|
||||||
import os
|
import os
|
||||||
|
|
||||||
# Get username and password from environment variables
|
# Fetch the environment (default to "dev")
|
||||||
db_username = os.getenv("DB_USERNAME")
|
environment = os.getenv("ENVIRONMENT", "dev")
|
||||||
db_password = os.getenv("DB_PASSWORD")
|
|
||||||
|
|
||||||
# Construct the database URL
|
# Configure database per environment
|
||||||
SQLALCHEMY_DATABASE_URL = f"mysql://{db_username}:{db_password}@localhost:3306/aare_db"
|
if environment == "prod":
|
||||||
|
db_username = os.getenv("DB_USERNAME", "prod_user")
|
||||||
|
db_password = os.getenv("DB_PASSWORD", "prod_password")
|
||||||
|
db_host = os.getenv("DB_HOST", "localhost")
|
||||||
|
db_name = os.getenv("DB_NAME", "aare_prod_db")
|
||||||
|
elif environment == "test":
|
||||||
|
db_username = os.getenv("DB_USERNAME", "test_user")
|
||||||
|
db_password = os.getenv("DB_PASSWORD", "test_password")
|
||||||
|
db_host = os.getenv("DB_HOST", "localhost")
|
||||||
|
db_name = os.getenv("DB_NAME", "aare_test_db")
|
||||||
|
else: # Default is dev
|
||||||
|
db_username = os.getenv("DB_USERNAME", "dev_user")
|
||||||
|
db_password = os.getenv("DB_PASSWORD", "dev_password")
|
||||||
|
db_host = os.getenv("DB_HOST", "localhost")
|
||||||
|
db_name = os.getenv("DB_NAME", "aare_dev_db")
|
||||||
|
|
||||||
# Remove the `connect_args` parameter
|
SQLALCHEMY_DATABASE_URL = f"mysql://{db_username}:{db_password}@{db_host}/{db_name}"
|
||||||
|
|
||||||
|
# Create engine and session
|
||||||
engine = create_engine(SQLALCHEMY_DATABASE_URL)
|
engine = create_engine(SQLALCHEMY_DATABASE_URL)
|
||||||
|
|
||||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
@@ -30,12 +43,17 @@ def get_db():
|
|||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
def init_db():
|
# Load only slots (minimal data)
|
||||||
Base.metadata.create_all(bind=engine)
|
def load_slots_data(session: Session):
|
||||||
|
from .data import slots
|
||||||
|
|
||||||
|
if not session.query(models.Slot).first(): # Load only if no slots exist
|
||||||
|
session.add_all(slots)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
# Load full sample data (used in dev/test)
|
||||||
def load_sample_data(session: Session):
|
def load_sample_data(session: Session):
|
||||||
# Import models inside function to avoid circular dependency
|
|
||||||
from .data import (
|
from .data import (
|
||||||
contacts,
|
contacts,
|
||||||
return_addresses,
|
return_addresses,
|
||||||
@@ -50,7 +68,7 @@ def load_sample_data(session: Session):
|
|||||||
sample_events,
|
sample_events,
|
||||||
)
|
)
|
||||||
|
|
||||||
# If any data already exists, skip seeding
|
# If any data exists, don't reseed
|
||||||
if session.query(models.ContactPerson).first():
|
if session.query(models.ContactPerson).first():
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
120
backend/main.py
120
backend/main.py
@@ -1,5 +1,4 @@
|
|||||||
# app/main.py
|
# app/main.py
|
||||||
import sys
|
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import tomllib
|
import tomllib
|
||||||
@@ -21,7 +20,7 @@ from app.routers import (
|
|||||||
auth,
|
auth,
|
||||||
sample,
|
sample,
|
||||||
)
|
)
|
||||||
from app.database import Base, engine, SessionLocal, load_sample_data
|
from app.database import Base, engine, SessionLocal, load_sample_data, load_slots_data
|
||||||
|
|
||||||
|
|
||||||
# Utility function to fetch metadata from pyproject.toml
|
# Utility function to fetch metadata from pyproject.toml
|
||||||
@@ -80,13 +79,29 @@ app.add_middleware(
|
|||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
def on_startup():
|
def on_startup():
|
||||||
# Drop and recreate database schema
|
|
||||||
Base.metadata.drop_all(bind=engine)
|
|
||||||
Base.metadata.create_all(bind=engine)
|
|
||||||
|
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
|
if environment == "prod":
|
||||||
|
from sqlalchemy.engine import reflection
|
||||||
|
|
||||||
|
inspector = reflection.Inspector.from_engine(engine)
|
||||||
|
tables_exist = inspector.get_table_names()
|
||||||
|
|
||||||
|
if not tables_exist:
|
||||||
|
print("Production database is empty. Initializing...")
|
||||||
|
Base.metadata.create_all(bind=engine)
|
||||||
|
load_slots_data(db)
|
||||||
|
else:
|
||||||
|
print("Production database already initialized.")
|
||||||
|
|
||||||
|
else: # dev or test
|
||||||
|
print(f"{environment.capitalize()} environment: Regenerating database.")
|
||||||
|
Base.metadata.drop_all(bind=engine)
|
||||||
|
Base.metadata.create_all(bind=engine)
|
||||||
|
if environment == "dev":
|
||||||
load_sample_data(db)
|
load_sample_data(db)
|
||||||
|
elif environment == "test":
|
||||||
|
load_slots_data(db)
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -106,52 +121,57 @@ app.include_router(sample.router, prefix="/samples", tags=["samples"])
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
from pathlib import Path
|
||||||
|
from app import ssl_heidi
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
# Check if the user has passed "generate-openapi" as the first CLI argument
|
# Load environment variables from .env file
|
||||||
if len(sys.argv) > 1 and sys.argv[1] == "generate-openapi":
|
load_dotenv()
|
||||||
# Generate and save the OpenAPI schema
|
|
||||||
openapi_schema = app.openapi()
|
|
||||||
with open("openapi.json", "w") as openapi_file:
|
|
||||||
json.dump(openapi_schema, openapi_file, indent=2)
|
|
||||||
print("OpenAPI schema has been generated and saved to 'openapi.json'.")
|
|
||||||
else:
|
|
||||||
# Default behavior: Run the FastAPI server
|
|
||||||
import os
|
|
||||||
from multiprocessing import Process
|
|
||||||
from time import sleep
|
|
||||||
|
|
||||||
# Get environment from an environment variable
|
# Fetch environment values
|
||||||
environment = os.getenv("ENVIRONMENT", "dev")
|
environment = os.getenv("ENVIRONMENT", "dev")
|
||||||
is_ci = os.getenv("CI", "false").lower() == "true" # Check if running in CI
|
|
||||||
port = int(os.getenv("PORT", 8000)) # Default to 8000 if PORT is not set
|
port = int(os.getenv("PORT", 8000)) # Default to 8000 if PORT is not set
|
||||||
|
is_ci = (
|
||||||
|
os.getenv("CI", "false").lower() == "true"
|
||||||
|
) # Detect if running in CI environment
|
||||||
|
|
||||||
# Paths for SSL certificates
|
# Determine SSL certificate and key paths
|
||||||
if is_ci:
|
if environment == "prod":
|
||||||
cert_path = "ssl/cert.pem"
|
# Production environment must use proper SSL cert and key paths
|
||||||
key_path = "ssl/key.pem"
|
cert_path = os.getenv("VITE_SSL_CERT_PATH", "ssl/prod-cert.pem")
|
||||||
host = "127.0.0.1"
|
key_path = os.getenv("VITE_SSL_KEY_PATH", "ssl/prod-key.pem")
|
||||||
print("Running in CI mode with self-signed certificates...")
|
|
||||||
# Ensure SSL certificate and key are generated
|
|
||||||
if is_ci or environment == "dev":
|
|
||||||
Path("ssl").mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
if not Path(cert_path).exists() or not Path(key_path).exists():
|
if not Path(cert_path).exists() or not Path(key_path).exists():
|
||||||
print(
|
raise FileNotFoundError(
|
||||||
f"Generating self-signed SSL certificate"
|
f"Production certificates not found."
|
||||||
f"at {cert_path} and {key_path}"
|
f"Make sure the following files exist:\n"
|
||||||
|
f"Certificate: {cert_path}\nKey: {key_path}"
|
||||||
)
|
)
|
||||||
ssl_heidi.generate_self_signed_cert(cert_path, key_path)
|
host = "0.0.0.0" # Allow external traffic
|
||||||
elif environment == "test":
|
print(
|
||||||
cert_path = "ssl/mx-aare-test.psi.ch.pem"
|
f"Running in production mode with provided SSL certificates:\n"
|
||||||
key_path = "ssl/mx-aare-test.psi.ch.key"
|
f" - Certificate: {cert_path}\n - Key: {key_path}"
|
||||||
host = "0.0.0.0"
|
)
|
||||||
print("Using test SSL certificates...")
|
|
||||||
else:
|
elif environment in ["test", "dev"]:
|
||||||
|
# Test/Development environments use self-signed certificates
|
||||||
cert_path = "ssl/cert.pem"
|
cert_path = "ssl/cert.pem"
|
||||||
key_path = "ssl/key.pem"
|
key_path = "ssl/key.pem"
|
||||||
host = "127.0.0.1"
|
host = "127.0.0.1" # Restrict to localhost
|
||||||
print("Using development SSL certificates...")
|
print(f"Running in {environment} mode with self-signed certificates...")
|
||||||
|
|
||||||
|
# Ensure self-signed certificates exist or generate them
|
||||||
|
Path("ssl").mkdir(parents=True, exist_ok=True)
|
||||||
|
if not Path(cert_path).exists() or not Path(key_path).exists():
|
||||||
|
print(f"Generating self-signed SSL certificate at {cert_path}...")
|
||||||
|
ssl_heidi.generate_self_signed_cert(cert_path, key_path)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown environment: {environment}. "
|
||||||
|
f"Must be one of 'prod', 'test', or 'dev'."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Function to run the server
|
||||||
def run_server():
|
def run_server():
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
app,
|
app,
|
||||||
@@ -162,14 +182,18 @@ if __name__ == "__main__":
|
|||||||
ssl_certfile=cert_path,
|
ssl_certfile=cert_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Continuous Integration handling
|
||||||
if is_ci:
|
if is_ci:
|
||||||
# In CI, start server in a subprocess and exit after a short delay
|
from multiprocessing import Process
|
||||||
|
from time import sleep
|
||||||
|
|
||||||
|
print("CI mode detected: Starting server in a subprocess...")
|
||||||
server_process = Process(target=run_server)
|
server_process = Process(target=run_server)
|
||||||
server_process.start()
|
server_process.start()
|
||||||
sleep(5) # Wait for 5 seconds to ensure the server starts without errors
|
sleep(5) # Wait 5 seconds to ensure the server starts without errors
|
||||||
server_process.terminate() # Terminate the server process
|
server_process.terminate() # Terminate the server (test purposes)
|
||||||
server_process.join() # Clean up the process
|
server_process.join() # Ensure proper cleanup
|
||||||
print("CI: Server started successfully and exited.")
|
print("CI: Server started and terminated successfully for test validation.")
|
||||||
else:
|
else:
|
||||||
# Normal behavior for running the FastAPI server
|
# Run the server normally
|
||||||
run_server()
|
run_server()
|
||||||
|
|||||||
Reference in New Issue
Block a user