fix(scan user data): added defaults to model

This commit is contained in:
2025-01-23 13:16:57 +01:00
parent a158e8deac
commit 3a142b919b
5 changed files with 91 additions and 25 deletions

View File

@ -136,7 +136,14 @@ class MongoDBDatasource:
return dtype(**out)
def find(
self, collection: str, query_filter: dict, dtype: Type[T], user: User | None = None
self,
collection: str,
query_filter: dict,
dtype: Type[T],
limit: int = 0,
offset: int = 0,
fields: list[str] = None,
user: User | None = None,
) -> list[T]:
"""
Find all documents in the collection.
@ -152,7 +159,7 @@ class MongoDBDatasource:
"""
if user is not None:
query_filter = self.add_user_filter(user, query_filter)
out = self.db[collection].find(query_filter)
out = self.db[collection].find(query_filter, limit=limit, skip=offset, projection=fields)
return [dtype(**x) for x in out]
def post(self, collection: str, data: dict, dtype: Type[T], user: User | None = None) -> T:
@ -251,6 +258,7 @@ class MongoDBDatasource:
access_filter = {"$match": self._read_only_user_filter(user)}
lookup_pipeline.insert(0, access_filter)
# pipeline = self.add_user_filter(user, pipeline)
out = self.db[collection].aggregate(pipeline)
return [dtype(**x) for x in out]

View File

@ -55,7 +55,7 @@ class DataIngestor:
"""
out = self.redis.get("deployments")
if out:
self.available_deployments = json.loads(out)
self.available_deployments = out.data
self.update_consumer_groups()
self.deployment_listener_thread = threading.Thread(
target=self.update_available_deployments, name="deployment_listener"
@ -192,7 +192,7 @@ class DataIngestor:
"""
out = self.datasource.db["sessions"].find_one(
{"name": "_default_", "deployment_id": ObjectId(deployment_id)}
{"name": "_default_", "deployment_id": deployment_id}
)
if out is None:
return None
@ -232,7 +232,7 @@ class DataIngestor:
out["_id"] = msg.scan_id
# TODO for compatibility with the old message format; remove once the bec_lib is updated
out["session_id"] = session_id
out["session_id"] = str(session_id)
self.datasource.db["scans"].insert_one(out)
else:

View File

@ -1,10 +1,26 @@
from __future__ import annotations
from typing import Any, Literal
from types import UnionType
from typing import Any, Literal, Optional, Type, TypeVar, Union
from bec_lib import messages
from bson import ObjectId
from pydantic import BaseModel, ConfigDict, Field, field_serializer
from pydantic import BaseModel, ConfigDict, Field, create_model, field_serializer
T = TypeVar("T")
def make_all_fields_optional(model: Type[T], model_name: str) -> Type[T]:
"""Convert all fields in a Pydantic model to Optional."""
# create a dictionary of fields with the same name but with the type Optional[field]
# and a default value of None
fields = {}
for name, field in model.__fields__.items():
fields[name] = (field.annotation, None)
return create_model(model_name, **fields, __config__=model.model_config)
class MongoBaseModel(BaseModel):
@ -29,7 +45,11 @@ class AccessProfilePartial(AccessProfile):
access_groups: list[str] | None = None
class ScanStatus(MongoBaseModel, AccessProfile, messages.ScanStatusMessage): ...
class ScanStatus(MongoBaseModel, AccessProfile, messages.ScanStatusMessage):
user_data: ScanUserData | None = None
ScanStatusPartial = make_all_fields_optional(ScanStatus, "ScanStatusPartial")
class UserCredentials(MongoBaseModel, AccessProfile):
@ -167,14 +187,12 @@ class DatasetUserData(AccessProfile):
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
class ScanUserData(AccessProfile):
class ScanUserData(MongoBaseModel, AccessProfile):
scan_id: str
name: str
rating: int
comments: str
preview: bytes
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
name: str | None = None
rating: int | None = None
comments: str | None = None
preview: bytes | None = None
class DeviceConfig(AccessProfile):

View File

@ -1,8 +1,8 @@
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Query
from bec_atlas.authentication import get_current_user
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model.model import ScanStatus, UserInfo
from bec_atlas.model.model import ScanStatusPartial, UserInfo
from bec_atlas.router.base_router import BaseRouter
@ -16,32 +16,72 @@ class ScanRouter(BaseRouter):
self.scans,
methods=["GET"],
description="Get all scans for a session",
response_model=list[ScanStatus],
response_model=list[ScanStatusPartial],
response_model_exclude_none=True,
)
self.router.add_api_route(
"/scans/id",
self.scans_with_id,
methods=["GET"],
description="Get a single scan by id for a session",
response_model=ScanStatus,
response_model=ScanStatusPartial,
response_model_exclude_none=True,
)
async def scans(
self, session_id: str, current_user: UserInfo = Depends(get_current_user)
) -> list[ScanStatus]:
self,
session_id: str,
include_user_data: bool = False,
fields: list[str] = Query(default=None),
offset: int = 0,
limit: int = 100,
current_user: UserInfo = Depends(get_current_user),
) -> list[ScanStatusPartial]:
"""
Get all scans for a session.
Args:
session_id (str): The session id
"""
return self.db.find("scans", {"session_id": session_id}, ScanStatus, user=current_user)
if fields:
fields = {field: 1 for field in fields}
if include_user_data:
include = [{"$match": {"session_id": session_id}}]
if fields:
include.append({"$project": fields})
include += [
{"$skip": offset},
{"$limit": limit},
{
"$lookup": {
"from": "scan_user_data",
"let": {"_id": "$_id"},
"pipeline": [{"$match": {"$expr": {"$eq": ["$_id", "$$_id"]}}}],
"as": "user_data",
}
},
]
return self.db.aggregate("scans", include, ScanStatusPartial, user=current_user)
return self.db.find(
"scans",
{"session_id": session_id},
ScanStatusPartial,
limit=limit,
offset=offset,
fields=fields,
user=current_user,
)
async def scans_with_id(self, scan_id: str, current_user: UserInfo = Depends(get_current_user)):
async def scans_with_id(
self,
scan_id: str,
include_user_data: bool = False,
current_user: UserInfo = Depends(get_current_user),
):
"""
Get scan with id from session
Args:
scan_id (str): The scan id
"""
return self.db.find_one("scans", {"_id": scan_id}, ScanStatus, user=current_user)
return self.db.find_one("scans", {"_id": scan_id}, ScanStatusPartial, user=current_user)

View File

@ -74,7 +74,7 @@ class DemoSetupLoader:
default_session = Session(
owner_groups=["admin", "demo"],
access_groups=["demo"],
deployment_id=deployment["_id"],
deployment_id=str(deployment["_id"]),
name="_default_",
)
self.db["sessions"].insert_one(default_session.model_dump(exclude_none=True))