"""Database configuration, session factory, and seed-data loaders.

Connection settings are selected by the ``ENVIRONMENT`` environment variable
("prod", "test"; anything else, including unset, uses the "dev" defaults).
Each setting can be overridden individually via ``DB_USERNAME``,
``DB_PASSWORD``, ``DB_HOST`` and ``DB_NAME``.
"""

import os
from typing import Iterator
from urllib.parse import quote

from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker

try:  # SQLAlchemy >= 1.4 exposes declarative_base from sqlalchemy.orm
    from sqlalchemy.orm import declarative_base
except ImportError:  # older SQLAlchemy releases only have the ext location
    from sqlalchemy.ext.declarative import declarative_base

from . import models

# Fetch the environment (default to "dev")
environment = os.getenv("ENVIRONMENT", "dev")

# Per-environment fallbacks used when the matching DB_* variable is unset:
# (username, password, database name). Any unrecognised environment value
# falls back to the "dev" defaults.
# NOTE(review): a default password for "prod" means a missing DB_PASSWORD
# silently connects with "prod_password"; consider failing fast instead.
_ENV_DEFAULTS = {
    "prod": ("prod_user", "prod_password", "aare_prod_db"),
    "test": ("test_user", "test_password", "aare_test_db"),
    "dev": ("dev_user", "dev_password", "aare_dev_db"),
}
_default_user, _default_password, _default_name = _ENV_DEFAULTS.get(
    environment, _ENV_DEFAULTS["dev"]
)

db_username = os.getenv("DB_USERNAME", _default_user)
db_password = os.getenv("DB_PASSWORD", _default_password)
db_host = os.getenv("DB_HOST", "localhost")
db_name = os.getenv("DB_NAME", _default_name)

# Credentials are percent-encoded so characters such as "@", ":", "/" or "%"
# in a username/password cannot be mistaken for URL delimiters; SQLAlchemy
# decodes them again when it parses the URL. Credentials made only of
# unreserved characters (letters, digits, "_", "-", ".", "~") are unchanged.
SQLALCHEMY_DATABASE_URL = (
    f"mysql://{quote(db_username, safe='')}:{quote(db_password, safe='')}"
    f"@{db_host}/{db_name}"
)

# Create engine and session factory
engine = create_engine(SQLALCHEMY_DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

Base = declarative_base()


def get_db() -> Iterator[Session]:
    """Yield a fresh session and close it once the consumer is done.

    Written as a generator so it can serve as a per-request dependency
    (presumably FastAPI ``Depends(get_db)`` — confirm against callers).
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        # Runs even if the consumer raises, so the connection is released.
        db.close()


def load_slots_data(session: Session) -> None:
    """Seed slots and proposals, each only if its table is currently empty.

    Despite the name, this also seeds proposals. A single commit is issued at
    the end, even when nothing was added.

    Args:
        session: Open session used for the emptiness checks and inserts.
    """
    # Imported here rather than at module level; presumably to defer building
    # the seed objects until seeding actually runs.
    from .data import proposals, slots

    # Only load data if slots are missing
    if not session.query(models.Slot).first():
        session.add_all(slots)
        print("[INFO] Seeding slots...")

    # Check if proposals table is empty and seed proposals
    if not session.query(models.Proposal).first():
        session.add_all(proposals)
        print("[INFO] Seeding proposals...")

    # Commit all changes to the database
    session.commit()
    print("[INFO] Seeding complete.")


def load_sample_data(session: Session) -> None:
    """Seed the full sample data set (intended for dev/test) and commit.

    Seeding is skipped entirely when at least one ``Contact`` row exists.

    Args:
        session: Open session that receives the sample objects.
    """
    from .data import (
        contacts,
        return_addresses,
        dewars,
        proposals,
        shipments,
        pucks,
        samples,
        dewar_types,
        serial_numbers,
        slots,
        sample_events,
        local_contacts,
        beamtimes,
    )

    # If any data exists, don't reseed.
    # NOTE(review): only the Contact table is checked; if other tables (e.g.
    # slots/proposals via load_slots_data) are already populated, this may
    # insert duplicates — confirm that is acceptable.
    if session.query(models.Contact).first():
        return

    session.add_all(
        contacts
        + return_addresses
        + dewars
        + proposals
        + shipments
        + pucks
        + samples
        + dewar_types
        + serial_numbers
        + slots
        + sample_events
        + local_contacts
        + beamtimes
    )
    session.commit()