import logging
import shutil
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import List

from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from sqlalchemy.orm import Session, joinedload

from app.schemas import (
    Puck as PuckSchema,
    Sample as SampleSchema,
    SampleEventCreate,
    Sample,
    Image,
    ImageCreate,
    SampleResult,
    ExperimentParametersCreate,
    ExperimentParametersRead,
)
from app.models import (
    Puck as PuckModel,
    Sample as SampleModel,
    SampleEvent as SampleEventModel,
    Image as ImageModel,
    Dewar as DewarModel,
    ExperimentParameters as ExperimentParametersModel,
)
from app.dependencies import get_db

logger = logging.getLogger(__name__)

router = APIRouter()


def _get_sample_or_404(db: Session, sample_id: int) -> SampleModel:
    """Return the sample whose primary key is ``sample_id``.

    Raises:
        HTTPException: 404 if no such sample exists.
    """
    sample = db.query(SampleModel).filter(SampleModel.id == sample_id).first()
    if not sample:
        raise HTTPException(status_code=404, detail="Sample not found")
    return sample


def _query_for_pgroup(db: Session, model, active_pgroup: str) -> list:
    """Return every ``model`` row whose sample belongs to ``active_pgroup``.

    ``model`` must expose ``sample_id`` and ``id`` columns (as the image and
    experiment-parameter models used in ``get_sample_results`` do). Rows are
    ordered by ``id`` so the per-sample lists built from them are stable.
    """
    return (
        db.query(model)
        .join(SampleModel, model.sample_id == SampleModel.id)
        .join(SampleModel.puck)
        .join(PuckModel.dewar)
        .filter(DewarModel.pgroups == active_pgroup)
        .order_by(model.id)
        .all()
    )


@router.get("/{puck_id}/samples", response_model=List[SampleSchema])
async def get_samples_with_events(puck_id: int, db: Session = Depends(get_db)):
    """Return every sample in puck ``puck_id`` together with its events.

    Raises:
        HTTPException: 404 if no puck with ``puck_id`` exists.
    """
    puck = db.query(PuckModel).filter(PuckModel.id == puck_id).first()
    if not puck:
        raise HTTPException(status_code=404, detail="Puck not found")

    # Load the events in the same query instead of issuing one extra SELECT
    # per sample.
    return (
        db.query(SampleModel)
        .options(joinedload(SampleModel.events))
        .filter(SampleModel.puck_id == puck_id)
        .all()
    )


@router.get("/pucks-samples", response_model=List[PuckSchema])
async def get_all_pucks_with_samples_and_events(
    active_pgroup: str, db: Session = Depends(get_db)
):
    """Return the pucks of ``active_pgroup`` with their samples and events.

    The inner joins on samples and sample events restrict the result to pucks
    that have at least one sample with at least one event. The eager loads
    then populate each puck's samples, their events, and the dewar.

    Raises:
        HTTPException: 404 if no matching puck is found.
    """
    logger.info(
        "Fetching all pucks with samples and events for active_pgroup: %s",
        active_pgroup,
    )
    pucks = (
        db.query(PuckModel)
        .join(PuckModel.samples)
        .join(PuckModel.dewar)
        .join(SampleModel.events)
        .filter(DewarModel.pgroups == active_pgroup)
        .options(
            joinedload(PuckModel.samples).joinedload(SampleModel.events),
            joinedload(PuckModel.dewar),
        )
        # The joins yield one row per sample event; collapse to one per puck.
        .distinct()
        .all()
    )
    if not pucks:
        raise HTTPException(
            status_code=404,
            detail="No pucks found with sample events for the active pgroup",
        )
    return pucks


@router.post("/samples/{sample_id}/events", response_model=Sample)
async def create_sample_event(
    sample_id: int, event: SampleEventCreate, db: Session = Depends(get_db)
):
    """Record a new event for a sample and return the sample with its events.

    The event timestamp is set server-side with ``datetime.now()`` (naive
    local time).

    Raises:
        HTTPException: 404 if the sample does not exist.
    """
    sample = _get_sample_or_404(db, sample_id)

    sample_event = SampleEventModel(
        sample_id=sample_id,
        event_type=event.event_type,
        timestamp=datetime.now(),
    )
    db.add(sample_event)
    db.commit()
    db.refresh(sample_event)

    # Populate the events explicitly so they are loaded before the response
    # model serializes ``sample``.
    sample.events = (
        db.query(SampleEventModel)
        .filter(SampleEventModel.sample_id == sample_id)
        .all()
    )
    return sample


@router.post("/{sample_id}/upload-images", response_model=Image)
async def upload_sample_image(
    sample_id: int,
    uploaded_file: UploadFile = File(...),
    comment: str = Form(None),
    db: Session = Depends(get_db),
):
    """Store an uploaded image for a sample and record it in the database.

    The file is written to
    ``images/<pgroup>/<YYYY-MM-DD>/<dewar_name>/<puck_name>/<position>/``.

    Raises:
        HTTPException: 404 if the sample does not exist; 400 if the upload is
            not an image, has an unusable filename, or the sample has no
            dewar (so no pgroup can be determined); 500 if the file cannot
            be written.
    """
    logger.info("Received file: %s", uploaded_file.filename)

    sample = _get_sample_or_404(db, sample_id)

    # Validate before touching the filesystem so a rejected upload does not
    # leave empty directories behind. content_type may be None.
    if not (uploaded_file.content_type or "").startswith("image/"):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file type: {uploaded_file.filename}."
            f" Only images are accepted.",
        )

    # The filename is client-controlled: keep only its final component so
    # values such as "../../x" or "/etc/x" cannot escape the target directory.
    safe_name = Path(uploaded_file.filename or "").name
    if safe_name in ("", ".", ".."):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file name: {uploaded_file.filename!r}.",
        )

    # The pgroup comes from the sample's dewar. Without one, neither the
    # storage path nor the image record can be built.
    puck = sample.puck
    dewar = puck.dewar if puck else None
    if dewar is None:
        raise HTTPException(
            status_code=400,
            detail=f"Sample {sample_id} is not assigned to a dewar;"
            f" cannot determine its pgroup.",
        )
    pgroup = dewar.pgroups
    today = datetime.now().strftime("%Y-%m-%d")
    position = sample.position if sample.position else "default_position"
    base_dir = Path(
        f"images/{pgroup}/{today}/{dewar.dewar_name}/{puck.puck_name}/{position}"
    )
    file_path = base_dir / safe_name

    logger.debug("Saving file %s to %s", uploaded_file.filename, file_path)
    try:
        base_dir.mkdir(parents=True, exist_ok=True)
        with file_path.open("wb") as buffer:
            shutil.copyfileobj(uploaded_file.file, buffer)
    except OSError as exc:
        logger.exception("Error saving file %s", uploaded_file.filename)
        raise HTTPException(
            status_code=500,
            detail=f"Could not save file {uploaded_file.filename}."
            f" Ensure the server has correct permissions.",
        ) from exc
    logger.info("File saved: %s", file_path)

    # Validate the record through the Pydantic schema, then map it onto the
    # SQLAlchemy model.
    image_payload = ImageCreate(
        pgroup=pgroup,
        comment=comment,
        filepath=str(file_path),
        status="active",
        sample_id=sample_id,
    ).dict()
    new_image = ImageModel(**image_payload)
    db.add(new_image)
    db.commit()
    db.refresh(new_image)

    logger.info(
        "Uploaded 1 file for sample %s and added record %s to the database.",
        sample_id,
        new_image.id,
    )
    # FastAPI converts the ORM object through the ``Image`` response model.
    return new_image


@router.get("/results", response_model=List[SampleResult])
async def get_sample_results(active_pgroup: str, db: Session = Depends(get_db)):
    """Return images and experiment runs for every sample in ``active_pgroup``.

    Raises:
        HTTPException: 404 if the pgroup has no samples.
    """
    samples = (
        db.query(SampleModel)
        .join(SampleModel.puck)
        .join(PuckModel.dewar)
        .filter(DewarModel.pgroups == active_pgroup)
        .all()
    )
    if not samples:
        raise HTTPException(
            status_code=404, detail="No samples found for the active pgroup"
        )

    # Fetch images and experiment parameters for the whole pgroup in two
    # queries and group them by sample, instead of two queries per sample.
    images_by_sample = defaultdict(list)
    for img in _query_for_pgroup(db, ImageModel, active_pgroup):
        images_by_sample[img.sample_id].append(img)

    runs_by_sample = defaultdict(list)
    for ex in _query_for_pgroup(db, ExperimentParametersModel, active_pgroup):
        runs_by_sample[ex.sample_id].append(ex)

    results = []
    for sample in samples:
        images = images_by_sample.get(sample.id, [])
        experiment_parameters = runs_by_sample.get(sample.id, [])
        logger.debug(
            "Experiment Parameters for sample %s: %s",
            sample.id,
            experiment_parameters,
        )
        results.append(
            {
                "sample_id": sample.id,
                "sample_name": sample.sample_name,
                "puck_name": sample.puck.puck_name if sample.puck else None,
                "dewar_name": sample.puck.dewar.dewar_name
                if (sample.puck and sample.puck.dewar)
                else None,
                "images": [
                    {"id": img.id, "filepath": img.filepath, "comment": img.comment}
                    for img in images
                ],
                "experiment_runs": [
                    {
                        "id": ex.id,
                        "run_number": ex.run_number,
                        "beamline_parameters": ex.beamline_parameters,
                        "sample_id": ex.sample_id,
                    }
                    for ex in experiment_parameters
                ],
            }
        )
    return results


@router.post(
    "/samples/{sample_id}/experiment_parameters",
    response_model=ExperimentParametersRead,
)
def create_experiment_parameters_for_sample(
    sample_id: int,
    exp_params: ExperimentParametersCreate,
    db: Session = Depends(get_db),
):
    """Create the next experiment-parameters run for a sample.

    ``run_number`` is one more than the sample's current highest run number,
    or 1 for its first run.

    Raises:
        HTTPException: 404 if the sample does not exist.
    """
    # Refuse to attach a run to a sample id that does not exist.
    _get_sample_or_404(db, sample_id)

    last_exp = (
        db.query(ExperimentParametersModel)
        .filter(ExperimentParametersModel.sample_id == sample_id)
        .order_by(ExperimentParametersModel.run_number.desc())
        .first()
    )
    # NOTE(review): two concurrent requests can compute the same run_number;
    # a unique constraint on (sample_id, run_number) would detect that.
    new_run_number = last_exp.run_number + 1 if last_exp else 1

    # beamline_parameters is stored as a plain dict (JSON column).
    new_exp = ExperimentParametersModel(
        run_number=new_run_number,
        beamline_parameters=(
            exp_params.beamline_parameters.dict()
            if exp_params.beamline_parameters
            else None
        ),
        sample_id=sample_id,
    )
    db.add(new_exp)
    db.commit()
    db.refresh(new_exp)
    return new_exp