Add endpoint for creating local contacts with access control

Introduced a new `local_contact_router` to handle creation of local contacts. The endpoint enforces role-based access control and ensures no duplication of email addresses. Updated the router exports for consistency and cleaned up a large test file to improve readability.
This commit is contained in:
GotthardG
2025-02-26 09:58:19 +01:00
parent 43d67b1044
commit f588bc0cda
13 changed files with 360 additions and 418 deletions

View File

@ -9,6 +9,8 @@ from .data import (
dewar_types,
serial_numbers,
sample_events,
local_contacts,
beamtimes,
)
from .slots_data import slots
@ -24,4 +26,6 @@ __all__ = [
"serial_numbers",
"sample_events",
"slots",
"local_contacts",
"beamtimes",
]

View File

@ -10,13 +10,14 @@ from app.models import (
DewarSerialNumber,
SampleEvent,
LogisticsEvent,
LocalContact,
Beamtime,
)
from datetime import datetime, timedelta
import random
import time
import hashlib
dewar_types = [
DewarType(id=1, dewar_type="Type A"),
DewarType(id=2, dewar_type="Type B"),
@ -373,6 +374,50 @@ specific_dewars1 = [dewar for dewar in dewars if dewar.id in specific_dewar_ids1
specific_dewars2 = [dewar for dewar in dewars if dewar.id in specific_dewar_ids2]
specific_dewars3 = [dewar for dewar in dewars if dewar.id in specific_dewar_ids3]
local_contacts = [
LocalContact(
id=1,
firstname="John",
lastname="Rambo",
phone_number="+410000000",
email="john.rambo@war.com",
),
LocalContact(
id=2,
firstname="John",
lastname="Mclane",
phone_number="+9990000099",
email="john.mclane@war.com",
),
]
beamtimes = [
Beamtime(
id=1,
pgroups="p20001",
beamtime_name="p20001-test",
beamline="X06DA",
start_date=datetime.strptime("06.02.2025", "%d.%m.%Y").date(),
end_date=datetime.strptime("07.02.2025", "%d.%m.%Y").date(),
status="confirmed",
comments="this is a test beamtime",
proposal_id=1,
local_contact_id=1,
),
Beamtime(
id=2,
pgroups="p20002",
beamtime_name="p20001-test",
beamline="X06DA",
start_date=datetime.strptime("07.02.2025", "%d.%m.%Y").date(),
end_date=datetime.strptime("08.02.2025", "%d.%m.%Y").date(),
status="confirmed",
comments="this is a test beamtime",
proposal_id=2,
local_contact_id=2,
),
]
# Define shipments
shipments = [
Shipment(

View File

@ -77,6 +77,8 @@ def load_sample_data(session: Session):
serial_numbers,
slots,
sample_events,
local_contacts,
beamtimes,
)
# If any data exists, don't reseed
@ -95,5 +97,8 @@ def load_sample_data(session: Session):
+ serial_numbers
+ slots
+ sample_events
+ local_contacts
+ beamtimes
)
session.commit()

View File

@ -45,6 +45,17 @@ class Contact(Base):
shipments = relationship("Shipment", back_populates="contact")
class LocalContact(Base):
__tablename__ = "local_contacts"
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
status = Column(String(255), default="active")
firstname = Column(String(255), nullable=False)
lastname = Column(String(255), nullable=False)
phone_number = Column(String(255), nullable=False)
email = Column(String(255), nullable=False)
class Address(Base):
__tablename__ = "addresses"
@ -103,6 +114,10 @@ class Dewar(Base):
slot = relationship("Slot", back_populates="dewar")
events = relationship("LogisticsEvent", back_populates="dewar")
beamline_location = None
local_contact_id = Column(Integer, ForeignKey("local_contacts.id"), nullable=True)
local_contact = relationship("LocalContact")
beamtime = relationship("Beamtime", back_populates="dewars")
beamtime_id = Column(Integer, ForeignKey("beamtimes.id"), nullable=True)
@property
def number_of_pucks(self) -> int:
@ -216,6 +231,24 @@ class PuckEvent(Base):
puck = relationship("Puck", back_populates="events")
class Beamtime(Base):
__tablename__ = "beamtimes"
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
pgroups = Column(String(255), nullable=False)
beamtime_name = Column(String(255), index=True)
beamline = Column(String(255), nullable=True)
start_date = Column(Date, nullable=True)
end_date = Column(Date, nullable=True)
status = Column(String(255), nullable=True)
comments = Column(String(200), nullable=True)
proposal_id = Column(Integer, ForeignKey("proposals.id"), nullable=True)
local_contact_id = Column(Integer, ForeignKey("local_contacts.id"), nullable=False)
local_contact = relationship("LocalContact")
dewars = relationship("Dewar", back_populates="beamtime")
# class Results(Base):
# __tablename__ = "results"
#

View File

@ -1,5 +1,6 @@
from .address import address_router
from .contact import contact_router
from .local_contact import local_contact_router
from .proposal import router as proposal_router
from .dewar import dewar_router
from .shipment import shipment_router
@ -9,6 +10,7 @@ from .protected_router import protected_router as protected_router
__all__ = [
"address_router",
"contact_router",
"local_contact_router",
"proposal_router",
"dewar_router",
"shipment_router",

View File

@ -21,6 +21,12 @@ mock_users_db = {
"password": "testpass2", # In a real scenario, store the hash of the password
"pgroups": ["p20004", "p20005", "p20006"],
},
"admin": {
"username": "admin",
"password": "adminpass",
"pgroups": ["p20007"],
# "role": "admin",
},
}
@ -49,6 +55,9 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> loginData:
username: str = payload.get("sub")
print(f"[DEBUG] Username decoded from token: {username}") # Add debug log here
return loginData(username=username, pgroups=payload.get("pgroups"))
# return loginData(username=username, pgroups=payload.get("pgroups"),
# role=payload.get("role"))
except jwt.ExpiredSignatureError:
print("[DEBUG] Token expired")
raise HTTPException(status_code=401, detail="Token expired")
@ -57,6 +66,14 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> loginData:
raise HTTPException(status_code=401, detail="Invalid token")
# async def get_user_role(token: str = Depends(oauth2_scheme)) -> str:
# try:
# payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
# return payload.get("role")
# except jwt.ExpiredSignatureError:
# raise HTTPException(status_code=401, detail="Token expired")
@router.post("/token/login", response_model=loginToken)
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
user = mock_users_db.get(form_data.username)
@ -70,10 +87,14 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()):
# Create token
access_token = create_access_token(
data={"sub": user["username"], "pgroups": user["pgroups"]}
# data = {"sub": user["username"], "pgroups": user["pgroups"],
# "role": user["role"]}
)
return loginToken(access_token=access_token, token_type="bearer")
@router.get("/protected-route")
async def read_protected_data(current_user: loginData = Depends(get_current_user)):
# return {"username": current_user.username, "pgroups":
# current_user.pgroups, "role": current_user.role}
return {"username": current_user.username, "pgroups": current_user.pgroups}

View File

@ -0,0 +1,56 @@
from fastapi import APIRouter, HTTPException, status, Depends
from sqlalchemy.orm import Session
from app.models import LocalContact as LocalContactModel
from app.schemas import LocalContactCreate as LocalContactSchema, loginData
from app.dependencies import get_db
from app.routers.auth import get_current_user
local_contact_router = APIRouter()
@local_contact_router.post(
"/",
response_model=LocalContactSchema,
status_code=status.HTTP_201_CREATED,
)
async def create_local_contact(
local_contact: LocalContactSchema,
db: Session = Depends(get_db),
current_user: loginData = Depends(get_current_user),
):
"""
Create a new local contact. Only selected users can create a local contact.
"""
# Access control: Only allow users with specific roles (e.g., "admin" or
# "contact_manager")
if current_user.role not in ["admin", "contact_manager"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You do not have permission to create a local contact.",
)
# Check if a local contact with the same email already exists
if (
db.query(LocalContactModel)
.filter(LocalContactModel.email == local_contact.email)
.first()
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="A local contact with this email already exists.",
)
# Create a new LocalContact
db_local_contact = LocalContactModel(
firstname=local_contact.firstname,
lastname=local_contact.lastname,
phone_number=local_contact.phone_number,
email=local_contact.email,
status=local_contact.status or "active",
)
db.add(db_local_contact)
db.commit()
db.refresh(db_local_contact)
return db_local_contact

View File

@ -351,7 +351,7 @@ async def get_all_dewars(db: Session = Depends(get_db)):
@router.get("/dewar/table", response_model=List[DewarTable])
async def get_all_dewars_table(db: Session = Depends(get_db)):
dewars = db.query(DewarModel).all()
dewars = db.query(DewarModel).filter(DewarModel.events.any()).all()
# Flatten relationships for simplified frontend rendering
response = []
@ -365,6 +365,7 @@ async def get_all_dewars_table(db: Session = Depends(get_db)):
dewar_name=dewar.dewar_name,
shipment_name=dewar.shipment.shipment_name if dewar.shipment else "N/A",
# Use the most recent event if available
beamtime=dewar.beamtime,
status=dewar.events[-1].event_type if dewar.events else "No Events",
tracking_number=dewar.tracking_number or "N/A",
slot_id=dewar.slot[0].id

View File

@ -5,6 +5,7 @@ from app.routers.address import address_router
from app.routers.contact import contact_router
from app.routers.shipment import shipment_router
from app.routers.dewar import dewar_router
from app.routers.local_contact import local_contact_router
protected_router = APIRouter(
dependencies=[Depends(get_current_user)] # Applies to all routes
@ -12,6 +13,9 @@ protected_router = APIRouter(
protected_router.include_router(address_router, prefix="/addresses", tags=["addresses"])
protected_router.include_router(contact_router, prefix="/contacts", tags=["contacts"])
protected_router.include_router(
local_contact_router, prefix="/local-contacts", tags=["local-contacts"]
)
protected_router.include_router(
shipment_router, prefix="/shipments", tags=["shipments"]
)

View File

@ -89,19 +89,19 @@ async def create_sample_event(
return sample # Return the sample, now including `mount_count`
@router.post("/samples/{sample_id}/upload-images")
async def upload_sample_images(
@router.post("/{sample_id}/upload-images")
async def upload_sample_image(
sample_id: int,
uploaded_files: list[UploadFile] = File(...),
uploaded_file: UploadFile = File(...),
db: Session = Depends(get_db),
):
logging.info(f"Received files: {[file.filename for file in uploaded_files]}")
logging.info(f"Received file: {uploaded_file.filename}")
"""
Uploads images for a given sample and saves them to a directory structure.
Uploads an image for a given sample and saves it to a directory structure.
Args:
sample_id (int): ID of the sample.
uploaded_files (list[UploadFile]): A list of files uploaded with the request.
uploaded_file (UploadFile): The file uploaded with the request.
db (Session): Database session.
"""
@ -123,35 +123,32 @@ async def upload_sample_images(
base_dir = Path(f"images/{pgroup}/{today}/{dewar_name}/{puck_name}/{position}")
base_dir.mkdir(parents=True, exist_ok=True)
# 3. Process and Save Each File
saved_files = []
for file in uploaded_files:
# Validate MIME type
if not file.content_type.startswith("image/"):
raise HTTPException(
status_code=400,
detail=f"Invalid file type: {file.filename}. Only images are accepted.",
)
# 3. Validate MIME type and Save the File
if not uploaded_file.content_type.startswith("image/"):
raise HTTPException(
status_code=400,
detail=f"Invalid file type: {uploaded_file.filename}."
f" Only images are accepted.",
)
# Save file to the base directory
file_path = base_dir / file.filename
file_path = base_dir / uploaded_file.filename
logging.debug(f"Saving file {uploaded_file.filename} to {file_path}")
# Save the file from the file stream
try:
with file_path.open("wb") as buffer:
shutil.copyfileobj(file.file, buffer)
saved_files.append(str(file_path)) # Track saved file paths
except Exception as e:
logging.error(f"Error saving file {file.filename}: {str(e)}")
raise HTTPException(
status_code=500,
detail=f"Could not save file {file.filename}."
f" Ensure the server has correct permissions.",
)
try:
with file_path.open("wb") as buffer:
shutil.copyfileobj(uploaded_file.file, buffer)
logging.info(f"File saved: {file_path}")
except Exception as e:
logging.error(f"Error saving file {uploaded_file.filename}: {str(e)}")
raise HTTPException(
status_code=500,
detail=f"Could not save file {uploaded_file.filename}."
f" Ensure the server has correct permissions.",
)
# 4. Return Saved Files Information
logging.info(f"Uploaded {len(saved_files)} files for sample {sample_id}.")
# 4. Return Saved File Information
logging.info(f"Uploaded 1 file for sample {sample_id}.")
return {
"message": f"{len(saved_files)} images uploaded successfully.",
"files": saved_files,
"message": "1 image uploaded successfully.",
"file": str(file_path),
}

View File

@ -17,6 +17,7 @@ class loginToken(BaseModel):
class loginData(BaseModel):
username: str
pgroups: List[str]
# role: Optional[str] = "user"
class DewarTypeBase(BaseModel):
@ -417,6 +418,29 @@ class ContactMinimal(BaseModel):
id: int
class Proposal(BaseModel):
id: int
number: str
class Config:
from_attributes = True
class LocalContactCreate(BaseModel):
firstname: str
lastname: str
phone_number: str
email: EmailStr
status: str = "active"
class Config:
from_attributes = True
class LocalContact(LocalContactCreate):
id: int
class AddressCreate(BaseModel):
pgroups: str
house_number: Optional[str] = None
@ -617,14 +641,6 @@ class DewarTable(BaseModel):
from_attributes = True
class Proposal(BaseModel):
id: int
number: str
class Config:
from_attributes = True
class Shipment(BaseModel):
id: int
pgroups: str
@ -752,3 +768,18 @@ class PuckWithTellPosition(BaseModel):
class Config:
from_attributes = True
class Beamtime(BaseModel):
id: int
pgroups: str
beamtime_name: str
beamline: str
start_date: date
end_date: date
status: str
comments: Optional[constr(max_length=200)] = None
proposal_id: Optional[int]
proposal: Optional[Proposal]
local_contact_id: Optional[int]
local_contact: Optional[LocalContact]