mirror of
https://github.com/bec-project/bec_atlas.git
synced 2025-07-14 07:01:48 +02:00
fix(scan user data): added defaults to model
This commit is contained in:
@ -136,7 +136,14 @@ class MongoDBDatasource:
|
||||
return dtype(**out)
|
||||
|
||||
def find(
|
||||
self, collection: str, query_filter: dict, dtype: Type[T], user: User | None = None
|
||||
self,
|
||||
collection: str,
|
||||
query_filter: dict,
|
||||
dtype: Type[T],
|
||||
limit: int = 0,
|
||||
offset: int = 0,
|
||||
fields: list[str] = None,
|
||||
user: User | None = None,
|
||||
) -> list[T]:
|
||||
"""
|
||||
Find all documents in the collection.
|
||||
@ -152,7 +159,7 @@ class MongoDBDatasource:
|
||||
"""
|
||||
if user is not None:
|
||||
query_filter = self.add_user_filter(user, query_filter)
|
||||
out = self.db[collection].find(query_filter)
|
||||
out = self.db[collection].find(query_filter, limit=limit, skip=offset, projection=fields)
|
||||
return [dtype(**x) for x in out]
|
||||
|
||||
def post(self, collection: str, data: dict, dtype: Type[T], user: User | None = None) -> T:
|
||||
@ -251,6 +258,7 @@ class MongoDBDatasource:
|
||||
access_filter = {"$match": self._read_only_user_filter(user)}
|
||||
lookup_pipeline.insert(0, access_filter)
|
||||
# pipeline = self.add_user_filter(user, pipeline)
|
||||
|
||||
out = self.db[collection].aggregate(pipeline)
|
||||
return [dtype(**x) for x in out]
|
||||
|
||||
|
@ -55,7 +55,7 @@ class DataIngestor:
|
||||
"""
|
||||
out = self.redis.get("deployments")
|
||||
if out:
|
||||
self.available_deployments = json.loads(out)
|
||||
self.available_deployments = out.data
|
||||
self.update_consumer_groups()
|
||||
self.deployment_listener_thread = threading.Thread(
|
||||
target=self.update_available_deployments, name="deployment_listener"
|
||||
@ -192,7 +192,7 @@ class DataIngestor:
|
||||
|
||||
"""
|
||||
out = self.datasource.db["sessions"].find_one(
|
||||
{"name": "_default_", "deployment_id": ObjectId(deployment_id)}
|
||||
{"name": "_default_", "deployment_id": deployment_id}
|
||||
)
|
||||
if out is None:
|
||||
return None
|
||||
@ -232,7 +232,7 @@ class DataIngestor:
|
||||
out["_id"] = msg.scan_id
|
||||
|
||||
# TODO for compatibility with the old message format; remove once the bec_lib is updated
|
||||
out["session_id"] = session_id
|
||||
out["session_id"] = str(session_id)
|
||||
|
||||
self.datasource.db["scans"].insert_one(out)
|
||||
else:
|
||||
|
@ -1,10 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
from types import UnionType
|
||||
from typing import Any, Literal, Optional, Type, TypeVar, Union
|
||||
|
||||
from bec_lib import messages
|
||||
from bson import ObjectId
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
||||
from pydantic import BaseModel, ConfigDict, Field, create_model, field_serializer
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def make_all_fields_optional(model: Type[T], model_name: str) -> Type[T]:
|
||||
"""Convert all fields in a Pydantic model to Optional."""
|
||||
|
||||
# create a dictionary of fields with the same name but with the type Optional[field]
|
||||
# and a default value of None
|
||||
fields = {}
|
||||
|
||||
for name, field in model.__fields__.items():
|
||||
fields[name] = (field.annotation, None)
|
||||
|
||||
return create_model(model_name, **fields, __config__=model.model_config)
|
||||
|
||||
|
||||
class MongoBaseModel(BaseModel):
|
||||
@ -29,7 +45,11 @@ class AccessProfilePartial(AccessProfile):
|
||||
access_groups: list[str] | None = None
|
||||
|
||||
|
||||
class ScanStatus(MongoBaseModel, AccessProfile, messages.ScanStatusMessage): ...
|
||||
class ScanStatus(MongoBaseModel, AccessProfile, messages.ScanStatusMessage):
|
||||
user_data: ScanUserData | None = None
|
||||
|
||||
|
||||
ScanStatusPartial = make_all_fields_optional(ScanStatus, "ScanStatusPartial")
|
||||
|
||||
|
||||
class UserCredentials(MongoBaseModel, AccessProfile):
|
||||
@ -167,14 +187,12 @@ class DatasetUserData(AccessProfile):
|
||||
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ScanUserData(AccessProfile):
|
||||
class ScanUserData(MongoBaseModel, AccessProfile):
|
||||
scan_id: str
|
||||
name: str
|
||||
rating: int
|
||||
comments: str
|
||||
preview: bytes
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
|
||||
name: str | None = None
|
||||
rating: int | None = None
|
||||
comments: str | None = None
|
||||
preview: bytes | None = None
|
||||
|
||||
|
||||
class DeviceConfig(AccessProfile):
|
||||
|
@ -1,8 +1,8 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from bec_atlas.authentication import get_current_user
|
||||
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
|
||||
from bec_atlas.model.model import ScanStatus, UserInfo
|
||||
from bec_atlas.model.model import ScanStatusPartial, UserInfo
|
||||
from bec_atlas.router.base_router import BaseRouter
|
||||
|
||||
|
||||
@ -16,32 +16,72 @@ class ScanRouter(BaseRouter):
|
||||
self.scans,
|
||||
methods=["GET"],
|
||||
description="Get all scans for a session",
|
||||
response_model=list[ScanStatus],
|
||||
response_model=list[ScanStatusPartial],
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
self.router.add_api_route(
|
||||
"/scans/id",
|
||||
self.scans_with_id,
|
||||
methods=["GET"],
|
||||
description="Get a single scan by id for a session",
|
||||
response_model=ScanStatus,
|
||||
response_model=ScanStatusPartial,
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
|
||||
async def scans(
|
||||
self, session_id: str, current_user: UserInfo = Depends(get_current_user)
|
||||
) -> list[ScanStatus]:
|
||||
self,
|
||||
session_id: str,
|
||||
include_user_data: bool = False,
|
||||
fields: list[str] = Query(default=None),
|
||||
offset: int = 0,
|
||||
limit: int = 100,
|
||||
current_user: UserInfo = Depends(get_current_user),
|
||||
) -> list[ScanStatusPartial]:
|
||||
"""
|
||||
Get all scans for a session.
|
||||
|
||||
Args:
|
||||
session_id (str): The session id
|
||||
"""
|
||||
return self.db.find("scans", {"session_id": session_id}, ScanStatus, user=current_user)
|
||||
if fields:
|
||||
fields = {field: 1 for field in fields}
|
||||
if include_user_data:
|
||||
include = [{"$match": {"session_id": session_id}}]
|
||||
if fields:
|
||||
include.append({"$project": fields})
|
||||
include += [
|
||||
{"$skip": offset},
|
||||
{"$limit": limit},
|
||||
{
|
||||
"$lookup": {
|
||||
"from": "scan_user_data",
|
||||
"let": {"_id": "$_id"},
|
||||
"pipeline": [{"$match": {"$expr": {"$eq": ["$_id", "$$_id"]}}}],
|
||||
"as": "user_data",
|
||||
}
|
||||
},
|
||||
]
|
||||
return self.db.aggregate("scans", include, ScanStatusPartial, user=current_user)
|
||||
return self.db.find(
|
||||
"scans",
|
||||
{"session_id": session_id},
|
||||
ScanStatusPartial,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
fields=fields,
|
||||
user=current_user,
|
||||
)
|
||||
|
||||
async def scans_with_id(self, scan_id: str, current_user: UserInfo = Depends(get_current_user)):
|
||||
async def scans_with_id(
|
||||
self,
|
||||
scan_id: str,
|
||||
include_user_data: bool = False,
|
||||
current_user: UserInfo = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Get scan with id from session
|
||||
|
||||
Args:
|
||||
scan_id (str): The scan id
|
||||
"""
|
||||
return self.db.find_one("scans", {"_id": scan_id}, ScanStatus, user=current_user)
|
||||
return self.db.find_one("scans", {"_id": scan_id}, ScanStatusPartial, user=current_user)
|
||||
|
@ -74,7 +74,7 @@ class DemoSetupLoader:
|
||||
default_session = Session(
|
||||
owner_groups=["admin", "demo"],
|
||||
access_groups=["demo"],
|
||||
deployment_id=deployment["_id"],
|
||||
deployment_id=str(deployment["_id"]),
|
||||
name="_default_",
|
||||
)
|
||||
self.db["sessions"].insert_one(default_session.model_dump(exclude_none=True))
|
||||
|
Reference in New Issue
Block a user