diff --git a/backend/app/database.py b/backend/app/database.py index 2d026cc..bdd244a 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -48,10 +48,19 @@ def load_slots_data(session: Session): from .data import slots from .data import proposals - if not session.query(models.Slot).first(): # Load only if no slots exist - session.add_all(slots) - session.add_all(proposals) - session.commit() + # Only load data if slots are missing + if not session.query(models.Slot).first(): + session.add_all(slots) # Add slots + print("[INFO] Seeding slots...") + + # Check if proposals table is empty and seed proposals + if not session.query(models.Proposal).first(): + session.add_all(proposals) # Add proposals + print("[INFO] Seeding proposals...") + + # Commit all changes to the database + session.commit() + print("[INFO] Seeding complete.") # Load full sample data (used in dev/test) diff --git a/backend/main.py b/backend/main.py index 3946c63..b7b2126 100644 --- a/backend/main.py +++ b/backend/main.py @@ -17,7 +17,7 @@ from app.routers import ( auth, sample, ) -from app.database import Base, engine, SessionLocal, load_sample_data, load_slots_data +from app.database import Base, engine, SessionLocal # Utility function to fetch metadata from pyproject.toml @@ -94,6 +94,7 @@ app.add_middleware( @app.on_event("startup") def on_startup(): + print("[INFO] Running application startup tasks...") db = SessionLocal() try: if environment == "prod": @@ -102,20 +103,26 @@ def on_startup(): inspector = reflection.Inspector.from_engine(engine) tables_exist = inspector.get_table_names() + # Ensure the production database is initialized if not tables_exist: print("Production database is empty. Initializing...") Base.metadata.create_all(bind=engine) - load_slots_data(db) - else: - print("Production database already initialized.") - else: # dev or test + # Seed the database (slots + proposals) + from app.database import load_slots_data + + load_slots_data(db) + else: # dev or test environments print(f"{environment.capitalize()} environment: Regenerating database.") Base.metadata.drop_all(bind=engine) Base.metadata.create_all(bind=engine) if environment == "dev": + from app.database import load_sample_data + load_sample_data(db) elif environment == "test": + from app.database import load_slots_data + load_slots_data(db) finally: db.close()