From 0178de96fda74531187a6aeeee65d92e6a910c14 Mon Sep 17 00:00:00 2001 From: GotthardG <51994228+GotthardG@users.noreply.github.com> Date: Mon, 16 Dec 2024 22:50:04 +0100 Subject: [PATCH] fixing bugs with ci pipeline --- .pre-commit-config.yaml | 8 ++- backend/app/data/__init__.py | 14 ++++ backend/app/data/data.py | 4 +- backend/app/data/slots_data.py | 2 +- backend/app/database.py | 4 -- backend/app/routers/dewar.py | 12 ++-- backend/app/routers/logistics.py | 65 +++++++++-------- backend/app/routers/puck.py | 14 ++-- backend/app/routers/sample.py | 4 +- backend/app/routers/shipment.py | 12 ++-- backend/app/routers/spreadsheet.py | 16 +++-- backend/app/sample_models.py | 79 ++++++++++++--------- backend/app/schemas.py | 2 +- backend/app/services/spreadsheet_service.py | 5 +- 14 files changed, 145 insertions(+), 96 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 95fa821..e1f7723 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,13 @@ repos: hooks: - id: black name: Black code formatter - args: ["--check"] # Only check formatting without modifying + args: ["--line-length", "88"] # Actively fix code, no --check + + - repo: https://github.com/pre-commit/mirrors-autopep8 + rev: v2.0.4 # Use a specific stable version + hooks: + - id: autopep8 + args: [ "--in-place", "--aggressive", "--max-line-length=88" ] - repo: https://github.com/pycqa/flake8 rev: 6.1.0 # Use the latest stable version of flake8 diff --git a/backend/app/data/__init__.py b/backend/app/data/__init__.py index f6f83e9..d783213 100644 --- a/backend/app/data/__init__.py +++ b/backend/app/data/__init__.py @@ -11,3 +11,17 @@ from .data import ( sample_events, ) from .slots_data import slots + +__all__ = [ + "contacts", + "return_addresses", + "dewars", + "proposals", + "shipments", + "pucks", + "samples", + "dewar_types", + "serial_numbers", + "sample_events", + "slots", +] diff --git a/backend/app/data/data.py b/backend/app/data/data.py index 705887b..3174e80 100644 --- a/backend/app/data/data.py +++ b/backend/app/data/data.py @@ -8,7 +8,6 @@ from app.models import ( Sample, DewarType, DewarSerialNumber, - Slot, SampleEvent, ) from datetime import datetime, timedelta @@ -526,7 +525,8 @@ event_types = ["Mounted", "Failed", "Unmounted", "Lost"] def generate_sample_events(samples, chance_no_event=0.2, chance_lost=0.1): - """Generate events for samples with timestamps increasing between different samples.""" + """Generate events for samples with timestamps + increasing between different samples.""" events = [] # Set the start time to yesterday at 9:33 AM diff --git a/backend/app/data/slots_data.py b/backend/app/data/slots_data.py index 43f26da..4aa61e3 100644 --- a/backend/app/data/slots_data.py +++ b/backend/app/data/slots_data.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta +from datetime import timedelta from app.models import Slot slotQRCodes = [ diff --git a/backend/app/database.py b/backend/app/database.py index 40af914..d5b6090 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -11,7 +11,6 @@ db_username = os.getenv("DB_USERNAME") db_password = os.getenv("DB_PASSWORD") # Construct the database URL -# SQLALCHEMY_DATABASE_URL = f"mysql://{db_username}:{db_password}@localhost:3306/aare_db" SQLALCHEMY_DATABASE_URL = f"mysql://{db_username}:{db_password}@localhost:3306/aare_db" # Remove the `connect_args` parameter @@ -32,9 +31,6 @@ def get_db(): def init_db(): - # Import models inside function to avoid circular dependency - from . import models - Base.metadata.create_all(bind=engine) diff --git a/backend/app/routers/dewar.py b/backend/app/routers/dewar.py index 106341d..25ee92f 100644 --- a/backend/app/routers/dewar.py +++ b/backend/app/routers/dewar.py @@ -1,4 +1,8 @@ -import os, tempfile, time, random, hashlib +import os +import tempfile +import time +import random +import hashlib from fastapi import APIRouter, HTTPException, status, Depends, Response from sqlalchemy.orm import Session, joinedload from typing import List @@ -21,14 +25,12 @@ from app.models import ( Sample as SampleModel, DewarType as DewarTypeModel, DewarSerialNumber as DewarSerialNumberModel, - Shipment as ShipmentModel, # Clearer name for model ) from app.dependencies import get_db -import uuid import qrcode import io from io import BytesIO -from PIL import ImageFont, ImageDraw, Image +from PIL import Image from reportlab.lib.pagesizes import A5, landscape from reportlab.lib.units import cm from reportlab.pdfgen import canvas @@ -211,7 +213,7 @@ def generate_label(dewar): c.drawString(2 * cm, y_position, f"Country: {return_address.country}") y_position -= line_height - c.drawString(2 * cm, y_position, f"Beamtime Information: Placeholder") + c.drawString(2 * cm, y_position, "Beamtime Information: Placeholder") # Generate QR code qr = qrcode.QRCode(version=1, box_size=10, border=4) diff --git a/backend/app/routers/logistics.py b/backend/app/routers/logistics.py index 87be0c8..fab4217 100644 --- a/backend/app/routers/logistics.py +++ b/backend/app/routers/logistics.py @@ -34,7 +34,8 @@ def calculate_time_until_refill( @router.post("/dewars/return", response_model=DewarSchema) async def return_to_storage(data: LogisticsEventCreate, db: Session = Depends(get_db)): logger.info( - f"Returning dewar to storage: {data.dewar_qr_code} at location {data.location_qr_code}" + f"Returning dewar to storage: {data.dewar_qr_code}" + f"at location {data.location_qr_code}" ) try: @@ -57,11 +58,13 @@ async def return_to_storage(data: LogisticsEventCreate, db: Session = Depends(ge ) if original_slot and original_slot.qr_code != data.location_qr_code: logger.error( - f"Dewar {data.dewar_qr_code} is associated with slot {original_slot.qr_code}" + f"Dewar {data.dewar_qr_code} is" + f"associated with slot {original_slot.qr_code}" ) raise HTTPException( status_code=400, - detail=f"Dewar {data.dewar_qr_code} is associated with a different slot {original_slot.qr_code}.", + detail=f"Dewar {data.dewar_qr_code} is associated" + f"with a different slot {original_slot.qr_code}.", ) slot = ( @@ -87,12 +90,16 @@ async def return_to_storage(data: LogisticsEventCreate, db: Session = Depends(ge slot.occupied = True dewar.last_retrieved_timestamp = None + # Set the `at_beamline` attribute to False + dewar.at_beamline = False + # Log the event log_event(db, dewar.id, slot.id, "returned") db.commit() logger.info( - f"Dewar {data.dewar_qr_code} successfully returned to storage slot {slot.qr_code}." + f"Dewar {data.dewar_qr_code} successfully" + f"returned to storage slot {slot.qr_code}." ) db.refresh(dewar) return dewar @@ -174,7 +181,8 @@ async def scan_dewar(event_data: LogisticsEventCreate, db: Session = Depends(get log_event(db, dewar.id, slot.id if slot else None, transaction_type) db.commit() logger.info( - f"Transaction completed: {transaction_type} for dewar {dewar_qr_code} in slot {slot.qr_code if slot else 'N/A'}" + f"Transaction completed: {transaction_type}" + f"for dewar {dewar_qr_code} in slot {slot.qr_code if slot else 'N/A'}" ) return {"message": "Status updated successfully"} @@ -191,7 +199,6 @@ async def get_all_slots(db: Session = Depends(get_db)): retrievedTimestamp = None beamlineLocation = None at_beamline = False - retrieved = False if slot.dewar_unique_id: # Calculate time until refill @@ -212,32 +219,32 @@ async def get_all_slots(db: Session = Depends(get_db)): else: time_until_refill = -1 - # Fetch the latest beamline event - last_beamline_event = ( + # Fetch the latest event for the dewar + last_event = ( db.query(LogisticsEventModel) .join(DewarModel, DewarModel.id == LogisticsEventModel.dewar_id) - .filter( - DewarModel.unique_id == slot.dewar.unique_id, - LogisticsEventModel.event_type == "beamline", - ) + .filter(DewarModel.unique_id == slot.dewar.unique_id) .order_by(LogisticsEventModel.timestamp.desc()) .first() ) - if last_beamline_event: - # Set retrievedTimestamp to the timestamp of the beamline event - retrievedTimestamp = last_beamline_event.timestamp.isoformat() - - # Fetch the associated slot's label for beamlineLocation - associated_slot = ( - db.query(SlotModel) - .filter(SlotModel.id == last_beamline_event.slot_id) - .first() - ) - beamlineLocation = associated_slot.label if associated_slot else None - - # Mark as being at a beamline - at_beamline = True + # Determine if the dewar is at the beamline + if last_event: + if last_event.event_type == "beamline": + at_beamline = True + # Optionally set retrievedTimestamp and beamlineLocation for + # beamline events + retrievedTimestamp = last_event.timestamp.isoformat() + associated_slot = ( + db.query(SlotModel) + .filter(SlotModel.id == last_event.slot_id) + .first() + ) + beamlineLocation = ( + associated_slot.label if associated_slot else None + ) + elif last_event.event_type == "returned": + at_beamline = False # Correct the contact_person assignment contact_person = None @@ -296,7 +303,8 @@ async def refill_dewar(qr_code: str, db: Session = Depends(get_db)): time_until_refill_seconds = calculate_time_until_refill(now) logger.info( - f"Dewar refilled successfully with time_until_refill: {time_until_refill_seconds}" + f"Dewar refilled successfully" + f"with time_until_refill: {time_until_refill_seconds}" ) return { @@ -334,5 +342,6 @@ def log_event(db: Session, dewar_id: int, slot_id: Optional[int], event_type: st db.add(new_event) db.commit() logger.info( - f"Logged event: {event_type} for dewar: {dewar_id} in slot: {slot_id if slot_id else 'N/A'}" + f"Logged event: {event_type} for dewar: {dewar_id} " + f"in slot: {slot_id if slot_id else 'N/A'}" ) diff --git a/backend/app/routers/puck.py b/backend/app/routers/puck.py index 8a1b54b..be18028 100644 --- a/backend/app/routers/puck.py +++ b/backend/app/routers/puck.py @@ -7,7 +7,6 @@ from app.schemas import ( PuckCreate, PuckUpdate, SetTellPosition, - PuckEvent, ) from app.models import ( Puck as PuckModel, @@ -35,7 +34,8 @@ async def get_pucks(db: Session = Depends(get_db)): @router.get("/with-tell-position", response_model=List[dict]) async def get_pucks_with_tell_position(db: Session = Depends(get_db)): """ - Retrieve all pucks with a `tell_position` set (not null) and their associated samples. + Retrieve all pucks with a `tell_position` + set (not null) and their associated samples. """ # Query all pucks that have an event with a non-null tell_position pucks = ( @@ -157,8 +157,10 @@ async def set_tell_position( # Create a new PuckEvent (always a new event, even with null/None) new_puck_event = PuckEventModel( puck_id=puck_id, - tell_position=actual_position, # Null for disassociation, else the valid position - event_type="tell_position_set", # Example event type + tell_position=actual_position, + # Null for disassociation, else the valid position + event_type="tell_position_set", + # Example event type timestamp=datetime.utcnow(), ) db.add(new_puck_event) @@ -232,7 +234,9 @@ async def get_pucks_by_slot(slot_identifier: str, db: Session = Depends(get_db)) if not slot_id: raise HTTPException( status_code=400, - detail="Invalid slot identifier. Must be an ID or one of the following: PXI, PXII, PXIII, X06SA, X10SA, X06DA.", + detail="Invalid slot identifier." + "Must be an ID or one of the following:" + "PXI, PXII, PXIII, X06SA, X10SA, X06DA.", ) # Verify that the slot exists diff --git a/backend/app/routers/sample.py b/backend/app/routers/sample.py index f3edf1d..2d24b5b 100644 --- a/backend/app/routers/sample.py +++ b/backend/app/routers/sample.py @@ -1,7 +1,7 @@ -from fastapi import APIRouter, HTTPException, status, Depends +from fastapi import APIRouter, HTTPException, Depends from sqlalchemy.orm import Session from typing import List -from app.schemas import Puck as PuckSchema, Sample as SampleSchema, SampleEventCreate +from app.schemas import Puck as PuckSchema, Sample as SampleSchema from app.models import ( Puck as PuckModel, Sample as SampleModel, diff --git a/backend/app/routers/shipment.py b/backend/app/routers/shipment.py index 03f02fa..bbbd718 100644 --- a/backend/app/routers/shipment.py +++ b/backend/app/routers/shipment.py @@ -2,9 +2,10 @@ from fastapi import APIRouter, HTTPException, status, Query, Depends from sqlalchemy.orm import Session from typing import List, Optional import logging -from pydantic import BaseModel, ValidationError +from pydantic import ValidationError from datetime import date from sqlalchemy.exc import SQLAlchemyError +import json from app.models import ( Shipment as ShipmentModel, @@ -19,12 +20,9 @@ from app.schemas import ( ShipmentCreate, UpdateShipmentComments, Shipment as ShipmentSchema, - DewarUpdate, ContactPerson as ContactPersonSchema, Sample as SampleSchema, DewarCreate, - PuckCreate, - SampleCreate, DewarSchema, ) from app.database import get_db @@ -185,7 +183,8 @@ async def update_shipment( if not contact_person: raise HTTPException( status_code=404, - detail=f"Contact person with ID {value} for Dewar {dewar_data.dewar_id} not found", + detail=f"Contact person with ID {value}" + f"for Dewar {dewar_data.dewar_id} not found", ) if key == "return_address_id": address = ( @@ -194,7 +193,8 @@ async def update_shipment( if not address: raise HTTPException( status_code=404, - detail=f"Address with ID {value} for Dewar {dewar_data.dewar_id} not found", + detail=f"Address with ID {value}" + f"for Dewar {dewar_data.dewar_id} not found", ) for key, value in update_fields.items(): diff --git a/backend/app/routers/spreadsheet.py b/backend/app/routers/spreadsheet.py index 37c5331..d93733c 100644 --- a/backend/app/routers/spreadsheet.py +++ b/backend/app/routers/spreadsheet.py @@ -51,9 +51,12 @@ async def upload_file(file: UploadFile = File(...)): ) # Initialize the importer and process the spreadsheet - validated_model, errors, raw_data, headers = ( - importer.import_spreadsheet_with_errors(file) - ) + ( + validated_model, + errors, + raw_data, + headers, + ) = importer.import_spreadsheet_with_errors(file) # Extract unique values for dewars, pucks, and samples dewars = {sample.dewarname for sample in validated_model if sample.dewarname} @@ -82,7 +85,8 @@ async def upload_file(file: UploadFile = File(...)): row_storage.set_row(row_num, row.dict()) logger.info( - f"Returning response with {len(validated_model)} records and {len(errors)} errors." + f"Returning response with {len(validated_model)}" + f"records and {len(errors)} errors." ) return response_data @@ -121,7 +125,9 @@ async def validate_cell(data: dict): try: # Ensure we're using the full row data context for validation - validated_row = SpreadsheetModel(**current_row_data) + SpreadsheetModel( + **current_row_data + ) # Instantiates the Pydantic model, performing validation logger.info(f"Validation succeeded for row {row_num}, column {col_name}") return {"is_valid": True, "message": ""} except ValidationError as e: diff --git a/backend/app/sample_models.py b/backend/app/sample_models.py index 111dc79..b29b2c6 100644 --- a/backend/app/sample_models.py +++ b/backend/app/sample_models.py @@ -15,7 +15,8 @@ class SpreadsheetModel(BaseModel): ..., max_length=64, title="Crystal Name", - description="max_length imposed by MTZ file header format https://www.ccp4.ac.uk/html/mtzformat.html", + description="max_length imposed by MTZ file header" + "format https://www.ccp4.ac.uk/html/mtzformat.html", alias="crystalname", ), ] @@ -27,31 +28,31 @@ class SpreadsheetModel(BaseModel): oscillation: Optional[float] = None # Only accept positive float exposure: Optional[float] = None # Only accept positive floats between 0 and 1 totalrange: Optional[int] = None # Only accept positive integers between 0 and 360 - transmission: Optional[int] = ( - None # Only accept positive integers between 0 and 100 - ) + transmission: Optional[ + int + ] = None # Only accept positive integers between 0 and 100 targetresolution: Optional[float] = None # Only accept positive float aperture: Optional[str] = None # Optional string field - datacollectiontype: Optional[str] = ( - None # Only accept "standard", other types might be added later - ) - processingpipeline: Optional[str] = ( - "" # Only accept "gopy", "autoproc", "xia2dials" - ) - spacegroupnumber: Optional[int] = ( - None # Only accept positive integers between 1 and 230 - ) - cellparameters: Optional[str] = ( - None # Must be a set of six positive floats or integers - ) + datacollectiontype: Optional[ + str + ] = None # Only accept "standard", other types might be added later + processingpipeline: Optional[ + str + ] = "" # Only accept "gopy", "autoproc", "xia2dials" + spacegroupnumber: Optional[ + int + ] = None # Only accept positive integers between 1 and 230 + cellparameters: Optional[ + str + ] = None # Must be a set of six positive floats or integers rescutkey: Optional[str] = None # Only accept "is" or "cchalf" - rescutvalue: Optional[float] = ( - None # Must be a positive float if rescutkey is provided - ) + rescutvalue: Optional[ + float + ] = None # Must be a positive float if rescutkey is provided userresolution: Optional[float] = None - pdbid: Optional[str] = ( - "" # Accepts either the format of the protein data bank code or {provided} - ) + pdbid: Optional[ + str + ] = "" # Accepts either the format of the protein data bank code or {provided} autoprocfull: Optional[bool] = None procfull: Optional[bool] = None adpenabled: Optional[bool] = None @@ -206,11 +207,13 @@ class SpreadsheetModel(BaseModel): v = int(v) if not (0 <= v <= 360): raise ValueError( - f" '{v}' is not valid. Value must be an integer between 0 and 360." + f" '{v}' is not valid." + f"Value must be an integer between 0 and 360." ) except (ValueError, TypeError) as e: raise ValueError( - f" '{v}' is not valid. Value must be an integer between 0 and 360." + f" '{v}' is not valid." + f"Value must be an integer between 0 and 360." ) from e return v @@ -222,11 +225,13 @@ class SpreadsheetModel(BaseModel): v = int(v) if not (0 <= v <= 100): raise ValueError( - f" '{v}' is not valid. Value must be an integer between 0 and 100." + f" '{v}' is not valid." + f"Value must be an integer between 0 and 100." ) except (ValueError, TypeError) as e: raise ValueError( - f" '{v}' is not valid. Value must be an integer between 0 and 100." + f" '{v}' is not valid." + f"Value must be an integer between 0 and 100." ) from e return v @@ -235,7 +240,7 @@ class SpreadsheetModel(BaseModel): def datacollectiontype_allowed(cls, v): allowed = {"standard"} # Other types of data collection might be added later if v and v.lower() not in allowed: - raise ValueError(f" '{v}' is not valid. Value must be one of {allowed}.") + raise ValueError(f" '{v}' is not valid." f"Value must be one of {allowed}.") return v @field_validator("processingpipeline", mode="before") @@ -243,7 +248,7 @@ class SpreadsheetModel(BaseModel): def processingpipeline_allowed(cls, v): allowed = {"gopy", "autoproc", "xia2dials"} if v and v.lower() not in allowed: - raise ValueError(f" '{v}' is not valid. Value must be one of {allowed}.") + raise ValueError(f" '{v}' is not valid." f"Value must be one of {allowed}.") return v @field_validator("spacegroupnumber", mode="before") @@ -254,11 +259,13 @@ class SpreadsheetModel(BaseModel): v = int(v) if not (1 <= v <= 230): raise ValueError( - f" '{v}' is not valid. Value must be an integer between 1 and 230." + f" '{v}' is not valid." + f"Value must be an integer between 1 and 230." ) except (ValueError, TypeError) as e: raise ValueError( - f" '{v}' is not valid. Value must be an integer between 1 and 230." + f" '{v}' is not valid." + f"Value must be an integer between 1 and 230." ) from e return v @@ -269,7 +276,8 @@ class SpreadsheetModel(BaseModel): values = [float(i) for i in v.split(",")] if len(values) != 6 or any(val <= 0 for val in values): raise ValueError( - f" '{v}' is not valid. Value must be a set of six positive floats or integers." + f" '{v}' is not valid." + f"Value must be a set of six positive floats or integers." ) return v @@ -295,11 +303,13 @@ class SpreadsheetModel(BaseModel): v = float(v) if not (0 <= v <= 2.0): raise ValueError( - f" '{v}' is not valid. Value must be a float between 0 and 2.0." + f" '{v}' is not valid." + f"Value must be a float between 0 and 2.0." ) except (ValueError, TypeError) as e: raise ValueError( - f" '{v}' is not valid. Value must be a float between 0 and 2.0." + f" '{v}' is not valid." + f"Value must be a float between 0 and 2.0." ) from e return v @@ -311,7 +321,8 @@ class SpreadsheetModel(BaseModel): v = float(v) if not (0 <= v <= 30): raise ValueError( - f" '{v}' is not valid. Value must be a float between 0 and 30." + f" '{v}' is not valid." + f"Value must be a float between 0 and 30." ) except (ValueError, TypeError) as e: raise ValueError( diff --git a/backend/app/schemas.py b/backend/app/schemas.py index 6c4281d..854bbce 100644 --- a/backend/app/schemas.py +++ b/backend/app/schemas.py @@ -1,5 +1,5 @@ from typing import List, Optional -from datetime import datetime, timedelta # Add this import +from datetime import datetime from pydantic import BaseModel, EmailStr, constr, Field from datetime import date diff --git a/backend/app/services/spreadsheet_service.py b/backend/app/services/spreadsheet_service.py index 9b3212f..7df4c97 100644 --- a/backend/app/services/spreadsheet_service.py +++ b/backend/app/services/spreadsheet_service.py @@ -1,7 +1,7 @@ import logging import openpyxl from pydantic import ValidationError -from typing import Union, List, Tuple +from typing import List, Tuple from io import BytesIO from app.sample_models import SpreadsheetModel @@ -201,7 +201,8 @@ class SampleSpreadsheetImporter: for error in e.errors(): field = error["loc"][0] msg = error["msg"] - # Map field name (which is the key in `record`) to its index in the row + # Map field name (which is the key in `record`) to its index in the + # row field_to_col = { "dewarname": 0, "puckname": 1,