
This commit adds relationships to link Pucks and Samples to Beamtime in the models, enabling better data association. Includes changes to assign beamtime IDs during data generation and updates in API response models for improved data loading. Removed redundant code in testfunctions.ipynb to clean up the notebook.
300 lines
10 KiB
Python
300 lines
10 KiB
Python
import os
|
|
import json
|
|
import tomllib
|
|
from contextlib import asynccontextmanager
|
|
from pathlib import Path
|
|
from fastapi import FastAPI
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
|
from app import ssl_heidi
|
|
from app.routers import (
|
|
proposal,
|
|
puck,
|
|
spreadsheet,
|
|
logistics,
|
|
auth,
|
|
sample,
|
|
processing,
|
|
)
|
|
from app.database import Base, engine, SessionLocal
|
|
from app.routers.protected_router import protected_router
|
|
|
|
# StaticFiles (mounted at /images further down) checks that its directory
# exists when it is constructed, so create it up front (relative to the CWD).
Path("images").mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
# Utility function to fetch metadata from pyproject.toml
|
|
def get_project_metadata():
|
|
script_dir = Path(__file__).resolve().parent
|
|
pyproject_path = script_dir / "pyproject.toml" # Check current directory first
|
|
|
|
if pyproject_path.exists():
|
|
with open(pyproject_path, "rb") as f:
|
|
pyproject = tomllib.load(f)
|
|
name = pyproject["project"]["name"]
|
|
version = pyproject["project"]["version"]
|
|
return name, version
|
|
|
|
# Search in parent directories
|
|
for parent in script_dir.parents:
|
|
pyproject_path = parent / "pyproject.toml"
|
|
if pyproject_path.exists():
|
|
with open(pyproject_path, "rb") as f:
|
|
pyproject = tomllib.load(f)
|
|
name = pyproject["project"]["name"]
|
|
version = pyproject["project"]["version"]
|
|
return name, version
|
|
|
|
raise FileNotFoundError(
|
|
f"pyproject.toml not found in any parent directory of {script_dir}"
|
|
)
|
|
|
|
|
|
def run_server(config, cert_path, key_path):
    """Serve the module-level FastAPI ``app`` over HTTPS with uvicorn.

    Args:
        config: Parsed configuration mapping; its ``"PORT"`` entry selects the
            listening port (any value accepted by ``int()``).
        cert_path: SSL certificate file handed to uvicorn.
        key_path: SSL private-key file handed to uvicorn.

    Exits the process with status 1 when ``config`` has no truthy ``"PORT"``.
    """
    # ``sys`` is otherwise only bound inside the ``if __name__ == "__main__"``
    # block, which does not run when this module is merely imported (e.g. by a
    # spawned child process), so the exit path below would raise NameError.
    import sys

    environment = os.getenv(
        "ENVIRONMENT", "dev"
    )  # needs to be set explicitly here if not globally available

    print(f"[INFO] Starting server in {environment} environment...")
    print(f"[INFO] SSL Certificate Path: {cert_path}")
    print(f"[INFO] SSL Key Path: {key_path}")

    port = config.get("PORT")
    if not port:
        print("[ERROR] No port defined in config or environment variables. Aborting!")
        sys.exit(1)  # Exit if no port is defined
    port = int(port)
    print(f"[INFO] Running on port {port}")

    # Imported only after the configuration has been validated, so a missing
    # port is reported even where uvicorn cannot be imported.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",  # listen on all IPv4 interfaces
        port=port,
        log_level="debug",
        ssl_keyfile=key_path,
        ssl_certfile=cert_path,
    )
|
|
|
|
|
|
# Get project metadata from pyproject.toml
|
|
project_name, project_version = get_project_metadata()

# Determine environment and configuration file path
environment = os.getenv("ENVIRONMENT", "dev")
config_file = Path(__file__).resolve().parent / f"config_{environment}.json"

if not config_file.exists():
    raise FileNotFoundError(f"Config file '{config_file}' does not exist.")

# Load configuration. JSON text is UTF-8 by specification, so decode it
# explicitly instead of relying on the platform's locale-dependent default.
with open(config_file, encoding="utf-8") as f:
    config = json.load(f)

# Resolve SSL paths based on environment: dev/test read lower-case keys and
# fall back to files under ssl/; prod requires upper-case keys and both files
# to already exist.
if environment in ["test", "dev"]:
    cert_path = config.get("ssl_cert_path", "ssl/cert.pem")
    key_path = config.get("ssl_key_path", "ssl/key.pem")
    ssl_dir = Path(cert_path).parent

    # Ensure the directory exists before file operations
    ssl_dir.mkdir(parents=True, exist_ok=True)

elif environment == "prod":
    cert_path = config.get("SSL_CERT_PATH")
    key_path = config.get("SSL_KEY_PATH")
    if not cert_path or not key_path:
        raise ValueError(
            "SSL_CERT_PATH and SSL_KEY_PATH must be set in config_prod.json"
            " for production."
        )
    if not Path(cert_path).exists() or not Path(key_path).exists():
        raise FileNotFoundError(
            f"Missing SSL files in production. Ensure the following files exist:\n"
            f"SSL Certificate: {cert_path}\nSSL Key: {key_path}"
        )
else:
    raise ValueError(f"Unknown environment: {environment}")

# Generate a self-signed SSL key and certificate for dev if either is missing.
if environment == "dev":
    if not Path(cert_path).exists() or not Path(key_path).exists():
        ssl_heidi.generate_self_signed_cert(cert_path, key_path)
|
|
|
|
|
|
def cleanup_job_loop(interval_seconds: float = 3600):
    """Repeatedly purge cancelled processing jobs; never returns.

    Each pass obtains a session from ``get_db``, calls
    ``cleanup_cancelled_jobs`` with it and closes the session. The loop then
    sleeps for ``interval_seconds`` before the next pass. The first pass runs
    immediately.

    Args:
        interval_seconds: Pause between passes, in seconds. Defaults to one
            hour.
    """
    import time
    import traceback

    from app.dependencies import get_db
    from app.routers.processing import cleanup_cancelled_jobs

    while True:
        db = None
        try:
            # NOTE(review): assumes get_db() is a generator dependency that
            # yields a session; only the yielded session is closed below.
            db = next(get_db())
            cleanup_cancelled_jobs(db)
        except Exception:
            # Report and keep looping: an uncaught error would end this
            # thread and no further cleanup pass would ever run.
            print("[ERROR] Cancelled-job cleanup failed; retrying next cycle.")
            traceback.print_exc()
        finally:
            if db is not None:
                db.close()
        time.sleep(interval_seconds)
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Prepare the database and start background jobs for the app's lifetime.

    Startup depends on the module-level ``environment``:

    * ``prod`` -- existing data is never dropped. The schema is created only
      when the database has no tables; ``load_slots_data`` is then called.
    * any other value (``dev``/``test``) -- every table is dropped and
      recreated. The tables are then filled by ``load_sample_data`` (dev) or
      ``load_slots_data`` (test).

    Afterwards ``cleanup_job_loop`` is started in a daemon thread and control
    is yielded to FastAPI. Nothing further happens at shutdown.
    """
    print("[INFO] Running application startup tasks...")
    db = SessionLocal()
    try:
        if environment == "prod":
            # inspect(engine) replaces the deprecated Inspector.from_engine().
            from sqlalchemy import inspect as sa_inspect

            # Ensure the production database is initialized
            if not sa_inspect(engine).get_table_names():
                print("Production database is empty. Initializing...")
                Base.metadata.create_all(bind=engine)

            # Seed the database (slots + proposals).
            # NOTE(review): this runs on every prod start, not only after the
            # first initialisation; presumably load_slots_data skips data
            # that already exists -- confirm.
            from app.database import load_slots_data

            load_slots_data(db)
        else:  # dev or test environments
            print(f"{environment.capitalize()} environment: Regenerating database.")
            Base.metadata.drop_all(bind=engine)
            Base.metadata.create_all(bind=engine)

            if environment == "dev":
                from app.database import load_sample_data

                load_sample_data(db)
            elif environment == "test":
                from app.database import load_slots_data

                load_slots_data(db)
    finally:
        # The session is only used for seeding above; release it now rather
        # than keeping it open until application shutdown.
        db.close()

    from threading import Thread

    # Start cleanup in background thread (daemon, so it never blocks exit).
    thread = Thread(target=cleanup_job_loop, daemon=True)
    thread.start()

    yield
|
|
|
|
|
|
# The ASGI application. Title and version come from pyproject.toml; startup
# work is delegated to ``lifespan``.
app = FastAPI(
    lifespan=lifespan,
    title=project_name,
    description="Backend for next-gen sample management system",
    version=project_version,
    servers=[
        {"url": "https://mx-aare-test.psi.ch:8000", "description": "Default server"}
    ],
)

# Apply CORS middleware.
# NOTE(review): wildcard origins are combined with allow_credentials=True.
# The Starlette docs ask for explicit origin lists when credentials are
# allowed -- confirm this permissive policy is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Routers in registration order, each with the keyword arguments passed to
# ``include_router``. Entries without a "prefix" mount at the root.
_ROUTER_SPECS = (
    (protected_router, {"prefix": "/protected"}),
    (auth.router, {"prefix": "/auth", "tags": ["auth"]}),
    (proposal.router, {"prefix": "/proposals", "tags": ["proposals"]}),
    (puck.router, {"prefix": "/pucks", "tags": ["pucks"]}),
    (spreadsheet.router, {"tags": ["spreadsheet"]}),
    (logistics.router, {"prefix": "/logistics", "tags": ["logistics"]}),
    (sample.router, {"prefix": "/samples", "tags": ["samples"]}),
    (processing.router, {"prefix": "/processing", "tags": ["processing"]}),
)
for _router, _router_options in _ROUTER_SPECS:
    app.include_router(_router, **_router_options)

# Serve files from the local "images" directory under /images.
app.mount("/images", StaticFiles(directory="images"), name="images")
|
|
|
|
|
|
if __name__ == "__main__":
    import sys
    from multiprocessing import Process
    from time import sleep

    from dotenv import load_dotenv

    # Load environment variables from .env file
    load_dotenv()

    # Re-read the environment now that .env has been loaded: the module-level
    # configuration code ran at import time, before load_dotenv().
    environment = os.getenv("ENVIRONMENT", "dev")
    config_file = Path(__file__).resolve().parent / f"config_{environment}.json"

    # Check if `generate-openapi` option is passed
    if len(sys.argv) > 1 and sys.argv[1] == "generate-openapi":
        from fastapi.openapi.utils import get_openapi

        # Generate and save OpenAPI JSON file
        openapi_schema = get_openapi(
            title=app.title,
            version=app.version,
            description=app.description,
            routes=app.routes,
        )
        with open("openapi.json", "w") as f:
            json.dump(openapi_schema, f, indent=4)
        print("openapi.json generated successfully.")
        sys.exit(0)  # Exit after generating the file

    # Explicitly load the configuration file (JSON is UTF-8 by specification)
    with open(config_file, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Explicitly obtain SSL paths from config
    if environment in ["test", "dev"]:
        cert_path = config.get("ssl_cert_path", "ssl/cert.pem")
        key_path = config.get("ssl_key_path", "ssl/key.pem")
    elif environment == "prod":
        cert_path = config.get("SSL_CERT_PATH")
        key_path = config.get("SSL_KEY_PATH")
        if not cert_path or not key_path:
            # The trailing space keeps the two literals from fusing into
            # "explicitlyset" when Python concatenates them.
            raise ValueError(
                "SSL_CERT_PATH and SSL_KEY_PATH must be explicitly "
                "set in config_prod.json for production."
            )
    else:
        raise ValueError(f"Unknown environment: {environment}")

    is_ci = os.getenv("CI", "false").lower() == "true"

    # Handle certificates for dev/test if not available
    ssl_dir = Path(cert_path).parent
    ssl_dir.mkdir(parents=True, exist_ok=True)
    if environment in ["dev", "test"] and (
        not Path(cert_path).exists() or not Path(key_path).exists()
    ):
        print(f"[INFO] Generating SSL certificates at {ssl_dir}")
        ssl_heidi.generate_self_signed_cert(cert_path, key_path)

    if is_ci or environment == "test":
        # Smoke test: start the server in a child process, let it run
        # briefly, then stop it.
        server_process = Process(target=run_server, args=(config, cert_path, key_path))
        server_process.start()
        sleep(5)
        server_process.terminate()
        server_process.join()
        print("CI/Test environment: Server started and terminated successfully.")
    else:
        run_server(config, cert_path, key_path)
|