Add endpoint for creating local contacts with access control
Introduced a new `local_contact_router` to handle creation of local contacts. The endpoint enforces role-based access control and ensures no duplication of email addresses. Updated the router exports for consistency and cleaned up a large test file to improve readability.
This commit is contained in:
@ -9,6 +9,8 @@ from .data import (
|
||||
dewar_types,
|
||||
serial_numbers,
|
||||
sample_events,
|
||||
local_contacts,
|
||||
beamtimes,
|
||||
)
|
||||
from .slots_data import slots
|
||||
|
||||
@ -24,4 +26,6 @@ __all__ = [
|
||||
"serial_numbers",
|
||||
"sample_events",
|
||||
"slots",
|
||||
"local_contacts",
|
||||
"beamtimes",
|
||||
]
|
||||
|
@ -10,13 +10,14 @@ from app.models import (
|
||||
DewarSerialNumber,
|
||||
SampleEvent,
|
||||
LogisticsEvent,
|
||||
LocalContact,
|
||||
Beamtime,
|
||||
)
|
||||
from datetime import datetime, timedelta
|
||||
import random
|
||||
import time
|
||||
import hashlib
|
||||
|
||||
|
||||
dewar_types = [
|
||||
DewarType(id=1, dewar_type="Type A"),
|
||||
DewarType(id=2, dewar_type="Type B"),
|
||||
@ -373,6 +374,50 @@ specific_dewars1 = [dewar for dewar in dewars if dewar.id in specific_dewar_ids1
|
||||
specific_dewars2 = [dewar for dewar in dewars if dewar.id in specific_dewar_ids2]
|
||||
specific_dewars3 = [dewar for dewar in dewars if dewar.id in specific_dewar_ids3]
|
||||
|
||||
local_contacts = [
|
||||
LocalContact(
|
||||
id=1,
|
||||
firstname="John",
|
||||
lastname="Rambo",
|
||||
phone_number="+410000000",
|
||||
email="john.rambo@war.com",
|
||||
),
|
||||
LocalContact(
|
||||
id=2,
|
||||
firstname="John",
|
||||
lastname="Mclane",
|
||||
phone_number="+9990000099",
|
||||
email="john.mclane@war.com",
|
||||
),
|
||||
]
|
||||
|
||||
beamtimes = [
|
||||
Beamtime(
|
||||
id=1,
|
||||
pgroups="p20001",
|
||||
beamtime_name="p20001-test",
|
||||
beamline="X06DA",
|
||||
start_date=datetime.strptime("06.02.2025", "%d.%m.%Y").date(),
|
||||
end_date=datetime.strptime("07.02.2025", "%d.%m.%Y").date(),
|
||||
status="confirmed",
|
||||
comments="this is a test beamtime",
|
||||
proposal_id=1,
|
||||
local_contact_id=1,
|
||||
),
|
||||
Beamtime(
|
||||
id=2,
|
||||
pgroups="p20002",
|
||||
beamtime_name="p20001-test",
|
||||
beamline="X06DA",
|
||||
start_date=datetime.strptime("07.02.2025", "%d.%m.%Y").date(),
|
||||
end_date=datetime.strptime("08.02.2025", "%d.%m.%Y").date(),
|
||||
status="confirmed",
|
||||
comments="this is a test beamtime",
|
||||
proposal_id=2,
|
||||
local_contact_id=2,
|
||||
),
|
||||
]
|
||||
|
||||
# Define shipments
|
||||
shipments = [
|
||||
Shipment(
|
||||
|
@ -77,6 +77,8 @@ def load_sample_data(session: Session):
|
||||
serial_numbers,
|
||||
slots,
|
||||
sample_events,
|
||||
local_contacts,
|
||||
beamtimes,
|
||||
)
|
||||
|
||||
# If any data exists, don't reseed
|
||||
@ -95,5 +97,8 @@ def load_sample_data(session: Session):
|
||||
+ serial_numbers
|
||||
+ slots
|
||||
+ sample_events
|
||||
+ local_contacts
|
||||
+ beamtimes
|
||||
)
|
||||
|
||||
session.commit()
|
||||
|
@ -45,6 +45,17 @@ class Contact(Base):
|
||||
shipments = relationship("Shipment", back_populates="contact")
|
||||
|
||||
|
||||
class LocalContact(Base):
|
||||
__tablename__ = "local_contacts"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
|
||||
status = Column(String(255), default="active")
|
||||
firstname = Column(String(255), nullable=False)
|
||||
lastname = Column(String(255), nullable=False)
|
||||
phone_number = Column(String(255), nullable=False)
|
||||
email = Column(String(255), nullable=False)
|
||||
|
||||
|
||||
class Address(Base):
|
||||
__tablename__ = "addresses"
|
||||
|
||||
@ -103,6 +114,10 @@ class Dewar(Base):
|
||||
slot = relationship("Slot", back_populates="dewar")
|
||||
events = relationship("LogisticsEvent", back_populates="dewar")
|
||||
beamline_location = None
|
||||
local_contact_id = Column(Integer, ForeignKey("local_contacts.id"), nullable=True)
|
||||
local_contact = relationship("LocalContact")
|
||||
beamtime = relationship("Beamtime", back_populates="dewars")
|
||||
beamtime_id = Column(Integer, ForeignKey("beamtimes.id"), nullable=True)
|
||||
|
||||
@property
|
||||
def number_of_pucks(self) -> int:
|
||||
@ -216,6 +231,24 @@ class PuckEvent(Base):
|
||||
puck = relationship("Puck", back_populates="events")
|
||||
|
||||
|
||||
class Beamtime(Base):
|
||||
__tablename__ = "beamtimes"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
|
||||
pgroups = Column(String(255), nullable=False)
|
||||
beamtime_name = Column(String(255), index=True)
|
||||
beamline = Column(String(255), nullable=True)
|
||||
start_date = Column(Date, nullable=True)
|
||||
end_date = Column(Date, nullable=True)
|
||||
status = Column(String(255), nullable=True)
|
||||
comments = Column(String(200), nullable=True)
|
||||
proposal_id = Column(Integer, ForeignKey("proposals.id"), nullable=True)
|
||||
local_contact_id = Column(Integer, ForeignKey("local_contacts.id"), nullable=False)
|
||||
|
||||
local_contact = relationship("LocalContact")
|
||||
dewars = relationship("Dewar", back_populates="beamtime")
|
||||
|
||||
|
||||
# class Results(Base):
|
||||
# __tablename__ = "results"
|
||||
#
|
||||
|
@ -1,5 +1,6 @@
|
||||
from .address import address_router
|
||||
from .contact import contact_router
|
||||
from .local_contact import local_contact_router
|
||||
from .proposal import router as proposal_router
|
||||
from .dewar import dewar_router
|
||||
from .shipment import shipment_router
|
||||
@ -9,6 +10,7 @@ from .protected_router import protected_router as protected_router
|
||||
__all__ = [
|
||||
"address_router",
|
||||
"contact_router",
|
||||
"local_contact_router",
|
||||
"proposal_router",
|
||||
"dewar_router",
|
||||
"shipment_router",
|
||||
|
@ -21,6 +21,12 @@ mock_users_db = {
|
||||
"password": "testpass2", # In a real scenario, store the hash of the password
|
||||
"pgroups": ["p20004", "p20005", "p20006"],
|
||||
},
|
||||
"admin": {
|
||||
"username": "admin",
|
||||
"password": "adminpass",
|
||||
"pgroups": ["p20007"],
|
||||
# "role": "admin",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@ -49,6 +55,9 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> loginData:
|
||||
username: str = payload.get("sub")
|
||||
print(f"[DEBUG] Username decoded from token: {username}") # Add debug log here
|
||||
return loginData(username=username, pgroups=payload.get("pgroups"))
|
||||
# return loginData(username=username, pgroups=payload.get("pgroups"),
|
||||
# role=payload.get("role"))
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
print("[DEBUG] Token expired")
|
||||
raise HTTPException(status_code=401, detail="Token expired")
|
||||
@ -57,6 +66,14 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> loginData:
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
|
||||
# async def get_user_role(token: str = Depends(oauth2_scheme)) -> str:
|
||||
# try:
|
||||
# payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
# return payload.get("role")
|
||||
# except jwt.ExpiredSignatureError:
|
||||
# raise HTTPException(status_code=401, detail="Token expired")
|
||||
|
||||
|
||||
@router.post("/token/login", response_model=loginToken)
|
||||
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
user = mock_users_db.get(form_data.username)
|
||||
@ -70,10 +87,14 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
# Create token
|
||||
access_token = create_access_token(
|
||||
data={"sub": user["username"], "pgroups": user["pgroups"]}
|
||||
# data = {"sub": user["username"], "pgroups": user["pgroups"],
|
||||
# "role": user["role"]}
|
||||
)
|
||||
return loginToken(access_token=access_token, token_type="bearer")
|
||||
|
||||
|
||||
@router.get("/protected-route")
|
||||
async def read_protected_data(current_user: loginData = Depends(get_current_user)):
|
||||
# return {"username": current_user.username, "pgroups":
|
||||
# current_user.pgroups, "role": current_user.role}
|
||||
return {"username": current_user.username, "pgroups": current_user.pgroups}
|
||||
|
56
backend/app/routers/local_contact.py
Normal file
56
backend/app/routers/local_contact.py
Normal file
@ -0,0 +1,56 @@
|
||||
from fastapi import APIRouter, HTTPException, status, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import LocalContact as LocalContactModel
|
||||
from app.schemas import LocalContactCreate as LocalContactSchema, loginData
|
||||
from app.dependencies import get_db
|
||||
from app.routers.auth import get_current_user
|
||||
|
||||
local_contact_router = APIRouter()
|
||||
|
||||
|
||||
@local_contact_router.post(
|
||||
"/",
|
||||
response_model=LocalContactSchema,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def create_local_contact(
|
||||
local_contact: LocalContactSchema,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: loginData = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Create a new local contact. Only selected users can create a local contact.
|
||||
"""
|
||||
# Access control: Only allow users with specific roles (e.g., "admin" or
|
||||
# "contact_manager")
|
||||
if current_user.role not in ["admin", "contact_manager"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You do not have permission to create a local contact.",
|
||||
)
|
||||
|
||||
# Check if a local contact with the same email already exists
|
||||
if (
|
||||
db.query(LocalContactModel)
|
||||
.filter(LocalContactModel.email == local_contact.email)
|
||||
.first()
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="A local contact with this email already exists.",
|
||||
)
|
||||
|
||||
# Create a new LocalContact
|
||||
db_local_contact = LocalContactModel(
|
||||
firstname=local_contact.firstname,
|
||||
lastname=local_contact.lastname,
|
||||
phone_number=local_contact.phone_number,
|
||||
email=local_contact.email,
|
||||
status=local_contact.status or "active",
|
||||
)
|
||||
db.add(db_local_contact)
|
||||
db.commit()
|
||||
db.refresh(db_local_contact)
|
||||
|
||||
return db_local_contact
|
@ -351,7 +351,7 @@ async def get_all_dewars(db: Session = Depends(get_db)):
|
||||
|
||||
@router.get("/dewar/table", response_model=List[DewarTable])
|
||||
async def get_all_dewars_table(db: Session = Depends(get_db)):
|
||||
dewars = db.query(DewarModel).all()
|
||||
dewars = db.query(DewarModel).filter(DewarModel.events.any()).all()
|
||||
|
||||
# Flatten relationships for simplified frontend rendering
|
||||
response = []
|
||||
@ -365,6 +365,7 @@ async def get_all_dewars_table(db: Session = Depends(get_db)):
|
||||
dewar_name=dewar.dewar_name,
|
||||
shipment_name=dewar.shipment.shipment_name if dewar.shipment else "N/A",
|
||||
# Use the most recent event if available
|
||||
beamtime=dewar.beamtime,
|
||||
status=dewar.events[-1].event_type if dewar.events else "No Events",
|
||||
tracking_number=dewar.tracking_number or "N/A",
|
||||
slot_id=dewar.slot[0].id
|
||||
|
@ -5,6 +5,7 @@ from app.routers.address import address_router
|
||||
from app.routers.contact import contact_router
|
||||
from app.routers.shipment import shipment_router
|
||||
from app.routers.dewar import dewar_router
|
||||
from app.routers.local_contact import local_contact_router
|
||||
|
||||
protected_router = APIRouter(
|
||||
dependencies=[Depends(get_current_user)] # Applies to all routes
|
||||
@ -12,6 +13,9 @@ protected_router = APIRouter(
|
||||
|
||||
protected_router.include_router(address_router, prefix="/addresses", tags=["addresses"])
|
||||
protected_router.include_router(contact_router, prefix="/contacts", tags=["contacts"])
|
||||
protected_router.include_router(
|
||||
local_contact_router, prefix="/local-contacts", tags=["local-contacts"]
|
||||
)
|
||||
protected_router.include_router(
|
||||
shipment_router, prefix="/shipments", tags=["shipments"]
|
||||
)
|
||||
|
@ -89,19 +89,19 @@ async def create_sample_event(
|
||||
return sample # Return the sample, now including `mount_count`
|
||||
|
||||
|
||||
@router.post("/samples/{sample_id}/upload-images")
|
||||
async def upload_sample_images(
|
||||
@router.post("/{sample_id}/upload-images")
|
||||
async def upload_sample_image(
|
||||
sample_id: int,
|
||||
uploaded_files: list[UploadFile] = File(...),
|
||||
uploaded_file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
logging.info(f"Received files: {[file.filename for file in uploaded_files]}")
|
||||
logging.info(f"Received file: {uploaded_file.filename}")
|
||||
|
||||
"""
|
||||
Uploads images for a given sample and saves them to a directory structure.
|
||||
Uploads an image for a given sample and saves it to a directory structure.
|
||||
Args:
|
||||
sample_id (int): ID of the sample.
|
||||
uploaded_files (list[UploadFile]): A list of files uploaded with the request.
|
||||
uploaded_file (UploadFile): The file uploaded with the request.
|
||||
db (Session): Database session.
|
||||
"""
|
||||
|
||||
@ -123,35 +123,32 @@ async def upload_sample_images(
|
||||
base_dir = Path(f"images/{pgroup}/{today}/{dewar_name}/{puck_name}/{position}")
|
||||
base_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 3. Process and Save Each File
|
||||
saved_files = []
|
||||
for file in uploaded_files:
|
||||
# Validate MIME type
|
||||
if not file.content_type.startswith("image/"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid file type: {file.filename}. Only images are accepted.",
|
||||
)
|
||||
# 3. Validate MIME type and Save the File
|
||||
if not uploaded_file.content_type.startswith("image/"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid file type: {uploaded_file.filename}."
|
||||
f" Only images are accepted.",
|
||||
)
|
||||
|
||||
# Save file to the base directory
|
||||
file_path = base_dir / file.filename
|
||||
file_path = base_dir / uploaded_file.filename
|
||||
logging.debug(f"Saving file {uploaded_file.filename} to {file_path}")
|
||||
|
||||
# Save the file from the file stream
|
||||
try:
|
||||
with file_path.open("wb") as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
saved_files.append(str(file_path)) # Track saved file paths
|
||||
except Exception as e:
|
||||
logging.error(f"Error saving file {file.filename}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Could not save file {file.filename}."
|
||||
f" Ensure the server has correct permissions.",
|
||||
)
|
||||
try:
|
||||
with file_path.open("wb") as buffer:
|
||||
shutil.copyfileobj(uploaded_file.file, buffer)
|
||||
logging.info(f"File saved: {file_path}")
|
||||
except Exception as e:
|
||||
logging.error(f"Error saving file {uploaded_file.filename}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Could not save file {uploaded_file.filename}."
|
||||
f" Ensure the server has correct permissions.",
|
||||
)
|
||||
|
||||
# 4. Return Saved Files Information
|
||||
logging.info(f"Uploaded {len(saved_files)} files for sample {sample_id}.")
|
||||
# 4. Return Saved File Information
|
||||
logging.info(f"Uploaded 1 file for sample {sample_id}.")
|
||||
return {
|
||||
"message": f"{len(saved_files)} images uploaded successfully.",
|
||||
"files": saved_files,
|
||||
"message": "1 image uploaded successfully.",
|
||||
"file": str(file_path),
|
||||
}
|
||||
|
@ -17,6 +17,7 @@ class loginToken(BaseModel):
|
||||
class loginData(BaseModel):
|
||||
username: str
|
||||
pgroups: List[str]
|
||||
# role: Optional[str] = "user"
|
||||
|
||||
|
||||
class DewarTypeBase(BaseModel):
|
||||
@ -417,6 +418,29 @@ class ContactMinimal(BaseModel):
|
||||
id: int
|
||||
|
||||
|
||||
class Proposal(BaseModel):
|
||||
id: int
|
||||
number: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class LocalContactCreate(BaseModel):
|
||||
firstname: str
|
||||
lastname: str
|
||||
phone_number: str
|
||||
email: EmailStr
|
||||
status: str = "active"
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class LocalContact(LocalContactCreate):
|
||||
id: int
|
||||
|
||||
|
||||
class AddressCreate(BaseModel):
|
||||
pgroups: str
|
||||
house_number: Optional[str] = None
|
||||
@ -617,14 +641,6 @@ class DewarTable(BaseModel):
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class Proposal(BaseModel):
|
||||
id: int
|
||||
number: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class Shipment(BaseModel):
|
||||
id: int
|
||||
pgroups: str
|
||||
@ -752,3 +768,18 @@ class PuckWithTellPosition(BaseModel):
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class Beamtime(BaseModel):
|
||||
id: int
|
||||
pgroups: str
|
||||
beamtime_name: str
|
||||
beamline: str
|
||||
start_date: date
|
||||
end_date: date
|
||||
status: str
|
||||
comments: Optional[constr(max_length=200)] = None
|
||||
proposal_id: Optional[int]
|
||||
proposal: Optional[Proposal]
|
||||
local_contact_id: Optional[int]
|
||||
local_contact: Optional[LocalContact]
|
||||
|
Reference in New Issue
Block a user