feat(sessions): added session router

This commit is contained in:
2025-01-24 15:35:11 +01:00
parent c03705c2cc
commit a0924b3360
7 changed files with 168 additions and 21 deletions

View File

@ -114,7 +114,12 @@ class MongoDBDatasource:
return UserCredentials(**out) return UserCredentials(**out)
def find_one( def find_one(
self, collection: str, query_filter: dict, dtype: Type[T], user: User | None = None self,
collection: str,
query_filter: dict,
dtype: Type[T],
fields: list[str] = None,
user: User | None = None,
) -> T | None: ) -> T | None:
""" """
Find one document in the collection. Find one document in the collection.
@ -130,7 +135,7 @@ class MongoDBDatasource:
""" """
if user is not None: if user is not None:
query_filter = self.add_user_filter(user, query_filter) query_filter = self.add_user_filter(user, query_filter)
out = self.db[collection].find_one(query_filter) out = self.db[collection].find_one(query_filter, projection=fields)
if out is None: if out is None:
return None return None
return dtype(**out) return dtype(**out)
@ -143,6 +148,7 @@ class MongoDBDatasource:
limit: int = 0, limit: int = 0,
offset: int = 0, offset: int = 0,
fields: list[str] = None, fields: list[str] = None,
sort: list[str] = None,
user: User | None = None, user: User | None = None,
) -> list[T]: ) -> list[T]:
""" """
@ -159,7 +165,9 @@ class MongoDBDatasource:
""" """
if user is not None: if user is not None:
query_filter = self.add_user_filter(user, query_filter) query_filter = self.add_user_filter(user, query_filter)
out = self.db[collection].find(query_filter, limit=limit, skip=offset, projection=fields) out = self.db[collection].find(
query_filter, limit=limit, skip=offset, projection=fields, sort=sort
)
return [dtype(**x) for x in out] return [dtype(**x) for x in out]
def post(self, collection: str, data: dict, dtype: Type[T], user: User | None = None) -> T: def post(self, collection: str, data: dict, dtype: Type[T], user: User | None = None) -> T:

View File

@ -10,6 +10,7 @@ from bec_atlas.router.deployments_router import DeploymentsRouter
from bec_atlas.router.realm_router import RealmRouter from bec_atlas.router.realm_router import RealmRouter
from bec_atlas.router.redis_router import RedisRouter, RedisWebsocket from bec_atlas.router.redis_router import RedisRouter, RedisWebsocket
from bec_atlas.router.scan_router import ScanRouter from bec_atlas.router.scan_router import ScanRouter
from bec_atlas.router.session_router import SessionRouter
from bec_atlas.router.user_router import UserRouter from bec_atlas.router.user_router import UserRouter
CONFIG = { CONFIG = {
@ -54,15 +55,20 @@ class AtlasApp:
# pylint: disable=attribute-defined-outside-init # pylint: disable=attribute-defined-outside-init
if not self.datasources.datasources: if not self.datasources.datasources:
raise ValueError("Datasources not loaded") raise ValueError("Datasources not loaded")
self.scan_router = ScanRouter(prefix=self.prefix, datasources=self.datasources)
self.app.include_router(self.scan_router.router, tags=["Scan"])
# User
self.user_router = UserRouter(prefix=self.prefix, datasources=self.datasources) self.user_router = UserRouter(prefix=self.prefix, datasources=self.datasources)
self.app.include_router(self.user_router.router, tags=["User"]) self.app.include_router(self.user_router.router, tags=["User"])
# Realm
self.realm_router = RealmRouter(prefix=self.prefix, datasources=self.datasources)
self.app.include_router(self.realm_router.router, tags=["Realm"])
# Deployment
self.deployment_router = DeploymentsRouter(prefix=self.prefix, datasources=self.datasources) self.deployment_router = DeploymentsRouter(prefix=self.prefix, datasources=self.datasources)
self.app.include_router(self.deployment_router.router, tags=["Deployment"]) self.app.include_router(self.deployment_router.router, tags=["Deployment"])
# Deployment Credentials
self.deployment_credentials_router = DeploymentCredentialsRouter( self.deployment_credentials_router = DeploymentCredentialsRouter(
prefix=self.prefix, datasources=self.datasources prefix=self.prefix, datasources=self.datasources
) )
@ -70,17 +76,25 @@ class AtlasApp:
self.deployment_credentials_router.router, tags=["Deployment Credentials"] self.deployment_credentials_router.router, tags=["Deployment Credentials"]
) )
# Deployment Access
self.deployment_access_router = DeploymentAccessRouter( self.deployment_access_router = DeploymentAccessRouter(
prefix=self.prefix, datasources=self.datasources prefix=self.prefix, datasources=self.datasources
) )
self.app.include_router(self.deployment_access_router.router, tags=["Deployment Access"]) self.app.include_router(self.deployment_access_router.router, tags=["Deployment Access"])
# BEC Access
self.bec_access_router = BECAccessRouter(prefix=self.prefix, datasources=self.datasources) self.bec_access_router = BECAccessRouter(prefix=self.prefix, datasources=self.datasources)
self.app.include_router(self.bec_access_router.router, tags=["BEC Access"]) self.app.include_router(self.bec_access_router.router, tags=["BEC Access"])
self.realm_router = RealmRouter(prefix=self.prefix, datasources=self.datasources) # Session
self.app.include_router(self.realm_router.router, tags=["Realm"]) self.session_router = SessionRouter(prefix=self.prefix, datasources=self.datasources)
self.app.include_router(self.session_router.router, tags=["Session"])
# Scan
self.scan_router = ScanRouter(prefix=self.prefix, datasources=self.datasources)
self.app.include_router(self.scan_router.router, tags=["Scan"])
# Redis
self.redis_router = RedisRouter(prefix=self.prefix, datasources=self.datasources) self.redis_router = RedisRouter(prefix=self.prefix, datasources=self.datasources)
self.app.include_router(self.redis_router.router, tags=["Redis"]) self.app.include_router(self.redis_router.router, tags=["Redis"])

View File

@ -70,14 +70,14 @@ class UserInfo(BaseModel):
class Deployments(MongoBaseModel, AccessProfile): class Deployments(MongoBaseModel, AccessProfile):
realm_id: str | ObjectId realm_id: str
name: str name: str
active_session_id: str | ObjectId | None = None active_session_id: str | ObjectId | None = None
config_templates: list[str | ObjectId] = [] config_templates: list[str | ObjectId] = []
class DeploymentsPartial(MongoBaseModel, AccessProfilePartial): class DeploymentsPartial(MongoBaseModel, AccessProfilePartial):
realm_id: str | ObjectId | None = None realm_id: str | None = None
name: str | None = None name: str | None = None
active_session_id: str | ObjectId | None = None active_session_id: str | ObjectId | None = None
config_templates: list[str | ObjectId] | None = None config_templates: list[str | ObjectId] | None = None
@ -120,7 +120,7 @@ class BECAccessProfile(MongoBaseModel, AccessProfile):
""" """
deployment_id: str | ObjectId deployment_id: str
username: str username: str
passwords: dict[str, str] = {} passwords: dict[str, str] = {}
categories: list[str] = [] categories: list[str] = []
@ -167,10 +167,13 @@ class State(AccessProfile):
class Session(MongoBaseModel, AccessProfile): class Session(MongoBaseModel, AccessProfile):
deployment_id: str | ObjectId deployment_id: str
name: str name: str
SessionPartial = make_all_fields_optional(Session, "SessionPartial")
class Datasets(AccessProfile): class Datasets(AccessProfile):
realm_id: str realm_id: str
dataset_id: str dataset_id: str
@ -214,8 +217,8 @@ class SignalData(AccessProfile, MongoBaseModel):
min, and max values for the signal. min, and max values for the signal.
""" """
scan_id: str | ObjectId | None = None scan_id: str | None = None
device_id: str | ObjectId device_id: str
signal_name: str signal_name: str
data: list[Any] data: list[Any]
timestamps: list[float] timestamps: list[float]
@ -227,9 +230,9 @@ class SignalData(AccessProfile, MongoBaseModel):
class DeviceData(AccessProfile, MongoBaseModel): class DeviceData(AccessProfile, MongoBaseModel):
scan_id: str | ObjectId | None = None scan_id: str | None = None
name: str name: str
device_config_id: str | ObjectId device_config_id: str
signals: list[SignalData] signals: list[SignalData]

View File

@ -38,7 +38,7 @@ class BECAccessRouter(BaseRouter):
user = current_user.email user = current_user.email
out = self.db.find_one( out = self.db.find_one(
"bec_access_profiles", "bec_access_profiles",
{"deployment_id": ObjectId(deployment_id), "username": user}, {"deployment_id": deployment_id, "username": user},
BECAccessProfile, BECAccessProfile,
user=current_user, user=current_user,
) )

View File

@ -153,7 +153,9 @@ class DeploymentAccessRouter(BaseRouter):
redis: RedisDatasource = self.datasources.datasources.get("redis") redis: RedisDatasource = self.datasources.datasources.get("redis")
db: MongoDBDatasource = self.datasources.datasources.get("mongodb") db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
profiles = db.find( profiles = db.find(
"bec_access_profiles", {"deployment_id": ObjectId(deployment_id)}, BECAccessProfile collection="bec_access_profiles",
query_filter={"deployment_id": deployment_id},
dtype=BECAccessProfile,
) )
profiles = [profile.model_dump(exclude_none=True) for profile in profiles] profiles = [profile.model_dump(exclude_none=True) for profile in profiles]
for profile in profiles: for profile in profiles:

View File

@ -44,6 +44,7 @@ class ScanRouter(BaseRouter):
fields: list[str] = Query(default=None), fields: list[str] = Query(default=None),
offset: int = 0, offset: int = 0,
limit: int = 100, limit: int = 100,
sort: str | None = None,
current_user: UserInfo = Depends(get_current_user), current_user: UserInfo = Depends(get_current_user),
) -> list[ScanStatusPartial]: ) -> list[ScanStatusPartial]:
""" """
@ -51,6 +52,17 @@ class ScanRouter(BaseRouter):
Args: Args:
session_id (str): The session id session_id (str): The session id
filter (str): JSON filter for the query, e.g. '{"name": "test"}'
fields (list[str]): List of fields to return, e.g ["name", "description"]
offset (int): Offset for the query
limit (int): Limit for the query
sort (str): Sort order for the query, e.g. '{"name": 1}' for ascending order,
'{"name": -1}' for descending order. Multiple fields can be sorted by
separating them with a comma, e.g. '{"name": 1, "description": -1}'
current_user (UserInfo): The current user
Returns:
list[ScanStatusPartial]: List of scans
""" """
if fields: if fields:
@ -72,16 +84,14 @@ class ScanRouter(BaseRouter):
limit=limit, limit=limit,
offset=offset, offset=offset,
fields=fields, fields=fields,
sort=sort,
user=current_user, user=current_user,
) )
async def scans_with_id( async def scans_with_id(
self, self,
scan_id: str, scan_id: str,
filter: str | None = None,
fields: list[str] = Query(default=None), fields: list[str] = Query(default=None),
offset: int = 0,
limit: int = 100,
current_user: UserInfo = Depends(get_current_user), current_user: UserInfo = Depends(get_current_user),
): ):
""" """
@ -90,7 +100,18 @@ class ScanRouter(BaseRouter):
Args: Args:
scan_id (str): The scan id scan_id (str): The scan id
""" """
return self.db.find_one("scans", {"_id": scan_id}, ScanStatusPartial, user=current_user) if fields:
fields = {
field: 1
for field in fields
if field in ScanStatusPartial.model_json_schema()["properties"].keys()
}
return self.db.find_one(
collection="scans",
query_filter={"_id": scan_id},
dtype=ScanStatusPartial,
user=current_user,
)
async def update_scan_user_data( async def update_scan_user_data(
self, self,

View File

@ -0,0 +1,99 @@
import json
from typing import Annotated
from fastapi import APIRouter, Depends, Query
from fastapi.exceptions import HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel
from bec_atlas.authentication import create_access_token, get_current_user, verify_password
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model import UserInfo
from bec_atlas.model.model import Session
from bec_atlas.router.base_router import BaseRouter
class SessionRouter(BaseRouter):
def __init__(self, prefix="/api/v1", datasources=None):
super().__init__(prefix, datasources)
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
self.router = APIRouter(prefix=prefix)
self.router.add_api_route(
"/sessions",
self.sessions,
methods=["GET"],
description="Get all sessions",
response_model=list[Session],
response_model_exclude_none=True,
)
self.router.add_api_route(
"/sessions/realm",
self.sessions_by_realm,
methods=["GET"],
description="Get all sessions for a realm",
response_model=list[Session],
response_model_exclude_none=True,
)
async def sessions(
self,
filter: str | None = None,
fields: list[str] = Query(default=None),
offset: int = 0,
limit: int = 100,
sort: str | None = None,
current_user: UserInfo = Depends(get_current_user),
) -> list[Session]:
"""
Get all sessions.
Args:
filter (str): JSON filter for the query, e.g. '{"name": "test"}'
fields (list[str]): List of fields to return, e.g ["name", "description"]
offset (int): Offset for the query
limit (int): Limit for the query
sort (str): Sort order for the query, e.g. '{"name": 1}' for ascending order,
'{"name": -1}' for descending order. Multiple fields can be sorted by
separating them with a comma, e.g. '{"name": 1, "description": -1}'
current_user (UserInfo): The current user
Returns:
list[Sessions]: List of sessions
"""
if fields:
fields = {
field: 1
for field in fields
if field in Session.model_json_schema()["properties"].keys()
}
return self.db.find(
"sessions",
filter,
Session,
fields=fields,
offset=offset,
limit=limit,
sort=sort,
user=current_user,
)
async def sessions_by_realm(
self,
realm_id: str,
filter: str | None = None,
fields: list[str] = Query(default=None),
offset: int = 0,
limit: int = 100,
sort: str | None = None,
current_user: UserInfo = Depends(get_current_user),
) -> list[Session]:
"""
Get all sessions for a realm.
"""
filters = {"realm_id": realm_id}
if filter:
filter = json.loads(filter)
filters.update(filter)
out = await self.sessions(filter, fields, offset, limit, sort, current_user)
return out