From 29a9d16bc37174568b9428b0472defa75c235be3 Mon Sep 17 00:00:00 2001 From: GotthardG <51994228+GotthardG@users.noreply.github.com> Date: Fri, 6 Dec 2024 14:53:27 +0100 Subject: [PATCH] various bugs fixing --- backend/app/crud.py | 2 +- backend/app/database.py | 6 +- backend/app/dependencies.py | 2 +- backend/app/init_db.py | 2 +- backend/app/models.py | 2 +- backend/app/sample_models.py | 144 +++++++++++++++++------------------ backend/{app => }/main.py | 42 ++++++++-- backend/save/main.py | 2 +- config_dev.json | 4 + config_test.json | 4 + 10 files changed, 123 insertions(+), 87 deletions(-) rename backend/{app => }/main.py (59%) create mode 100644 config_dev.json create mode 100644 config_test.json diff --git a/backend/app/crud.py b/backend/app/crud.py index aa703ac..bba0edd 100644 --- a/backend/app/crud.py +++ b/backend/app/crud.py @@ -1,6 +1,6 @@ import logging from sqlalchemy.orm import Session, joinedload -from app.models import Shipment +from .models import Shipment def get_shipments(db: Session): logging.info("Fetching all shipments from the database.") diff --git a/backend/app/database.py b/backend/app/database.py index 47d7541..4ed6568 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -3,7 +3,7 @@ from sqlalchemy.orm import Session from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker -from app import models +from . import models import os # Get username and password from environment variables @@ -31,12 +31,12 @@ def get_db(): def init_db(): # Import models inside function to avoid circular dependency - from app import models + from . import models Base.metadata.create_all(bind=engine) def load_sample_data(session: Session): # Import models inside function to avoid circular dependency - from app.data import contacts, return_addresses, dewars, proposals, shipments, pucks, samples, dewar_types, serial_numbers, slots, sample_events + from .data import contacts, return_addresses, dewars, proposals, shipments, pucks, samples, dewar_types, serial_numbers, slots, sample_events # If any data already exists, skip seeding if session.query(models.ContactPerson).first(): diff --git a/backend/app/dependencies.py b/backend/app/dependencies.py index 5dfd7a8..f7d0a11 100644 --- a/backend/app/dependencies.py +++ b/backend/app/dependencies.py @@ -1,5 +1,5 @@ # app/dependencies.py -from app.database import SessionLocal # Import SessionLocal from database.py +from .database import SessionLocal # Import SessionLocal from database.py def get_db(): db = SessionLocal() diff --git a/backend/app/init_db.py b/backend/app/init_db.py index c2311fe..9d7a9ed 100644 --- a/backend/app/init_db.py +++ b/backend/app/init_db.py @@ -1,4 +1,4 @@ -from app.database import init_db +from .database import init_db def initialize_database(): diff --git a/backend/app/models.py b/backend/app/models.py index ff40228..0b29152 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -1,6 +1,6 @@ from sqlalchemy import Column, Integer, String, Date, ForeignKey, JSON, Interval, DateTime, Boolean from sqlalchemy.orm import relationship -from app.database import Base +from .database import Base from datetime import datetime, timedelta import uuid diff --git a/backend/app/sample_models.py b/backend/app/sample_models.py index b0ce6a4..a927bbe 100644 --- a/backend/app/sample_models.py +++ b/backend/app/sample_models.py @@ -191,84 +191,84 @@ class SpreadsheetModel(BaseModel): raise ValueError(f" '{v}' is not valid. Value must be one of {allowed}.") return v - @field_validator('spacegroupnumber', mode="before") - @classmethod - def spacegroupnumber_allowed(cls, v): - if v is not None: - try: - v = int(v) - if not (1 <= v <= 230): - raise ValueError(f" '{v}' is not valid. Value must be an integer between 1 and 230.") - except (ValueError, TypeError) as e: - raise ValueError(f" '{v}' is not valid. Value must be an integer between 1 and 230.") from e - return v + @field_validator('spacegroupnumber', mode="before") + @classmethod + def spacegroupnumber_allowed(cls, v): + if v is not None: + try: + v = int(v) + if not (1 <= v <= 230): + raise ValueError(f" '{v}' is not valid. Value must be an integer between 1 and 230.") + except (ValueError, TypeError) as e: + raise ValueError(f" '{v}' is not valid. Value must be an integer between 1 and 230.") from e + return v - @field_validator('cellparameters', mode="before") - @classmethod - def cellparameters_format(cls, v): - if v: - values = [float(i) for i in v.split(",")] - if len(values) != 6 or any(val <= 0 for val in values): - raise ValueError(f" '{v}' is not valid. Value must be a set of six positive floats or integers.") - return v + @field_validator('cellparameters', mode="before") + @classmethod + def cellparameters_format(cls, v): + if v: + values = [float(i) for i in v.split(",")] + if len(values) != 6 or any(val <= 0 for val in values): + raise ValueError(f" '{v}' is not valid. Value must be a set of six positive floats or integers.") + return v - @field_validator('rescutkey', 'rescutvalue', mode="before") - @classmethod - def rescutkey_value_pair(cls, values): - rescutkey = values.get('rescutkey') - rescutvalue = values.get('rescutvalue') - if rescutkey and rescutvalue: - if rescutkey not in {"is", "cchalf"}: - raise ValueError("Rescutkey must be either 'is' or 'cchalf'") - if not isinstance(rescutvalue, float) or rescutvalue <= 0: - raise ValueError("Rescutvalue must be a positive float if rescutkey is provided") - return values + @field_validator('rescutkey', 'rescutvalue', mode="before") + @classmethod + def rescutkey_value_pair(cls, values): + rescutkey = values.get('rescutkey') + rescutvalue = values.get('rescutvalue') + if rescutkey and rescutvalue: + if rescutkey not in {"is", "cchalf"}: + raise ValueError("Rescutkey must be either 'is' or 'cchalf'") + if not isinstance(rescutvalue, float) or rescutvalue <= 0: + raise ValueError("Rescutvalue must be a positive float if rescutkey is provided") + return values - @field_validator('trustedhigh', mode="before") - @classmethod - def trustedhigh_allowed(cls, v): - if v is not None: - try: - v = float(v) - if not (0 <= v <= 2.0): - raise ValueError(f" '{v}' is not valid. Value must be a float between 0 and 2.0.") - except (ValueError, TypeError) as e: - raise ValueError(f" '{v}' is not valid. Value must be a float between 0 and 2.0.") from e - return v + @field_validator('trustedhigh', mode="before") + @classmethod + def trustedhigh_allowed(cls, v): + if v is not None: + try: + v = float(v) + if not (0 <= v <= 2.0): + raise ValueError(f" '{v}' is not valid. Value must be a float between 0 and 2.0.") + except (ValueError, TypeError) as e: + raise ValueError(f" '{v}' is not valid. Value must be a float between 0 and 2.0.") from e + return v - @field_validator('chiphiangles', mode="before") - @classmethod - def chiphiangles_allowed(cls, v): - if v is not None: - try: - v = float(v) - if not (0 <= v <= 30): - raise ValueError(f" '{v}' is not valid. Value must be a float between 0 and 30.") - except (ValueError, TypeError) as e: - raise ValueError(f" '{v}' is not valid. Value must be a float between 0 and 30.") from e - return v + @field_validator('chiphiangles', mode="before") + @classmethod + def chiphiangles_allowed(cls, v): + if v is not None: + try: + v = float(v) + if not (0 <= v <= 30): + raise ValueError(f" '{v}' is not valid. Value must be a float between 0 and 30.") + except (ValueError, TypeError) as e: + raise ValueError(f" '{v}' is not valid. Value must be a float between 0 and 30.") from e + return v - @field_validator('dose', mode="before") - @classmethod - def dose_positive(cls, v): - if v is not None: - try: - v = float(v) - if v <= 0: - raise ValueError(f" '{v}' is not valid. Value must be a positive float.") - except (ValueError, TypeError) as e: - raise ValueError(f" '{v}' is not valid. Value must be a positive float.") from e - return v + @field_validator('dose', mode="before") + @classmethod + def dose_positive(cls, v): + if v is not None: + try: + v = float(v) + if v <= 0: + raise ValueError(f" '{v}' is not valid. Value must be a positive float.") + except (ValueError, TypeError) as e: + raise ValueError(f" '{v}' is not valid. Value must be a positive float.") from e + return v - class TELLModel(SpreadsheetModel): - input_order: int - samplemountcount: int = 0 - samplestatus: str = "not present" - puckaddress: str = "---" - username: str - puck_number: int - prefix: Optional[str] - folder: Optional[str] +class TELLModel(SpreadsheetModel): + input_order: int + samplemountcount: int = 0 + samplestatus: str = "not present" + puckaddress: str = "---" + username: str + puck_number: int + prefix: Optional[str] + folder: Optional[str] class SpreadsheetResponse(BaseModel): data: List[SpreadsheetModel] # Validated data rows as SpreadsheetModel instances diff --git a/backend/app/main.py b/backend/main.py similarity index 59% rename from backend/app/main.py rename to backend/main.py index cb9df90..fd66763 100644 --- a/backend/app/main.py +++ b/backend/main.py @@ -4,16 +4,31 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from app import ssl_heidi from pathlib import Path +import os +import json + from app.routers import address, contact, proposal, dewar, shipment, puck, spreadsheet, logistics, auth, sample from app.database import Base, engine, SessionLocal, load_sample_data app = FastAPI() -# Generate SSL Key and Certificate if not exist -Path("ssl").mkdir(parents=True, exist_ok=True) -if not Path("ssl/cert.pem").exists() or not Path("ssl/key.pem").exists(): - ssl_heidi.generate_self_signed_cert("ssl/cert.pem", "ssl/key.pem") +# Determine environment and configuration file path +environment = os.getenv('ENVIRONMENT', 'dev') +config_file = Path(__file__).resolve().parent.parent / f'config_{environment}.json' + +# Load configuration +with open(config_file) as f: + config = json.load(f) + +cert_path = config['ssl_cert_path'] +key_path = config['ssl_key_path'] + +# Generate SSL Key and Certificate if not exist (only for development) +if environment == 'development': + Path("ssl").mkdir(parents=True, exist_ok=True) + if not Path(cert_path).exists() or not Path(key_path).exists(): + ssl_heidi.generate_self_signed_cert(cert_path, key_path) # Apply CORS middleware app.add_middleware( @@ -52,7 +67,20 @@ app.include_router(sample.router, prefix="/samples", tags=["samples"]) if __name__ == "__main__": - ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - ssl_context.load_cert_chain(certfile="ssl/cert.pem", keyfile="ssl/key.pem") + import uvicorn + import os - uvicorn.run(app, host="127.0.0.1", port=8000, log_level="debug", ssl_context=ssl_context) + # Get environment from an environment variable + environment = os.getenv('ENVIRONMENT', 'dev') + + # Paths for SSL certificates + cert_path = "ssl/cert.pem" + key_path = "ssl/key.pem" + + if environment == 'testing': + cert_path = "ssl/mx-aare-test.psi.ch.pem" + key_path = "ssl/mx-aare-test.psi.ch.key" + + # Run the application with appropriate SSL setup + uvicorn.run(app, host="127.0.0.1", port=8000, log_level="debug", + ssl_keyfile=key_path, ssl_certfile=cert_path) diff --git a/backend/save/main.py b/backend/save/main.py index 9453dbb..f72e8ce 100644 --- a/backend/save/main.py +++ b/backend/save/main.py @@ -110,7 +110,7 @@ contacts = [ # Example data for return addresses return_addresses = [ Address(id=1, street='123 Hobbiton St', city='Shire', zipcode='12345', country='Middle Earth'), - Address(id=2, street='456 Rohan Rd', city='Edoras', zipcode='67890', country='Middle Earth') + Address(id=2, street='456 Rohan Rd', city='Edoras', zipcode='67890', country='Middle Earth'), Address(id=3, street='789 Greenwood Dr', city='Mirkwood', zipcode='13579', country='Middle Earth'), Address(id=4, street='321 Gondor Ave', city='Minas Tirith', zipcode='24680', country='Middle Earth'), Address(id=5, street='654 Falgorn Pass', city='Rivendell', zipcode='11223', country='Middle Earth') diff --git a/config_dev.json b/config_dev.json new file mode 100644 index 0000000..9952169 --- /dev/null +++ b/config_dev.json @@ -0,0 +1,4 @@ +{ + "ssl_cert_path": "ssl/cert.pem", + "ssl_key_path": "ssl/key.pem" +} \ No newline at end of file diff --git a/config_test.json b/config_test.json new file mode 100644 index 0000000..bc9d5c2 --- /dev/null +++ b/config_test.json @@ -0,0 +1,4 @@ +{ + "ssl_cert_path": "ssl/mx-aare-test.psi.ch.pem", + "ssl_key_path": "ss/mx-aare-test.psi.ch.key" +} \ No newline at end of file