mirror of
https://github.com/bec-project/bec_atlas.git
synced 2025-07-14 07:01:48 +02:00
fix(scan user data): added defaults to model
This commit is contained in:
@ -136,7 +136,14 @@ class MongoDBDatasource:
|
|||||||
return dtype(**out)
|
return dtype(**out)
|
||||||
|
|
||||||
def find(
|
def find(
|
||||||
self, collection: str, query_filter: dict, dtype: Type[T], user: User | None = None
|
self,
|
||||||
|
collection: str,
|
||||||
|
query_filter: dict,
|
||||||
|
dtype: Type[T],
|
||||||
|
limit: int = 0,
|
||||||
|
offset: int = 0,
|
||||||
|
fields: list[str] = None,
|
||||||
|
user: User | None = None,
|
||||||
) -> list[T]:
|
) -> list[T]:
|
||||||
"""
|
"""
|
||||||
Find all documents in the collection.
|
Find all documents in the collection.
|
||||||
@ -152,7 +159,7 @@ class MongoDBDatasource:
|
|||||||
"""
|
"""
|
||||||
if user is not None:
|
if user is not None:
|
||||||
query_filter = self.add_user_filter(user, query_filter)
|
query_filter = self.add_user_filter(user, query_filter)
|
||||||
out = self.db[collection].find(query_filter)
|
out = self.db[collection].find(query_filter, limit=limit, skip=offset, projection=fields)
|
||||||
return [dtype(**x) for x in out]
|
return [dtype(**x) for x in out]
|
||||||
|
|
||||||
def post(self, collection: str, data: dict, dtype: Type[T], user: User | None = None) -> T:
|
def post(self, collection: str, data: dict, dtype: Type[T], user: User | None = None) -> T:
|
||||||
@ -251,6 +258,7 @@ class MongoDBDatasource:
|
|||||||
access_filter = {"$match": self._read_only_user_filter(user)}
|
access_filter = {"$match": self._read_only_user_filter(user)}
|
||||||
lookup_pipeline.insert(0, access_filter)
|
lookup_pipeline.insert(0, access_filter)
|
||||||
# pipeline = self.add_user_filter(user, pipeline)
|
# pipeline = self.add_user_filter(user, pipeline)
|
||||||
|
|
||||||
out = self.db[collection].aggregate(pipeline)
|
out = self.db[collection].aggregate(pipeline)
|
||||||
return [dtype(**x) for x in out]
|
return [dtype(**x) for x in out]
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ class DataIngestor:
|
|||||||
"""
|
"""
|
||||||
out = self.redis.get("deployments")
|
out = self.redis.get("deployments")
|
||||||
if out:
|
if out:
|
||||||
self.available_deployments = json.loads(out)
|
self.available_deployments = out.data
|
||||||
self.update_consumer_groups()
|
self.update_consumer_groups()
|
||||||
self.deployment_listener_thread = threading.Thread(
|
self.deployment_listener_thread = threading.Thread(
|
||||||
target=self.update_available_deployments, name="deployment_listener"
|
target=self.update_available_deployments, name="deployment_listener"
|
||||||
@ -192,7 +192,7 @@ class DataIngestor:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
out = self.datasource.db["sessions"].find_one(
|
out = self.datasource.db["sessions"].find_one(
|
||||||
{"name": "_default_", "deployment_id": ObjectId(deployment_id)}
|
{"name": "_default_", "deployment_id": deployment_id}
|
||||||
)
|
)
|
||||||
if out is None:
|
if out is None:
|
||||||
return None
|
return None
|
||||||
@ -232,7 +232,7 @@ class DataIngestor:
|
|||||||
out["_id"] = msg.scan_id
|
out["_id"] = msg.scan_id
|
||||||
|
|
||||||
# TODO for compatibility with the old message format; remove once the bec_lib is updated
|
# TODO for compatibility with the old message format; remove once the bec_lib is updated
|
||||||
out["session_id"] = session_id
|
out["session_id"] = str(session_id)
|
||||||
|
|
||||||
self.datasource.db["scans"].insert_one(out)
|
self.datasource.db["scans"].insert_one(out)
|
||||||
else:
|
else:
|
||||||
|
@ -1,10 +1,26 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Literal
|
from types import UnionType
|
||||||
|
from typing import Any, Literal, Optional, Type, TypeVar, Union
|
||||||
|
|
||||||
from bec_lib import messages
|
from bec_lib import messages
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
from pydantic import BaseModel, ConfigDict, Field, create_model, field_serializer
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def make_all_fields_optional(model: Type[T], model_name: str) -> Type[T]:
|
||||||
|
"""Convert all fields in a Pydantic model to Optional."""
|
||||||
|
|
||||||
|
# create a dictionary of fields with the same name but with the type Optional[field]
|
||||||
|
# and a default value of None
|
||||||
|
fields = {}
|
||||||
|
|
||||||
|
for name, field in model.__fields__.items():
|
||||||
|
fields[name] = (field.annotation, None)
|
||||||
|
|
||||||
|
return create_model(model_name, **fields, __config__=model.model_config)
|
||||||
|
|
||||||
|
|
||||||
class MongoBaseModel(BaseModel):
|
class MongoBaseModel(BaseModel):
|
||||||
@ -29,7 +45,11 @@ class AccessProfilePartial(AccessProfile):
|
|||||||
access_groups: list[str] | None = None
|
access_groups: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class ScanStatus(MongoBaseModel, AccessProfile, messages.ScanStatusMessage): ...
|
class ScanStatus(MongoBaseModel, AccessProfile, messages.ScanStatusMessage):
|
||||||
|
user_data: ScanUserData | None = None
|
||||||
|
|
||||||
|
|
||||||
|
ScanStatusPartial = make_all_fields_optional(ScanStatus, "ScanStatusPartial")
|
||||||
|
|
||||||
|
|
||||||
class UserCredentials(MongoBaseModel, AccessProfile):
|
class UserCredentials(MongoBaseModel, AccessProfile):
|
||||||
@ -167,14 +187,12 @@ class DatasetUserData(AccessProfile):
|
|||||||
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
|
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
|
||||||
|
|
||||||
|
|
||||||
class ScanUserData(AccessProfile):
|
class ScanUserData(MongoBaseModel, AccessProfile):
|
||||||
scan_id: str
|
scan_id: str
|
||||||
name: str
|
name: str | None = None
|
||||||
rating: int
|
rating: int | None = None
|
||||||
comments: str
|
comments: str | None = None
|
||||||
preview: bytes
|
preview: bytes | None = None
|
||||||
|
|
||||||
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
|
|
||||||
|
|
||||||
|
|
||||||
class DeviceConfig(AccessProfile):
|
class DeviceConfig(AccessProfile):
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
|
||||||
from bec_atlas.authentication import get_current_user
|
from bec_atlas.authentication import get_current_user
|
||||||
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
|
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
|
||||||
from bec_atlas.model.model import ScanStatus, UserInfo
|
from bec_atlas.model.model import ScanStatusPartial, UserInfo
|
||||||
from bec_atlas.router.base_router import BaseRouter
|
from bec_atlas.router.base_router import BaseRouter
|
||||||
|
|
||||||
|
|
||||||
@ -16,32 +16,72 @@ class ScanRouter(BaseRouter):
|
|||||||
self.scans,
|
self.scans,
|
||||||
methods=["GET"],
|
methods=["GET"],
|
||||||
description="Get all scans for a session",
|
description="Get all scans for a session",
|
||||||
response_model=list[ScanStatus],
|
response_model=list[ScanStatusPartial],
|
||||||
|
response_model_exclude_none=True,
|
||||||
)
|
)
|
||||||
self.router.add_api_route(
|
self.router.add_api_route(
|
||||||
"/scans/id",
|
"/scans/id",
|
||||||
self.scans_with_id,
|
self.scans_with_id,
|
||||||
methods=["GET"],
|
methods=["GET"],
|
||||||
description="Get a single scan by id for a session",
|
description="Get a single scan by id for a session",
|
||||||
response_model=ScanStatus,
|
response_model=ScanStatusPartial,
|
||||||
|
response_model_exclude_none=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def scans(
|
async def scans(
|
||||||
self, session_id: str, current_user: UserInfo = Depends(get_current_user)
|
self,
|
||||||
) -> list[ScanStatus]:
|
session_id: str,
|
||||||
|
include_user_data: bool = False,
|
||||||
|
fields: list[str] = Query(default=None),
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
current_user: UserInfo = Depends(get_current_user),
|
||||||
|
) -> list[ScanStatusPartial]:
|
||||||
"""
|
"""
|
||||||
Get all scans for a session.
|
Get all scans for a session.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session_id (str): The session id
|
session_id (str): The session id
|
||||||
"""
|
"""
|
||||||
return self.db.find("scans", {"session_id": session_id}, ScanStatus, user=current_user)
|
if fields:
|
||||||
|
fields = {field: 1 for field in fields}
|
||||||
|
if include_user_data:
|
||||||
|
include = [{"$match": {"session_id": session_id}}]
|
||||||
|
if fields:
|
||||||
|
include.append({"$project": fields})
|
||||||
|
include += [
|
||||||
|
{"$skip": offset},
|
||||||
|
{"$limit": limit},
|
||||||
|
{
|
||||||
|
"$lookup": {
|
||||||
|
"from": "scan_user_data",
|
||||||
|
"let": {"_id": "$_id"},
|
||||||
|
"pipeline": [{"$match": {"$expr": {"$eq": ["$_id", "$$_id"]}}}],
|
||||||
|
"as": "user_data",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
]
|
||||||
|
return self.db.aggregate("scans", include, ScanStatusPartial, user=current_user)
|
||||||
|
return self.db.find(
|
||||||
|
"scans",
|
||||||
|
{"session_id": session_id},
|
||||||
|
ScanStatusPartial,
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
fields=fields,
|
||||||
|
user=current_user,
|
||||||
|
)
|
||||||
|
|
||||||
async def scans_with_id(self, scan_id: str, current_user: UserInfo = Depends(get_current_user)):
|
async def scans_with_id(
|
||||||
|
self,
|
||||||
|
scan_id: str,
|
||||||
|
include_user_data: bool = False,
|
||||||
|
current_user: UserInfo = Depends(get_current_user),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Get scan with id from session
|
Get scan with id from session
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
scan_id (str): The scan id
|
scan_id (str): The scan id
|
||||||
"""
|
"""
|
||||||
return self.db.find_one("scans", {"_id": scan_id}, ScanStatus, user=current_user)
|
return self.db.find_one("scans", {"_id": scan_id}, ScanStatusPartial, user=current_user)
|
||||||
|
@ -74,7 +74,7 @@ class DemoSetupLoader:
|
|||||||
default_session = Session(
|
default_session = Session(
|
||||||
owner_groups=["admin", "demo"],
|
owner_groups=["admin", "demo"],
|
||||||
access_groups=["demo"],
|
access_groups=["demo"],
|
||||||
deployment_id=deployment["_id"],
|
deployment_id=str(deployment["_id"]),
|
||||||
name="_default_",
|
name="_default_",
|
||||||
)
|
)
|
||||||
self.db["sessions"].insert_one(default_session.model_dump(exclude_none=True))
|
self.db["sessions"].insert_one(default_session.model_dump(exclude_none=True))
|
||||||
|
Reference in New Issue
Block a user