
Reorganized and enhanced the OpenAPI fetch logic for better maintainability and error handling. Key updates include improved environment variable validation, more detailed error messages, streamlined configuration loading, and additional safety checks for file paths and directories. Added proper logging and ensured the process flow is easy to trace.
91 lines · 2.4 KiB · Python
# database.py
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
from sqlalchemy.orm import sessionmaker
|
|
from . import models
|
|
import os
|
|
|
|
# Fetch the environment (default to "dev"); any value other than "prod" or
# "test" falls back to the dev settings.
environment = os.getenv("ENVIRONMENT", "dev")

# Per-environment fallbacks, used only when the matching DB_* environment
# variable is unset. DB_HOST falls back to "localhost" in every environment.
# NOTE(review): "prod" silently falls back to a hard-coded password; consider
# failing fast when DB_PASSWORD is missing in production.
_DB_DEFAULTS = {
    "prod": {"username": "prod_user", "password": "prod_password", "name": "aare_prod_db"},
    "test": {"username": "test_user", "password": "test_password", "name": "aare_test_db"},
    "dev": {"username": "dev_user", "password": "dev_password", "name": "aare_dev_db"},
}
_env_defaults = _DB_DEFAULTS.get(environment, _DB_DEFAULTS["dev"])

db_username = os.getenv("DB_USERNAME", _env_defaults["username"])
db_password = os.getenv("DB_PASSWORD", _env_defaults["password"])
db_host = os.getenv("DB_HOST", "localhost")
db_name = os.getenv("DB_NAME", _env_defaults["name"])


def build_database_url(username: str, password: str, host: str, name: str) -> str:
    """Return a ``mysql://`` SQLAlchemy URL for the given connection settings.

    ``username`` and ``password`` are expected raw (not already URL-encoded).
    They are percent-encoded here so that characters such as ``@``, ``:``,
    ``/`` or spaces cannot be mistaken for URL delimiters.

    ``host`` and ``name`` are inserted unchanged, so a ``host:port`` value
    still works.
    """
    from urllib.parse import quote

    # safe="" forces "/" to be encoded as well; quote (not quote_plus) encodes
    # spaces as %20, which URL parsers decode back to a space.
    user = quote(username, safe="")
    secret = quote(password, safe="")
    return f"mysql://{user}:{secret}@{host}/{name}"


SQLALCHEMY_DATABASE_URL = build_database_url(db_username, db_password, db_host, db_name)
|
|
|
|
# Build the engine once, at import time, from SQLALCHEMY_DATABASE_URL.
engine = create_engine(SQLALCHEMY_DATABASE_URL)

# Session factory bound to that engine. With autoflush disabled, queries do not
# trigger an automatic flush; changes reach the database on flush()/commit().
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Base class for the declarative ORM models.
# NOTE(review): sqlalchemy.ext.declarative.declarative_base is deprecated since
# SQLAlchemy 1.4 (moved to sqlalchemy.orm). Also, if `models` imports `Base`
# from this module, the top-of-file `from . import models` runs before `Base`
# exists — confirm the import order works.
Base = declarative_base()
|
|
|
|
|
|
# Dependency
|
|
def get_db():
|
|
db = SessionLocal()
|
|
try:
|
|
yield db
|
|
finally:
|
|
db.close()
|
|
|
|
|
|
def load_slots_data(session: Session):
    """Seed the slot and proposal fixtures from ``.data`` into an unseeded DB.

    This is not slots-only: the ``proposals`` fixture is inserted alongside
    ``slots``. Nothing happens if at least one ``models.Slot`` row already
    exists. Otherwise both fixtures are added and the session is committed.

    NOTE(review): the guard looks only at slots, so proposals that already
    exist without any slots would be inserted again — confirm that is intended.
    """
    from .data import slots, proposals

    # Guard clause: any existing slot means the minimal data is already there.
    if session.query(models.Slot).first():
        return

    session.add_all(slots)
    session.add_all(proposals)
    session.commit()
|
|
|
|
|
|
def load_sample_data(session: Session):
    """Seed every sample fixture from ``.data`` and commit (dev/test use).

    The reseed guard checks only ``models.ContactPerson``: if any contact row
    exists, the database is treated as already seeded and nothing is inserted,
    regardless of what the other tables contain.
    """
    from .data import (
        contacts,
        return_addresses,
        dewars,
        proposals,
        shipments,
        pucks,
        samples,
        dewar_types,
        serial_numbers,
        slots,
        sample_events,
    )

    if session.query(models.ContactPerson).first():
        return

    # Objects are added fixture by fixture, in this order; the commit below
    # writes them all in one transaction.
    fixture_groups = (
        contacts,
        return_addresses,
        dewars,
        proposals,
        shipments,
        pucks,
        samples,
        dewar_types,
        serial_numbers,
        slots,
        sample_events,
    )
    for group in fixture_groups:
        session.add_all(group)
    session.commit()
|