From 3a142b919b69431cfd5ccc4a8880ee1b725c3099 Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Thu, 23 Jan 2025 13:16:57 +0100 Subject: [PATCH] fix(scan user data): added defaults to model --- .../bec_atlas/datasources/mongodb/mongodb.py | 12 +++- backend/bec_atlas/ingestor/data_ingestor.py | 6 +- backend/bec_atlas/model/model.py | 38 ++++++++---- backend/bec_atlas/router/scan_router.py | 58 ++++++++++++++++--- .../bec_atlas/utils/demo_database_setup.py | 2 +- 5 files changed, 91 insertions(+), 25 deletions(-) diff --git a/backend/bec_atlas/datasources/mongodb/mongodb.py b/backend/bec_atlas/datasources/mongodb/mongodb.py index 445f69d..6c09cda 100644 --- a/backend/bec_atlas/datasources/mongodb/mongodb.py +++ b/backend/bec_atlas/datasources/mongodb/mongodb.py @@ -136,7 +136,14 @@ class MongoDBDatasource: return dtype(**out) def find( - self, collection: str, query_filter: dict, dtype: Type[T], user: User | None = None + self, + collection: str, + query_filter: dict, + dtype: Type[T], + limit: int = 0, + offset: int = 0, + fields: list[str] = None, + user: User | None = None, ) -> list[T]: """ Find all documents in the collection. @@ -152,7 +159,7 @@ class MongoDBDatasource: """ if user is not None: query_filter = self.add_user_filter(user, query_filter) - out = self.db[collection].find(query_filter) + out = self.db[collection].find(query_filter, limit=limit, skip=offset, projection=fields) return [dtype(**x) for x in out] def post(self, collection: str, data: dict, dtype: Type[T], user: User | None = None) -> T: @@ -251,6 +258,7 @@ class MongoDBDatasource: access_filter = {"$match": self._read_only_user_filter(user)} lookup_pipeline.insert(0, access_filter) # pipeline = self.add_user_filter(user, pipeline) + out = self.db[collection].aggregate(pipeline) return [dtype(**x) for x in out] diff --git a/backend/bec_atlas/ingestor/data_ingestor.py b/backend/bec_atlas/ingestor/data_ingestor.py index 1817c5d..709df59 100644 --- a/backend/bec_atlas/ingestor/data_ingestor.py +++ b/backend/bec_atlas/ingestor/data_ingestor.py @@ -55,7 +55,7 @@ class DataIngestor: """ out = self.redis.get("deployments") if out: - self.available_deployments = json.loads(out) + self.available_deployments = out.data self.update_consumer_groups() self.deployment_listener_thread = threading.Thread( target=self.update_available_deployments, name="deployment_listener" @@ -192,7 +192,7 @@ class DataIngestor: """ out = self.datasource.db["sessions"].find_one( - {"name": "_default_", "deployment_id": ObjectId(deployment_id)} + {"name": "_default_", "deployment_id": deployment_id} ) if out is None: return None @@ -232,7 +232,7 @@ class DataIngestor: out["_id"] = msg.scan_id # TODO for compatibility with the old message format; remove once the bec_lib is updated - out["session_id"] = session_id + out["session_id"] = str(session_id) self.datasource.db["scans"].insert_one(out) else: diff --git a/backend/bec_atlas/model/model.py b/backend/bec_atlas/model/model.py index 5e539c6..d2b452b 100644 --- a/backend/bec_atlas/model/model.py +++ b/backend/bec_atlas/model/model.py @@ -1,10 +1,26 @@ from __future__ import annotations -from typing import Any, Literal +from types import UnionType +from typing import Any, Literal, Optional, Type, TypeVar, Union from bec_lib import messages from bson import ObjectId -from pydantic import BaseModel, ConfigDict, Field, field_serializer +from pydantic import BaseModel, ConfigDict, Field, create_model, field_serializer + +T = TypeVar("T") + + +def make_all_fields_optional(model: Type[T], model_name: str) -> Type[T]: + """Convert all fields in a Pydantic model to Optional.""" + + # create a dictionary of fields with the same name but with the type Optional[field] + # and a default value of None + fields = {} + + for name, field in model.__fields__.items(): + fields[name] = (field.annotation, None) + + return create_model(model_name, **fields, __config__=model.model_config) class MongoBaseModel(BaseModel): @@ -29,7 +45,11 @@ class AccessProfilePartial(AccessProfile): access_groups: list[str] | None = None -class ScanStatus(MongoBaseModel, AccessProfile, messages.ScanStatusMessage): ... +class ScanStatus(MongoBaseModel, AccessProfile, messages.ScanStatusMessage): + user_data: ScanUserData | None = None + + +ScanStatusPartial = make_all_fields_optional(ScanStatus, "ScanStatusPartial") class UserCredentials(MongoBaseModel, AccessProfile): @@ -167,14 +187,12 @@ class DatasetUserData(AccessProfile): model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True) -class ScanUserData(AccessProfile): +class ScanUserData(MongoBaseModel, AccessProfile): scan_id: str - name: str - rating: int - comments: str - preview: bytes - - model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True) + name: str | None = None + rating: int | None = None + comments: str | None = None + preview: bytes | None = None class DeviceConfig(AccessProfile): diff --git a/backend/bec_atlas/router/scan_router.py b/backend/bec_atlas/router/scan_router.py index c7bed63..3143e39 100644 --- a/backend/bec_atlas/router/scan_router.py +++ b/backend/bec_atlas/router/scan_router.py @@ -1,8 +1,8 @@ -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Query from bec_atlas.authentication import get_current_user from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource -from bec_atlas.model.model import ScanStatus, UserInfo +from bec_atlas.model.model import ScanStatusPartial, UserInfo from bec_atlas.router.base_router import BaseRouter @@ -16,32 +16,72 @@ class ScanRouter(BaseRouter): self.scans, methods=["GET"], description="Get all scans for a session", - response_model=list[ScanStatus], + response_model=list[ScanStatusPartial], + response_model_exclude_none=True, ) self.router.add_api_route( "/scans/id", self.scans_with_id, methods=["GET"], description="Get a single scan by id for a session", - response_model=ScanStatus, + response_model=ScanStatusPartial, + response_model_exclude_none=True, ) async def scans( - self, session_id: str, current_user: UserInfo = Depends(get_current_user) - ) -> list[ScanStatus]: + self, + session_id: str, + include_user_data: bool = False, + fields: list[str] = Query(default=None), + offset: int = 0, + limit: int = 100, + current_user: UserInfo = Depends(get_current_user), + ) -> list[ScanStatusPartial]: """ Get all scans for a session. Args: session_id (str): The session id """ - return self.db.find("scans", {"session_id": session_id}, ScanStatus, user=current_user) + if fields: + fields = {field: 1 for field in fields} + if include_user_data: + include = [{"$match": {"session_id": session_id}}] + if fields: + include.append({"$project": fields}) + include += [ + {"$skip": offset}, + {"$limit": limit}, + { + "$lookup": { + "from": "scan_user_data", + "let": {"_id": "$_id"}, + "pipeline": [{"$match": {"$expr": {"$eq": ["$_id", "$$_id"]}}}], + "as": "user_data", + } + }, + ] + return self.db.aggregate("scans", include, ScanStatusPartial, user=current_user) + return self.db.find( + "scans", + {"session_id": session_id}, + ScanStatusPartial, + limit=limit, + offset=offset, + fields=fields, + user=current_user, + ) - async def scans_with_id(self, scan_id: str, current_user: UserInfo = Depends(get_current_user)): + async def scans_with_id( + self, + scan_id: str, + include_user_data: bool = False, + current_user: UserInfo = Depends(get_current_user), + ): """ Get scan with id from session Args: scan_id (str): The scan id """ - return self.db.find_one("scans", {"_id": scan_id}, ScanStatus, user=current_user) + return self.db.find_one("scans", {"_id": scan_id}, ScanStatusPartial, user=current_user) diff --git a/backend/bec_atlas/utils/demo_database_setup.py b/backend/bec_atlas/utils/demo_database_setup.py index 5b51147..c852d4d 100644 --- a/backend/bec_atlas/utils/demo_database_setup.py +++ b/backend/bec_atlas/utils/demo_database_setup.py @@ -74,7 +74,7 @@ class DemoSetupLoader: default_session = Session( owner_groups=["admin", "demo"], access_groups=["demo"], - deployment_id=deployment["_id"], + deployment_id=str(deployment["_id"]), name="_default_", ) self.db["sessions"].insert_one(default_session.model_dump(exclude_none=True))