Fix formatting with black

This commit is contained in:
GotthardG 2024-12-16 10:41:56 +01:00
parent 57763970f9
commit a0be71bdfe
26 changed files with 1657 additions and 645 deletions

View File

@ -2,29 +2,42 @@ import logging
from sqlalchemy.orm import Session, joinedload from sqlalchemy.orm import Session, joinedload
from .models import Shipment from .models import Shipment
def get_shipments(db: Session): def get_shipments(db: Session):
logging.info("Fetching all shipments from the database.") logging.info("Fetching all shipments from the database.")
shipments = db.query(Shipment).options( shipments = (
db.query(Shipment)
.options(
joinedload(Shipment.contact_person), joinedload(Shipment.contact_person),
joinedload(Shipment.return_address), joinedload(Shipment.return_address),
joinedload(Shipment.proposal), joinedload(Shipment.proposal),
joinedload(Shipment.dewars) joinedload(Shipment.dewars),
).all() )
.all()
)
logging.info(f"Total of {len(shipments)} shipments fetched.") logging.info(f"Total of {len(shipments)} shipments fetched.")
for shipment in shipments: for shipment in shipments:
if shipment.proposal_id is None: if shipment.proposal_id is None:
logging.warning(f"Shipment {shipment.id} is missing proposal ID.") logging.warning(f"Shipment {shipment.id} is missing proposal ID.")
logging.debug(f"Shipment ID: {shipment.id}, Shipment Name: {shipment.shipment_name}") logging.debug(
f"Shipment ID: {shipment.id}, Shipment Name: {shipment.shipment_name}"
)
return shipments return shipments
def get_shipment_by_id(db: Session, id: int): def get_shipment_by_id(db: Session, id: int):
logging.info(f"Fetching shipment with ID: {id}") logging.info(f"Fetching shipment with ID: {id}")
shipment = db.query(Shipment).options( shipment = (
db.query(Shipment)
.options(
joinedload(Shipment.contact_person), joinedload(Shipment.contact_person),
joinedload(Shipment.return_address), joinedload(Shipment.return_address),
joinedload(Shipment.proposal), joinedload(Shipment.proposal),
joinedload(Shipment.dewars) joinedload(Shipment.dewars),
).filter(Shipment.id == id).first() )
.filter(Shipment.id == id)
.first()
)
if shipment: if shipment:
if shipment.proposal_id is None: if shipment.proposal_id is None:
logging.warning(f"Shipment {shipment.id} is missing proposal ID.") logging.warning(f"Shipment {shipment.id} is missing proposal ID.")

View File

@ -1,2 +1,13 @@
from .data import contacts, return_addresses, dewars, proposals, shipments, pucks, samples, dewar_types, serial_numbers, sample_events from .data import (
contacts,
return_addresses,
dewars,
proposals,
shipments,
pucks,
samples,
dewar_types,
serial_numbers,
sample_events,
)
from .slots_data import slots from .slots_data import slots

View File

@ -1,4 +1,16 @@
from app.models import ContactPerson, Address, Dewar, Proposal, Shipment, Puck, Sample, DewarType, DewarSerialNumber, Slot, SampleEvent from app.models import (
ContactPerson,
Address,
Dewar,
Proposal,
Shipment,
Puck,
Sample,
DewarType,
DewarSerialNumber,
Slot,
SampleEvent,
)
from datetime import datetime, timedelta from datetime import datetime, timedelta
import random import random
import time import time
@ -23,74 +35,201 @@ serial_numbers = [
# Define contact persons # Define contact persons
contacts = [ contacts = [
ContactPerson(id=1, firstname="Frodo", lastname="Baggins", phone_number="123-456-7890", email="frodo.baggins@lotr.com"), ContactPerson(
ContactPerson(id=2, firstname="Samwise", lastname="Gamgee", phone_number="987-654-3210", email="samwise.gamgee@lotr.com"), id=1,
ContactPerson(id=3, firstname="Aragorn", lastname="Elessar", phone_number="123-333-4444", email="aragorn.elessar@lotr.com"), firstname="Frodo",
ContactPerson(id=4, firstname="Legolas", lastname="Greenleaf", phone_number="555-666-7777", email="legolas.greenleaf@lotr.com"), lastname="Baggins",
ContactPerson(id=5, firstname="Gimli", lastname="Son of Gloin", phone_number="888-999-0000", email="gimli.sonofgloin@lotr.com"), phone_number="123-456-7890",
ContactPerson(id=6, firstname="Gandalf", lastname="The Grey", phone_number="222-333-4444", email="gandalf.thegrey@lotr.com"), email="frodo.baggins@lotr.com",
ContactPerson(id=7, firstname="Boromir", lastname="Son of Denethor", phone_number="111-222-3333", email="boromir.sonofdenethor@lotr.com"), ),
ContactPerson(id=8, firstname="Galadriel", lastname="Lady of Lothlórien", phone_number="444-555-6666", email="galadriel.lothlorien@lotr.com"), ContactPerson(
ContactPerson(id=9, firstname="Elrond", lastname="Half-elven", phone_number="777-888-9999", email="elrond.halfelven@lotr.com"), id=2,
ContactPerson(id=10, firstname="Eowyn", lastname="Shieldmaiden of Rohan", phone_number="000-111-2222", email="eowyn.rohan@lotr.com"), firstname="Samwise",
lastname="Gamgee",
phone_number="987-654-3210",
email="samwise.gamgee@lotr.com",
),
ContactPerson(
id=3,
firstname="Aragorn",
lastname="Elessar",
phone_number="123-333-4444",
email="aragorn.elessar@lotr.com",
),
ContactPerson(
id=4,
firstname="Legolas",
lastname="Greenleaf",
phone_number="555-666-7777",
email="legolas.greenleaf@lotr.com",
),
ContactPerson(
id=5,
firstname="Gimli",
lastname="Son of Gloin",
phone_number="888-999-0000",
email="gimli.sonofgloin@lotr.com",
),
ContactPerson(
id=6,
firstname="Gandalf",
lastname="The Grey",
phone_number="222-333-4444",
email="gandalf.thegrey@lotr.com",
),
ContactPerson(
id=7,
firstname="Boromir",
lastname="Son of Denethor",
phone_number="111-222-3333",
email="boromir.sonofdenethor@lotr.com",
),
ContactPerson(
id=8,
firstname="Galadriel",
lastname="Lady of Lothlórien",
phone_number="444-555-6666",
email="galadriel.lothlorien@lotr.com",
),
ContactPerson(
id=9,
firstname="Elrond",
lastname="Half-elven",
phone_number="777-888-9999",
email="elrond.halfelven@lotr.com",
),
ContactPerson(
id=10,
firstname="Eowyn",
lastname="Shieldmaiden of Rohan",
phone_number="000-111-2222",
email="eowyn.rohan@lotr.com",
),
] ]
# Define return addresses # Define return addresses
return_addresses = [ return_addresses = [
Address(id=1, street='123 Hobbiton St', city='Shire', zipcode='12345', country='Middle Earth'), Address(
Address(id=2, street='456 Rohan Rd', city='Edoras', zipcode='67890', country='Middle Earth'), id=1,
Address(id=3, street='789 Greenwood Dr', city='Mirkwood', zipcode='13579', country='Middle Earth'), street="123 Hobbiton St",
Address(id=4, street='321 Gondor Ave', city='Minas Tirith', zipcode='24680', country='Middle Earth'), city="Shire",
Address(id=5, street='654 Falgorn Pass', city='Rivendell', zipcode='11223', country='Middle Earth'), zipcode="12345",
country="Middle Earth",
),
Address(
id=2,
street="456 Rohan Rd",
city="Edoras",
zipcode="67890",
country="Middle Earth",
),
Address(
id=3,
street="789 Greenwood Dr",
city="Mirkwood",
zipcode="13579",
country="Middle Earth",
),
Address(
id=4,
street="321 Gondor Ave",
city="Minas Tirith",
zipcode="24680",
country="Middle Earth",
),
Address(
id=5,
street="654 Falgorn Pass",
city="Rivendell",
zipcode="11223",
country="Middle Earth",
),
] ]
# Utilize a function to generate unique IDs # Utilize a function to generate unique IDs
def generate_unique_id(length=16): def generate_unique_id(length=16):
base_string = f"{time.time()}{random.randint(0, 10 ** 6)}" base_string = f"{time.time()}{random.randint(0, 10 ** 6)}"
hash_object = hashlib.sha256(base_string.encode()) hash_object = hashlib.sha256(base_string.encode())
hash_digest = hash_object.hexdigest() hash_digest = hash_object.hexdigest()
short_unique_id = ''.join(random.choices(hash_digest, k=length)) short_unique_id = "".join(random.choices(hash_digest, k=length))
return short_unique_id return short_unique_id
# Define dewars with unique IDs # Define dewars with unique IDs
dewars = [ dewars = [
Dewar( Dewar(
id=1, dewar_name='Dewar One', dewar_type_id=1, id=1,
dewar_serial_number_id=2, tracking_number='TRACK123', dewar_name="Dewar One",
return_address_id=1, contact_person_id=1, status='Ready for Shipping', dewar_type_id=1,
ready_date=datetime.strptime('2023-09-30', '%Y-%m-%d'), shipping_date=None, arrival_date=None, dewar_serial_number_id=2,
returning_date=None, unique_id=generate_unique_id() tracking_number="TRACK123",
return_address_id=1,
contact_person_id=1,
status="Ready for Shipping",
ready_date=datetime.strptime("2023-09-30", "%Y-%m-%d"),
shipping_date=None,
arrival_date=None,
returning_date=None,
unique_id=generate_unique_id(),
), ),
Dewar( Dewar(
id=2, dewar_name='Dewar Two', dewar_type_id=3, id=2,
dewar_serial_number_id=1, tracking_number='TRACK124', dewar_name="Dewar Two",
return_address_id=2, contact_person_id=2, status='In Preparation', dewar_type_id=3,
ready_date=None, shipping_date=None, arrival_date=None, returning_date=None, unique_id=generate_unique_id() dewar_serial_number_id=1,
tracking_number="TRACK124",
return_address_id=2,
contact_person_id=2,
status="In Preparation",
ready_date=None,
shipping_date=None,
arrival_date=None,
returning_date=None,
unique_id=generate_unique_id(),
), ),
Dewar( Dewar(
id=3, dewar_name='Dewar Three', dewar_type_id=2, id=3,
dewar_serial_number_id=3, tracking_number='TRACK125', dewar_name="Dewar Three",
return_address_id=1, contact_person_id=3, status='Not Shipped', dewar_type_id=2,
ready_date=datetime.strptime('2024-01-01', '%Y-%m-%d'), shipping_date=None, arrival_date=None, dewar_serial_number_id=3,
returning_date=None, unique_id=None tracking_number="TRACK125",
return_address_id=1,
contact_person_id=3,
status="Not Shipped",
ready_date=datetime.strptime("2024-01-01", "%Y-%m-%d"),
shipping_date=None,
arrival_date=None,
returning_date=None,
unique_id=None,
), ),
Dewar( Dewar(
id=4, dewar_name='Dewar Four', dewar_type_id=2, id=4,
dewar_serial_number_id=4, tracking_number='', dewar_name="Dewar Four",
return_address_id=1, contact_person_id=3, status='Delayed', dewar_type_id=2,
ready_date=datetime.strptime('2024-01-01', '%Y-%m-%d'), dewar_serial_number_id=4,
shipping_date=datetime.strptime('2024-01-02', '%Y-%m-%d'), tracking_number="",
arrival_date=None, returning_date=None, unique_id=None return_address_id=1,
contact_person_id=3,
status="Delayed",
ready_date=datetime.strptime("2024-01-01", "%Y-%m-%d"),
shipping_date=datetime.strptime("2024-01-02", "%Y-%m-%d"),
arrival_date=None,
returning_date=None,
unique_id=None,
), ),
Dewar( Dewar(
id=5, dewar_name='Dewar Five', dewar_type_id=1, id=5,
dewar_serial_number_id=1, tracking_number='', dewar_name="Dewar Five",
return_address_id=1, contact_person_id=3, status='Returned', dewar_type_id=1,
arrival_date=datetime.strptime('2024-01-03', '%Y-%m-%d'), dewar_serial_number_id=1,
returning_date=datetime.strptime('2024-01-07', '%Y-%m-%d'), tracking_number="",
unique_id=None return_address_id=1,
contact_person_id=3,
status="Returned",
arrival_date=datetime.strptime("2024-01-03", "%Y-%m-%d"),
returning_date=datetime.strptime("2024-01-07", "%Y-%m-%d"),
unique_id=None,
), ),
] ]
@ -115,54 +254,252 @@ specific_dewars3 = [dewar for dewar in dewars if dewar.id in specific_dewar_ids3
# Define shipments # Define shipments
shipments = [ shipments = [
Shipment( Shipment(
id=1, shipment_date=datetime.strptime('2024-10-10', '%Y-%m-%d'), id=1,
shipment_name='Shipment from Mordor', shipment_status='Delivered', contact_person_id=2, shipment_date=datetime.strptime("2024-10-10", "%Y-%m-%d"),
proposal_id=3, return_address_id=1, comments='Handle with care', dewars=specific_dewars1 shipment_name="Shipment from Mordor",
shipment_status="Delivered",
contact_person_id=2,
proposal_id=3,
return_address_id=1,
comments="Handle with care",
dewars=specific_dewars1,
), ),
Shipment( Shipment(
id=2, shipment_date=datetime.strptime('2024-10-24', '%Y-%m-%d'), id=2,
shipment_name='Shipment from Mordor', shipment_status='In Transit', contact_person_id=4, shipment_date=datetime.strptime("2024-10-24", "%Y-%m-%d"),
proposal_id=4, return_address_id=2, comments='Contains the one ring', dewars=specific_dewars2 shipment_name="Shipment from Mordor",
shipment_status="In Transit",
contact_person_id=4,
proposal_id=4,
return_address_id=2,
comments="Contains the one ring",
dewars=specific_dewars2,
), ),
Shipment( Shipment(
id=3, shipment_date=datetime.strptime('2024-10-28', '%Y-%m-%d'), id=3,
shipment_name='Shipment from Mordor', shipment_status='In Transit', contact_person_id=5, shipment_date=datetime.strptime("2024-10-28", "%Y-%m-%d"),
proposal_id=5, return_address_id=1, comments='Contains the one ring', dewars=specific_dewars3 shipment_name="Shipment from Mordor",
shipment_status="In Transit",
contact_person_id=5,
proposal_id=5,
return_address_id=1,
comments="Contains the one ring",
dewars=specific_dewars3,
), ),
] ]
# Define pucks # Define pucks
pucks = [ pucks = [
Puck(id=1, puck_name="PUCK001", puck_type="Unipuck", puck_location_in_dewar=1, dewar_id=1), Puck(
Puck(id=2, puck_name="PUCK002", puck_type="Unipuck", puck_location_in_dewar=2, dewar_id=1), id=1,
Puck(id=3, puck_name="PUCK003", puck_type="Unipuck", puck_location_in_dewar=3, dewar_id=1), puck_name="PUCK001",
Puck(id=4, puck_name="PUCK004", puck_type="Unipuck", puck_location_in_dewar=4, dewar_id=1), puck_type="Unipuck",
Puck(id=5, puck_name="PUCK005", puck_type="Unipuck", puck_location_in_dewar=5, dewar_id=1), puck_location_in_dewar=1,
Puck(id=6, puck_name="PUCK006", puck_type="Unipuck", puck_location_in_dewar=6, dewar_id=1), dewar_id=1,
Puck(id=7, puck_name="PUCK007", puck_type="Unipuck", puck_location_in_dewar=7, dewar_id=1), ),
Puck(id=8, puck_name="PK001", puck_type="Unipuck", puck_location_in_dewar=1, dewar_id=2), Puck(
Puck(id=9, puck_name="PK002", puck_type="Unipuck", puck_location_in_dewar=2, dewar_id=2), id=2,
Puck(id=10, puck_name="PK003", puck_type="Unipuck", puck_location_in_dewar=3, dewar_id=2), puck_name="PUCK002",
Puck(id=11, puck_name="PK004", puck_type="Unipuck", puck_location_in_dewar=4, dewar_id=2), puck_type="Unipuck",
Puck(id=12, puck_name="PK005", puck_type="Unipuck", puck_location_in_dewar=5, dewar_id=2), puck_location_in_dewar=2,
Puck(id=13, puck_name="PK006", puck_type="Unipuck", puck_location_in_dewar=6, dewar_id=2), dewar_id=1,
Puck(id=14, puck_name="P001", puck_type="Unipuck", puck_location_in_dewar=1, dewar_id=3), ),
Puck(id=15, puck_name="P002", puck_type="Unipuck", puck_location_in_dewar=2, dewar_id=3), Puck(
Puck(id=16, puck_name="P003", puck_type="Unipuck", puck_location_in_dewar=3, dewar_id=3), id=3,
Puck(id=17, puck_name="P004", puck_type="Unipuck", puck_location_in_dewar=4, dewar_id=3), puck_name="PUCK003",
Puck(id=18, puck_name="P005", puck_type="Unipuck", puck_location_in_dewar=5, dewar_id=3), puck_type="Unipuck",
Puck(id=19, puck_name="P006", puck_type="Unipuck", puck_location_in_dewar=6, dewar_id=3), puck_location_in_dewar=3,
Puck(id=20, puck_name="P007", puck_type="Unipuck", puck_location_in_dewar=7, dewar_id=3), dewar_id=1,
Puck(id=21, puck_name="PC002", puck_type="Unipuck", puck_location_in_dewar=2, dewar_id=4), ),
Puck(id=22, puck_name="PC003", puck_type="Unipuck", puck_location_in_dewar=3, dewar_id=4), Puck(
Puck(id=23, puck_name="PC004", puck_type="Unipuck", puck_location_in_dewar=4, dewar_id=4), id=4,
Puck(id=24, puck_name="PC005", puck_type="Unipuck", puck_location_in_dewar=5, dewar_id=4), puck_name="PUCK004",
Puck(id=25, puck_name="PC006", puck_type="Unipuck", puck_location_in_dewar=6, dewar_id=4), puck_type="Unipuck",
Puck(id=26, puck_name="PC007", puck_type="Unipuck", puck_location_in_dewar=7, dewar_id=4), puck_location_in_dewar=4,
Puck(id=27, puck_name="PKK004", puck_type="Unipuck", puck_location_in_dewar=4, dewar_id=5), dewar_id=1,
Puck(id=28, puck_name="PKK005", puck_type="Unipuck", puck_location_in_dewar=5, dewar_id=5), ),
Puck(id=29, puck_name="PKK006", puck_type="Unipuck", puck_location_in_dewar=6, dewar_id=5), Puck(
Puck(id=30, puck_name="PKK007", puck_type="Unipuck", puck_location_in_dewar=7, dewar_id=5) id=5,
puck_name="PUCK005",
puck_type="Unipuck",
puck_location_in_dewar=5,
dewar_id=1,
),
Puck(
id=6,
puck_name="PUCK006",
puck_type="Unipuck",
puck_location_in_dewar=6,
dewar_id=1,
),
Puck(
id=7,
puck_name="PUCK007",
puck_type="Unipuck",
puck_location_in_dewar=7,
dewar_id=1,
),
Puck(
id=8,
puck_name="PK001",
puck_type="Unipuck",
puck_location_in_dewar=1,
dewar_id=2,
),
Puck(
id=9,
puck_name="PK002",
puck_type="Unipuck",
puck_location_in_dewar=2,
dewar_id=2,
),
Puck(
id=10,
puck_name="PK003",
puck_type="Unipuck",
puck_location_in_dewar=3,
dewar_id=2,
),
Puck(
id=11,
puck_name="PK004",
puck_type="Unipuck",
puck_location_in_dewar=4,
dewar_id=2,
),
Puck(
id=12,
puck_name="PK005",
puck_type="Unipuck",
puck_location_in_dewar=5,
dewar_id=2,
),
Puck(
id=13,
puck_name="PK006",
puck_type="Unipuck",
puck_location_in_dewar=6,
dewar_id=2,
),
Puck(
id=14,
puck_name="P001",
puck_type="Unipuck",
puck_location_in_dewar=1,
dewar_id=3,
),
Puck(
id=15,
puck_name="P002",
puck_type="Unipuck",
puck_location_in_dewar=2,
dewar_id=3,
),
Puck(
id=16,
puck_name="P003",
puck_type="Unipuck",
puck_location_in_dewar=3,
dewar_id=3,
),
Puck(
id=17,
puck_name="P004",
puck_type="Unipuck",
puck_location_in_dewar=4,
dewar_id=3,
),
Puck(
id=18,
puck_name="P005",
puck_type="Unipuck",
puck_location_in_dewar=5,
dewar_id=3,
),
Puck(
id=19,
puck_name="P006",
puck_type="Unipuck",
puck_location_in_dewar=6,
dewar_id=3,
),
Puck(
id=20,
puck_name="P007",
puck_type="Unipuck",
puck_location_in_dewar=7,
dewar_id=3,
),
Puck(
id=21,
puck_name="PC002",
puck_type="Unipuck",
puck_location_in_dewar=2,
dewar_id=4,
),
Puck(
id=22,
puck_name="PC003",
puck_type="Unipuck",
puck_location_in_dewar=3,
dewar_id=4,
),
Puck(
id=23,
puck_name="PC004",
puck_type="Unipuck",
puck_location_in_dewar=4,
dewar_id=4,
),
Puck(
id=24,
puck_name="PC005",
puck_type="Unipuck",
puck_location_in_dewar=5,
dewar_id=4,
),
Puck(
id=25,
puck_name="PC006",
puck_type="Unipuck",
puck_location_in_dewar=6,
dewar_id=4,
),
Puck(
id=26,
puck_name="PC007",
puck_type="Unipuck",
puck_location_in_dewar=7,
dewar_id=4,
),
Puck(
id=27,
puck_name="PKK004",
puck_type="Unipuck",
puck_location_in_dewar=4,
dewar_id=5,
),
Puck(
id=28,
puck_name="PKK005",
puck_type="Unipuck",
puck_location_in_dewar=5,
dewar_id=5,
),
Puck(
id=29,
puck_name="PKK006",
puck_type="Unipuck",
puck_location_in_dewar=6,
dewar_id=5,
),
Puck(
id=30,
puck_name="PKK007",
puck_type="Unipuck",
puck_location_in_dewar=7,
dewar_id=5,
),
] ]
# Define samples # Define samples
@ -179,7 +516,7 @@ for puck in pucks:
id=sample_id_counter, id=sample_id_counter,
sample_name=f"Sample{sample_id_counter:03}", sample_name=f"Sample{sample_id_counter:03}",
position=pos, position=pos,
puck_id=puck.id puck_id=puck.id,
) )
samples.append(sample) samples.append(sample)
sample_id_counter += 1 sample_id_counter += 1
@ -193,7 +530,9 @@ def generate_sample_events(samples, chance_no_event=0.2, chance_lost=0.1):
events = [] events = []
# Set the start time to yesterday at 9:33 AM # Set the start time to yesterday at 9:33 AM
start_time = datetime.now().replace(hour=9, minute=33, second=0, microsecond=0) - timedelta(days=1) start_time = datetime.now().replace(
hour=9, minute=33, second=0, microsecond=0
) - timedelta(days=1)
for sample in samples: for sample in samples:
current_time = start_time current_time = start_time
@ -208,32 +547,37 @@ def generate_sample_events(samples, chance_no_event=0.2, chance_lost=0.1):
event_type = "Failed" if random.random() < 0.05 else "Mounted" event_type = "Failed" if random.random() < 0.05 else "Mounted"
# Append the initial event # Append the initial event
events.append(SampleEvent( events.append(
sample_id=sample.id, SampleEvent(
event_type=event_type, sample_id=sample.id, event_type=event_type, timestamp=current_time
timestamp=current_time )
)) )
current_time += timedelta(seconds=50) # Increment the time for subsequent events current_time += timedelta(
seconds=50
) # Increment the time for subsequent events
# Proceed if mounted and it's not the last sample # Proceed if mounted and it's not the last sample
if event_type == "Mounted" and sample is not samples[-1]: if event_type == "Mounted" and sample is not samples[-1]:
# Determine follow-up event # Determine follow-up event
if random.random() < chance_lost: if random.random() < chance_lost:
events.append(SampleEvent( events.append(
sample_id=sample.id, SampleEvent(
event_type="Lost", sample_id=sample.id, event_type="Lost", timestamp=current_time
timestamp=current_time )
)) )
else: else:
events.append(SampleEvent( events.append(
SampleEvent(
sample_id=sample.id, sample_id=sample.id,
event_type="Unmounted", event_type="Unmounted",
timestamp=current_time timestamp=current_time,
)) )
)
# Increment start_time for the next sample # Increment start_time for the next sample
start_time += timedelta(minutes=10) start_time += timedelta(minutes=10)
return events return events
sample_events = generate_sample_events(samples) sample_events = generate_sample_events(samples)

View File

@ -2,31 +2,73 @@ from datetime import datetime, timedelta
from app.models import Slot from app.models import Slot
slotQRCodes = [ slotQRCodes = [
"A1-X06SA", "A2-X06SA", "A3-X06SA", "A4-X06SA", "A5-X06SA", "A1-X06SA",
"B1-X06SA", "B2-X06SA", "B3-X06SA", "B4-X06SA", "B5-X06SA", "A2-X06SA",
"C1-X06SA", "C2-X06SA", "C3-X06SA", "C4-X06SA", "C5-X06SA", "A3-X06SA",
"D1-X06SA", "D2-X06SA", "D3-X06SA", "D4-X06SA", "D5-X06SA", "A4-X06SA",
"A1-X10SA", "A2-X10SA", "A3-X10SA", "A4-X10SA", "A5-X10SA", "A5-X06SA",
"B1-X10SA", "B2-X10SA", "B3-X10SA", "B4-X10SA", "B5-X10SA", "B1-X06SA",
"C1-X10SA", "C2-X10SA", "C3-X10SA", "C4-X10SA", "C5-X10SA", "B2-X06SA",
"D1-X10SA", "D2-X10SA", "D3-X10SA", "D4-X10SA", "D5-X10SA", "B3-X06SA",
"NB1", "NB2", "NB3", "NB4", "NB5", "NB6", "B4-X06SA",
"X10SA-Beamline", "X06SA-Beamline", "X06DA-Beamline", "B5-X06SA",
"Outgoing X10SA", "Outgoing X06SA" "C1-X06SA",
"C2-X06SA",
"C3-X06SA",
"C4-X06SA",
"C5-X06SA",
"D1-X06SA",
"D2-X06SA",
"D3-X06SA",
"D4-X06SA",
"D5-X06SA",
"A1-X10SA",
"A2-X10SA",
"A3-X10SA",
"A4-X10SA",
"A5-X10SA",
"B1-X10SA",
"B2-X10SA",
"B3-X10SA",
"B4-X10SA",
"B5-X10SA",
"C1-X10SA",
"C2-X10SA",
"C3-X10SA",
"C4-X10SA",
"C5-X10SA",
"D1-X10SA",
"D2-X10SA",
"D3-X10SA",
"D4-X10SA",
"D5-X10SA",
"NB1",
"NB2",
"NB3",
"NB4",
"NB5",
"NB6",
"X10SA-Beamline",
"X06SA-Beamline",
"X06DA-Beamline",
"Outgoing X10SA",
"Outgoing X06SA",
] ]
def timedelta_to_str(td: timedelta) -> str: def timedelta_to_str(td: timedelta) -> str:
days, seconds = td.days, td.seconds days, seconds = td.days, td.seconds
hours = days * 24 + seconds // 3600 hours = days * 24 + seconds // 3600
minutes = (seconds % 3600) // 60 minutes = (seconds % 3600) // 60
return f'PT{hours}H{minutes}M' return f"PT{hours}H{minutes}M"
slots = [ slots = [
Slot( Slot(
id=str(i + 1), # Convert id to string to match your schema id=str(i + 1), # Convert id to string to match your schema
qr_code=qrcode, qr_code=qrcode,
label=qrcode.split('-')[0], label=qrcode.split("-")[0],
qr_base=qrcode.split('-')[1] if '-' in qrcode else '', qr_base=qrcode.split("-")[1] if "-" in qrcode else "",
occupied=False, occupied=False,
needs_refill=False, needs_refill=False,
) )

View File

@ -21,6 +21,7 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base() Base = declarative_base()
# Dependency # Dependency
def get_db(): def get_db():
db = SessionLocal() db = SessionLocal()
@ -29,18 +30,45 @@ def get_db():
finally: finally:
db.close() db.close()
def init_db(): def init_db():
# Import models inside function to avoid circular dependency # Import models inside function to avoid circular dependency
from . import models from . import models
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
def load_sample_data(session: Session): def load_sample_data(session: Session):
# Import models inside function to avoid circular dependency # Import models inside function to avoid circular dependency
from .data import contacts, return_addresses, dewars, proposals, shipments, pucks, samples, dewar_types, serial_numbers, slots, sample_events from .data import (
contacts,
return_addresses,
dewars,
proposals,
shipments,
pucks,
samples,
dewar_types,
serial_numbers,
slots,
sample_events,
)
# If any data already exists, skip seeding # If any data already exists, skip seeding
if session.query(models.ContactPerson).first(): if session.query(models.ContactPerson).first():
return return
session.add_all(contacts + return_addresses + dewars + proposals + shipments + pucks + samples + dewar_types + serial_numbers + slots + sample_events) session.add_all(
contacts
+ return_addresses
+ dewars
+ proposals
+ shipments
+ pucks
+ samples
+ dewar_types
+ serial_numbers
+ slots
+ sample_events
)
session.commit() session.commit()

View File

@ -1,6 +1,7 @@
# app/dependencies.py # app/dependencies.py
from .database import SessionLocal # Import SessionLocal from database.py from .database import SessionLocal # Import SessionLocal from database.py
def get_db(): def get_db():
db = SessionLocal() db = SessionLocal()
try: try:

View File

@ -1,4 +1,13 @@
from sqlalchemy import Column, Integer, String, Date, ForeignKey, JSON, DateTime, Boolean from sqlalchemy import (
Column,
Integer,
String,
Date,
ForeignKey,
JSON,
DateTime,
Boolean,
)
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from .database import Base from .database import Base
from datetime import datetime from datetime import datetime
@ -14,7 +23,7 @@ class Shipment(Base):
comments = Column(String(200), nullable=True) comments = Column(String(200), nullable=True)
contact_person_id = Column(Integer, ForeignKey("contact_persons.id")) contact_person_id = Column(Integer, ForeignKey("contact_persons.id"))
return_address_id = Column(Integer, ForeignKey("addresses.id")) return_address_id = Column(Integer, ForeignKey("addresses.id"))
proposal_id = Column(Integer, ForeignKey('proposals.id'), nullable=True) proposal_id = Column(Integer, ForeignKey("proposals.id"), nullable=True)
contact_person = relationship("ContactPerson", back_populates="shipments") contact_person = relationship("ContactPerson", back_populates="shipments")
return_address = relationship("Address", back_populates="shipments") return_address = relationship("Address", back_populates="shipments")
@ -44,17 +53,19 @@ class Address(Base):
shipments = relationship("Shipment", back_populates="return_address") shipments = relationship("Shipment", back_populates="return_address")
class DewarType(Base): class DewarType(Base):
__tablename__ = "dewar_types" __tablename__ = "dewar_types"
id = Column(Integer, primary_key=True, index=True, autoincrement=True) id = Column(Integer, primary_key=True, index=True, autoincrement=True)
dewar_type = Column(String(255), unique=True, index=True) dewar_type = Column(String(255), unique=True, index=True)
serial_numbers = relationship("DewarSerialNumber", back_populates="dewar_type") serial_numbers = relationship("DewarSerialNumber", back_populates="dewar_type")
class DewarSerialNumber(Base): class DewarSerialNumber(Base):
__tablename__ = "dewar_serial_numbers" __tablename__ = "dewar_serial_numbers"
id = Column(Integer, primary_key=True, index=True, autoincrement=True) id = Column(Integer, primary_key=True, index=True, autoincrement=True)
serial_number = Column(String(255), index=True) serial_number = Column(String(255), index=True)
dewar_type_id = Column(Integer, ForeignKey('dewar_types.id')) dewar_type_id = Column(Integer, ForeignKey("dewar_types.id"))
dewar_type = relationship("DewarType", back_populates="serial_numbers") dewar_type = relationship("DewarType", back_populates="serial_numbers")
@ -64,7 +75,9 @@ class Dewar(Base):
id = Column(Integer, primary_key=True, index=True, autoincrement=True) id = Column(Integer, primary_key=True, index=True, autoincrement=True)
dewar_name = Column(String(255)) dewar_name = Column(String(255))
dewar_type_id = Column(Integer, ForeignKey("dewar_types.id"), nullable=True) dewar_type_id = Column(Integer, ForeignKey("dewar_types.id"), nullable=True)
dewar_serial_number_id = Column(Integer, ForeignKey("dewar_serial_numbers.id"), nullable=True) dewar_serial_number_id = Column(
Integer, ForeignKey("dewar_serial_numbers.id"), nullable=True
)
tracking_number = Column(String(255)) tracking_number = Column(String(255))
status = Column(String(255)) status = Column(String(255))
ready_date = Column(Date, nullable=True) ready_date = Column(Date, nullable=True)
@ -97,6 +110,7 @@ class Dewar(Base):
return 0 return 0
return sum(len(puck.samples) for puck in self.pucks) return sum(len(puck.samples) for puck in self.pucks)
class Proposal(Base): class Proposal(Base):
__tablename__ = "proposals" __tablename__ = "proposals"
@ -106,7 +120,7 @@ class Proposal(Base):
class Puck(Base): class Puck(Base):
__tablename__ = 'pucks' __tablename__ = "pucks"
id = Column(Integer, primary_key=True, index=True, autoincrement=True) id = Column(Integer, primary_key=True, index=True, autoincrement=True)
puck_name = Column(String(255), index=True) puck_name = Column(String(255), index=True)
@ -114,14 +128,14 @@ class Puck(Base):
puck_location_in_dewar = Column(Integer) puck_location_in_dewar = Column(Integer)
# Foreign keys and relationships # Foreign keys and relationships
dewar_id = Column(Integer, ForeignKey('dewars.id')) dewar_id = Column(Integer, ForeignKey("dewars.id"))
dewar = relationship("Dewar", back_populates="pucks") dewar = relationship("Dewar", back_populates="pucks")
samples = relationship("Sample", back_populates="puck") samples = relationship("Sample", back_populates="puck")
events = relationship("PuckEvent", back_populates="puck") events = relationship("PuckEvent", back_populates="puck")
class Sample(Base): class Sample(Base):
__tablename__ = 'samples' __tablename__ = "samples"
id = Column(Integer, primary_key=True, index=True, autoincrement=True) id = Column(Integer, primary_key=True, index=True, autoincrement=True)
sample_name = Column(String(255), index=True) sample_name = Column(String(255), index=True)
@ -129,7 +143,7 @@ class Sample(Base):
data_collection_parameters = Column(JSON, nullable=True) data_collection_parameters = Column(JSON, nullable=True)
# Foreign keys and relationships # Foreign keys and relationships
puck_id = Column(Integer, ForeignKey('pucks.id')) puck_id = Column(Integer, ForeignKey("pucks.id"))
puck = relationship("Puck", back_populates="samples") puck = relationship("Puck", back_populates="samples")
events = relationship("SampleEvent", back_populates="sample") events = relationship("SampleEvent", back_populates="sample")
@ -143,36 +157,39 @@ class Slot(Base):
qr_base = Column(String(255), nullable=True) qr_base = Column(String(255), nullable=True)
occupied = Column(Boolean, default=False) occupied = Column(Boolean, default=False)
needs_refill = Column(Boolean, default=False) needs_refill = Column(Boolean, default=False)
dewar_unique_id = Column(String(255), ForeignKey('dewars.unique_id'), nullable=True) dewar_unique_id = Column(String(255), ForeignKey("dewars.unique_id"), nullable=True)
dewar = relationship("Dewar", back_populates="slot") dewar = relationship("Dewar", back_populates="slot")
events = relationship("LogisticsEvent", back_populates="slot") events = relationship("LogisticsEvent", back_populates="slot")
class LogisticsEvent(Base): class LogisticsEvent(Base):
__tablename__ = "logistics_events" __tablename__ = "logistics_events"
id = Column(Integer, primary_key=True, index=True) id = Column(Integer, primary_key=True, index=True)
dewar_id = Column(Integer, ForeignKey('dewars.id')) dewar_id = Column(Integer, ForeignKey("dewars.id"))
slot_id = Column(Integer, ForeignKey('slots.id')) slot_id = Column(Integer, ForeignKey("slots.id"))
event_type = Column(String(255), index=True) event_type = Column(String(255), index=True)
timestamp = Column(DateTime, default=datetime.utcnow) timestamp = Column(DateTime, default=datetime.utcnow)
dewar = relationship("Dewar", back_populates="events") dewar = relationship("Dewar", back_populates="events")
slot = relationship("Slot", back_populates="events") slot = relationship("Slot", back_populates="events")
class SampleEvent(Base): class SampleEvent(Base):
__tablename__ = "sample_events" __tablename__ = "sample_events"
id = Column(Integer, primary_key=True, index=True) id = Column(Integer, primary_key=True, index=True)
sample_id = Column(Integer, ForeignKey('samples.id')) sample_id = Column(Integer, ForeignKey("samples.id"))
event_type = Column(String(255), index=True) event_type = Column(String(255), index=True)
timestamp = Column(DateTime, default=datetime.utcnow) timestamp = Column(DateTime, default=datetime.utcnow)
sample = relationship("Sample", back_populates="events") sample = relationship("Sample", back_populates="events")
class PuckEvent(Base): class PuckEvent(Base):
__tablename__ = "puck_events" __tablename__ = "puck_events"
id = Column(Integer, primary_key=True, index=True) id = Column(Integer, primary_key=True, index=True)
puck_id = Column(Integer, ForeignKey('pucks.id')) puck_id = Column(Integer, ForeignKey("pucks.id"))
tell_position = Column(String(255), nullable=True) tell_position = Column(String(255), nullable=True)
event_type = Column(String(255), index=True) event_type = Column(String(255), index=True)
timestamp = Column(DateTime, default=datetime.utcnow) timestamp = Column(DateTime, default=datetime.utcnow)

View File

@ -5,4 +5,11 @@ from .dewar import router as dewar_router
from .shipment import router as shipment_router from .shipment import router as shipment_router
from .auth import router as auth_router from .auth import router as auth_router
__all__ = ["address_router", "contact_router", "proposal_router", "dewar_router", "shipment_router", "auth_router"] __all__ = [
"address_router",
"contact_router",
"proposal_router",
"dewar_router",
"shipment_router",
"auth_router",
]

View File

@ -7,23 +7,25 @@ from app.dependencies import get_db
router = APIRouter() router = APIRouter()
@router.get("/", response_model=List[AddressSchema]) @router.get("/", response_model=List[AddressSchema])
async def get_return_addresses(db: Session = Depends(get_db)): async def get_return_addresses(db: Session = Depends(get_db)):
return db.query(AddressModel).all() return db.query(AddressModel).all()
@router.post("/", response_model=AddressSchema, status_code=status.HTTP_201_CREATED) @router.post("/", response_model=AddressSchema, status_code=status.HTTP_201_CREATED)
async def create_return_address(address: AddressCreate, db: Session = Depends(get_db)): async def create_return_address(address: AddressCreate, db: Session = Depends(get_db)):
if db.query(AddressModel).filter(AddressModel.city == address.city).first(): if db.query(AddressModel).filter(AddressModel.city == address.city).first():
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Address in this city already exists." detail="Address in this city already exists.",
) )
db_address = AddressModel( db_address = AddressModel(
street=address.street, street=address.street,
city=address.city, city=address.city,
zipcode=address.zipcode, zipcode=address.zipcode,
country=address.country country=address.country,
) )
db.add(db_address) db.add(db_address)
@ -31,13 +33,15 @@ async def create_return_address(address: AddressCreate, db: Session = Depends(ge
db.refresh(db_address) db.refresh(db_address)
return db_address return db_address
@router.put("/{address_id}", response_model=AddressSchema) @router.put("/{address_id}", response_model=AddressSchema)
async def update_return_address(address_id: int, address: AddressUpdate, db: Session = Depends(get_db)): async def update_return_address(
address_id: int, address: AddressUpdate, db: Session = Depends(get_db)
):
db_address = db.query(AddressModel).filter(AddressModel.id == address_id).first() db_address = db.query(AddressModel).filter(AddressModel.id == address_id).first()
if not db_address: if not db_address:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="Address not found."
detail="Address not found."
) )
for key, value in address.dict(exclude_unset=True).items(): for key, value in address.dict(exclude_unset=True).items():
setattr(db_address, key, value) setattr(db_address, key, value)
@ -45,13 +49,13 @@ async def update_return_address(address_id: int, address: AddressUpdate, db: Ses
db.refresh(db_address) db.refresh(db_address)
return db_address return db_address
@router.delete("/{address_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{address_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_return_address(address_id: int, db: Session = Depends(get_db)): async def delete_return_address(address_id: int, db: Session = Depends(get_db)):
db_address = db.query(AddressModel).filter(AddressModel.id == address_id).first() db_address = db.query(AddressModel).filter(AddressModel.id == address_id).first()
if not db_address: if not db_address:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="Address not found."
detail="Address not found."
) )
db.delete(db_address) db.delete(db_address)
db.commit() db.commit()

View File

@ -26,7 +26,10 @@ SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256" ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30 ACCESS_TOKEN_EXPIRE_MINUTES = 30
oauth2_scheme = OAuth2AuthorizationCodeBearer(authorizationUrl="/login", tokenUrl="/token/login") oauth2_scheme = OAuth2AuthorizationCodeBearer(
authorizationUrl="/login", tokenUrl="/token/login"
)
def create_access_token(data: dict) -> str: def create_access_token(data: dict) -> str:
to_encode = data.copy() to_encode = data.copy()
@ -34,6 +37,7 @@ def create_access_token(data: dict) -> str:
to_encode.update({"exp": expire}) to_encode.update({"exp": expire})
return jwt.encode(to_encode, SECRET_KEY, algorithm="HS256") return jwt.encode(to_encode, SECRET_KEY, algorithm="HS256")
async def get_current_user(token: str = Depends(oauth2_scheme)) -> loginData: async def get_current_user(token: str = Depends(oauth2_scheme)) -> loginData:
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -60,6 +64,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> loginData:
return token_data return token_data
@router.post("/token/login", response_model=loginToken) @router.post("/token/login", response_model=loginToken)
async def login(form_data: OAuth2PasswordRequestForm = Depends()): async def login(form_data: OAuth2PasswordRequestForm = Depends()):
user = mock_users_db.get(form_data.username) user = mock_users_db.get(form_data.username)

View File

@ -7,38 +7,48 @@ from app.dependencies import get_db
router = APIRouter() router = APIRouter()
# Existing routes # Existing routes
@router.get("/", response_model=List[ContactPerson]) @router.get("/", response_model=List[ContactPerson])
async def get_contacts(db: Session = Depends(get_db)): async def get_contacts(db: Session = Depends(get_db)):
return db.query(ContactPersonModel).all() return db.query(ContactPersonModel).all()
@router.post("/", response_model=ContactPerson, status_code=status.HTTP_201_CREATED) @router.post("/", response_model=ContactPerson, status_code=status.HTTP_201_CREATED)
async def create_contact(contact: ContactPersonCreate, db: Session = Depends(get_db)): async def create_contact(contact: ContactPersonCreate, db: Session = Depends(get_db)):
if db.query(ContactPersonModel).filter(ContactPersonModel.email == contact.email).first(): if (
db.query(ContactPersonModel)
.filter(ContactPersonModel.email == contact.email)
.first()
):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="This contact already exists." detail="This contact already exists.",
) )
db_contact = ContactPersonModel( db_contact = ContactPersonModel(
firstname=contact.firstname, firstname=contact.firstname,
lastname=contact.lastname, lastname=contact.lastname,
phone_number=contact.phone_number, phone_number=contact.phone_number,
email=contact.email email=contact.email,
) )
db.add(db_contact) db.add(db_contact)
db.commit() db.commit()
db.refresh(db_contact) db.refresh(db_contact)
return db_contact return db_contact
# New routes # New routes
@router.put("/{contact_id}", response_model=ContactPerson) @router.put("/{contact_id}", response_model=ContactPerson)
async def update_contact(contact_id: int, contact: ContactPersonUpdate, db: Session = Depends(get_db)): async def update_contact(
db_contact = db.query(ContactPersonModel).filter(ContactPersonModel.id == contact_id).first() contact_id: int, contact: ContactPersonUpdate, db: Session = Depends(get_db)
):
db_contact = (
db.query(ContactPersonModel).filter(ContactPersonModel.id == contact_id).first()
)
if not db_contact: if not db_contact:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="Contact not found."
detail="Contact not found."
) )
for key, value in contact.dict(exclude_unset=True).items(): for key, value in contact.dict(exclude_unset=True).items():
setattr(db_contact, key, value) setattr(db_contact, key, value)
@ -46,13 +56,15 @@ async def update_contact(contact_id: int, contact: ContactPersonUpdate, db: Sess
db.refresh(db_contact) db.refresh(db_contact)
return db_contact return db_contact
@router.delete("/{contact_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{contact_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_contact(contact_id: int, db: Session = Depends(get_db)): async def delete_contact(contact_id: int, db: Session = Depends(get_db)):
db_contact = db.query(ContactPersonModel).filter(ContactPersonModel.id == contact_id).first() db_contact = (
db.query(ContactPersonModel).filter(ContactPersonModel.id == contact_id).first()
)
if not db_contact: if not db_contact:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="Contact not found."
detail="Contact not found."
) )
db.delete(db_contact) db.delete(db_contact)
db.commit() db.commit()

View File

@ -13,7 +13,7 @@ from app.schemas import (
DewarTypeCreate, DewarTypeCreate,
DewarSerialNumber as DewarSerialNumberSchema, DewarSerialNumber as DewarSerialNumberSchema,
DewarSerialNumberCreate, DewarSerialNumberCreate,
Shipment as ShipmentSchema # Clearer name for schema Shipment as ShipmentSchema, # Clearer name for schema
) )
from app.models import ( from app.models import (
Dewar as DewarModel, Dewar as DewarModel,
@ -21,7 +21,7 @@ from app.models import (
Sample as SampleModel, Sample as SampleModel,
DewarType as DewarTypeModel, DewarType as DewarTypeModel,
DewarSerialNumber as DewarSerialNumberModel, DewarSerialNumber as DewarSerialNumberModel,
Shipment as ShipmentModel # Clearer name for model Shipment as ShipmentModel, # Clearer name for model
) )
from app.dependencies import get_db from app.dependencies import get_db
import uuid import uuid
@ -32,23 +32,32 @@ from PIL import ImageFont, ImageDraw, Image
from reportlab.lib.pagesizes import A5, landscape from reportlab.lib.pagesizes import A5, landscape
from reportlab.lib.units import cm from reportlab.lib.units import cm
from reportlab.pdfgen import canvas from reportlab.pdfgen import canvas
from app.crud import get_shipments, get_shipment_by_id # Import CRUD functions for shipment from app.crud import (
get_shipments,
get_shipment_by_id,
) # Import CRUD functions for shipment
router = APIRouter() router = APIRouter()
def generate_unique_id(db: Session, length: int = 16) -> str: def generate_unique_id(db: Session, length: int = 16) -> str:
while True: while True:
base_string = f"{time.time()}{random.randint(0, 10 ** 6)}" base_string = f"{time.time()}{random.randint(0, 10 ** 6)}"
hash_object = hashlib.sha256(base_string.encode()) hash_object = hashlib.sha256(base_string.encode())
hash_digest = hash_object.hexdigest() hash_digest = hash_object.hexdigest()
unique_id = ''.join(random.choices(hash_digest, k=length)) unique_id = "".join(random.choices(hash_digest, k=length))
existing_dewar = db.query(DewarModel).filter(DewarModel.unique_id == unique_id).first() existing_dewar = (
db.query(DewarModel).filter(DewarModel.unique_id == unique_id).first()
)
if not existing_dewar: if not existing_dewar:
break break
return unique_id return unique_id
@router.post("/", response_model=DewarSchema, status_code=status.HTTP_201_CREATED) @router.post("/", response_model=DewarSchema, status_code=status.HTTP_201_CREATED)
async def create_dewar(dewar: DewarCreate, db: Session = Depends(get_db)) -> DewarSchema: async def create_dewar(
dewar: DewarCreate, db: Session = Depends(get_db)
) -> DewarSchema:
try: try:
db_dewar = DewarModel( db_dewar = DewarModel(
dewar_name=dewar.dewar_name, dewar_name=dewar.dewar_name,
@ -96,6 +105,7 @@ async def create_dewar(dewar: DewarCreate, db: Session = Depends(get_db)) -> Dew
logging.error(f"Validation error occurred: {e}") logging.error(f"Validation error occurred: {e}")
raise HTTPException(status_code=400, detail="Validation error") raise HTTPException(status_code=400, detail="Validation error")
@router.post("/{dewar_id}/generate-qrcode") @router.post("/{dewar_id}/generate-qrcode")
async def generate_dewar_qrcode(dewar_id: int, db: Session = Depends(get_db)): async def generate_dewar_qrcode(dewar_id: int, db: Session = Depends(get_db)):
dewar = db.query(DewarModel).filter(DewarModel.id == dewar_id).first() dewar = db.query(DewarModel).filter(DewarModel.id == dewar_id).first()
@ -109,7 +119,7 @@ async def generate_dewar_qrcode(dewar_id: int, db: Session = Depends(get_db)):
qr = qrcode.QRCode(version=1, box_size=10, border=5) qr = qrcode.QRCode(version=1, box_size=10, border=5)
qr.add_data(dewar.unique_id) qr.add_data(dewar.unique_id)
qr.make(fit=True) qr.make(fit=True)
img = qr.make_image(fill='black', back_color='white') img = qr.make_image(fill="black", back_color="white")
buf = io.BytesIO() buf = io.BytesIO()
img.save(buf) img.save(buf)
@ -120,6 +130,7 @@ async def generate_dewar_qrcode(dewar_id: int, db: Session = Depends(get_db)):
return {"message": "QR Code generated", "qrcode": dewar.unique_id} return {"message": "QR Code generated", "qrcode": dewar.unique_id}
def generate_label(dewar): def generate_label(dewar):
buffer = BytesIO() buffer = BytesIO()
# Set page orientation to landscape # Set page orientation to landscape
@ -138,25 +149,36 @@ def generate_label(dewar):
# Desired logo width in the PDF (you can adjust this size) # Desired logo width in the PDF (you can adjust this size)
desired_logo_width = 4 * cm desired_logo_width = 4 * cm
desired_logo_height = desired_logo_width / logo_aspect_ratio # maintain aspect ratio desired_logo_height = (
desired_logo_width / logo_aspect_ratio
) # maintain aspect ratio
# Draw header text # Draw header text
c.setFont("Helvetica-Bold", 16) c.setFont("Helvetica-Bold", 16)
c.drawString(2 * cm, page_height - 2 * cm, "Paul Scherrer Institut") c.drawString(2 * cm, page_height - 2 * cm, "Paul Scherrer Institut")
# Draw the Heidi logo with preserved aspect ratio # Draw the Heidi logo with preserved aspect ratio
c.drawImage(png_logo_path, page_width - desired_logo_width - 2 * cm, c.drawImage(
png_logo_path,
page_width - desired_logo_width - 2 * cm,
page_height - desired_logo_height - 2 * cm, page_height - desired_logo_height - 2 * cm,
width=desired_logo_width, height=desired_logo_height, mask='auto') width=desired_logo_width,
height=desired_logo_height,
mask="auto",
)
# Draw details section # Draw details section
c.setFont("Helvetica", 12) c.setFont("Helvetica", 12)
y_position = page_height - 4 * cm # Adjusted to ensure text doesn't overlap with the logo y_position = (
page_height - 4 * cm
) # Adjusted to ensure text doesn't overlap with the logo
line_height = 0.8 * cm line_height = 0.8 * cm
if dewar.shipment: if dewar.shipment:
c.drawString(2 * cm, y_position, f"Shipment Name: {dewar.shipment.shipment_name}") c.drawString(
2 * cm, y_position, f"Shipment Name: {dewar.shipment.shipment_name}"
)
y_position -= line_height y_position -= line_height
c.drawString(2 * cm, y_position, f"Dewar Name: {dewar.dewar_name}") c.drawString(2 * cm, y_position, f"Dewar Name: {dewar.dewar_name}")
@ -167,7 +189,11 @@ def generate_label(dewar):
if dewar.contact_person: if dewar.contact_person:
contact_person = dewar.contact_person contact_person = dewar.contact_person
c.drawString(2 * cm, y_position, f"Contact: {contact_person.firstname} {contact_person.lastname}") c.drawString(
2 * cm,
y_position,
f"Contact: {contact_person.firstname} {contact_person.lastname}",
)
y_position -= line_height y_position -= line_height
c.drawString(2 * cm, y_position, f"Email: {contact_person.email}") c.drawString(2 * cm, y_position, f"Email: {contact_person.email}")
y_position -= line_height y_position -= line_height
@ -191,15 +217,17 @@ def generate_label(dewar):
qr = qrcode.QRCode(version=1, box_size=10, border=4) qr = qrcode.QRCode(version=1, box_size=10, border=4)
qr.add_data(dewar.unique_id) qr.add_data(dewar.unique_id)
qr.make(fit=True) qr.make(fit=True)
qr_img = qr.make_image(fill='black', back_color='white').convert("RGBA") qr_img = qr.make_image(fill="black", back_color="white").convert("RGBA")
# Save this QR code to a temporary file # Save this QR code to a temporary file
with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file: with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
qr_img.save(temp_file, format='PNG') qr_img.save(temp_file, format="PNG")
temp_file_path = temp_file.name temp_file_path = temp_file.name
# Add QR code to PDF # Add QR code to PDF
c.drawImage(temp_file_path, page_width - 6 * cm, 5 * cm, width=4 * cm, height=4 * cm) c.drawImage(
temp_file_path, page_width - 6 * cm, 5 * cm, width=4 * cm, height=4 * cm
)
# Add footer text # Add footer text
c.setFont("Helvetica", 10) c.setFont("Helvetica", 10)
@ -207,7 +235,9 @@ def generate_label(dewar):
# Draw border # Draw border
c.setLineWidth(1) c.setLineWidth(1)
c.rect(1 * cm, 1 * cm, page_width - 2 * cm, page_height - 2 * cm) # Adjusted dimensions c.rect(
1 * cm, 1 * cm, page_width - 2 * cm, page_height - 2 * cm
) # Adjusted dimensions
# Finalize the canvas # Finalize the canvas
c.showPage() c.showPage()
@ -220,25 +250,38 @@ def generate_label(dewar):
return buffer return buffer
@router.get("/{dewar_id}/download-label", response_class=Response) @router.get("/{dewar_id}/download-label", response_class=Response)
async def download_dewar_label(dewar_id: int, db: Session = Depends(get_db)): async def download_dewar_label(dewar_id: int, db: Session = Depends(get_db)):
dewar = db.query(DewarModel).options( dewar = (
db.query(DewarModel)
.options(
joinedload(DewarModel.pucks).joinedload(PuckModel.samples), joinedload(DewarModel.pucks).joinedload(PuckModel.samples),
joinedload(DewarModel.contact_person), joinedload(DewarModel.contact_person),
joinedload(DewarModel.return_address), joinedload(DewarModel.return_address),
joinedload(DewarModel.shipment) joinedload(DewarModel.shipment),
).filter(DewarModel.id == dewar_id).first() )
.filter(DewarModel.id == dewar_id)
.first()
)
if not dewar: if not dewar:
raise HTTPException(status_code=404, detail="Dewar not found") raise HTTPException(status_code=404, detail="Dewar not found")
if not dewar.unique_id: if not dewar.unique_id:
raise HTTPException(status_code=404, detail="QR Code not generated for this dewar") raise HTTPException(
status_code=404, detail="QR Code not generated for this dewar"
)
buffer = generate_label(dewar) buffer = generate_label(dewar)
return Response(buffer.getvalue(), media_type="application/pdf", headers={ return Response(
buffer.getvalue(),
media_type="application/pdf",
headers={
"Content-Disposition": f"attachment; filename=dewar_label_{dewar.id}.pdf" "Content-Disposition": f"attachment; filename=dewar_label_{dewar.id}.pdf"
}) },
)
@router.get("/", response_model=List[DewarSchema]) @router.get("/", response_model=List[DewarSchema])
async def get_dewars(db: Session = Depends(get_db)): async def get_dewars(db: Session = Depends(get_db)):
@ -249,13 +292,23 @@ async def get_dewars(db: Session = Depends(get_db)):
logging.error(f"Database error occurred: {e}") logging.error(f"Database error occurred: {e}")
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/dewar-types", response_model=List[DewarTypeSchema]) @router.get("/dewar-types", response_model=List[DewarTypeSchema])
def get_dewar_types(db: Session = Depends(get_db)): def get_dewar_types(db: Session = Depends(get_db)):
return db.query(DewarTypeModel).all() return db.query(DewarTypeModel).all()
@router.get("/dewar-types/{type_id}/serial-numbers", response_model=List[DewarSerialNumberSchema])
@router.get(
"/dewar-types/{type_id}/serial-numbers",
response_model=List[DewarSerialNumberSchema],
)
def get_serial_numbers(type_id: int, db: Session = Depends(get_db)): def get_serial_numbers(type_id: int, db: Session = Depends(get_db)):
return db.query(DewarSerialNumberModel).filter(DewarSerialNumberModel.dewar_type_id == type_id).all() return (
db.query(DewarSerialNumberModel)
.filter(DewarSerialNumberModel.dewar_type_id == type_id)
.all()
)
@router.post("/dewar-types", response_model=DewarTypeSchema) @router.post("/dewar-types", response_model=DewarTypeSchema)
def create_dewar_type(dewar_type: DewarTypeCreate, db: Session = Depends(get_db)): def create_dewar_type(dewar_type: DewarTypeCreate, db: Session = Depends(get_db)):
@ -265,14 +318,18 @@ def create_dewar_type(dewar_type: DewarTypeCreate, db: Session = Depends(get_db)
db.refresh(db_type) db.refresh(db_type)
return db_type return db_type
@router.post("/dewar-serial-numbers", response_model=DewarSerialNumberSchema) @router.post("/dewar-serial-numbers", response_model=DewarSerialNumberSchema)
def create_dewar_serial_number(serial_number: DewarSerialNumberCreate, db: Session = Depends(get_db)): def create_dewar_serial_number(
serial_number: DewarSerialNumberCreate, db: Session = Depends(get_db)
):
db_serial = DewarSerialNumberModel(**serial_number.dict()) db_serial = DewarSerialNumberModel(**serial_number.dict())
db.add(db_serial) db.add(db_serial)
db.commit() db.commit()
db.refresh(db_serial) db.refresh(db_serial)
return db_serial return db_serial
@router.get("/dewar-serial-numbers", response_model=List[DewarSerialNumberSchema]) @router.get("/dewar-serial-numbers", response_model=List[DewarSerialNumberSchema])
def get_all_serial_numbers(db: Session = Depends(get_db)): def get_all_serial_numbers(db: Session = Depends(get_db)):
try: try:
@ -282,22 +339,31 @@ def get_all_serial_numbers(db: Session = Depends(get_db)):
logging.error(f"Database error occurred: {e}") logging.error(f"Database error occurred: {e}")
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/{dewar_id}", response_model=DewarSchema) @router.get("/{dewar_id}", response_model=DewarSchema)
async def get_dewar(dewar_id: int, db: Session = Depends(get_db)): async def get_dewar(dewar_id: int, db: Session = Depends(get_db)):
dewar = db.query(DewarModel).options( dewar = (
db.query(DewarModel)
.options(
joinedload(DewarModel.pucks).joinedload(PuckModel.samples), joinedload(DewarModel.pucks).joinedload(PuckModel.samples),
joinedload(DewarModel.contact_person), joinedload(DewarModel.contact_person),
joinedload(DewarModel.return_address), joinedload(DewarModel.return_address),
joinedload(DewarModel.shipment) joinedload(DewarModel.shipment),
).filter(DewarModel.id == dewar_id).first() )
.filter(DewarModel.id == dewar_id)
.first()
)
if not dewar: if not dewar:
raise HTTPException(status_code=404, detail="Dewar not found") raise HTTPException(status_code=404, detail="Dewar not found")
return DewarSchema.from_orm(dewar) return DewarSchema.from_orm(dewar)
@router.put("/{dewar_id}", response_model=DewarSchema) @router.put("/{dewar_id}", response_model=DewarSchema)
async def update_dewar(dewar_id: int, dewar_update: DewarUpdate, db: Session = Depends(get_db)) -> DewarSchema: async def update_dewar(
dewar_id: int, dewar_update: DewarUpdate, db: Session = Depends(get_db)
) -> DewarSchema:
dewar = db.query(DewarModel).filter(DewarModel.id == dewar_id).first() dewar = db.query(DewarModel).filter(DewarModel.id == dewar_id).first()
if not dewar: if not dewar:
@ -311,6 +377,7 @@ async def update_dewar(dewar_id: int, dewar_update: DewarUpdate, db: Session = D
db.refresh(dewar) db.refresh(dewar)
return dewar return dewar
@router.delete("/{dewar_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{dewar_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_dewar(dewar_id: int, db: Session = Depends(get_db)): async def delete_dewar(dewar_id: int, db: Session = Depends(get_db)):
dewar = db.query(DewarModel).filter(DewarModel.id == dewar_id).first() dewar = db.query(DewarModel).filter(DewarModel.id == dewar_id).first()
@ -322,6 +389,7 @@ async def delete_dewar(dewar_id: int, db: Session = Depends(get_db)):
db.commit() db.commit()
return return
# New routes for shipments # New routes for shipments
@router.get("/shipments", response_model=List[ShipmentSchema]) @router.get("/shipments", response_model=List[ShipmentSchema])
async def get_all_shipments(db: Session = Depends(get_db)): async def get_all_shipments(db: Session = Depends(get_db)):
@ -332,6 +400,7 @@ async def get_all_shipments(db: Session = Depends(get_db)):
logging.error(f"Database error occurred: {e}") logging.error(f"Database error occurred: {e}")
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/shipments/{id}", response_model=ShipmentSchema) @router.get("/shipments/{id}", response_model=ShipmentSchema)
async def get_single_shipment(id: int, db: Session = Depends(get_db)): async def get_single_shipment(id: int, db: Session = Depends(get_db)):
try: try:

View File

@ -2,7 +2,11 @@ from fastapi import APIRouter, HTTPException, Depends
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy.orm import Session, joinedload from sqlalchemy.orm import Session, joinedload
from typing import List, Optional from typing import List, Optional
from ..models import Dewar as DewarModel, Slot as SlotModel, LogisticsEvent as LogisticsEventModel from ..models import (
Dewar as DewarModel,
Slot as SlotModel,
LogisticsEvent as LogisticsEventModel,
)
from ..schemas import LogisticsEventCreate, SlotSchema, Dewar as DewarSchema from ..schemas import LogisticsEventCreate, SlotSchema, Dewar as DewarSchema
from ..database import get_db from ..database import get_db
import logging import logging
@ -14,7 +18,9 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def calculate_time_until_refill(last_refill: Optional[datetime], refill_interval_hours: int = 1) -> int: def calculate_time_until_refill(
last_refill: Optional[datetime], refill_interval_hours: int = 1
) -> int:
refill_interval = timedelta(hours=refill_interval_hours) refill_interval = timedelta(hours=refill_interval_hours)
now = datetime.now() now = datetime.now()
@ -27,30 +33,54 @@ def calculate_time_until_refill(last_refill: Optional[datetime], refill_interval
@router.post("/dewars/return", response_model=DewarSchema) @router.post("/dewars/return", response_model=DewarSchema)
async def return_to_storage(data: LogisticsEventCreate, db: Session = Depends(get_db)): async def return_to_storage(data: LogisticsEventCreate, db: Session = Depends(get_db)):
logger.info(f"Returning dewar to storage: {data.dewar_qr_code} at location {data.location_qr_code}") logger.info(
f"Returning dewar to storage: {data.dewar_qr_code} at location {data.location_qr_code}"
)
try: try:
# Log the incoming payload # Log the incoming payload
logger.info("Received payload: %s", data.json()) logger.info("Received payload: %s", data.json())
dewar = db.query(DewarModel).filter(DewarModel.unique_id == data.dewar_qr_code).first() dewar = (
db.query(DewarModel)
.filter(DewarModel.unique_id == data.dewar_qr_code)
.first()
)
if not dewar: if not dewar:
logger.error(f"Dewar not found for unique ID: {data.dewar_qr_code}") logger.error(f"Dewar not found for unique ID: {data.dewar_qr_code}")
raise HTTPException(status_code=404, detail="Dewar not found") raise HTTPException(status_code=404, detail="Dewar not found")
original_slot = db.query(SlotModel).filter(SlotModel.dewar_unique_id == data.dewar_qr_code).first() original_slot = (
db.query(SlotModel)
.filter(SlotModel.dewar_unique_id == data.dewar_qr_code)
.first()
)
if original_slot and original_slot.qr_code != data.location_qr_code: if original_slot and original_slot.qr_code != data.location_qr_code:
logger.error(f"Dewar {data.dewar_qr_code} is associated with slot {original_slot.qr_code}") logger.error(
raise HTTPException(status_code=400, detail=f"Dewar {data.dewar_qr_code} is associated with a different slot {original_slot.qr_code}.") f"Dewar {data.dewar_qr_code} is associated with slot {original_slot.qr_code}"
)
raise HTTPException(
status_code=400,
detail=f"Dewar {data.dewar_qr_code} is associated with a different slot {original_slot.qr_code}.",
)
slot = db.query(SlotModel).filter(SlotModel.qr_code == data.location_qr_code).first() slot = (
db.query(SlotModel)
.filter(SlotModel.qr_code == data.location_qr_code)
.first()
)
if not slot: if not slot:
logger.error(f"Slot not found for QR code: {data.location_qr_code}") logger.error(f"Slot not found for QR code: {data.location_qr_code}")
raise HTTPException(status_code=404, detail="Slot not found") raise HTTPException(status_code=404, detail="Slot not found")
if slot.occupied and slot.dewar_unique_id != data.dewar_qr_code: if slot.occupied and slot.dewar_unique_id != data.dewar_qr_code:
logger.error(f"Slot {data.location_qr_code} is already occupied by another dewar") logger.error(
raise HTTPException(status_code=400, detail="Selected slot is already occupied by another dewar") f"Slot {data.location_qr_code} is already occupied by another dewar"
)
raise HTTPException(
status_code=400,
detail="Selected slot is already occupied by another dewar",
)
# Update slot with dewar information # Update slot with dewar information
slot.dewar_unique_id = dewar.unique_id slot.dewar_unique_id = dewar.unique_id
@ -61,7 +91,9 @@ async def return_to_storage(data: LogisticsEventCreate, db: Session = Depends(ge
log_event(db, dewar.id, slot.id, "returned") log_event(db, dewar.id, slot.id, "returned")
db.commit() db.commit()
logger.info(f"Dewar {data.dewar_qr_code} successfully returned to storage slot {slot.qr_code}.") logger.info(
f"Dewar {data.dewar_qr_code} successfully returned to storage slot {slot.qr_code}."
)
db.refresh(dewar) db.refresh(dewar)
return dewar return dewar
except ValidationError as e: except ValidationError as e:
@ -71,6 +103,7 @@ async def return_to_storage(data: LogisticsEventCreate, db: Session = Depends(ge
logger.error(f"Unexpected error: {str(e)}") logger.error(f"Unexpected error: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@router.post("/dewar/scan", response_model=dict) @router.post("/dewar/scan", response_model=dict)
async def scan_dewar(event_data: LogisticsEventCreate, db: Session = Depends(get_db)): async def scan_dewar(event_data: LogisticsEventCreate, db: Session = Depends(get_db)):
logger.info(f"Received event data: {event_data}") logger.info(f"Received event data: {event_data}")
@ -82,7 +115,9 @@ async def scan_dewar(event_data: LogisticsEventCreate, db: Session = Depends(get
# Validate Dewar QR Code # Validate Dewar QR Code
if not dewar_qr_code or not dewar_qr_code.strip(): if not dewar_qr_code or not dewar_qr_code.strip():
logger.error("Dewar QR Code is null or empty") logger.error("Dewar QR Code is null or empty")
raise HTTPException(status_code=422, detail="Dewar QR Code cannot be null or empty") raise HTTPException(
status_code=422, detail="Dewar QR Code cannot be null or empty"
)
# Retrieve the Dewar # Retrieve the Dewar
dewar = db.query(DewarModel).filter(DewarModel.unique_id == dewar_qr_code).first() dewar = db.query(DewarModel).filter(DewarModel.unique_id == dewar_qr_code).first()
@ -92,31 +127,42 @@ async def scan_dewar(event_data: LogisticsEventCreate, db: Session = Depends(get
# Check for Outgoing QR Codes and set transaction type # Check for Outgoing QR Codes and set transaction type
if location_qr_code in ["Outgoing X10-SA", "Outgoing X06-SA"]: if location_qr_code in ["Outgoing X10-SA", "Outgoing X06-SA"]:
transaction_type = 'outgoing' transaction_type = "outgoing"
# Retrieve the Slot associated with the Dewar (for outgoing) # Retrieve the Slot associated with the Dewar (for outgoing)
slot = None slot = None
if transaction_type == 'outgoing': if transaction_type == "outgoing":
slot = db.query(SlotModel).filter(SlotModel.dewar_unique_id == dewar.unique_id).first() slot = (
db.query(SlotModel)
.filter(SlotModel.dewar_unique_id == dewar.unique_id)
.first()
)
if not slot: if not slot:
logger.error(f"No slot associated with dewar for outgoing: {dewar_qr_code}") logger.error(f"No slot associated with dewar for outgoing: {dewar_qr_code}")
raise HTTPException(status_code=404, detail="No slot associated with dewar for outgoing") raise HTTPException(
status_code=404, detail="No slot associated with dewar for outgoing"
)
# Incoming Logic # Incoming Logic
if transaction_type == 'incoming': if transaction_type == "incoming":
slot = db.query(SlotModel).filter(SlotModel.qr_code == location_qr_code).first() slot = db.query(SlotModel).filter(SlotModel.qr_code == location_qr_code).first()
if not slot or slot.occupied: if not slot or slot.occupied:
logger.error(f"Slot not found or already occupied: {location_qr_code}") logger.error(f"Slot not found or already occupied: {location_qr_code}")
raise HTTPException(status_code=400, detail="Slot not found or already occupied") raise HTTPException(
status_code=400, detail="Slot not found or already occupied"
)
slot.dewar_unique_id = dewar.unique_id slot.dewar_unique_id = dewar.unique_id
slot.occupied = True slot.occupied = True
elif transaction_type == 'outgoing': elif transaction_type == "outgoing":
if not slot.occupied or slot.dewar_unique_id != dewar.unique_id: if not slot.occupied or slot.dewar_unique_id != dewar.unique_id:
logger.error(f"Slot not valid for outgoing: {location_qr_code}") logger.error(f"Slot not valid for outgoing: {location_qr_code}")
raise HTTPException(status_code=400, detail="Dewar not associated with the slot for outgoing") raise HTTPException(
status_code=400,
detail="Dewar not associated with the slot for outgoing",
)
slot.dewar_unique_id = None slot.dewar_unique_id = None
slot.occupied = False slot.occupied = False
elif transaction_type == 'beamline': elif transaction_type == "beamline":
slot = db.query(SlotModel).filter(SlotModel.qr_code == location_qr_code).first() slot = db.query(SlotModel).filter(SlotModel.qr_code == location_qr_code).first()
if not slot: if not slot:
logger.error(f"Beamline location not found: {location_qr_code}") logger.error(f"Beamline location not found: {location_qr_code}")
@ -128,10 +174,12 @@ async def scan_dewar(event_data: LogisticsEventCreate, db: Session = Depends(get
log_event(db, dewar.id, slot.id if slot else None, transaction_type) log_event(db, dewar.id, slot.id if slot else None, transaction_type)
db.commit() db.commit()
logger.info( logger.info(
f"Transaction completed: {transaction_type} for dewar {dewar_qr_code} in slot {slot.qr_code if slot else 'N/A'}") f"Transaction completed: {transaction_type} for dewar {dewar_qr_code} in slot {slot.qr_code if slot else 'N/A'}"
)
return {"message": "Status updated successfully"} return {"message": "Status updated successfully"}
@router.get("/slots", response_model=List[SlotSchema]) @router.get("/slots", response_model=List[SlotSchema])
async def get_all_slots(db: Session = Depends(get_db)): async def get_all_slots(db: Session = Depends(get_db)):
slots = db.query(SlotModel).options(joinedload(SlotModel.dewar)).all() slots = db.query(SlotModel).options(joinedload(SlotModel.dewar)).all()
@ -147,14 +195,16 @@ async def get_all_slots(db: Session = Depends(get_db)):
if slot.dewar_unique_id: if slot.dewar_unique_id:
# Calculate time until refill # Calculate time until refill
last_refill_event = db.query(LogisticsEventModel) \ last_refill_event = (
.join(DewarModel, DewarModel.id == LogisticsEventModel.dewar_id) \ db.query(LogisticsEventModel)
.join(DewarModel, DewarModel.id == LogisticsEventModel.dewar_id)
.filter( .filter(
DewarModel.unique_id == slot.dewar.unique_id, DewarModel.unique_id == slot.dewar.unique_id,
LogisticsEventModel.event_type == "refill" LogisticsEventModel.event_type == "refill",
) \ )
.order_by(LogisticsEventModel.timestamp.desc()) \ .order_by(LogisticsEventModel.timestamp.desc())
.first() .first()
)
if last_refill_event: if last_refill_event:
last_refill = last_refill_event.timestamp last_refill = last_refill_event.timestamp
@ -163,21 +213,27 @@ async def get_all_slots(db: Session = Depends(get_db)):
time_until_refill = -1 time_until_refill = -1
# Fetch the latest beamline event # Fetch the latest beamline event
last_beamline_event = db.query(LogisticsEventModel) \ last_beamline_event = (
.join(DewarModel, DewarModel.id == LogisticsEventModel.dewar_id) \ db.query(LogisticsEventModel)
.join(DewarModel, DewarModel.id == LogisticsEventModel.dewar_id)
.filter( .filter(
DewarModel.unique_id == slot.dewar.unique_id, DewarModel.unique_id == slot.dewar.unique_id,
LogisticsEventModel.event_type == "beamline" LogisticsEventModel.event_type == "beamline",
) \ )
.order_by(LogisticsEventModel.timestamp.desc()) \ .order_by(LogisticsEventModel.timestamp.desc())
.first() .first()
)
if last_beamline_event: if last_beamline_event:
# Set retrievedTimestamp to the timestamp of the beamline event # Set retrievedTimestamp to the timestamp of the beamline event
retrievedTimestamp = last_beamline_event.timestamp.isoformat() retrievedTimestamp = last_beamline_event.timestamp.isoformat()
# Fetch the associated slot's label for beamlineLocation # Fetch the associated slot's label for beamlineLocation
associated_slot = db.query(SlotModel).filter(SlotModel.id == last_beamline_event.slot_id).first() associated_slot = (
db.query(SlotModel)
.filter(SlotModel.id == last_beamline_event.slot_id)
.first()
)
beamlineLocation = associated_slot.label if associated_slot else None beamlineLocation = associated_slot.label if associated_slot else None
# Mark as being at a beamline # Mark as being at a beamline
@ -204,7 +260,11 @@ async def get_all_slots(db: Session = Depends(get_db)):
at_beamline=at_beamline, at_beamline=at_beamline,
retrievedTimestamp=retrievedTimestamp, retrievedTimestamp=retrievedTimestamp,
beamlineLocation=beamlineLocation, beamlineLocation=beamlineLocation,
shipment_name=slot.dewar.shipment.shipment_name if slot.dewar and slot.dewar.shipment else None, shipment_name=(
slot.dewar.shipment.shipment_name
if slot.dewar and slot.dewar.shipment
else None
),
contact_person=contact_person, contact_person=contact_person,
local_contact="local contact placeholder", local_contact="local contact placeholder",
) )
@ -214,7 +274,6 @@ async def get_all_slots(db: Session = Depends(get_db)):
return slots_with_refill_time return slots_with_refill_time
@router.post("/dewar/refill", response_model=dict) @router.post("/dewar/refill", response_model=dict)
async def refill_dewar(qr_code: str, db: Session = Depends(get_db)): async def refill_dewar(qr_code: str, db: Session = Depends(get_db)):
logger.info(f"Refilling dewar with QR code: {qr_code}") logger.info(f"Refilling dewar with QR code: {qr_code}")
@ -236,9 +295,14 @@ async def refill_dewar(qr_code: str, db: Session = Depends(get_db)):
db.commit() db.commit()
time_until_refill_seconds = calculate_time_until_refill(now) time_until_refill_seconds = calculate_time_until_refill(now)
logger.info(f"Dewar refilled successfully with time_until_refill: {time_until_refill_seconds}") logger.info(
f"Dewar refilled successfully with time_until_refill: {time_until_refill_seconds}"
)
return {"message": "Dewar refilled successfully", "time_until_refill": time_until_refill_seconds} return {
"message": "Dewar refilled successfully",
"time_until_refill": time_until_refill_seconds,
}
@router.get("/dewars", response_model=List[DewarSchema]) @router.get("/dewars", response_model=List[DewarSchema])
@ -250,7 +314,9 @@ async def get_all_dewars(db: Session = Depends(get_db)):
@router.get("/dewar/{unique_id}", response_model=DewarSchema) @router.get("/dewar/{unique_id}", response_model=DewarSchema)
async def get_dewar_by_unique_id(unique_id: str, db: Session = Depends(get_db)): async def get_dewar_by_unique_id(unique_id: str, db: Session = Depends(get_db)):
logger.info(f"Received request for dewar with unique_id: {unique_id}") logger.info(f"Received request for dewar with unique_id: {unique_id}")
dewar = db.query(DewarModel).filter(DewarModel.unique_id == unique_id.strip()).first() dewar = (
db.query(DewarModel).filter(DewarModel.unique_id == unique_id.strip()).first()
)
if not dewar: if not dewar:
logger.warning(f"Dewar with unique_id '{unique_id}' not found.") logger.warning(f"Dewar with unique_id '{unique_id}' not found.")
raise HTTPException(status_code=404, detail="Dewar not found") raise HTTPException(status_code=404, detail="Dewar not found")
@ -263,8 +329,10 @@ def log_event(db: Session, dewar_id: int, slot_id: Optional[int], event_type: st
dewar_id=dewar_id, dewar_id=dewar_id,
slot_id=slot_id, slot_id=slot_id,
event_type=event_type, event_type=event_type,
timestamp=datetime.now() timestamp=datetime.now(),
) )
db.add(new_event) db.add(new_event)
db.commit() db.commit()
logger.info(f"Logged event: {event_type} for dewar: {dewar_id} in slot: {slot_id if slot_id else 'N/A'}") logger.info(
f"Logged event: {event_type} for dewar: {dewar_id} in slot: {slot_id if slot_id else 'N/A'}"
)

View File

@ -8,6 +8,7 @@ from app.dependencies import get_db
router = APIRouter() router = APIRouter()
@router.get("/", response_model=List[ProposalSchema]) @router.get("/", response_model=List[ProposalSchema])
async def get_proposals(db: Session = Depends(get_db)): async def get_proposals(db: Session = Depends(get_db)):
return db.query(ProposalModel).all() return db.query(ProposalModel).all()

View File

@ -2,8 +2,21 @@ from fastapi import APIRouter, HTTPException, status, Depends
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import List from typing import List
import uuid import uuid
from app.schemas import Puck as PuckSchema, PuckCreate, PuckUpdate, SetTellPosition, PuckEvent from app.schemas import (
from app.models import Puck as PuckModel, Sample as SampleModel, PuckEvent as PuckEventModel, Slot as SlotModel, LogisticsEvent as LogisticsEventModel, Dewar as DewarModel Puck as PuckSchema,
PuckCreate,
PuckUpdate,
SetTellPosition,
PuckEvent,
)
from app.models import (
Puck as PuckModel,
Sample as SampleModel,
PuckEvent as PuckEventModel,
Slot as SlotModel,
LogisticsEvent as LogisticsEventModel,
Dewar as DewarModel,
)
from app.dependencies import get_db from app.dependencies import get_db
from datetime import datetime from datetime import datetime
import logging import logging
@ -13,6 +26,7 @@ router = APIRouter()
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@router.get("/", response_model=List[PuckSchema]) @router.get("/", response_model=List[PuckSchema])
async def get_pucks(db: Session = Depends(get_db)): async def get_pucks(db: Session = Depends(get_db)):
return db.query(PuckModel).all() return db.query(PuckModel).all()
@ -35,8 +49,7 @@ async def get_pucks_with_tell_position(db: Session = Depends(get_db)):
if not pucks: if not pucks:
logger.info("No pucks with tell_position found.") # Log for debugging logger.info("No pucks with tell_position found.") # Log for debugging
raise HTTPException( raise HTTPException(
status_code=404, status_code=404, detail="No pucks with a `tell_position` found."
detail="No pucks with a `tell_position` found."
) )
result = [] result = []
@ -67,6 +80,7 @@ async def get_pucks_with_tell_position(db: Session = Depends(get_db)):
return result return result
@router.get("/{puck_id}", response_model=PuckSchema) @router.get("/{puck_id}", response_model=PuckSchema)
async def get_puck(puck_id: str, db: Session = Depends(get_db)): async def get_puck(puck_id: str, db: Session = Depends(get_db)):
puck = db.query(PuckModel).filter(PuckModel.id == puck_id).first() puck = db.query(PuckModel).filter(PuckModel.id == puck_id).first()
@ -77,13 +91,13 @@ async def get_puck(puck_id: str, db: Session = Depends(get_db)):
@router.post("/", response_model=PuckSchema, status_code=status.HTTP_201_CREATED) @router.post("/", response_model=PuckSchema, status_code=status.HTTP_201_CREATED)
async def create_puck(puck: PuckCreate, db: Session = Depends(get_db)) -> PuckSchema: async def create_puck(puck: PuckCreate, db: Session = Depends(get_db)) -> PuckSchema:
puck_id = f'PUCK-{uuid.uuid4().hex[:8].upper()}' puck_id = f"PUCK-{uuid.uuid4().hex[:8].upper()}"
db_puck = PuckModel( db_puck = PuckModel(
id=puck_id, id=puck_id,
puck_name=puck.puck_name, puck_name=puck.puck_name,
puck_type=puck.puck_type, puck_type=puck.puck_type,
puck_location_in_dewar=puck.puck_location_in_dewar, puck_location_in_dewar=puck.puck_location_in_dewar,
dewar_id=puck.dewar_id dewar_id=puck.dewar_id,
) )
db.add(db_puck) db.add(db_puck)
db.commit() db.commit()
@ -92,7 +106,9 @@ async def create_puck(puck: PuckCreate, db: Session = Depends(get_db)) -> PuckSc
@router.put("/{puck_id}", response_model=PuckSchema) @router.put("/{puck_id}", response_model=PuckSchema)
async def update_puck(puck_id: str, updated_puck: PuckUpdate, db: Session = Depends(get_db)): async def update_puck(
puck_id: str, updated_puck: PuckUpdate, db: Session = Depends(get_db)
):
puck = db.query(PuckModel).filter(PuckModel.id == puck_id).first() puck = db.query(PuckModel).filter(PuckModel.id == puck_id).first()
if not puck: if not puck:
raise HTTPException(status_code=404, detail="Puck not found") raise HTTPException(status_code=404, detail="Puck not found")
@ -115,17 +131,18 @@ async def delete_puck(puck_id: str, db: Session = Depends(get_db)):
db.commit() db.commit()
return return
@router.put("/{puck_id}/tell_position", status_code=status.HTTP_200_OK) @router.put("/{puck_id}/tell_position", status_code=status.HTTP_200_OK)
async def set_tell_position( async def set_tell_position(
puck_id: int, puck_id: int, request: SetTellPosition, db: Session = Depends(get_db)
request: SetTellPosition,
db: Session = Depends(get_db)
): ):
# Get the requested tell_position # Get the requested tell_position
tell_position = request.tell_position tell_position = request.tell_position
# Define valid positions # Define valid positions
valid_positions = [f"{letter}{num}" for letter in "ABCDEF" for num in range(1, 6)] + ["null", None] valid_positions = [
f"{letter}{num}" for letter in "ABCDEF" for num in range(1, 6)
] + ["null", None]
# Validate tell_position # Validate tell_position
if tell_position not in valid_positions: if tell_position not in valid_positions:
@ -161,7 +178,10 @@ async def get_last_tell_position(puck_id: str, db: Session = Depends(get_db)):
# Query the most recent tell_position_set event for the given puck_id # Query the most recent tell_position_set event for the given puck_id
last_event = ( last_event = (
db.query(PuckEventModel) db.query(PuckEventModel)
.filter(PuckEventModel.puck_id == puck_id, PuckEventModel.event_type == "tell_position_set") .filter(
PuckEventModel.puck_id == puck_id,
PuckEventModel.event_type == "tell_position_set",
)
.order_by(PuckEventModel.timestamp.desc()) .order_by(PuckEventModel.timestamp.desc())
.first() .first()
) )
@ -182,10 +202,7 @@ async def get_last_tell_position(puck_id: str, db: Session = Depends(get_db)):
@router.get("/slot/{slot_identifier}", response_model=List[dict]) @router.get("/slot/{slot_identifier}", response_model=List[dict])
async def get_pucks_by_slot( async def get_pucks_by_slot(slot_identifier: str, db: Session = Depends(get_db)):
slot_identifier: str,
db: Session = Depends(get_db)
):
""" """
Retrieve all pucks associated with all dewars linked to the given slot Retrieve all pucks associated with all dewars linked to the given slot
(by ID or keyword) via 'beamline' events. (by ID or keyword) via 'beamline' events.
@ -200,28 +217,29 @@ async def get_pucks_by_slot(
"PXIII": 49, "PXIII": 49,
"X06SA": 47, "X06SA": 47,
"X10SA": 48, "X10SA": 48,
"X06DA": 49 "X06DA": 49,
} }
# Check if the slot identifier is an alias or ID # Check if the slot identifier is an alias or ID
try: try:
slot_id = int(slot_identifier) # If the user provided a numeric ID slot_id = int(slot_identifier) # If the user provided a numeric ID
alias = next((k for k, v in slot_aliases.items() if v == slot_id), slot_identifier) alias = next(
(k for k, v in slot_aliases.items() if v == slot_id), slot_identifier
)
except ValueError: except ValueError:
slot_id = slot_aliases.get(slot_identifier.upper()) # Try mapping alias slot_id = slot_aliases.get(slot_identifier.upper()) # Try mapping alias
alias = slot_identifier.upper() # Keep alias as-is for error messages alias = slot_identifier.upper() # Keep alias as-is for error messages
if not slot_id: if not slot_id:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="Invalid slot identifier. Must be an ID or one of the following: PXI, PXII, PXIII, X06SA, X10SA, X06DA." detail="Invalid slot identifier. Must be an ID or one of the following: PXI, PXII, PXIII, X06SA, X10SA, X06DA.",
) )
# Verify that the slot exists # Verify that the slot exists
slot = db.query(SlotModel).filter(SlotModel.id == slot_id).first() slot = db.query(SlotModel).filter(SlotModel.id == slot_id).first()
if not slot: if not slot:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404, detail=f"Slot not found for identifier '{alias}'."
detail=f"Slot not found for identifier '{alias}'."
) )
logger.info(f"Slot found: ID={slot.id}, Label={slot.label}") logger.info(f"Slot found: ID={slot.id}, Label={slot.label}")
@ -231,7 +249,7 @@ async def get_pucks_by_slot(
db.query(LogisticsEventModel) db.query(LogisticsEventModel)
.filter( .filter(
LogisticsEventModel.slot_id == slot_id, LogisticsEventModel.slot_id == slot_id,
LogisticsEventModel.event_type == "beamline" LogisticsEventModel.event_type == "beamline",
) )
.order_by(LogisticsEventModel.timestamp.desc()) .order_by(LogisticsEventModel.timestamp.desc())
.all() .all()
@ -240,8 +258,7 @@ async def get_pucks_by_slot(
if not beamline_events: if not beamline_events:
logger.warning(f"No dewars associated to this beamline '{alias}'.") logger.warning(f"No dewars associated to this beamline '{alias}'.")
raise HTTPException( raise HTTPException(
status_code=404, status_code=404, detail=f"No dewars found for the given beamline '{alias}'."
detail=f"No dewars found for the given beamline '{alias}'."
) )
logger.info(f"Found {len(beamline_events)} beamline events for slot_id={slot_id}.") logger.info(f"Found {len(beamline_events)} beamline events for slot_id={slot_id}.")
@ -253,8 +270,7 @@ async def get_pucks_by_slot(
if not dewars: if not dewars:
logger.warning(f"No dewars found for beamline '{alias}'.") logger.warning(f"No dewars found for beamline '{alias}'.")
raise HTTPException( raise HTTPException(
status_code=404, status_code=404, detail=f"No dewars found for beamline '{alias}'."
detail=f"No dewars found for beamline '{alias}'."
) )
logger.info(f"Found {len(dewars)} dewars for beamline '{alias}'.") logger.info(f"Found {len(dewars)} dewars for beamline '{alias}'.")
@ -273,7 +289,7 @@ async def get_pucks_by_slot(
logger.warning(f"No pucks found for dewars associated with beamline '{alias}'.") logger.warning(f"No pucks found for dewars associated with beamline '{alias}'.")
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"No pucks found for dewars associated with beamline '{alias}'." detail=f"No pucks found for dewars associated with beamline '{alias}'.",
) )
logger.info(f"Found {len(puck_list)} pucks for beamline '{alias}'.") logger.info(f"Found {len(puck_list)} pucks for beamline '{alias}'.")
@ -285,7 +301,7 @@ async def get_pucks_by_slot(
"puck_name": puck.puck_name, "puck_name": puck.puck_name,
"puck_type": puck.puck_type, "puck_type": puck.puck_type,
"dewar_id": puck.dewar_id, "dewar_id": puck.dewar_id,
"dewar_name": dewar_mapping.get(puck.dewar_id) # Link dewar_name "dewar_name": dewar_mapping.get(puck.dewar_id), # Link dewar_name
} }
for puck in puck_list for puck in puck_list
] ]

View File

@ -2,7 +2,11 @@ from fastapi import APIRouter, HTTPException, status, Depends
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import List from typing import List
from app.schemas import Puck as PuckSchema, Sample as SampleSchema, SampleEventCreate from app.schemas import Puck as PuckSchema, Sample as SampleSchema, SampleEventCreate
from app.models import Puck as PuckModel, Sample as SampleModel, SampleEvent as SampleEventModel from app.models import (
Puck as PuckModel,
Sample as SampleModel,
SampleEvent as SampleEventModel,
)
from app.dependencies import get_db from app.dependencies import get_db
import logging import logging
@ -18,10 +22,15 @@ async def get_samples_with_events(puck_id: str, db: Session = Depends(get_db)):
samples = db.query(SampleModel).filter(SampleModel.puck_id == puck_id).all() samples = db.query(SampleModel).filter(SampleModel.puck_id == puck_id).all()
for sample in samples: for sample in samples:
sample.events = db.query(SampleEventModel).filter(SampleEventModel.sample_id == sample.id).all() sample.events = (
db.query(SampleEventModel)
.filter(SampleEventModel.sample_id == sample.id)
.all()
)
return samples return samples
@router.get("/pucks-samples", response_model=List[PuckSchema]) @router.get("/pucks-samples", response_model=List[PuckSchema])
async def get_all_pucks_with_samples_and_events(db: Session = Depends(get_db)): async def get_all_pucks_with_samples_and_events(db: Session = Depends(get_db)):
logging.info("Fetching all pucks with samples and events") logging.info("Fetching all pucks with samples and events")
@ -32,5 +41,7 @@ async def get_all_pucks_with_samples_and_events(db: Session = Depends(get_db)):
logging.info(f"Puck ID: {puck.id}, Name: {puck.puck_name}") logging.info(f"Puck ID: {puck.id}, Name: {puck.puck_name}")
if not pucks: if not pucks:
raise HTTPException(status_code=404, detail="No pucks found in the database") # More descriptive raise HTTPException(
status_code=404, detail="No pucks found in the database"
) # More descriptive
return pucks return pucks

View File

@ -6,10 +6,27 @@ from pydantic import BaseModel, ValidationError
from datetime import date from datetime import date
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from app.models import Shipment as ShipmentModel, ContactPerson as ContactPersonModel, Address as AddressModel, \ from app.models import (
Proposal as ProposalModel, Dewar as DewarModel, Puck as PuckModel, Sample as SampleModel Shipment as ShipmentModel,
from app.schemas import ShipmentCreate, UpdateShipmentComments, Shipment as ShipmentSchema, DewarUpdate, \ ContactPerson as ContactPersonModel,
ContactPerson as ContactPersonSchema, Sample as SampleSchema, DewarCreate, PuckCreate, SampleCreate, DewarSchema Address as AddressModel,
Proposal as ProposalModel,
Dewar as DewarModel,
Puck as PuckModel,
Sample as SampleModel,
)
from app.schemas import (
ShipmentCreate,
UpdateShipmentComments,
Shipment as ShipmentSchema,
DewarUpdate,
ContactPerson as ContactPersonSchema,
Sample as SampleSchema,
DewarCreate,
PuckCreate,
SampleCreate,
DewarSchema,
)
from app.database import get_db from app.database import get_db
from app.crud import get_shipments, get_shipment_by_id from app.crud import get_shipments, get_shipment_by_id
@ -23,7 +40,9 @@ def default_serializer(obj):
@router.get("", response_model=List[ShipmentSchema]) @router.get("", response_model=List[ShipmentSchema])
async def fetch_shipments(id: Optional[int] = Query(None), db: Session = Depends(get_db)): async def fetch_shipments(
id: Optional[int] = Query(None), db: Session = Depends(get_db)
):
if id: if id:
shipment = get_shipment_by_id(db, id) shipment = get_shipment_by_id(db, id)
if not shipment: if not shipment:
@ -35,9 +54,12 @@ async def fetch_shipments(id: Optional[int] = Query(None), db: Session = Depends
shipments = get_shipments(db) shipments = get_shipments(db)
logging.info(f"Total shipments fetched: {len(shipments)}") logging.info(f"Total shipments fetched: {len(shipments)}")
for shipment in shipments: for shipment in shipments:
logging.info(f"Shipment ID: {shipment.id}, Shipment Name: {shipment.shipment_name}") logging.info(
f"Shipment ID: {shipment.id}, Shipment Name: {shipment.shipment_name}"
)
return shipments return shipments
@router.get("/{shipment_id}/dewars", response_model=List[DewarSchema]) @router.get("/{shipment_id}/dewars", response_model=List[DewarSchema])
async def get_dewars_by_shipment_id(shipment_id: int, db: Session = Depends(get_db)): async def get_dewars_by_shipment_id(shipment_id: int, db: Session = Depends(get_db)):
shipment = db.query(ShipmentModel).filter(ShipmentModel.id == shipment_id).first() shipment = db.query(ShipmentModel).filter(ShipmentModel.id == shipment_id).first()
@ -51,12 +73,21 @@ async def get_dewars_by_shipment_id(shipment_id: int, db: Session = Depends(get_
return dewars return dewars
@router.post("", response_model=ShipmentSchema, status_code=status.HTTP_201_CREATED) @router.post("", response_model=ShipmentSchema, status_code=status.HTTP_201_CREATED)
async def create_shipment(shipment: ShipmentCreate, db: Session = Depends(get_db)): async def create_shipment(shipment: ShipmentCreate, db: Session = Depends(get_db)):
contact_person = db.query(ContactPersonModel).filter(ContactPersonModel.id == shipment.contact_person_id).first() contact_person = (
return_address = db.query(AddressModel).filter(AddressModel.id == shipment.return_address_id).first() db.query(ContactPersonModel)
proposal = db.query(ProposalModel).filter(ProposalModel.id == shipment.proposal_id).first() .filter(ContactPersonModel.id == shipment.contact_person_id)
.first()
)
return_address = (
db.query(AddressModel)
.filter(AddressModel.id == shipment.return_address_id)
.first()
)
proposal = (
db.query(ProposalModel).filter(ProposalModel.id == shipment.proposal_id).first()
)
if not (contact_person or return_address or proposal): if not (contact_person or return_address or proposal):
raise HTTPException(status_code=404, detail="Associated entity not found") raise HTTPException(status_code=404, detail="Associated entity not found")
@ -97,17 +128,29 @@ async def delete_shipment(shipment_id: int, db: Session = Depends(get_db)):
@router.put("/{shipment_id}", response_model=ShipmentSchema) @router.put("/{shipment_id}", response_model=ShipmentSchema)
async def update_shipment(shipment_id: int, updated_shipment: ShipmentCreate, db: Session = Depends(get_db)): async def update_shipment(
print("Received payload:", json.dumps(updated_shipment.dict(), indent=2, default=default_serializer)) shipment_id: int, updated_shipment: ShipmentCreate, db: Session = Depends(get_db)
):
print(
"Received payload:",
json.dumps(updated_shipment.dict(), indent=2, default=default_serializer),
)
shipment = db.query(ShipmentModel).filter(ShipmentModel.id == shipment_id).first() shipment = db.query(ShipmentModel).filter(ShipmentModel.id == shipment_id).first()
if not shipment: if not shipment:
raise HTTPException(status_code=404, detail="Shipment not found") raise HTTPException(status_code=404, detail="Shipment not found")
# Validate relationships by IDs # Validate relationships by IDs
contact_person = db.query(ContactPersonModel).filter( contact_person = (
ContactPersonModel.id == updated_shipment.contact_person_id).first() db.query(ContactPersonModel)
return_address = db.query(AddressModel).filter(AddressModel.id == updated_shipment.return_address_id).first() .filter(ContactPersonModel.id == updated_shipment.contact_person_id)
.first()
)
return_address = (
db.query(AddressModel)
.filter(AddressModel.id == updated_shipment.return_address_id)
.first()
)
if not contact_person: if not contact_person:
raise HTTPException(status_code=404, detail="Contact person not found") raise HTTPException(status_code=404, detail="Contact person not found")
if not return_address: if not return_address:
@ -123,25 +166,39 @@ async def update_shipment(shipment_id: int, updated_shipment: ShipmentCreate, db
# Process and update dewars' details # Process and update dewars' details
for dewar_data in updated_shipment.dewars: for dewar_data in updated_shipment.dewars:
dewar = db.query(DewarModel).filter(DewarModel.id == dewar_data.dewar_id).first() dewar = (
db.query(DewarModel).filter(DewarModel.id == dewar_data.dewar_id).first()
)
if not dewar: if not dewar:
raise HTTPException(status_code=404, detail=f"Dewar with ID {dewar_data.dewar_id} not found") raise HTTPException(
status_code=404, detail=f"Dewar with ID {dewar_data.dewar_id} not found"
)
update_fields = dewar_data.dict(exclude_unset=True) update_fields = dewar_data.dict(exclude_unset=True)
for key, value in update_fields.items(): for key, value in update_fields.items():
if key == 'contact_person_id': if key == "contact_person_id":
contact_person = db.query(ContactPersonModel).filter(ContactPersonModel.id == value).first() contact_person = (
db.query(ContactPersonModel)
.filter(ContactPersonModel.id == value)
.first()
)
if not contact_person: if not contact_person:
raise HTTPException(status_code=404, raise HTTPException(
detail=f"Contact person with ID {value} for Dewar {dewar_data.dewar_id} not found") status_code=404,
if key == 'return_address_id': detail=f"Contact person with ID {value} for Dewar {dewar_data.dewar_id} not found",
address = db.query(AddressModel).filter(AddressModel.id == value).first() )
if key == "return_address_id":
address = (
db.query(AddressModel).filter(AddressModel.id == value).first()
)
if not address: if not address:
raise HTTPException(status_code=404, raise HTTPException(
detail=f"Address with ID {value} for Dewar {dewar_data.dewar_id} not found") status_code=404,
detail=f"Address with ID {value} for Dewar {dewar_data.dewar_id} not found",
)
for key, value in update_fields.items(): for key, value in update_fields.items():
if key != 'dewar_id': if key != "dewar_id":
setattr(dewar, key, value) setattr(dewar, key, value)
db.commit() db.commit()
@ -150,7 +207,9 @@ async def update_shipment(shipment_id: int, updated_shipment: ShipmentCreate, db
@router.post("/{shipment_id}/add_dewar", response_model=ShipmentSchema) @router.post("/{shipment_id}/add_dewar", response_model=ShipmentSchema)
async def add_dewar_to_shipment(shipment_id: int, dewar_id: int, db: Session = Depends(get_db)): async def add_dewar_to_shipment(
shipment_id: int, dewar_id: int, db: Session = Depends(get_db)
):
shipment = db.query(ShipmentModel).filter(ShipmentModel.id == shipment_id).first() shipment = db.query(ShipmentModel).filter(ShipmentModel.id == shipment_id).first()
if not shipment: if not shipment:
raise HTTPException(status_code=404, detail="Shipment not found") raise HTTPException(status_code=404, detail="Shipment not found")
@ -166,14 +225,18 @@ async def add_dewar_to_shipment(shipment_id: int, dewar_id: int, db: Session = D
@router.delete("/{shipment_id}/remove_dewar/{dewar_id}", response_model=ShipmentSchema) @router.delete("/{shipment_id}/remove_dewar/{dewar_id}", response_model=ShipmentSchema)
async def remove_dewar_from_shipment(shipment_id: int, dewar_id: int, db: Session = Depends(get_db)): async def remove_dewar_from_shipment(
shipment_id: int, dewar_id: int, db: Session = Depends(get_db)
):
shipment = db.query(ShipmentModel).filter(ShipmentModel.id == shipment_id).first() shipment = db.query(ShipmentModel).filter(ShipmentModel.id == shipment_id).first()
if not shipment: if not shipment:
raise HTTPException(status_code=404, detail="Shipment not found") raise HTTPException(status_code=404, detail="Shipment not found")
dewar_exists = any(dw.id == dewar_id for dw in shipment.dewars) dewar_exists = any(dw.id == dewar_id for dw in shipment.dewars)
if not dewar_exists: if not dewar_exists:
raise HTTPException(status_code=404, detail=f"Dewar with ID {dewar_id} not found in shipment") raise HTTPException(
status_code=404, detail=f"Dewar with ID {dewar_id} not found in shipment"
)
shipment.dewars = [dw for dw in shipment.dewars if dw.id != dewar_id] shipment.dewars = [dw for dw in shipment.dewars if dw.id != dewar_id]
db.commit() db.commit()
@ -201,8 +264,13 @@ async def get_samples_in_shipment(shipment_id: int, db: Session = Depends(get_db
return samples return samples
@router.get("/shipments/{shipment_id}/dewars/{dewar_id}/samples", response_model=List[SampleSchema]) @router.get(
async def get_samples_in_dewar(shipment_id: int, dewar_id: int, db: Session = Depends(get_db)): "/shipments/{shipment_id}/dewars/{dewar_id}/samples",
response_model=List[SampleSchema],
)
async def get_samples_in_dewar(
shipment_id: int, dewar_id: int, db: Session = Depends(get_db)
):
shipment = get_shipment_by_id(db, shipment_id) shipment = get_shipment_by_id(db, shipment_id)
if not shipment: if not shipment:
raise HTTPException(status_code=404, detail="Shipment not found") raise HTTPException(status_code=404, detail="Shipment not found")
@ -220,8 +288,11 @@ async def get_samples_in_dewar(shipment_id: int, dewar_id: int, db: Session = De
@router.put("/{shipment_id}/comments", response_model=ShipmentSchema) @router.put("/{shipment_id}/comments", response_model=ShipmentSchema)
async def update_shipment_comments(shipment_id: int, comments_data: UpdateShipmentComments, async def update_shipment_comments(
db: Session = Depends(get_db)): shipment_id: int,
comments_data: UpdateShipmentComments,
db: Session = Depends(get_db),
):
shipment = db.query(ShipmentModel).filter(ShipmentModel.id == shipment_id).first() shipment = db.query(ShipmentModel).filter(ShipmentModel.id == shipment_id).first()
if not shipment: if not shipment:
raise HTTPException(status_code=404, detail="Shipment not found") raise HTTPException(status_code=404, detail="Shipment not found")
@ -232,15 +303,25 @@ async def update_shipment_comments(shipment_id: int, comments_data: UpdateShipme
return shipment return shipment
@router.post("/{shipment_id}/add_dewar_puck_sample", response_model=ShipmentSchema, status_code=status.HTTP_201_CREATED) @router.post(
def add_dewar_puck_sample_to_shipment(shipment_id: int, payload: DewarCreate, db: Session = Depends(get_db)): "/{shipment_id}/add_dewar_puck_sample",
response_model=ShipmentSchema,
status_code=status.HTTP_201_CREATED,
)
def add_dewar_puck_sample_to_shipment(
shipment_id: int, payload: DewarCreate, db: Session = Depends(get_db)
):
shipment = db.query(ShipmentModel).filter(ShipmentModel.id == shipment_id).first() shipment = db.query(ShipmentModel).filter(ShipmentModel.id == shipment_id).first()
if not shipment: if not shipment:
raise HTTPException(status_code=404, detail="Shipment not found") raise HTTPException(status_code=404, detail="Shipment not found")
try: try:
for dewar_data in payload.dewars: for dewar_data in payload.dewars:
dewar = db.query(DewarModel).filter(DewarModel.dewar_name == dewar_data.dewar_name).first() dewar = (
db.query(DewarModel)
.filter(DewarModel.dewar_name == dewar_data.dewar_name)
.first()
)
if dewar: if dewar:
# Update existing dewar # Update existing dewar
dewar.tracking_number = dewar_data.tracking_number dewar.tracking_number = dewar_data.tracking_number

View File

@ -1,7 +1,10 @@
from app.sample_models import SpreadsheetModel, SpreadsheetResponse from app.sample_models import SpreadsheetModel, SpreadsheetResponse
from fastapi import APIRouter, UploadFile, File, HTTPException from fastapi import APIRouter, UploadFile, File, HTTPException
import logging import logging
from app.services.spreadsheet_service import SampleSpreadsheetImporter, SpreadsheetImportError from app.services.spreadsheet_service import (
SampleSpreadsheetImporter,
SpreadsheetImportError,
)
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
import os import os
from pydantic import ValidationError # Import ValidationError here from pydantic import ValidationError # Import ValidationError here
@ -10,20 +13,27 @@ from app.row_storage import row_storage # Import the RowStorage instance
router = APIRouter() router = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
importer = SampleSpreadsheetImporter() # assuming this is a singleton or manageable instance importer = (
SampleSpreadsheetImporter()
) # assuming this is a singleton or manageable instance
@router.get("/download-template", response_class=FileResponse) @router.get("/download-template", response_class=FileResponse)
async def download_template(): async def download_template():
"""Serve a template file for spreadsheet upload.""" """Serve a template file for spreadsheet upload."""
current_dir = os.path.dirname(__file__) current_dir = os.path.dirname(__file__)
template_path = os.path.join(current_dir, "../../downloads/V7_TELLSamplesSpreadsheetTemplate.xlsx") template_path = os.path.join(
current_dir, "../../downloads/V7_TELLSamplesSpreadsheetTemplate.xlsx"
)
if not os.path.exists(template_path): if not os.path.exists(template_path):
raise HTTPException(status_code=404, detail="Template file not found.") raise HTTPException(status_code=404, detail="Template file not found.")
return FileResponse(template_path, filename="template.xlsx", return FileResponse(
media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet") template_path,
filename="template.xlsx",
media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
)
@router.post("/upload", response_model=SpreadsheetResponse) @router.post("/upload", response_model=SpreadsheetResponse)
@ -33,17 +43,24 @@ async def upload_file(file: UploadFile = File(...)):
logger.info(f"Received file: {file.filename}") logger.info(f"Received file: {file.filename}")
# Validate file format # Validate file format
if not file.filename.endswith('.xlsx'): if not file.filename.endswith(".xlsx"):
logger.error("Invalid file format") logger.error("Invalid file format")
raise HTTPException(status_code=400, detail="Invalid file format. Please upload an .xlsx file.") raise HTTPException(
status_code=400,
detail="Invalid file format. Please upload an .xlsx file.",
)
# Initialize the importer and process the spreadsheet # Initialize the importer and process the spreadsheet
validated_model, errors, raw_data, headers = importer.import_spreadsheet_with_errors(file) validated_model, errors, raw_data, headers = (
importer.import_spreadsheet_with_errors(file)
)
# Extract unique values for dewars, pucks, and samples # Extract unique values for dewars, pucks, and samples
dewars = {sample.dewarname for sample in validated_model if sample.dewarname} dewars = {sample.dewarname for sample in validated_model if sample.dewarname}
pucks = {sample.puckname for sample in validated_model if sample.puckname} pucks = {sample.puckname for sample in validated_model if sample.puckname}
samples = {sample.crystalname for sample in validated_model if sample.crystalname} samples = {
sample.crystalname for sample in validated_model if sample.crystalname
}
# Construct the response model with the processed data # Construct the response model with the processed data
response_data = SpreadsheetResponse( response_data = SpreadsheetResponse(
@ -56,7 +73,7 @@ async def upload_file(file: UploadFile = File(...)):
pucks=list(pucks), pucks=list(pucks),
samples_count=len(samples), samples_count=len(samples),
samples=list(samples), samples=list(samples),
headers=headers # Include headers in the response headers=headers, # Include headers in the response
) )
# Store row data for future use # Store row data for future use
@ -64,16 +81,23 @@ async def upload_file(file: UploadFile = File(...)):
row_num = idx + 4 # Adjust row numbering if necessary row_num = idx + 4 # Adjust row numbering if necessary
row_storage.set_row(row_num, row.dict()) row_storage.set_row(row_num, row.dict())
logger.info(f"Returning response with {len(validated_model)} records and {len(errors)} errors.") logger.info(
f"Returning response with {len(validated_model)} records and {len(errors)} errors."
)
return response_data return response_data
except SpreadsheetImportError as e: except SpreadsheetImportError as e:
logger.error(f"Spreadsheet import error: {str(e)}") logger.error(f"Spreadsheet import error: {str(e)}")
raise HTTPException(status_code=400, detail=f"Error processing spreadsheet: {str(e)}") raise HTTPException(
status_code=400, detail=f"Error processing spreadsheet: {str(e)}"
)
except Exception as e: except Exception as e:
logger.error(f"Unexpected error occurred: {str(e)}") logger.error(f"Unexpected error occurred: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to upload file. Please try again. Error: {str(e)}") raise HTTPException(
status_code=500,
detail=f"Failed to upload file. Please try again. Error: {str(e)}",
)
@router.post("/validate-cell") @router.post("/validate-cell")
@ -86,7 +110,9 @@ async def validate_cell(data: dict):
current_row_data = row_storage.get_row(row_num) current_row_data = row_storage.get_row(row_num)
# Update the cell value # Update the cell value
current_row_data[col_name] = importer._clean_value(value, importer.get_expected_type(col_name)) current_row_data[col_name] = importer._clean_value(
value, importer.get_expected_type(col_name)
)
# Temporarily store the updated row data # Temporarily store the updated row data
row_storage.set_row(row_num, current_row_data) row_storage.set_row(row_num, current_row_data)
@ -100,6 +126,8 @@ async def validate_cell(data: dict):
return {"is_valid": True, "message": ""} return {"is_valid": True, "message": ""}
except ValidationError as e: except ValidationError as e:
# Extract the first error message # Extract the first error message
message = e.errors()[0]['msg'] message = e.errors()[0]["msg"]
logger.error(f"Validation failed for row {row_num}, column {col_name}: {message}") logger.error(
f"Validation failed for row {row_num}, column {col_name}: {message}"
)
return {"is_valid": False, "message": message} return {"is_valid": False, "message": message}

View File

@ -6,16 +6,17 @@ from typing_extensions import Annotated
class SpreadsheetModel(BaseModel): class SpreadsheetModel(BaseModel):
dewarname: str = Field(..., alias='dewarname') dewarname: str = Field(..., alias="dewarname")
puckname: str = Field(..., alias='puckname') puckname: str = Field(..., alias="puckname")
pucktype: Optional[str] = Field(None, alias="pucktype") pucktype: Optional[str] = Field(None, alias="pucktype")
crystalname: Annotated[ crystalname: Annotated[
str, str,
Field(..., Field(
...,
max_length=64, max_length=64,
title="Crystal Name", title="Crystal Name",
description="max_length imposed by MTZ file header format https://www.ccp4.ac.uk/html/mtzformat.html", description="max_length imposed by MTZ file header format https://www.ccp4.ac.uk/html/mtzformat.html",
alias='crystalname' alias="crystalname",
), ),
] ]
positioninpuck: int # Only accept positive integers between 1 and 16 positioninpuck: int # Only accept positive integers between 1 and 16
@ -26,17 +27,31 @@ class SpreadsheetModel(BaseModel):
oscillation: Optional[float] = None # Only accept positive float oscillation: Optional[float] = None # Only accept positive float
exposure: Optional[float] = None # Only accept positive floats between 0 and 1 exposure: Optional[float] = None # Only accept positive floats between 0 and 1
totalrange: Optional[int] = None # Only accept positive integers between 0 and 360 totalrange: Optional[int] = None # Only accept positive integers between 0 and 360
transmission: Optional[int] = None # Only accept positive integers between 0 and 100 transmission: Optional[int] = (
None # Only accept positive integers between 0 and 100
)
targetresolution: Optional[float] = None # Only accept positive float targetresolution: Optional[float] = None # Only accept positive float
aperture: Optional[str] = None # Optional string field aperture: Optional[str] = None # Optional string field
datacollectiontype: Optional[str] = None # Only accept "standard", other types might be added later datacollectiontype: Optional[str] = (
processingpipeline: Optional[str] = "" # Only accept "gopy", "autoproc", "xia2dials" None # Only accept "standard", other types might be added later
spacegroupnumber: Optional[int] = None # Only accept positive integers between 1 and 230 )
cellparameters: Optional[str] = None # Must be a set of six positive floats or integers processingpipeline: Optional[str] = (
"" # Only accept "gopy", "autoproc", "xia2dials"
)
spacegroupnumber: Optional[int] = (
None # Only accept positive integers between 1 and 230
)
cellparameters: Optional[str] = (
None # Must be a set of six positive floats or integers
)
rescutkey: Optional[str] = None # Only accept "is" or "cchalf" rescutkey: Optional[str] = None # Only accept "is" or "cchalf"
rescutvalue: Optional[float] = None # Must be a positive float if rescutkey is provided rescutvalue: Optional[float] = (
None # Must be a positive float if rescutkey is provided
)
userresolution: Optional[float] = None userresolution: Optional[float] = None
pdbid: Optional[str] = "" # Accepts either the format of the protein data bank code or {provided} pdbid: Optional[str] = (
"" # Accepts either the format of the protein data bank code or {provided}
)
autoprocfull: Optional[bool] = None autoprocfull: Optional[bool] = None
procfull: Optional[bool] = None procfull: Optional[bool] = None
adpenabled: Optional[bool] = None adpenabled: Optional[bool] = None
@ -48,7 +63,7 @@ class SpreadsheetModel(BaseModel):
dose: Optional[float] = None # Optional float field dose: Optional[float] = None # Optional float field
# Add pucktype validation # Add pucktype validation
@field_validator('pucktype', mode="before") @field_validator("pucktype", mode="before")
@classmethod @classmethod
def validate_pucktype(cls, v): def validate_pucktype(cls, v):
if v != "unipuck": if v != "unipuck":
@ -56,7 +71,7 @@ class SpreadsheetModel(BaseModel):
return v return v
# Validators # Validators
@field_validator('dewarname', 'puckname', mode="before") @field_validator("dewarname", "puckname", mode="before")
@classmethod @classmethod
def dewarname_puckname_characters(cls, v): def dewarname_puckname_characters(cls, v):
if v: if v:
@ -67,17 +82,19 @@ class SpreadsheetModel(BaseModel):
return v return v
raise ValueError("Value must be provided for dewarname and puckname.") raise ValueError("Value must be provided for dewarname and puckname.")
@field_validator('crystalname', mode="before") @field_validator("crystalname", mode="before")
@classmethod @classmethod
def parameter_characters(cls, v): def parameter_characters(cls, v):
v = str(v).replace(" ", "_") v = str(v).replace(" ", "_")
if re.search("\n", v): if re.search("\n", v):
assert v.isalnum(), "is not valid. newline character detected." assert v.isalnum(), "is not valid. newline character detected."
characters = re.sub("[._+-]", "", v) characters = re.sub("[._+-]", "", v)
assert characters.isalnum(), f" '{v}' is not valid. Only alphanumeric and . _ + - characters allowed." assert (
characters.isalnum()
), f" '{v}' is not valid. Only alphanumeric and . _ + - characters allowed."
return v return v
@field_validator('directory', mode="before") @field_validator("directory", mode="before")
@classmethod @classmethod
def directory_characters(cls, v): def directory_characters(cls, v):
if v: if v:
@ -85,37 +102,57 @@ class SpreadsheetModel(BaseModel):
if re.search("\n", v): if re.search("\n", v):
raise ValueError(f" '{v}' is not valid. newline character detected.") raise ValueError(f" '{v}' is not valid. newline character detected.")
valid_macros = ["{date}", "{prefix}", "{sgpuck}", "{puck}", "{beamline}", "{sgprefix}", valid_macros = [
"{sgpriority}", "{sgposition}", "{protein}", "{method}"] "{date}",
"{prefix}",
"{sgpuck}",
"{puck}",
"{beamline}",
"{sgprefix}",
"{sgpriority}",
"{sgposition}",
"{protein}",
"{method}",
]
pattern = re.compile("|".join(re.escape(macro) for macro in valid_macros)) pattern = re.compile("|".join(re.escape(macro) for macro in valid_macros))
v = pattern.sub('macro', v) v = pattern.sub("macro", v)
allowed_chars = "[a-z0-9_.+-]" allowed_chars = "[a-z0-9_.+-]"
directory_re = re.compile(f"^(({allowed_chars}*|{allowed_chars}+)*/*)*$", re.IGNORECASE) directory_re = re.compile(
f"^(({allowed_chars}*|{allowed_chars}+)*/*)*$", re.IGNORECASE
)
if not directory_re.match(v): if not directory_re.match(v):
raise ValueError(f" '{v}' is not valid. Value must be a valid path or macro.") raise ValueError(
f" '{v}' is not valid. Value must be a valid path or macro."
)
return v return v
@field_validator('positioninpuck', mode="before") @field_validator("positioninpuck", mode="before")
@classmethod @classmethod
def positioninpuck_possible(cls, v): def positioninpuck_possible(cls, v):
if not isinstance(v, int) or v < 1 or v > 16: if not isinstance(v, int) or v < 1 or v > 16:
raise ValueError(f" '{v}' is not valid. Value must be an integer between 1 and 16.") raise ValueError(
f" '{v}' is not valid. Value must be an integer between 1 and 16."
)
return v return v
@field_validator('priority', mode="before") @field_validator("priority", mode="before")
@classmethod @classmethod
def priority_positive(cls, v): def priority_positive(cls, v):
if v is not None: if v is not None:
try: try:
v = int(v) v = int(v)
if v <= 0: if v <= 0:
raise ValueError(f" '{v}' is not valid. Value must be a positive integer.") raise ValueError(
f" '{v}' is not valid. Value must be a positive integer."
)
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
raise ValueError(f" '{v}' is not valid. Value must be a positive integer.") from e raise ValueError(
f" '{v}' is not valid. Value must be a positive integer."
) from e
return v return v
@field_validator('aperture', mode="before") @field_validator("aperture", mode="before")
@classmethod @classmethod
def aperture_selection(cls, v): def aperture_selection(cls, v):
if v is not None: if v is not None:
@ -124,58 +161,76 @@ class SpreadsheetModel(BaseModel):
if v not in {1, 2, 3}: if v not in {1, 2, 3}:
raise ValueError(f" '{v}' is not valid. Value must be 1, 2, or 3.") raise ValueError(f" '{v}' is not valid. Value must be 1, 2, or 3.")
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
raise ValueError(f" '{v}' is not valid. Value must be 1, 2, or 3.") from e raise ValueError(
f" '{v}' is not valid. Value must be 1, 2, or 3."
) from e
return v return v
@field_validator('oscillation', 'targetresolution', mode="before") @field_validator("oscillation", "targetresolution", mode="before")
@classmethod @classmethod
def positive_float_validator(cls, v): def positive_float_validator(cls, v):
if v is not None: if v is not None:
try: try:
v = float(v) v = float(v)
if v <= 0: if v <= 0:
raise ValueError(f" '{v}' is not valid. Value must be a positive float.") raise ValueError(
f" '{v}' is not valid. Value must be a positive float."
)
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
raise ValueError(f" '{v}' is not valid. Value must be a positive float.") from e raise ValueError(
f" '{v}' is not valid. Value must be a positive float."
) from e
return v return v
@field_validator('exposure', mode="before") @field_validator("exposure", mode="before")
@classmethod @classmethod
def exposure_in_range(cls, v): def exposure_in_range(cls, v):
if v is not None: if v is not None:
try: try:
v = float(v) v = float(v)
if not (0 <= v <= 1): if not (0 <= v <= 1):
raise ValueError(f" '{v}' is not valid. Value must be a float between 0 and 1.") raise ValueError(
f" '{v}' is not valid. Value must be a float between 0 and 1."
)
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
raise ValueError(f" '{v}' is not valid. Value must be a float between 0 and 1.") from e raise ValueError(
f" '{v}' is not valid. Value must be a float between 0 and 1."
) from e
return v return v
@field_validator('totalrange', mode="before") @field_validator("totalrange", mode="before")
@classmethod @classmethod
def totalrange_in_range(cls, v): def totalrange_in_range(cls, v):
if v is not None: if v is not None:
try: try:
v = int(v) v = int(v)
if not (0 <= v <= 360): if not (0 <= v <= 360):
raise ValueError(f" '{v}' is not valid. Value must be an integer between 0 and 360.") raise ValueError(
f" '{v}' is not valid. Value must be an integer between 0 and 360."
)
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
raise ValueError(f" '{v}' is not valid. Value must be an integer between 0 and 360.") from e raise ValueError(
f" '{v}' is not valid. Value must be an integer between 0 and 360."
) from e
return v return v
@field_validator('transmission', mode="before") @field_validator("transmission", mode="before")
@classmethod @classmethod
def transmission_fraction(cls, v): def transmission_fraction(cls, v):
if v is not None: if v is not None:
try: try:
v = int(v) v = int(v)
if not (0 <= v <= 100): if not (0 <= v <= 100):
raise ValueError(f" '{v}' is not valid. Value must be an integer between 0 and 100.") raise ValueError(
f" '{v}' is not valid. Value must be an integer between 0 and 100."
)
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
raise ValueError(f" '{v}' is not valid. Value must be an integer between 0 and 100.") from e raise ValueError(
f" '{v}' is not valid. Value must be an integer between 0 and 100."
) from e
return v return v
@field_validator('datacollectiontype', mode="before") @field_validator("datacollectiontype", mode="before")
@classmethod @classmethod
def datacollectiontype_allowed(cls, v): def datacollectiontype_allowed(cls, v):
allowed = {"standard"} # Other types of data collection might be added later allowed = {"standard"} # Other types of data collection might be added later
@ -183,7 +238,7 @@ class SpreadsheetModel(BaseModel):
raise ValueError(f" '{v}' is not valid. Value must be one of {allowed}.") raise ValueError(f" '{v}' is not valid. Value must be one of {allowed}.")
return v return v
@field_validator('processingpipeline', mode="before") @field_validator("processingpipeline", mode="before")
@classmethod @classmethod
def processingpipeline_allowed(cls, v): def processingpipeline_allowed(cls, v):
allowed = {"gopy", "autoproc", "xia2dials"} allowed = {"gopy", "autoproc", "xia2dials"}
@ -191,73 +246,93 @@ class SpreadsheetModel(BaseModel):
raise ValueError(f" '{v}' is not valid. Value must be one of {allowed}.") raise ValueError(f" '{v}' is not valid. Value must be one of {allowed}.")
return v return v
@field_validator('spacegroupnumber', mode="before") @field_validator("spacegroupnumber", mode="before")
@classmethod @classmethod
def spacegroupnumber_allowed(cls, v): def spacegroupnumber_allowed(cls, v):
if v is not None: if v is not None:
try: try:
v = int(v) v = int(v)
if not (1 <= v <= 230): if not (1 <= v <= 230):
raise ValueError(f" '{v}' is not valid. Value must be an integer between 1 and 230.") raise ValueError(
f" '{v}' is not valid. Value must be an integer between 1 and 230."
)
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
raise ValueError(f" '{v}' is not valid. Value must be an integer between 1 and 230.") from e raise ValueError(
f" '{v}' is not valid. Value must be an integer between 1 and 230."
) from e
return v return v
@field_validator('cellparameters', mode="before") @field_validator("cellparameters", mode="before")
@classmethod @classmethod
def cellparameters_format(cls, v): def cellparameters_format(cls, v):
if v: if v:
values = [float(i) for i in v.split(",")] values = [float(i) for i in v.split(",")]
if len(values) != 6 or any(val <= 0 for val in values): if len(values) != 6 or any(val <= 0 for val in values):
raise ValueError(f" '{v}' is not valid. Value must be a set of six positive floats or integers.") raise ValueError(
f" '{v}' is not valid. Value must be a set of six positive floats or integers."
)
return v return v
@field_validator('rescutkey', 'rescutvalue', mode="before") @field_validator("rescutkey", "rescutvalue", mode="before")
@classmethod @classmethod
def rescutkey_value_pair(cls, values): def rescutkey_value_pair(cls, values):
rescutkey = values.get('rescutkey') rescutkey = values.get("rescutkey")
rescutvalue = values.get('rescutvalue') rescutvalue = values.get("rescutvalue")
if rescutkey and rescutvalue: if rescutkey and rescutvalue:
if rescutkey not in {"is", "cchalf"}: if rescutkey not in {"is", "cchalf"}:
raise ValueError("Rescutkey must be either 'is' or 'cchalf'") raise ValueError("Rescutkey must be either 'is' or 'cchalf'")
if not isinstance(rescutvalue, float) or rescutvalue <= 0: if not isinstance(rescutvalue, float) or rescutvalue <= 0:
raise ValueError("Rescutvalue must be a positive float if rescutkey is provided") raise ValueError(
"Rescutvalue must be a positive float if rescutkey is provided"
)
return values return values
@field_validator('trustedhigh', mode="before") @field_validator("trustedhigh", mode="before")
@classmethod @classmethod
def trustedhigh_allowed(cls, v): def trustedhigh_allowed(cls, v):
if v is not None: if v is not None:
try: try:
v = float(v) v = float(v)
if not (0 <= v <= 2.0): if not (0 <= v <= 2.0):
raise ValueError(f" '{v}' is not valid. Value must be a float between 0 and 2.0.") raise ValueError(
f" '{v}' is not valid. Value must be a float between 0 and 2.0."
)
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
raise ValueError(f" '{v}' is not valid. Value must be a float between 0 and 2.0.") from e raise ValueError(
f" '{v}' is not valid. Value must be a float between 0 and 2.0."
) from e
return v return v
@field_validator('chiphiangles', mode="before") @field_validator("chiphiangles", mode="before")
@classmethod @classmethod
def chiphiangles_allowed(cls, v): def chiphiangles_allowed(cls, v):
if v is not None: if v is not None:
try: try:
v = float(v) v = float(v)
if not (0 <= v <= 30): if not (0 <= v <= 30):
raise ValueError(f" '{v}' is not valid. Value must be a float between 0 and 30.") raise ValueError(
f" '{v}' is not valid. Value must be a float between 0 and 30."
)
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
raise ValueError(f" '{v}' is not valid. Value must be a float between 0 and 30.") from e raise ValueError(
f" '{v}' is not valid. Value must be a float between 0 and 30."
) from e
return v return v
@field_validator('dose', mode="before") @field_validator("dose", mode="before")
@classmethod @classmethod
def dose_positive(cls, v): def dose_positive(cls, v):
if v is not None: if v is not None:
try: try:
v = float(v) v = float(v)
if v <= 0: if v <= 0:
raise ValueError(f" '{v}' is not valid. Value must be a positive float.") raise ValueError(
f" '{v}' is not valid. Value must be a positive float."
)
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
raise ValueError(f" '{v}' is not valid. Value must be a positive float.") from e raise ValueError(
f" '{v}' is not valid. Value must be a positive float."
) from e
return v return v
class TELLModel(SpreadsheetModel): class TELLModel(SpreadsheetModel):
@ -270,6 +345,7 @@ class SpreadsheetModel(BaseModel):
prefix: Optional[str] prefix: Optional[str]
folder: Optional[str] folder: Optional[str]
class SpreadsheetResponse(BaseModel): class SpreadsheetResponse(BaseModel):
data: List[SpreadsheetModel] # Validated data rows as SpreadsheetModel instances data: List[SpreadsheetModel] # Validated data rows as SpreadsheetModel instances
errors: List[Dict[str, Any]] # Errors encountered during validation errors: List[Dict[str, Any]] # Errors encountered during validation
@ -283,4 +359,4 @@ class SpreadsheetResponse(BaseModel):
headers: Optional[List[str]] = None # Add headers if needed headers: Optional[List[str]] = None # Add headers if needed
__all__ = ['SpreadsheetModel', 'SpreadsheetResponse'] __all__ = ["SpreadsheetModel", "SpreadsheetResponse"]

View File

@ -8,10 +8,12 @@ class loginToken(BaseModel):
access_token: str access_token: str
token_type: str token_type: str
class loginData(BaseModel): class loginData(BaseModel):
username: str username: str
pgroups: List[int] pgroups: List[int]
class DewarTypeBase(BaseModel): class DewarTypeBase(BaseModel):
dewar_type: str dewar_type: str
@ -76,9 +78,11 @@ class DataCollectionParameters(BaseModel):
class Config: class Config:
from_attributes = True from_attributes = True
class SampleEventCreate(BaseModel): class SampleEventCreate(BaseModel):
event_type: str event_type: str
class Results(BaseModel): class Results(BaseModel):
# Define attributes for Results here # Define attributes for Results here
pass pass
@ -150,6 +154,7 @@ class SampleCreate(BaseModel):
class Config: class Config:
populate_by_name = True populate_by_name = True
class PuckEvent(BaseModel): class PuckEvent(BaseModel):
id: int id: int
puck_id: int puck_id: int
@ -160,6 +165,7 @@ class PuckEvent(BaseModel):
class Config: class Config:
from_attributes = True from_attributes = True
class PuckBase(BaseModel): class PuckBase(BaseModel):
puck_name: str puck_name: str
puck_type: str puck_type: str
@ -299,6 +305,7 @@ class LogisticsEventCreate(BaseModel):
location_qr_code: str location_qr_code: str
transaction_type: str transaction_type: str
class SlotSchema(BaseModel): class SlotSchema(BaseModel):
id: int id: int
qr_code: str qr_code: str
@ -319,9 +326,10 @@ class SlotSchema(BaseModel):
class Config: class Config:
from_attributes = True from_attributes = True
class SetTellPosition(BaseModel): class SetTellPosition(BaseModel):
tell_position: str = Field( tell_position: str = Field(
..., ...,
pattern="^[A-F][1-5]$|^null$|^None$", # Use 'pattern' instead of 'regex' pattern="^[A-F][1-5]$|^null$|^None$", # Use 'pattern' instead of 'regex'
description="Valid values are A1-A5, B1-B5, ..., F1-F5, or null." description="Valid values are A1-A5, B1-B5, ..., F1-F5, or null.",
) )

View File

@ -48,12 +48,13 @@ class ShipmentProcessor:
for sample_data in puck_data.samples: for sample_data in puck_data.samples:
data_collection_params = DataCollectionParameters( data_collection_params = DataCollectionParameters(
**sample_data.data_collection_parameters.dict(by_alias=True)) **sample_data.data_collection_parameters.dict(by_alias=True)
)
sample = Sample( sample = Sample(
puck_id=puck.id, puck_id=puck.id,
sample_name=sample_data.sample_name, sample_name=sample_data.sample_name,
position=sample_data.position, position=sample_data.position,
data_collection_parameters=data_collection_params data_collection_parameters=data_collection_params,
) )
self.db.add(sample) self.db.add(sample)
self.db.commit() self.db.commit()
@ -62,7 +63,7 @@ class ShipmentProcessor:
return ShipmentResponse( return ShipmentResponse(
shipment_id=new_shipment.id, shipment_id=new_shipment.id,
status="success", status="success",
message="Shipment processed successfully" message="Shipment processed successfully",
) )
except Exception as e: except Exception as e:
logger.error(f"Error processing shipment: {str(e)}") logger.error(f"Error processing shipment: {str(e)}")

View File

@ -34,7 +34,7 @@ class SampleSpreadsheetImporter:
if isinstance(value, str): if isinstance(value, str):
try: try:
# Handle numeric strings # Handle numeric strings
if '.' in value: if "." in value:
return float(value) return float(value)
else: else:
return int(value) return int(value)
@ -50,16 +50,18 @@ class SampleSpreadsheetImporter:
def get_expected_type(self, col_name): def get_expected_type(self, col_name):
type_mapping = { type_mapping = {
'dewarname': str, "dewarname": str,
'puckname': str, "puckname": str,
'positioninpuck': int, "positioninpuck": int,
'priority': int, "priority": int,
'oscillation': float, "oscillation": float,
# Add all other mappings based on model requirements # Add all other mappings based on model requirements
} }
return type_mapping.get(col_name, str) # Default to `str` return type_mapping.get(col_name, str) # Default to `str`
def import_spreadsheet_with_errors(self, file) -> Tuple[List[SpreadsheetModel], List[dict], List[dict], List[str]]: def import_spreadsheet_with_errors(
self, file
) -> Tuple[List[SpreadsheetModel], List[dict], List[dict], List[str]]:
self.model = [] self.model = []
self.filename = file.filename self.filename = file.filename
logger.info(f"Importing spreadsheet from .xlsx file: {self.filename}") logger.info(f"Importing spreadsheet from .xlsx file: {self.filename}")
@ -88,7 +90,9 @@ class SampleSpreadsheetImporter:
# Now, return the values correctly # Now, return the values correctly
return model, errors, raw_data, headers return model, errors, raw_data, headers
def process_spreadsheet(self, sheet) -> Tuple[List[SpreadsheetModel], List[dict], List[dict], List[str]]: def process_spreadsheet(
self, sheet
) -> Tuple[List[SpreadsheetModel], List[dict], List[dict], List[str]]:
model = [] model = []
errors = [] errors = []
raw_data = [] raw_data = []
@ -106,12 +110,38 @@ class SampleSpreadsheetImporter:
# Add the headers (the first row in the spreadsheet or map them explicitly) # Add the headers (the first row in the spreadsheet or map them explicitly)
headers = [ headers = [
'dewarname', 'puckname', 'pucktype', 'crystalname', 'positioninpuck', 'priority', "dewarname",
'comments', 'directory', 'proteinname', 'oscillation', 'aperture', 'exposure', "puckname",
'totalrange', 'transmission', 'dose', 'targetresolution', 'datacollectiontype', "pucktype",
'processingpipeline', 'spacegroupnumber', 'cellparameters', 'rescutkey', 'rescutvalue', "crystalname",
'userresolution', 'pdbid', 'autoprocfull', 'procfull', 'adpenabled', 'noano', "positioninpuck",
'ffcscampaign', 'trustedhigh', 'autoprocextraparams', 'chiphiangles' "priority",
"comments",
"directory",
"proteinname",
"oscillation",
"aperture",
"exposure",
"totalrange",
"transmission",
"dose",
"targetresolution",
"datacollectiontype",
"processingpipeline",
"spacegroupnumber",
"cellparameters",
"rescutkey",
"rescutvalue",
"userresolution",
"pdbid",
"autoprocfull",
"procfull",
"adpenabled",
"noano",
"ffcscampaign",
"trustedhigh",
"autoprocextraparams",
"chiphiangles",
] ]
for index, row in enumerate(rows): for index, row in enumerate(rows):
@ -128,38 +158,38 @@ class SampleSpreadsheetImporter:
# Prepare the record with the cleaned values # Prepare the record with the cleaned values
record = { record = {
'dewarname': self._clean_value(row[0], str), "dewarname": self._clean_value(row[0], str),
'puckname': self._clean_value(row[1], str), "puckname": self._clean_value(row[1], str),
'pucktype': self._clean_value(row[2], str), "pucktype": self._clean_value(row[2], str),
'crystalname': self._clean_value(row[3], str), "crystalname": self._clean_value(row[3], str),
'positioninpuck': self._clean_value(row[4], int), "positioninpuck": self._clean_value(row[4], int),
'priority': self._clean_value(row[5], int), "priority": self._clean_value(row[5], int),
'comments': self._clean_value(row[6], str), "comments": self._clean_value(row[6], str),
'directory': self._clean_value(row[7], str), "directory": self._clean_value(row[7], str),
'proteinname': self._clean_value(row[8], str), "proteinname": self._clean_value(row[8], str),
'oscillation': self._clean_value(row[9], float), "oscillation": self._clean_value(row[9], float),
'aperture': self._clean_value(row[10], str), "aperture": self._clean_value(row[10], str),
'exposure': self._clean_value(row[11], float), "exposure": self._clean_value(row[11], float),
'totalrange': self._clean_value(row[12], float), "totalrange": self._clean_value(row[12], float),
'transmission': self._clean_value(row[13], int), "transmission": self._clean_value(row[13], int),
'dose': self._clean_value(row[14], float), "dose": self._clean_value(row[14], float),
'targetresolution': self._clean_value(row[15], float), "targetresolution": self._clean_value(row[15], float),
'datacollectiontype': self._clean_value(row[16], str), "datacollectiontype": self._clean_value(row[16], str),
'processingpipeline': self._clean_value(row[17], str), "processingpipeline": self._clean_value(row[17], str),
'spacegroupnumber': self._clean_value(row[18], int), "spacegroupnumber": self._clean_value(row[18], int),
'cellparameters': self._clean_value(row[19], str), "cellparameters": self._clean_value(row[19], str),
'rescutkey': self._clean_value(row[20], str), "rescutkey": self._clean_value(row[20], str),
'rescutvalue': self._clean_value(row[21], str), "rescutvalue": self._clean_value(row[21], str),
'userresolution': self._clean_value(row[22], str), "userresolution": self._clean_value(row[22], str),
'pdbid': self._clean_value(row[23], str), "pdbid": self._clean_value(row[23], str),
'autoprocfull': self._clean_value(row[24], str), "autoprocfull": self._clean_value(row[24], str),
'procfull': self._clean_value(row[25], str), "procfull": self._clean_value(row[25], str),
'adpenabled': self._clean_value(row[26], str), "adpenabled": self._clean_value(row[26], str),
'noano': self._clean_value(row[27], str), "noano": self._clean_value(row[27], str),
'ffcscampaign': self._clean_value(row[28], str), "ffcscampaign": self._clean_value(row[28], str),
'trustedhigh': self._clean_value(row[29], str), "trustedhigh": self._clean_value(row[29], str),
'autoprocextraparams': self._clean_value(row[30], str), "autoprocextraparams": self._clean_value(row[30], str),
'chiphiangles': self._clean_value(row[31], str) "chiphiangles": self._clean_value(row[31], str),
} }
try: try:
@ -169,52 +199,54 @@ class SampleSpreadsheetImporter:
except ValidationError as e: except ValidationError as e:
logger.error(f"Validation error in row {index + 4}: {e}") logger.error(f"Validation error in row {index + 4}: {e}")
for error in e.errors(): for error in e.errors():
field = error['loc'][0] field = error["loc"][0]
msg = error['msg'] msg = error["msg"]
# Map field name (which is the key in `record`) to its index in the row # Map field name (which is the key in `record`) to its index in the row
field_to_col = { field_to_col = {
'dewarname': 0, "dewarname": 0,
'puckname': 1, "puckname": 1,
'pucktype': 2, "pucktype": 2,
'crystalname': 3, "crystalname": 3,
'positioninpuck': 4, "positioninpuck": 4,
'priority': 5, "priority": 5,
'comments': 6, "comments": 6,
'directory': 7, "directory": 7,
'proteinname': 8, "proteinname": 8,
'oscillation': 9, "oscillation": 9,
'aperture': 10, "aperture": 10,
'exposure': 11, "exposure": 11,
'totalrange': 12, "totalrange": 12,
'transmission': 13, "transmission": 13,
'dose': 14, "dose": 14,
'targetresolution': 15, "targetresolution": 15,
'datacollectiontype': 16, "datacollectiontype": 16,
'processingpipeline': 17, "processingpipeline": 17,
'spacegroupnumber': 18, "spacegroupnumber": 18,
'cellparameters': 19, "cellparameters": 19,
'rescutkey': 20, "rescutkey": 20,
'rescutvalue': 21, "rescutvalue": 21,
'userresolution': 22, "userresolution": 22,
'pdbid': 23, "pdbid": 23,
'autoprocfull': 24, "autoprocfull": 24,
'procfull': 25, "procfull": 25,
'adpenabled': 26, "adpenabled": 26,
'noano': 27, "noano": 27,
'ffcscampaign': 28, "ffcscampaign": 28,
'trustedhigh': 29, "trustedhigh": 29,
'autoprocextraparams': 30, "autoprocextraparams": 30,
'chiphiangles': 31 "chiphiangles": 31,
} }
column_index = field_to_col[field] column_index = field_to_col[field]
error_info = { error_info = {
'row': index + 4, "row": index + 4,
'cell': column_index, "cell": column_index,
'value': row[column_index], # Value that caused the error "value": row[column_index], # Value that caused the error
'message': msg "message": msg,
} }
errors.append(error_info) errors.append(error_info)
self.model = model self.model = model
logger.info(f"Finished processing {len(model)} records with {len(errors)} errors") logger.info(
f"Finished processing {len(model)} records with {len(errors)} errors"
)
return self.model, errors, raw_data, headers # Include headers in the response return self.model, errors, raw_data, headers # Include headers in the response

View File

@ -15,37 +15,42 @@ def generate_self_signed_cert(cert_file: str, key_file: str):
) )
# Write private key to file # Write private key to file
with open(key_file, "wb") as f: with open(key_file, "wb") as f:
f.write(key.private_bytes( f.write(
key.private_bytes(
encoding=serialization.Encoding.PEM, encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL, format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(), encryption_algorithm=serialization.NoEncryption(),
)) )
)
# Generate self-signed certificate # Generate self-signed certificate
subject = issuer = x509.Name([ subject = issuer = x509.Name(
x509.NameAttribute(NameOID.COUNTRY_NAME, u"CH"), [
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, u"Argau"), x509.NameAttribute(NameOID.COUNTRY_NAME, "CH"),
x509.NameAttribute(NameOID.LOCALITY_NAME, u"Villigen"), x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Argau"),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, u"Paul Scherrer Institut"), x509.NameAttribute(NameOID.LOCALITY_NAME, "Villigen"),
x509.NameAttribute(NameOID.COMMON_NAME, u"PSI.CH"), x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Paul Scherrer Institut"),
]) x509.NameAttribute(NameOID.COMMON_NAME, "PSI.CH"),
cert = x509.CertificateBuilder().subject_name( ]
subject )
).issuer_name( cert = (
issuer x509.CertificateBuilder()
).public_key( .subject_name(subject)
key.public_key() .issuer_name(issuer)
).serial_number( .public_key(key.public_key())
x509.random_serial_number() .serial_number(x509.random_serial_number())
).not_valid_before( .not_valid_before(datetime.datetime.utcnow())
datetime.datetime.utcnow() .not_valid_after(
).not_valid_after(
# Our certificate will be valid for 10 days # Our certificate will be valid for 10 days
datetime.datetime.utcnow() + datetime.timedelta(days=10) datetime.datetime.utcnow()
).add_extension( + datetime.timedelta(days=10)
x509.SubjectAlternativeName([x509.DNSName(u"localhost")]), )
.add_extension(
x509.SubjectAlternativeName([x509.DNSName("localhost")]),
critical=False, critical=False,
).sign(key, hashes.SHA256()) )
.sign(key, hashes.SHA256())
)
# Write certificate to file # Write certificate to file
with open(cert_file, "wb") as f: with open(cert_file, "wb") as f:

View File

@ -8,24 +8,35 @@ from fastapi.middleware.cors import CORSMiddleware
from app import ssl_heidi from app import ssl_heidi
from app.routers import address, contact, proposal, dewar, shipment, puck, spreadsheet, logistics, auth, sample from app.routers import (
address,
contact,
proposal,
dewar,
shipment,
puck,
spreadsheet,
logistics,
auth,
sample,
)
from app.database import Base, engine, SessionLocal, load_sample_data from app.database import Base, engine, SessionLocal, load_sample_data
app = FastAPI() app = FastAPI()
# Determine environment and configuration file path # Determine environment and configuration file path
environment = os.getenv('ENVIRONMENT', 'dev') environment = os.getenv("ENVIRONMENT", "dev")
config_file = Path(__file__).resolve().parent.parent / f'config_{environment}.json' config_file = Path(__file__).resolve().parent.parent / f"config_{environment}.json"
# Load configuration # Load configuration
with open(config_file) as f: with open(config_file) as f:
config = json.load(f) config = json.load(f)
cert_path = config['ssl_cert_path'] cert_path = config["ssl_cert_path"]
key_path = config['ssl_key_path'] key_path = config["ssl_key_path"]
# Generate SSL Key and Certificate if not exist (only for development) # Generate SSL Key and Certificate if not exist (only for development)
if environment == 'dev': if environment == "dev":
Path("ssl").mkdir(parents=True, exist_ok=True) Path("ssl").mkdir(parents=True, exist_ok=True)
if not Path(cert_path).exists() or not Path(key_path).exists(): if not Path(cert_path).exists() or not Path(key_path).exists():
ssl_heidi.generate_self_signed_cert(cert_path, key_path) ssl_heidi.generate_self_signed_cert(cert_path, key_path)
@ -81,13 +92,13 @@ if __name__ == "__main__":
import os import os
# Get environment from an environment variable # Get environment from an environment variable
environment = os.getenv('ENVIRONMENT', 'dev') environment = os.getenv("ENVIRONMENT", "dev")
# Paths for SSL certificates # Paths for SSL certificates
cert_path = "ssl/cert.pem" cert_path = "ssl/cert.pem"
key_path = "ssl/key.pem" key_path = "ssl/key.pem"
if environment == 'test': if environment == "test":
cert_path = "ssl/mx-aare-test.psi.ch.pem" cert_path = "ssl/mx-aare-test.psi.ch.pem"
key_path = "ssl/mx-aare-test.psi.ch.key" key_path = "ssl/mx-aare-test.psi.ch.key"
host = "0.0.0.0" # Bind to all interfaces host = "0.0.0.0" # Bind to all interfaces

View File

@ -85,113 +85,193 @@ class Shipment(BaseModel):
# Example data for contacts # Example data for contacts
contacts = [ contacts = [
ContactPerson(id=1, firstname="Frodo", lastname="Baggins", phone_number="123-456-7890", ContactPerson(
email="frodo.baggins@lotr.com"), id=1,
ContactPerson(id=2, firstname="Samwise", lastname="Gamgee", phone_number="987-654-3210", firstname="Frodo",
email="samwise.gamgee@lotr.com"), lastname="Baggins",
ContactPerson(id=3, firstname="Aragorn", lastname="Elessar", phone_number="123-333-4444", phone_number="123-456-7890",
email="aragorn.elessar@lotr.com"), email="frodo.baggins@lotr.com",
ContactPerson(id=4, firstname="Legolas", lastname="Greenleaf", phone_number="555-666-7777", ),
email="legolas.greenleaf@lotr.com"), ContactPerson(
ContactPerson(id=5, firstname="Gimli", lastname="Son of Gloin", phone_number="888-999-0000", id=2,
email="gimli.sonofgloin@lotr.com"), firstname="Samwise",
ContactPerson(id=6, firstname="Gandalf", lastname="The Grey", phone_number="222-333-4444", lastname="Gamgee",
email="gandalf.thegrey@lotr.com"), phone_number="987-654-3210",
ContactPerson(id=7, firstname="Boromir", lastname="Son of Denethor", phone_number="111-222-3333", email="samwise.gamgee@lotr.com",
email="boromir.sonofdenethor@lotr.com"), ),
ContactPerson(id=8, firstname="Galadriel", lastname="Lady of Lothlórien", phone_number="444-555-6666", ContactPerson(
email="galadriel.lothlorien@lotr.com"), id=3,
ContactPerson(id=9, firstname="Elrond", lastname="Half-elven", phone_number="777-888-9999", firstname="Aragorn",
email="elrond.halfelven@lotr.com"), lastname="Elessar",
ContactPerson(id=10, firstname="Eowyn", lastname="Shieldmaiden of Rohan", phone_number="000-111-2222", phone_number="123-333-4444",
email="eowyn.rohan@lotr.com"), email="aragorn.elessar@lotr.com",
),
ContactPerson(
id=4,
firstname="Legolas",
lastname="Greenleaf",
phone_number="555-666-7777",
email="legolas.greenleaf@lotr.com",
),
ContactPerson(
id=5,
firstname="Gimli",
lastname="Son of Gloin",
phone_number="888-999-0000",
email="gimli.sonofgloin@lotr.com",
),
ContactPerson(
id=6,
firstname="Gandalf",
lastname="The Grey",
phone_number="222-333-4444",
email="gandalf.thegrey@lotr.com",
),
ContactPerson(
id=7,
firstname="Boromir",
lastname="Son of Denethor",
phone_number="111-222-3333",
email="boromir.sonofdenethor@lotr.com",
),
ContactPerson(
id=8,
firstname="Galadriel",
lastname="Lady of Lothlórien",
phone_number="444-555-6666",
email="galadriel.lothlorien@lotr.com",
),
ContactPerson(
id=9,
firstname="Elrond",
lastname="Half-elven",
phone_number="777-888-9999",
email="elrond.halfelven@lotr.com",
),
ContactPerson(
id=10,
firstname="Eowyn",
lastname="Shieldmaiden of Rohan",
phone_number="000-111-2222",
email="eowyn.rohan@lotr.com",
),
] ]
# Example data for return addresses # Example data for return addresses
return_addresses = [ return_addresses = [
Address(id=1, street='123 Hobbiton St', city='Shire', zipcode='12345', country='Middle Earth'), Address(
Address(id=2, street='456 Rohan Rd', city='Edoras', zipcode='67890', country='Middle Earth'), id=1,
Address(id=3, street='789 Greenwood Dr', city='Mirkwood', zipcode='13579', country='Middle Earth'), street="123 Hobbiton St",
Address(id=4, street='321 Gondor Ave', city='Minas Tirith', zipcode='24680', country='Middle Earth'), city="Shire",
Address(id=5, street='654 Falgorn Pass', city='Rivendell', zipcode='11223', country='Middle Earth') zipcode="12345",
country="Middle Earth",
),
Address(
id=2,
street="456 Rohan Rd",
city="Edoras",
zipcode="67890",
country="Middle Earth",
),
Address(
id=3,
street="789 Greenwood Dr",
city="Mirkwood",
zipcode="13579",
country="Middle Earth",
),
Address(
id=4,
street="321 Gondor Ave",
city="Minas Tirith",
zipcode="24680",
country="Middle Earth",
),
Address(
id=5,
street="654 Falgorn Pass",
city="Rivendell",
zipcode="11223",
country="Middle Earth",
),
] ]
# Example data for dewars # Example data for dewars
dewars = [ dewars = [
Dewar( Dewar(
id='DEWAR001', id="DEWAR001",
dewar_name='Dewar One', dewar_name="Dewar One",
tracking_number='TRACK123', tracking_number="TRACK123",
number_of_pucks=7, number_of_pucks=7,
number_of_samples=70, number_of_samples=70,
return_address=[return_addresses[0]], return_address=[return_addresses[0]],
contact_person=[contacts[0]], contact_person=[contacts[0]],
status='Ready for Shipping', status="Ready for Shipping",
ready_date='2023-09-30', ready_date="2023-09-30",
shipping_date='', shipping_date="",
arrival_date='', arrival_date="",
returning_date='', returning_date="",
qrcode='QR123DEWAR001' qrcode="QR123DEWAR001",
), ),
Dewar( Dewar(
id='DEWAR002', id="DEWAR002",
dewar_name='Dewar Two', dewar_name="Dewar Two",
tracking_number='TRACK124', tracking_number="TRACK124",
number_of_pucks=3, number_of_pucks=3,
number_of_samples=33, number_of_samples=33,
return_address=[return_addresses[1]], return_address=[return_addresses[1]],
contact_person=[contacts[1]], contact_person=[contacts[1]],
status='In Preparation', status="In Preparation",
ready_date='', ready_date="",
shipping_date='', shipping_date="",
arrival_date='', arrival_date="",
returning_date='', returning_date="",
qrcode='QR123DEWAR002' qrcode="QR123DEWAR002",
), ),
Dewar( Dewar(
id='DEWAR003', id="DEWAR003",
dewar_name='Dewar Three', dewar_name="Dewar Three",
tracking_number='TRACK125', tracking_number="TRACK125",
number_of_pucks=7, number_of_pucks=7,
number_of_samples=72, number_of_samples=72,
return_address=[return_addresses[0]], return_address=[return_addresses[0]],
contact_person=[contacts[2]], contact_person=[contacts[2]],
status='Not Shipped', status="Not Shipped",
ready_date='2024.01.01', ready_date="2024.01.01",
shipping_date='', shipping_date="",
arrival_date='', arrival_date="",
returning_date='', returning_date="",
qrcode='QR123DEWAR003' qrcode="QR123DEWAR003",
), ),
Dewar( Dewar(
id='DEWAR004', id="DEWAR004",
dewar_name='Dewar Four', dewar_name="Dewar Four",
tracking_number='', tracking_number="",
number_of_pucks=7, number_of_pucks=7,
number_of_samples=70, number_of_samples=70,
return_address=[return_addresses[0]], return_address=[return_addresses[0]],
contact_person=[contacts[2]], contact_person=[contacts[2]],
status='Delayed', status="Delayed",
ready_date='2024.01.01', ready_date="2024.01.01",
shipping_date='2024.01.02', shipping_date="2024.01.02",
arrival_date='', arrival_date="",
returning_date='', returning_date="",
qrcode='QR123DEWAR003' qrcode="QR123DEWAR003",
), ),
Dewar( Dewar(
id='DEWAR005', id="DEWAR005",
dewar_name='Dewar Five', dewar_name="Dewar Five",
tracking_number='', tracking_number="",
number_of_pucks=3, number_of_pucks=3,
number_of_samples=30, number_of_samples=30,
return_address=[return_addresses[0]], return_address=[return_addresses[0]],
contact_person=[contacts[2]], contact_person=[contacts[2]],
status='Returned', status="Returned",
ready_date='2024.01.01', ready_date="2024.01.01",
shipping_date='2024.01.02', shipping_date="2024.01.02",
arrival_date='2024.01.03', arrival_date="2024.01.03",
returning_date='2024.01.07', returning_date="2024.01.07",
qrcode='QR123DEWAR003' qrcode="QR123DEWAR003",
), ),
] ]
@ -205,9 +285,11 @@ proposals = [
] ]
# Example: Attach specific Dewars by their ids to shipments # Example: Attach specific Dewars by their ids to shipments
specific_dewar_ids1 = ['DEWAR003'] # The IDs of the Dewars you want to attach to the first shipment specific_dewar_ids1 = [
specific_dewar_ids2 = ['DEWAR001', 'DEWAR002'] "DEWAR003"
specific_dewar_ids3 = ['DEWAR003', 'DEWAR004', 'DEWAR005'] ] # The IDs of the Dewars you want to attach to the first shipment
specific_dewar_ids2 = ["DEWAR001", "DEWAR002"]
specific_dewar_ids3 = ["DEWAR003", "DEWAR004", "DEWAR005"]
# The IDs of the Dewars you want to attach to the second shipment # The IDs of the Dewars you want to attach to the second shipment
# Find the Dewars with the matching ids # Find the Dewars with the matching ids
@ -218,38 +300,38 @@ specific_dewars3 = [dewar for dewar in dewars if dewar.id in specific_dewar_ids3
# Define shipments with the selected Dewars # Define shipments with the selected Dewars
shipments = [ shipments = [
Shipment( Shipment(
shipment_id='SHIPMORDOR', shipment_id="SHIPMORDOR",
shipment_date='2024-10-10', shipment_date="2024-10-10",
shipment_name='Shipment from Mordor', shipment_name="Shipment from Mordor",
shipment_status='Delivered', shipment_status="Delivered",
contact_person=[contacts[1]], contact_person=[contacts[1]],
proposal_number=[proposals[1]], proposal_number=[proposals[1]],
return_address=[return_addresses[0]], return_address=[return_addresses[0]],
comments='Handle with care', comments="Handle with care",
dewars=specific_dewars1 # Attach specific Dewars for this shipment dewars=specific_dewars1, # Attach specific Dewars for this shipment
), ),
Shipment( Shipment(
shipment_id='SHIPMORDOR2', shipment_id="SHIPMORDOR2",
shipment_date='2024-10-24', shipment_date="2024-10-24",
shipment_name='Shipment from Mordor', shipment_name="Shipment from Mordor",
shipment_status='In Transit', shipment_status="In Transit",
contact_person=[contacts[3]], contact_person=[contacts[3]],
proposal_number=[proposals[2]], proposal_number=[proposals[2]],
return_address=[return_addresses[1]], # Changed index to a valid one return_address=[return_addresses[1]], # Changed index to a valid one
comments='Contains the one ring', comments="Contains the one ring",
dewars=specific_dewars2 # Attach specific Dewars for this shipment dewars=specific_dewars2, # Attach specific Dewars for this shipment
), ),
Shipment( Shipment(
shipment_id='SHIPMORDOR3', shipment_id="SHIPMORDOR3",
shipment_date='2024-10-28', shipment_date="2024-10-28",
shipment_name='Shipment from Mordor', shipment_name="Shipment from Mordor",
shipment_status='In Transit', shipment_status="In Transit",
contact_person=[contacts[4]], contact_person=[contacts[4]],
proposal_number=[proposals[3]], proposal_number=[proposals[3]],
return_address=[return_addresses[0]], # Changed index to a valid one return_address=[return_addresses[0]], # Changed index to a valid one
comments='Contains the one ring', comments="Contains the one ring",
dewars=specific_dewars3 dewars=specific_dewars3,
) ),
] ]
@ -269,7 +351,11 @@ async def get_proposals():
@app.get("/shipments", response_model=List[Shipment]) @app.get("/shipments", response_model=List[Shipment])
async def get_shipments(shipment_id: Optional[str] = Query(None, description="ID of the specific shipment to retrieve")): async def get_shipments(
shipment_id: Optional[str] = Query(
None, description="ID of the specific shipment to retrieve"
)
):
if shipment_id: if shipment_id:
shipment = next((sh for sh in shipments if sh.shipment_id == shipment_id), None) shipment = next((sh for sh in shipments if sh.shipment_id == shipment_id), None)
if not shipment: if not shipment:
@ -281,7 +367,9 @@ async def get_shipments(shipment_id: Optional[str] = Query(None, description="ID
@app.delete("/shipments/{shipment_id}", status_code=status.HTTP_204_NO_CONTENT) @app.delete("/shipments/{shipment_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_shipment(shipment_id: str): async def delete_shipment(shipment_id: str):
global shipments # Use global variable to access the shipments list global shipments # Use global variable to access the shipments list
shipments = [shipment for shipment in shipments if shipment.shipment_id != shipment_id] shipments = [
shipment for shipment in shipments if shipment.shipment_id != shipment_id
]
@app.post("/shipments/{shipment_id}/add_dewar", response_model=Shipment) @app.post("/shipments/{shipment_id}/add_dewar", response_model=Shipment)
@ -322,16 +410,32 @@ async def update_shipment(shipment_id: str, updated_shipment: Shipment):
if updated_dewar.id in existing_dewar_dict: if updated_dewar.id in existing_dewar_dict:
# Update existing dewar # Update existing dewar
existing_dewar_dict[updated_dewar.id].dewar_name = updated_dewar.dewar_name existing_dewar_dict[updated_dewar.id].dewar_name = updated_dewar.dewar_name
existing_dewar_dict[updated_dewar.id].tracking_number = updated_dewar.tracking_number existing_dewar_dict[updated_dewar.id].tracking_number = (
existing_dewar_dict[updated_dewar.id].number_of_pucks = updated_dewar.number_of_pucks updated_dewar.tracking_number
existing_dewar_dict[updated_dewar.id].number_of_samples = updated_dewar.number_of_samples )
existing_dewar_dict[updated_dewar.id].return_address = updated_dewar.return_address existing_dewar_dict[updated_dewar.id].number_of_pucks = (
existing_dewar_dict[updated_dewar.id].contact_person = updated_dewar.contact_person updated_dewar.number_of_pucks
)
existing_dewar_dict[updated_dewar.id].number_of_samples = (
updated_dewar.number_of_samples
)
existing_dewar_dict[updated_dewar.id].return_address = (
updated_dewar.return_address
)
existing_dewar_dict[updated_dewar.id].contact_person = (
updated_dewar.contact_person
)
existing_dewar_dict[updated_dewar.id].status = updated_dewar.status existing_dewar_dict[updated_dewar.id].status = updated_dewar.status
existing_dewar_dict[updated_dewar.id].ready_date = updated_dewar.ready_date existing_dewar_dict[updated_dewar.id].ready_date = updated_dewar.ready_date
existing_dewar_dict[updated_dewar.id].shipping_date = updated_dewar.shipping_date existing_dewar_dict[updated_dewar.id].shipping_date = (
existing_dewar_dict[updated_dewar.id].arrival_date = updated_dewar.arrival_date updated_dewar.shipping_date
existing_dewar_dict[updated_dewar.id].returning_date = updated_dewar.returning_date )
existing_dewar_dict[updated_dewar.id].arrival_date = (
updated_dewar.arrival_date
)
existing_dewar_dict[updated_dewar.id].returning_date = (
updated_dewar.returning_date
)
existing_dewar_dict[updated_dewar.id].qrcode = updated_dewar.qrcode existing_dewar_dict[updated_dewar.id].qrcode = updated_dewar.qrcode
else: else:
# Add new dewar # Add new dewar
@ -358,7 +462,7 @@ async def get_dewars():
@app.post("/dewars", response_model=Dewar, status_code=status.HTTP_201_CREATED) @app.post("/dewars", response_model=Dewar, status_code=status.HTTP_201_CREATED)
async def create_dewar(dewar: Dewar) -> Dewar: async def create_dewar(dewar: Dewar) -> Dewar:
dewar_id = f'DEWAR-{uuid.uuid4().hex[:8].upper()}' # Generates a unique dewar ID dewar_id = f"DEWAR-{uuid.uuid4().hex[:8].upper()}" # Generates a unique dewar ID
dewar.id = dewar_id # Set the generated ID on the dewar object dewar.id = dewar_id # Set the generated ID on the dewar object
dewars.append(dewar) # Add the modified dewar object to the list dewars.append(dewar) # Add the modified dewar object to the list
@ -382,14 +486,21 @@ async def remove_dewar_from_shipment(shipment_id: str, dewar_id: str):
@app.get("/shipment_contact_persons") @app.get("/shipment_contact_persons")
async def get_shipment_contact_persons(): async def get_shipment_contact_persons():
return [{"shipment_id": shipment.shipment_id, "contact_person": shipment.get_shipment_contact_persons()} for return [
shipment in shipments] {
"shipment_id": shipment.shipment_id,
"contact_person": shipment.get_shipment_contact_persons(),
}
for shipment in shipments
]
@app.post("/shipments", response_model=Shipment, status_code=status.HTTP_201_CREATED) @app.post("/shipments", response_model=Shipment, status_code=status.HTTP_201_CREATED)
async def create_shipment(shipment: Shipment): async def create_shipment(shipment: Shipment):
# Automatically generate a shipment ID # Automatically generate a shipment ID
shipment_id = f'SHIP-{uuid.uuid4().hex[:8].upper()}' # Generates a unique shipment ID shipment_id = (
f"SHIP-{uuid.uuid4().hex[:8].upper()}" # Generates a unique shipment ID
)
shipment.shipment_id = shipment_id # Set the generated ID shipment.shipment_id = shipment_id # Set the generated ID
# Append the shipment to the list # Append the shipment to the list
@ -398,13 +509,15 @@ async def create_shipment(shipment: Shipment):
# Creation of a new contact # Creation of a new contact
@app.post("/contacts", response_model=ContactPerson, status_code=status.HTTP_201_CREATED) @app.post(
"/contacts", response_model=ContactPerson, status_code=status.HTTP_201_CREATED
)
async def create_contact(contact: ContactPerson): async def create_contact(contact: ContactPerson):
# Check for duplicate contact by email (or other unique fields) # Check for duplicate contact by email (or other unique fields)
if any(c.email == contact.email for c in contacts): if any(c.email == contact.email for c in contacts):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="This contact already exists." detail="This contact already exists.",
) )
# Find the next available id # Find the next available id
@ -419,13 +532,15 @@ async def create_contact(contact: ContactPerson):
# Creation of a return address # Creation of a return address
@app.post("/return_addresses", response_model=Address, status_code=status.HTTP_201_CREATED) @app.post(
"/return_addresses", response_model=Address, status_code=status.HTTP_201_CREATED
)
async def create_return_address(address: Address): async def create_return_address(address: Address):
# Check for duplicate address by city # Check for duplicate address by city
if any(a.city == address.city for a in return_addresses): if any(a.city == address.city for a in return_addresses):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Address in this city already exists." detail="Address in this city already exists.",
) )
# Find the next available id # Find the next available id

View File

@ -7,20 +7,26 @@ client = TestClient(app)
def test_login_success(): def test_login_success():
response = client.post("/auth/token/login", data={"username": "testuser", "password": "testpass"}) response = client.post(
"/auth/token/login", data={"username": "testuser", "password": "testpass"}
)
assert response.status_code == 200 assert response.status_code == 200
assert "access_token" in response.json() assert "access_token" in response.json()
def test_login_failure(): def test_login_failure():
response = client.post("/auth/token/login", data={"username": "wrong", "password": "wrongpass"}) response = client.post(
"/auth/token/login", data={"username": "wrong", "password": "wrongpass"}
)
assert response.status_code == 401 assert response.status_code == 401
assert response.json() == {"detail": "Incorrect username or password"} assert response.json() == {"detail": "Incorrect username or password"}
def test_protected_route(): def test_protected_route():
# Step 1: Login # Step 1: Login
response = client.post("/auth/token/login", data={"username": "testuser", "password": "testpass"}) response = client.post(
"/auth/token/login", data={"username": "testuser", "password": "testpass"}
)
token = response.json()["access_token"] token = response.json()["access_token"]
# Step 2: Access protected route # Step 2: Access protected route