Fix bugs in the CI pipeline

This commit is contained in:
GotthardG 2024-12-16 22:50:04 +01:00
parent e0e176881b
commit 0178de96fd
14 changed files with 145 additions and 96 deletions

View File

@ -4,7 +4,13 @@ repos:
hooks:
- id: black
name: Black code formatter
args: ["--check"] # Only check formatting without modifying
args: ["--line-length", "88"] # Actively fix code, no --check
- repo: https://github.com/pre-commit/mirrors-autopep8
rev: v2.0.4 # Use a specific stable version
hooks:
- id: autopep8
args: [ "--in-place", "--aggressive", "--max-line-length=88" ]
- repo: https://github.com/pycqa/flake8
rev: 6.1.0 # Use the latest stable version of flake8

View File

@ -11,3 +11,17 @@ from .data import (
sample_events,
)
from .slots_data import slots
# Explicit public API of this package: the names re-exported when a
# consumer does `from <package> import *` (mirrors the imports above).
__all__ = [
"contacts",
"return_addresses",
"dewars",
"proposals",
"shipments",
"pucks",
"samples",
"dewar_types",
"serial_numbers",
"sample_events",
"slots",
]

View File

@ -8,7 +8,6 @@ from app.models import (
Sample,
DewarType,
DewarSerialNumber,
Slot,
SampleEvent,
)
from datetime import datetime, timedelta
@ -526,7 +525,8 @@ event_types = ["Mounted", "Failed", "Unmounted", "Lost"]
def generate_sample_events(samples, chance_no_event=0.2, chance_lost=0.1):
"""Generate events for samples with timestamps increasing between different samples."""
"""Generate events for samples with timestamps
increasing between different samples."""
events = []
# Set the start time to yesterday at 9:33 AM

View File

@ -1,4 +1,4 @@
from datetime import datetime, timedelta
from datetime import timedelta
from app.models import Slot
slotQRCodes = [

View File

@ -11,7 +11,6 @@ db_username = os.getenv("DB_USERNAME")
db_password = os.getenv("DB_PASSWORD")
# Construct the database URL
# SQLALCHEMY_DATABASE_URL = f"mysql://{db_username}:{db_password}@localhost:3306/aare_db"
SQLALCHEMY_DATABASE_URL = f"mysql://{db_username}:{db_password}@localhost:3306/aare_db"
# Remove the `connect_args` parameter
@ -32,9 +31,6 @@ def get_db():
def init_db():
# Import models inside function to avoid circular dependency
from . import models
Base.metadata.create_all(bind=engine)

View File

@ -1,4 +1,8 @@
import os, tempfile, time, random, hashlib
import os
import tempfile
import time
import random
import hashlib
from fastapi import APIRouter, HTTPException, status, Depends, Response
from sqlalchemy.orm import Session, joinedload
from typing import List
@ -21,14 +25,12 @@ from app.models import (
Sample as SampleModel,
DewarType as DewarTypeModel,
DewarSerialNumber as DewarSerialNumberModel,
Shipment as ShipmentModel, # Clearer name for model
)
from app.dependencies import get_db
import uuid
import qrcode
import io
from io import BytesIO
from PIL import ImageFont, ImageDraw, Image
from PIL import Image
from reportlab.lib.pagesizes import A5, landscape
from reportlab.lib.units import cm
from reportlab.pdfgen import canvas
@ -211,7 +213,7 @@ def generate_label(dewar):
c.drawString(2 * cm, y_position, f"Country: {return_address.country}")
y_position -= line_height
c.drawString(2 * cm, y_position, f"Beamtime Information: Placeholder")
c.drawString(2 * cm, y_position, "Beamtime Information: Placeholder")
# Generate QR code
qr = qrcode.QRCode(version=1, box_size=10, border=4)

View File

@ -34,7 +34,8 @@ def calculate_time_until_refill(
@router.post("/dewars/return", response_model=DewarSchema)
async def return_to_storage(data: LogisticsEventCreate, db: Session = Depends(get_db)):
logger.info(
f"Returning dewar to storage: {data.dewar_qr_code} at location {data.location_qr_code}"
f"Returning dewar to storage: {data.dewar_qr_code} "
f"at location {data.location_qr_code}"
)
try:
@ -57,11 +58,13 @@ async def return_to_storage(data: LogisticsEventCreate, db: Session = Depends(ge
)
if original_slot and original_slot.qr_code != data.location_qr_code:
logger.error(
f"Dewar {data.dewar_qr_code} is associated with slot {original_slot.qr_code}"
f"Dewar {data.dewar_qr_code} is "
f"associated with slot {original_slot.qr_code}"
)
raise HTTPException(
status_code=400,
detail=f"Dewar {data.dewar_qr_code} is associated with a different slot {original_slot.qr_code}.",
detail=f"Dewar {data.dewar_qr_code} is associated "
f"with a different slot {original_slot.qr_code}.",
)
slot = (
@ -87,12 +90,16 @@ async def return_to_storage(data: LogisticsEventCreate, db: Session = Depends(ge
slot.occupied = True
dewar.last_retrieved_timestamp = None
# Set the `at_beamline` attribute to False
dewar.at_beamline = False
# Log the event
log_event(db, dewar.id, slot.id, "returned")
db.commit()
logger.info(
f"Dewar {data.dewar_qr_code} successfully returned to storage slot {slot.qr_code}."
f"Dewar {data.dewar_qr_code} successfully "
f"returned to storage slot {slot.qr_code}."
)
db.refresh(dewar)
return dewar
@ -174,7 +181,8 @@ async def scan_dewar(event_data: LogisticsEventCreate, db: Session = Depends(get
log_event(db, dewar.id, slot.id if slot else None, transaction_type)
db.commit()
logger.info(
f"Transaction completed: {transaction_type} for dewar {dewar_qr_code} in slot {slot.qr_code if slot else 'N/A'}"
f"Transaction completed: {transaction_type} "
f"for dewar {dewar_qr_code} in slot {slot.qr_code if slot else 'N/A'}"
)
return {"message": "Status updated successfully"}
@ -191,7 +199,6 @@ async def get_all_slots(db: Session = Depends(get_db)):
retrievedTimestamp = None
beamlineLocation = None
at_beamline = False
retrieved = False
if slot.dewar_unique_id:
# Calculate time until refill
@ -212,32 +219,32 @@ async def get_all_slots(db: Session = Depends(get_db)):
else:
time_until_refill = -1
# Fetch the latest beamline event
last_beamline_event = (
# Fetch the latest event for the dewar
last_event = (
db.query(LogisticsEventModel)
.join(DewarModel, DewarModel.id == LogisticsEventModel.dewar_id)
.filter(
DewarModel.unique_id == slot.dewar.unique_id,
LogisticsEventModel.event_type == "beamline",
)
.filter(DewarModel.unique_id == slot.dewar.unique_id)
.order_by(LogisticsEventModel.timestamp.desc())
.first()
)
if last_beamline_event:
# Set retrievedTimestamp to the timestamp of the beamline event
retrievedTimestamp = last_beamline_event.timestamp.isoformat()
# Fetch the associated slot's label for beamlineLocation
associated_slot = (
db.query(SlotModel)
.filter(SlotModel.id == last_beamline_event.slot_id)
.first()
)
beamlineLocation = associated_slot.label if associated_slot else None
# Mark as being at a beamline
at_beamline = True
# Determine if the dewar is at the beamline
if last_event:
if last_event.event_type == "beamline":
at_beamline = True
# Optionally set retrievedTimestamp and beamlineLocation for
# beamline events
retrievedTimestamp = last_event.timestamp.isoformat()
associated_slot = (
db.query(SlotModel)
.filter(SlotModel.id == last_event.slot_id)
.first()
)
beamlineLocation = (
associated_slot.label if associated_slot else None
)
elif last_event.event_type == "returned":
at_beamline = False
# Correct the contact_person assignment
contact_person = None
@ -296,7 +303,8 @@ async def refill_dewar(qr_code: str, db: Session = Depends(get_db)):
time_until_refill_seconds = calculate_time_until_refill(now)
logger.info(
f"Dewar refilled successfully with time_until_refill: {time_until_refill_seconds}"
f"Dewar refilled successfully "
f"with time_until_refill: {time_until_refill_seconds}"
)
return {
@ -334,5 +342,6 @@ def log_event(db: Session, dewar_id: int, slot_id: Optional[int], event_type: st
db.add(new_event)
db.commit()
logger.info(
f"Logged event: {event_type} for dewar: {dewar_id} in slot: {slot_id if slot_id else 'N/A'}"
f"Logged event: {event_type} for dewar: {dewar_id} "
f"in slot: {slot_id if slot_id else 'N/A'}"
)

View File

@ -7,7 +7,6 @@ from app.schemas import (
PuckCreate,
PuckUpdate,
SetTellPosition,
PuckEvent,
)
from app.models import (
Puck as PuckModel,
@ -35,7 +34,8 @@ async def get_pucks(db: Session = Depends(get_db)):
@router.get("/with-tell-position", response_model=List[dict])
async def get_pucks_with_tell_position(db: Session = Depends(get_db)):
"""
Retrieve all pucks with a `tell_position` set (not null) and their associated samples.
Retrieve all pucks with a `tell_position`
set (not null) and their associated samples.
"""
# Query all pucks that have an event with a non-null tell_position
pucks = (
@ -157,8 +157,10 @@ async def set_tell_position(
# Create a new PuckEvent (always a new event, even with null/None)
new_puck_event = PuckEventModel(
puck_id=puck_id,
tell_position=actual_position, # Null for disassociation, else the valid position
event_type="tell_position_set", # Example event type
tell_position=actual_position,
# Null for disassociation, else the valid position
event_type="tell_position_set",
# Example event type
timestamp=datetime.utcnow(),
)
db.add(new_puck_event)
@ -232,7 +234,9 @@ async def get_pucks_by_slot(slot_identifier: str, db: Session = Depends(get_db))
if not slot_id:
raise HTTPException(
status_code=400,
detail="Invalid slot identifier. Must be an ID or one of the following: PXI, PXII, PXIII, X06SA, X10SA, X06DA.",
detail="Invalid slot identifier. "
"Must be an ID or one of the following: "
"PXI, PXII, PXIII, X06SA, X10SA, X06DA.",
)
# Verify that the slot exists

View File

@ -1,7 +1,7 @@
from fastapi import APIRouter, HTTPException, status, Depends
from fastapi import APIRouter, HTTPException, Depends
from sqlalchemy.orm import Session
from typing import List
from app.schemas import Puck as PuckSchema, Sample as SampleSchema, SampleEventCreate
from app.schemas import Puck as PuckSchema, Sample as SampleSchema
from app.models import (
Puck as PuckModel,
Sample as SampleModel,

View File

@ -2,9 +2,10 @@ from fastapi import APIRouter, HTTPException, status, Query, Depends
from sqlalchemy.orm import Session
from typing import List, Optional
import logging
from pydantic import BaseModel, ValidationError
from pydantic import ValidationError
from datetime import date
from sqlalchemy.exc import SQLAlchemyError
import json
from app.models import (
Shipment as ShipmentModel,
@ -19,12 +20,9 @@ from app.schemas import (
ShipmentCreate,
UpdateShipmentComments,
Shipment as ShipmentSchema,
DewarUpdate,
ContactPerson as ContactPersonSchema,
Sample as SampleSchema,
DewarCreate,
PuckCreate,
SampleCreate,
DewarSchema,
)
from app.database import get_db
@ -185,7 +183,8 @@ async def update_shipment(
if not contact_person:
raise HTTPException(
status_code=404,
detail=f"Contact person with ID {value} for Dewar {dewar_data.dewar_id} not found",
detail=f"Contact person with ID {value} "
f"for Dewar {dewar_data.dewar_id} not found",
)
if key == "return_address_id":
address = (
@ -194,7 +193,8 @@ async def update_shipment(
if not address:
raise HTTPException(
status_code=404,
detail=f"Address with ID {value} for Dewar {dewar_data.dewar_id} not found",
detail=f"Address with ID {value} "
f"for Dewar {dewar_data.dewar_id} not found",
)
for key, value in update_fields.items():

View File

@ -51,9 +51,12 @@ async def upload_file(file: UploadFile = File(...)):
)
# Initialize the importer and process the spreadsheet
validated_model, errors, raw_data, headers = (
importer.import_spreadsheet_with_errors(file)
)
(
validated_model,
errors,
raw_data,
headers,
) = importer.import_spreadsheet_with_errors(file)
# Extract unique values for dewars, pucks, and samples
dewars = {sample.dewarname for sample in validated_model if sample.dewarname}
@ -82,7 +85,8 @@ async def upload_file(file: UploadFile = File(...)):
row_storage.set_row(row_num, row.dict())
logger.info(
f"Returning response with {len(validated_model)} records and {len(errors)} errors."
f"Returning response with {len(validated_model)} "
f"records and {len(errors)} errors."
)
return response_data
@ -121,7 +125,9 @@ async def validate_cell(data: dict):
try:
# Ensure we're using the full row data context for validation
validated_row = SpreadsheetModel(**current_row_data)
SpreadsheetModel(
**current_row_data
) # Instantiates the Pydantic model, performing validation
logger.info(f"Validation succeeded for row {row_num}, column {col_name}")
return {"is_valid": True, "message": ""}
except ValidationError as e:

View File

@ -15,7 +15,8 @@ class SpreadsheetModel(BaseModel):
...,
max_length=64,
title="Crystal Name",
description="max_length imposed by MTZ file header format https://www.ccp4.ac.uk/html/mtzformat.html",
description="max_length imposed by MTZ file header "
"format https://www.ccp4.ac.uk/html/mtzformat.html",
alias="crystalname",
),
]
@ -27,31 +28,31 @@ class SpreadsheetModel(BaseModel):
oscillation: Optional[float] = None # Only accept positive float
exposure: Optional[float] = None # Only accept positive floats between 0 and 1
totalrange: Optional[int] = None # Only accept positive integers between 0 and 360
transmission: Optional[int] = (
None # Only accept positive integers between 0 and 100
)
transmission: Optional[
int
] = None # Only accept positive integers between 0 and 100
targetresolution: Optional[float] = None # Only accept positive float
aperture: Optional[str] = None # Optional string field
datacollectiontype: Optional[str] = (
None # Only accept "standard", other types might be added later
)
processingpipeline: Optional[str] = (
"" # Only accept "gopy", "autoproc", "xia2dials"
)
spacegroupnumber: Optional[int] = (
None # Only accept positive integers between 1 and 230
)
cellparameters: Optional[str] = (
None # Must be a set of six positive floats or integers
)
datacollectiontype: Optional[
str
] = None # Only accept "standard", other types might be added later
processingpipeline: Optional[
str
] = "" # Only accept "gopy", "autoproc", "xia2dials"
spacegroupnumber: Optional[
int
] = None # Only accept positive integers between 1 and 230
cellparameters: Optional[
str
] = None # Must be a set of six positive floats or integers
rescutkey: Optional[str] = None # Only accept "is" or "cchalf"
rescutvalue: Optional[float] = (
None # Must be a positive float if rescutkey is provided
)
rescutvalue: Optional[
float
] = None # Must be a positive float if rescutkey is provided
userresolution: Optional[float] = None
pdbid: Optional[str] = (
"" # Accepts either the format of the protein data bank code or {provided}
)
pdbid: Optional[
str
] = "" # Accepts either the format of the protein data bank code or {provided}
autoprocfull: Optional[bool] = None
procfull: Optional[bool] = None
adpenabled: Optional[bool] = None
@ -206,11 +207,13 @@ class SpreadsheetModel(BaseModel):
v = int(v)
if not (0 <= v <= 360):
raise ValueError(
f" '{v}' is not valid. Value must be an integer between 0 and 360."
f" '{v}' is not valid. "
f"Value must be an integer between 0 and 360."
)
except (ValueError, TypeError) as e:
raise ValueError(
f" '{v}' is not valid. Value must be an integer between 0 and 360."
f" '{v}' is not valid."
f"Value must be an integer between 0 and 360."
) from e
return v
@ -222,11 +225,13 @@ class SpreadsheetModel(BaseModel):
v = int(v)
if not (0 <= v <= 100):
raise ValueError(
f" '{v}' is not valid. Value must be an integer between 0 and 100."
f" '{v}' is not valid. "
f"Value must be an integer between 0 and 100."
)
except (ValueError, TypeError) as e:
raise ValueError(
f" '{v}' is not valid. Value must be an integer between 0 and 100."
f" '{v}' is not valid. "
f"Value must be an integer between 0 and 100."
) from e
return v
@ -235,7 +240,7 @@ class SpreadsheetModel(BaseModel):
def datacollectiontype_allowed(cls, v):
allowed = {"standard"} # Other types of data collection might be added later
if v and v.lower() not in allowed:
raise ValueError(f" '{v}' is not valid. Value must be one of {allowed}.")
raise ValueError(f" '{v}' is not valid. " f"Value must be one of {allowed}.")
return v
@field_validator("processingpipeline", mode="before")
@ -243,7 +248,7 @@ class SpreadsheetModel(BaseModel):
def processingpipeline_allowed(cls, v):
allowed = {"gopy", "autoproc", "xia2dials"}
if v and v.lower() not in allowed:
raise ValueError(f" '{v}' is not valid. Value must be one of {allowed}.")
raise ValueError(f" '{v}' is not valid. " f"Value must be one of {allowed}.")
return v
@field_validator("spacegroupnumber", mode="before")
@ -254,11 +259,13 @@ class SpreadsheetModel(BaseModel):
v = int(v)
if not (1 <= v <= 230):
raise ValueError(
f" '{v}' is not valid. Value must be an integer between 1 and 230."
f" '{v}' is not valid. "
f"Value must be an integer between 1 and 230."
)
except (ValueError, TypeError) as e:
raise ValueError(
f" '{v}' is not valid. Value must be an integer between 1 and 230."
f" '{v}' is not valid. "
f"Value must be an integer between 1 and 230."
) from e
return v
@ -269,7 +276,8 @@ class SpreadsheetModel(BaseModel):
values = [float(i) for i in v.split(",")]
if len(values) != 6 or any(val <= 0 for val in values):
raise ValueError(
f" '{v}' is not valid. Value must be a set of six positive floats or integers."
f" '{v}' is not valid. "
f"Value must be a set of six positive floats or integers."
)
return v
@ -295,11 +303,13 @@ class SpreadsheetModel(BaseModel):
v = float(v)
if not (0 <= v <= 2.0):
raise ValueError(
f" '{v}' is not valid. Value must be a float between 0 and 2.0."
f" '{v}' is not valid. "
f"Value must be a float between 0 and 2.0."
)
except (ValueError, TypeError) as e:
raise ValueError(
f" '{v}' is not valid. Value must be a float between 0 and 2.0."
f" '{v}' is not valid. "
f"Value must be a float between 0 and 2.0."
) from e
return v
@ -311,7 +321,8 @@ class SpreadsheetModel(BaseModel):
v = float(v)
if not (0 <= v <= 30):
raise ValueError(
f" '{v}' is not valid. Value must be a float between 0 and 30."
f" '{v}' is not valid. "
f"Value must be a float between 0 and 30."
)
except (ValueError, TypeError) as e:
raise ValueError(

View File

@ -1,5 +1,5 @@
from typing import List, Optional
from datetime import datetime, timedelta # Add this import
from datetime import datetime
from pydantic import BaseModel, EmailStr, constr, Field
from datetime import date

View File

@ -1,7 +1,7 @@
import logging
import openpyxl
from pydantic import ValidationError
from typing import Union, List, Tuple
from typing import List, Tuple
from io import BytesIO
from app.sample_models import SpreadsheetModel
@ -201,7 +201,8 @@ class SampleSpreadsheetImporter:
for error in e.errors():
field = error["loc"][0]
msg = error["msg"]
# Map field name (which is the key in `record`) to its index in the row
# Map field name (which is the key in `record`) to its index in the
# row
field_to_col = {
"dewarname": 0,
"puckname": 1,