refactor(backend): replaced datasources dict with properties

This commit is contained in:
2025-03-02 11:58:12 +01:00
parent ce04dac5be
commit ff6946d2bc
13 changed files with 80 additions and 46 deletions

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from bec_lib.logger import bec_logger
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
@ -9,26 +11,36 @@ logger = bec_logger.logger
class DatasourceManager:
def __init__(self, config: dict):
self.config = config
self.datasources = {}
self._redis: RedisDatasource | None = None
self._mongodb: MongoDBDatasource | None = None
self.load_datasources()
def connect(self):
for datasource in self.datasources.values():
datasource.connect()
self.redis.connect()
self.mongodb.connect()
def load_datasources(self):
for datasource_name, datasource_config in self.config.items():
if datasource_name == "redis":
logger.info(
f"Loading Redis datasource. Host: {datasource_config.get('host')}, Port: {datasource_config.get('port')}, Username: {datasource_config.get('username')}"
)
self.datasources[datasource_name] = RedisDatasource(datasource_config)
if datasource_name == "mongodb":
logger.info(
f"Loading MongoDB datasource. Host: {datasource_config.get('host')}, Port: {datasource_config.get('port')}, Username: {datasource_config.get('username')}"
)
self.datasources[datasource_name] = MongoDBDatasource(datasource_config)
redis_config = self.config.get("redis")
if redis_config:
self._redis = RedisDatasource(redis_config)
mongodb_config = self.config.get("mongodb")
if mongodb_config:
self._mongodb = MongoDBDatasource(mongodb_config)
@property
def redis(self) -> RedisDatasource:
if not self._redis:
raise RuntimeError("Redis datasource not loaded")
return self._redis
@property
def mongodb(self) -> MongoDBDatasource:
if not self._mongodb:
raise RuntimeError("MongoDB datasource not loaded")
return self._mongodb
def shutdown(self):
for datasource in self.datasources.values():
datasource.shutdown()
if self._redis:
self._redis.shutdown()
if self._mongodb:
self._mongodb.shutdown()

View File

@ -58,7 +58,7 @@ class AtlasApp:
def add_routers(self):
# pylint: disable=attribute-defined-outside-init
if not self.datasources.datasources:
if not self.datasources.redis or not self.datasources.mongodb:
raise ValueError("Datasources not loaded")
# User

View File

@ -10,12 +10,16 @@ if TYPE_CHECKING: # pragma: no cover
class BaseRouter:
def __init__(self, prefix: str = "/api/v1", datasources: DatasourceManager = None) -> None:
def __init__(
self, prefix: str = "/api/v1", datasources: DatasourceManager | None = None
) -> None:
self.datasources = datasources
self.prefix = prefix
if not self.datasources:
raise RuntimeError("Datasources not loaded")
@lru_cache(maxsize=128)
def get_user_from_db(self, _token: str, email: str) -> User:
def get_user_from_db(self, _token: str, email: str) -> User | None:
"""
Get the user from the database. This is a helper function to be used by the
convert_to_user decorator. The function is cached to avoid repeated database
@ -26,4 +30,6 @@ class BaseRouter:
_token (str): The token
email (str): The email
"""
return self.datasources.datasources["mongodb"].get_user_by_email(email)
if not self.datasources:
raise RuntimeError("Datasources not loaded")
return self.datasources.mongodb.get_user_by_email(email)

View File

@ -1,15 +1,26 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from fastapi import APIRouter, Depends, HTTPException, Query
from bec_atlas.authentication import convert_to_user, get_current_user
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model.model import BECAccessProfile, User, UserInfo
from bec_atlas.model.model import BECAccessProfile, User
from bec_atlas.router.base_router import BaseRouter
if TYPE_CHECKING: # pragma: no cover
from bec_atlas.datasources.datasource_manager import DatasourceManager
class BECAccessRouter(BaseRouter):
def __init__(self, prefix="/api/v1", datasources=None):
def __init__(self, prefix="/api/v1", datasources: DatasourceManager | None = None):
super().__init__(prefix, datasources)
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
if not self.datasources:
raise RuntimeError("Datasources not loaded")
self.db: MongoDBDatasource = self.datasources.mongodb
self.router = APIRouter(prefix=prefix)
self.router.add_api_route(
"/bec_access",

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import secrets
import time
from typing import TYPE_CHECKING, Any
@ -14,13 +16,16 @@ from bec_atlas.router.base_router import BaseRouter
from bec_atlas.router.redis_router import RedisAtlasEndpoints
if TYPE_CHECKING: # pragma: no cover
from bec_atlas.datasources.datasource_manager import DatasourceManager
from bec_atlas.datasources.redis_datasource import RedisDatasource
class DeploymentAccessRouter(BaseRouter):
def __init__(self, prefix="/api/v1", datasources=None):
def __init__(self, prefix="/api/v1", datasources: DatasourceManager | None = None):
super().__init__(prefix, datasources)
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
if not self.datasources:
raise RuntimeError("Datasources not loaded")
self.db: MongoDBDatasource = self.datasources.mongodb
self.router = APIRouter(prefix=prefix)
self.router.add_api_route(
"/deployment_access",
@ -40,7 +45,7 @@ class DeploymentAccessRouter(BaseRouter):
@convert_to_user
async def get_deployment_access(
self, deployment_id: str, current_user: User = Depends(get_current_user)
) -> DeploymentAccess:
) -> DeploymentAccess | None:
"""
Get the access lists for a specific deployment.
@ -104,7 +109,7 @@ class DeploymentAccessRouter(BaseRouter):
Args:
deployment_access (DeploymentAccess): The deployment access object
"""
db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
db: MongoDBDatasource = self.datasources.mongodb
new_profiles = set(
updated.user_read_access
@ -160,8 +165,8 @@ class DeploymentAccessRouter(BaseRouter):
"""
Refresh the redis BEC access.
"""
redis: RedisDatasource = self.datasources.datasources.get("redis")
db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
redis: RedisDatasource = self.datasources.redis
db: MongoDBDatasource = self.datasources.mongodb
profiles = db.find(
collection="bec_access_profiles",
query_filter={"deployment_id": deployment_id},
@ -190,7 +195,7 @@ class DeploymentAccessRouter(BaseRouter):
Returns:
bool: True if the user exists, False otherwise
"""
db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
db: MongoDBDatasource = self.datasources.mongodb
user = db.find_one("users", {"email": user}, User)
return user is not None

View File

@ -16,7 +16,7 @@ if TYPE_CHECKING: # pragma: no cover
class DeploymentCredentialsRouter(BaseRouter):
def __init__(self, prefix="/api/v1", datasources=None):
super().__init__(prefix, datasources)
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
self.db: MongoDBDatasource = self.datasources.mongodb
self.router = APIRouter(prefix=prefix)
self.router.add_api_route(
"/deploymentCredentials",
@ -80,7 +80,7 @@ class DeploymentCredentialsRouter(BaseRouter):
raise HTTPException(status_code=404, detail="Deployment not found")
# update the redis deployment key
redis: RedisDatasource = self.datasources.datasources.get("redis")
redis: RedisDatasource = self.datasources.redis
redis.add_deployment_acl(out)
return out

View File

@ -16,7 +16,7 @@ if TYPE_CHECKING: # pragma: no cover
class DeploymentsRouter(BaseRouter):
def __init__(self, prefix="/api/v1", datasources=None):
super().__init__(prefix, datasources)
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
self.db: MongoDBDatasource = self.datasources.mongodb
self.router = APIRouter(prefix=prefix)
self.router.add_api_route(
"/deployments/realm",
@ -72,7 +72,7 @@ class DeploymentsRouter(BaseRouter):
self.available_deployments = self.db.find("deployments", {}, Deployments)
credentials = self.db.find("deployment_credentials", {}, DeploymentCredential)
redis: RedisDatasource = self.datasources.datasources.get("redis")
redis: RedisDatasource = self.datasources.redis
msg = json.dumps([msg.model_dump() for msg in self.available_deployments])
redis.connector.set_and_publish("deployments", msg)
for deployment in credentials:

View File

@ -9,7 +9,7 @@ from bec_atlas.router.base_router import BaseRouter
class RealmRouter(BaseRouter):
def __init__(self, prefix="/api/v1", datasources=None):
super().__init__(prefix, datasources)
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
self.db: MongoDBDatasource = self.datasources.mongodb
self.router = APIRouter(prefix=prefix)
self.router.add_api_route(
"/realms",

View File

@ -141,8 +141,8 @@ class RedisRouter(BaseRouter):
def __init__(self, prefix="/api/v1", datasources: DatasourceManager = None):
super().__init__(prefix, datasources)
self.redis = self.datasources.datasources["redis"].async_connector
self.db = self.datasources.datasources["mongodb"]
self.redis = self.datasources.redis.async_connector
self.db = self.datasources.mongodb
self.router = APIRouter(prefix=prefix)
self.router.add_api_route(
@ -474,15 +474,15 @@ class RedisWebsocket:
"""
def __init__(self, prefix="/api/v1", datasources=None, app: AtlasApp = None):
self.redis: RedisConnector = datasources.datasources["redis"].connector
self.redis: RedisConnector = datasources.redis.connector
self.prefix = prefix
self.fastapi_app = app
self.redis_router = app.redis_router
self.active_connections = set()
redis_host = datasources.datasources["redis"].config["host"]
redis_port = datasources.datasources["redis"].config["port"]
redis_password = datasources.datasources["redis"].config.get("password", "ingestor")
self.db = datasources.datasources["mongodb"]
redis_host = datasources.redis.config["host"]
redis_port = datasources.redis.config["port"]
redis_password = datasources.redis.config.get("password", "ingestor")
self.db = datasources.mongodb
self.socket = socketio.AsyncServer(
transports=["websocket"],
ping_timeout=60,

View File

@ -14,7 +14,7 @@ from bec_atlas.router.base_router import BaseRouter
class ScanRouter(BaseRouter):
def __init__(self, prefix="/api/v1", datasources=None):
super().__init__(prefix, datasources)
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
self.db: MongoDBDatasource = self.datasources.mongodb
self.router = APIRouter(prefix=prefix)
self.router.add_api_route(
"/scans/session",

View File

@ -11,7 +11,7 @@ from bec_atlas.router.base_router import BaseRouter
class SessionRouter(BaseRouter):
def __init__(self, prefix="/api/v1", datasources=None):
super().__init__(prefix, datasources)
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
self.db: MongoDBDatasource = self.datasources.mongodb
self.router = APIRouter(prefix=prefix)
self.router.add_api_route(
"/sessions",

View File

@ -28,7 +28,7 @@ class UserRouter(BaseRouter):
def __init__(self, prefix="/api/v1", datasources=None, use_ssl=True):
super().__init__(prefix, datasources)
self.use_ssl = use_ssl
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
self.db: MongoDBDatasource = self.datasources.mongodb
self.ldap = LDAPUserService(
ldap_server="ldaps://d.psi.ch", base_dn="OU=users,OU=psi,DC=d,DC=psi,DC=ch"
)

View File

@ -25,7 +25,7 @@ def test_scan_ingestor_create_scan(scan_ingestor, backend):
Test that the login endpoint returns a token.
"""
client, app = backend
mongo: MongoDBDatasource = app.datasources.datasources["mongodb"]
mongo: MongoDBDatasource = app.datasources.mongodb
deployment_id = str(mongo.find_one("deployments", {}, dtype=Deployments).id)
session_id = str(mongo.find_one("sessions", {"deployment_id": deployment_id}, dtype=Session).id)
msg = messages.ScanStatusMessage(