
Refactored database and server configuration to handle environments (dev, test, prod) explicitly, including tailored database setup and SSL management. Separated slot and sample data loading for better control during initialization. Improved environment variable usage and error handling for production certificates.
89 lines
2.3 KiB
Python
89 lines
2.3 KiB
Python
# database.py
|
|
import os
from urllib.parse import quote

from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker

from . import models

# Fetch the environment (default to "dev").
environment = os.getenv("ENVIRONMENT", "dev")

# Per-environment fallbacks, used only when the matching DB_* variable is unset.
# Any ENVIRONMENT value other than exactly "prod" or "test" uses the dev values.
# NOTE(review): the prod fallbacks are placeholder credentials; consider failing
# fast when DB_USERNAME / DB_PASSWORD are missing in prod instead of using them.
_DB_DEFAULTS = {
    "prod": {"username": "prod_user", "password": "prod_password", "name": "aare_prod_db"},
    "test": {"username": "test_user", "password": "test_password", "name": "aare_test_db"},
    "dev": {"username": "dev_user", "password": "dev_password", "name": "aare_dev_db"},
}
_db_defaults = _DB_DEFAULTS.get(environment, _DB_DEFAULTS["dev"])

db_username = os.getenv("DB_USERNAME", _db_defaults["username"])
db_password = os.getenv("DB_PASSWORD", _db_defaults["password"])
db_host = os.getenv("DB_HOST", "localhost")
db_name = os.getenv("DB_NAME", _db_defaults["name"])

# Percent-encode the credentials so characters such as "@", ":", "/", "#" or "%"
# in a password cannot corrupt the URL (SQLAlchemy decodes them when parsing it).
# DB_USERNAME / DB_PASSWORD must therefore hold the raw, unencoded values.
# DB_HOST is deliberately not encoded so "host:port" values keep working.
SQLALCHEMY_DATABASE_URL = (
    f"mysql://{quote(db_username, safe='')}:{quote(db_password, safe='')}"
    f"@{db_host}/{db_name}"
)

# Create the engine and the session factory bound to it.
engine = create_engine(SQLALCHEMY_DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Declarative base class for the ORM models.
Base = declarative_base()


def get_db():
    """Dependency that yields a database session.

    A new ``SessionLocal()`` session is created for each use and is always
    closed afterwards, including when an exception is raised while the
    consumer is holding it.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()


def load_slots_data(session: Session):
    """Seed only the slot data (the minimal dataset).

    Adds every object from ``.data.slots`` and commits, but only when the
    ``Slot`` query returns no row; otherwise nothing is added or committed.

    Args:
        session: Open SQLAlchemy session to seed through.
    """
    from .data import slots

    # Already seeded: at least one Slot row exists, so there is nothing to do.
    if session.query(models.Slot).first():
        return

    session.add_all(slots)
    session.commit()


def load_sample_data(session: Session):
    """Seed the full sample dataset from ``.data`` (used in dev/test).

    Seeding is skipped entirely when at least one ``ContactPerson`` row
    already exists; no other table is checked before inserting.

    Args:
        session: Open SQLAlchemy session; committed after the objects are added.
    """
    from .data import (
        contacts,
        return_addresses,
        dewars,
        proposals,
        shipments,
        pucks,
        samples,
        dewar_types,
        serial_numbers,
        slots,
        sample_events,
    )

    # A ContactPerson row is the only "already seeded" marker that is checked.
    already_seeded = session.query(models.ContactPerson).first()
    if already_seeded:
        return

    # Combine every sample collection and add them in a single batch.
    seed_objects = (
        contacts + return_addresses + dewars + proposals + shipments
        + pucks + samples + dewar_types + serial_numbers + slots + sample_events
    )
    session.add_all(seed_objects)
    session.commit()