
Updated Dewar API methods to use protected endpoints for enhanced security and consistency. Added `pgroups` handling in various frontend components and modified the LogisticsView contact field for clarity. Simplified backend router imports for better readability.
212 lines
7.3 KiB
Python
212 lines
7.3 KiB
Python
import os
|
|
import json
|
|
import tomllib
|
|
from pathlib import Path
|
|
from fastapi import FastAPI
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from app import ssl_heidi
|
|
from app.routers import (
|
|
proposal,
|
|
puck,
|
|
spreadsheet,
|
|
logistics,
|
|
auth,
|
|
sample,
|
|
)
|
|
from app.database import Base, engine, SessionLocal
|
|
from app.routers.protected_router import protected_router
|
|
|
|
|
|
# Utility function to fetch metadata from pyproject.toml
|
|
def get_project_metadata():
|
|
script_dir = Path(__file__).resolve().parent
|
|
for parent in script_dir.parents:
|
|
pyproject_path = parent / "pyproject.toml"
|
|
if pyproject_path.exists():
|
|
with open(pyproject_path, "rb") as f:
|
|
pyproject = tomllib.load(f)
|
|
name = pyproject["project"]["name"]
|
|
version = pyproject["project"]["version"]
|
|
return name, version
|
|
raise FileNotFoundError(
|
|
f"pyproject.toml not found in any parent directory of {script_dir}"
|
|
)
|
|
|
|
|
|
def run_server():
    """Run the FastAPI ``app`` under uvicorn with the configured SSL files.

    Uses the module-level ``environment``, ``config``, ``cert_path`` and
    ``key_path`` values set by this module's configuration code. The port is
    taken from ``config["PORT"]``, falling back to the ``PORT`` environment
    variable. If no port is set, or it is not an integer, the process exits
    with status 1. Binds to 127.0.0.1 in "dev"/"test" and to 0.0.0.0 otherwise.
    """
    # Import sys locally: the module only imports it inside the
    # `if __name__ == "__main__"` block, so without this the `sys.exit` calls
    # below raise NameError whenever run_server runs anywhere else (e.g. in a
    # child process started with the "spawn" method, or when imported).
    import sys

    import uvicorn

    print(f"[INFO] Starting server in {environment} environment...")
    print(f"[INFO] SSL Certificate Path: {cert_path}")
    print(f"[INFO] SSL Key Path: {key_path}")

    # A PORT value in the config file takes precedence over the environment.
    port = config.get("PORT", os.getenv("PORT"))
    if not port:
        print("[ERROR] No port defined in config or environment variables. Aborting!")
        sys.exit(1)  # Exit if no port is defined

    # Fail with a clear message instead of an unhandled traceback when the
    # configured value is not an integer.
    try:
        port = int(port)
    except (TypeError, ValueError):
        print(f"[ERROR] Invalid port value {port!r}. Aborting!")
        sys.exit(1)
    print(f"[INFO] Running on port {port}")

    uvicorn.run(
        app,
        host="127.0.0.1" if environment in ["dev", "test"] else "0.0.0.0",
        port=port,
        log_level="debug",
        ssl_keyfile=key_path,
        ssl_certfile=cert_path,
    )
|
|
|
|
|
|
# Build the FastAPI application. Title and version come from pyproject.toml
# so the generated OpenAPI document matches the packaged project.
project_name, project_version = get_project_metadata()

app = FastAPI(
    title=project_name,
    version=project_version,
    description="Backend for next-gen sample management system",
    # Server list advertised in the OpenAPI document.
    servers=[
        {
            "url": "https://mx-aare-test.psi.ch:1492",
            "description": "Default server",
        }
    ],
)
|
|
|
|
# Determine environment and configuration file path.
# The config file is expected one directory above this file's directory.
environment = os.getenv("ENVIRONMENT", "dev")
config_file = Path(__file__).resolve().parent.parent / f"config_{environment}.json"

if not config_file.exists():
    raise FileNotFoundError(f"Config file '{config_file}' does not exist.")

# Load configuration. JSON is UTF-8 by specification, so do not depend on the
# platform's default (locale) encoding.
with open(config_file, encoding="utf-8") as f:
    config = json.load(f)

# Set SSL paths based on environment.
# NOTE(review): dev/test read lower-case keys ("ssl_cert_path") while prod
# reads upper-case keys ("SSL_CERT_PATH"); confirm the asymmetry is intended.
if environment in ["test", "dev"]:
    cert_path = config.get("ssl_cert_path", "ssl/cert.pem")
    key_path = config.get("ssl_key_path", "ssl/key.pem")
elif environment == "prod":
    cert_path = config.get("SSL_CERT_PATH")
    key_path = config.get("SSL_KEY_PATH")

    # Validate production SSL paths: both must be configured and present.
    if not cert_path or not key_path:
        raise ValueError(
            "SSL_CERT_PATH and SSL_KEY_PATH must be set in config_prod.json"
            " for production."
        )
    if not Path(cert_path).exists() or not Path(key_path).exists():
        raise FileNotFoundError(
            f"Missing SSL files in production. Ensure the following files exist:\n"
            f"SSL Certificate: {cert_path}\nSSL Key: {key_path}"
        )
else:
    raise ValueError(f"Unknown environment: {environment}")

# Generate a self-signed SSL key and certificate if missing (development only).
# Create the directories the configured paths actually live in, rather than a
# hard-coded "ssl" directory, so custom ssl_cert_path / ssl_key_path values
# outside "ssl/" work too.
if environment == "dev":
    if not Path(cert_path).exists() or not Path(key_path).exists():
        Path(cert_path).parent.mkdir(parents=True, exist_ok=True)
        Path(key_path).parent.mkdir(parents=True, exist_ok=True)
        ssl_heidi.generate_self_signed_cert(cert_path, key_path)
|
|
|
|
# Install CORS handling for every route: any origin, method and header, with
# credentials allowed.
# NOTE(review): a wildcard origin list combined with allow_credentials=True
# lets any website issue credentialed cross-origin requests (Starlette
# reflects the request's Origin in this case — verify against the installed
# version). Consider restricting allow_origins to the known frontend URLs, at
# least in production.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
@app.on_event("startup")
def on_startup():
    """Prepare the database when the application starts.

    * prod: if the database has no tables yet, create the schema from
      ``Base.metadata`` and seed it with ``load_slots_data``; a database that
      already has tables is left untouched.
    * dev: load the sample data set via ``load_sample_data``.
    * test: seed via ``load_slots_data``.

    One session is opened for seeding and is always closed afterwards.
    """
    print("[INFO] Running application startup tasks...")
    db = SessionLocal()
    try:
        if environment == "prod":
            # `sqlalchemy.inspect()` is the supported way to get an Inspector;
            # `Inspector.from_engine()` is deprecated since SQLAlchemy 1.4.
            from sqlalchemy import inspect as sa_inspect

            existing_tables = sa_inspect(engine).get_table_names()

            # Ensure the production database is initialized
            if not existing_tables:
                print("Production database is empty. Initializing...")
                Base.metadata.create_all(bind=engine)

                # Seed the database (slots + proposals)
                from app.database import load_slots_data

                load_slots_data(db)
        else:  # dev or test environments
            # NOTE(review): despite this message, the schema is NOT dropped or
            # recreated here (drop_all/create_all are disabled); either
            # re-enable them or reword the message.
            print(f"{environment.capitalize()} environment: Regenerating database.")
            if environment == "dev":
                from app.database import load_sample_data

                load_sample_data(db)
            elif environment == "test":
                from app.database import load_slots_data

                load_slots_data(db)
    finally:
        db.close()
|
|
|
|
|
|
def _register_routers(application):
    """Attach every API router to *application*, in a fixed order.

    Each entry pairs a router with the keyword arguments passed to
    ``include_router``; the spreadsheet router is included without a prefix.
    """
    registrations = (
        (protected_router, {"prefix": "/protected"}),
        (auth.router, {"prefix": "/auth", "tags": ["auth"]}),
        (proposal.router, {"prefix": "/proposals", "tags": ["proposals"]}),
        (puck.router, {"prefix": "/pucks", "tags": ["pucks"]}),
        (spreadsheet.router, {"tags": ["spreadsheet"]}),
        (logistics.router, {"prefix": "/logistics", "tags": ["logistics"]}),
        (sample.router, {"prefix": "/samples", "tags": ["samples"]}),
    )
    for router, options in registrations:
        application.include_router(router, **options)


_register_routers(app)
|
|
|
|
if __name__ == "__main__":
    import sys
    from dotenv import load_dotenv
    from multiprocessing import Process
    from time import sleep

    # Load environment variables from .env file.
    # NOTE: this runs after the module-level configuration code, so an
    # ENVIRONMENT value that only appears in .env does not affect which config
    # file and SSL paths were loaded. Variables read later (PORT in
    # run_server, CI below) do pick up .env values.
    load_dotenv()

    # Check if `generate-openapi` option is passed
    if len(sys.argv) > 1 and sys.argv[1] == "generate-openapi":
        from fastapi.openapi.utils import get_openapi

        # Generate and save OpenAPI JSON file
        openapi_schema = get_openapi(
            title=app.title,
            version=app.version,
            description=app.description,
            routes=app.routes,
        )
        with open("openapi.json", "w") as f:
            json.dump(openapi_schema, f, indent=4)
        print("openapi.json generated successfully.")
        sys.exit(0)  # Exit after generating the file

    # Default behavior: run the server for the environment this module was
    # configured with. `environment` is deliberately NOT re-read here: doing
    # so after load_dotenv() could select a different environment than the
    # one whose config and SSL paths were loaded at import time. The port is
    # resolved inside run_server().
    is_ci = os.getenv("CI", "false").lower() == "true"

    if is_ci or environment == "test":
        # Test or CI Mode: Run server process temporarily for test validation
        ssl_dir = Path(cert_path).parent
        ssl_dir.mkdir(parents=True, exist_ok=True)
        # The key may be configured in a different directory than the cert.
        Path(key_path).parent.mkdir(parents=True, exist_ok=True)

        # Generate self-signed certs if missing
        if not Path(cert_path).exists() or not Path(key_path).exists():
            print(f"[INFO] Generating self-signed SSL certificates at {ssl_dir}")
            ssl_heidi.generate_self_signed_cert(cert_path, key_path)

        # Start the server as a subprocess, wait, then terminate
        server_process = Process(target=run_server)
        server_process.start()
        sleep(5)  # Give the server time to start

        # If the child has already exited, the server did not stay up (e.g.
        # run_server aborted because no port was configured). Fail the check
        # instead of reporting success.
        if not server_process.is_alive():
            server_process.join()
            print(
                "[ERROR] Server process exited prematurely "
                f"(exit code {server_process.exitcode})."
            )
            sys.exit(server_process.exitcode or 1)

        server_process.terminate()  # Terminate the server process (for CI)
        server_process.join()  # Ensure proper cleanup
        print("CI: Server started and terminated successfully for test validation.")
    else:
        # Dev or Prod: Start the server as usual
        run_server()
|