From ff6946d2bc9cb8acde46587b3e8e480be79f55cf Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Sun, 2 Mar 2025 11:58:12 +0100 Subject: [PATCH] refactor(backend): replaced datasources dict with properties --- .../datasources/datasource_manager.py | 44 ++++++++++++------- backend/bec_atlas/main.py | 2 +- backend/bec_atlas/router/base_router.py | 12 +++-- backend/bec_atlas/router/bec_access_router.py | 17 +++++-- .../router/deployment_access_router.py | 19 +++++--- .../router/deployment_credentials.py | 4 +- .../bec_atlas/router/deployments_router.py | 4 +- backend/bec_atlas/router/realm_router.py | 2 +- backend/bec_atlas/router/redis_router.py | 14 +++--- backend/bec_atlas/router/scan_router.py | 2 +- backend/bec_atlas/router/session_router.py | 2 +- backend/bec_atlas/router/user_router.py | 2 +- backend/tests/test_scan_ingestor.py | 2 +- 13 files changed, 80 insertions(+), 46 deletions(-) diff --git a/backend/bec_atlas/datasources/datasource_manager.py b/backend/bec_atlas/datasources/datasource_manager.py index c2a64c4..15a738e 100644 --- a/backend/bec_atlas/datasources/datasource_manager.py +++ b/backend/bec_atlas/datasources/datasource_manager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from bec_lib.logger import bec_logger from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource @@ -9,26 +11,36 @@ logger = bec_logger.logger class DatasourceManager: def __init__(self, config: dict): self.config = config - self.datasources = {} + self._redis: RedisDatasource | None = None + self._mongodb: MongoDBDatasource | None = None self.load_datasources() def connect(self): - for datasource in self.datasources.values(): - datasource.connect() + self.redis.connect() + self.mongodb.connect() def load_datasources(self): - for datasource_name, datasource_config in self.config.items(): - if datasource_name == "redis": - logger.info( - f"Loading Redis datasource. Host: {datasource_config.get('host')}, Port: {datasource_config.get('port')}, Username: {datasource_config.get('username')}" - ) - self.datasources[datasource_name] = RedisDatasource(datasource_config) - if datasource_name == "mongodb": - logger.info( - f"Loading MongoDB datasource. Host: {datasource_config.get('host')}, Port: {datasource_config.get('port')}, Username: {datasource_config.get('username')}" - ) - self.datasources[datasource_name] = MongoDBDatasource(datasource_config) + redis_config = self.config.get("redis") + if redis_config: + self._redis = RedisDatasource(redis_config) + mongodb_config = self.config.get("mongodb") + if mongodb_config: + self._mongodb = MongoDBDatasource(mongodb_config) + + @property + def redis(self) -> RedisDatasource: + if not self._redis: + raise RuntimeError("Redis datasource not loaded") + return self._redis + + @property + def mongodb(self) -> MongoDBDatasource: + if not self._mongodb: + raise RuntimeError("MongoDB datasource not loaded") + return self._mongodb def shutdown(self): - for datasource in self.datasources.values(): - datasource.shutdown() + if self._redis: + self._redis.shutdown() + if self._mongodb: + self._mongodb.shutdown() diff --git a/backend/bec_atlas/main.py b/backend/bec_atlas/main.py index 36857d7..8bb68f0 100644 --- a/backend/bec_atlas/main.py +++ b/backend/bec_atlas/main.py @@ -58,7 +58,7 @@ class AtlasApp: def add_routers(self): # pylint: disable=attribute-defined-outside-init - if not self.datasources.datasources: + if not self.datasources.redis or not self.datasources.mongodb: raise ValueError("Datasources not loaded") # User diff --git a/backend/bec_atlas/router/base_router.py b/backend/bec_atlas/router/base_router.py index 1b44150..041e5f7 100644 --- a/backend/bec_atlas/router/base_router.py +++ b/backend/bec_atlas/router/base_router.py @@ -10,12 +10,16 @@ if TYPE_CHECKING: # pragma: no cover class BaseRouter: - def __init__(self, prefix: str = "/api/v1", datasources: DatasourceManager = None) -> None: + def __init__( + self, prefix: str = "/api/v1", datasources: DatasourceManager | None = None + ) -> None: self.datasources = datasources self.prefix = prefix + if not self.datasources: + raise RuntimeError("Datasources not loaded") @lru_cache(maxsize=128) - def get_user_from_db(self, _token: str, email: str) -> User: + def get_user_from_db(self, _token: str, email: str) -> User | None: """ Get the user from the database. This is a helper function to be used by the convert_to_user decorator. The function is cached to avoid repeated database @@ -26,4 +30,6 @@ class BaseRouter: _token (str): The token email (str): The email """ - return self.datasources.datasources["mongodb"].get_user_by_email(email) + if not self.datasources: + raise RuntimeError("Datasources not loaded") + return self.datasources.mongodb.get_user_by_email(email) diff --git a/backend/bec_atlas/router/bec_access_router.py b/backend/bec_atlas/router/bec_access_router.py index 81be5d6..974d133 100644 --- a/backend/bec_atlas/router/bec_access_router.py +++ b/backend/bec_atlas/router/bec_access_router.py @@ -1,15 +1,26 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from fastapi import APIRouter, Depends, HTTPException, Query from bec_atlas.authentication import convert_to_user, get_current_user from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource -from bec_atlas.model.model import BECAccessProfile, User, UserInfo +from bec_atlas.model.model import BECAccessProfile, User from bec_atlas.router.base_router import BaseRouter +if TYPE_CHECKING: # pragma: no cover + from bec_atlas.datasources.datasource_manager import DatasourceManager + class BECAccessRouter(BaseRouter): - def __init__(self, prefix="/api/v1", datasources=None): + def __init__(self, prefix="/api/v1", datasources: DatasourceManager | None = None): super().__init__(prefix, datasources) - self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb") + + if not self.datasources: + raise RuntimeError("Datasources not loaded") + + self.db: MongoDBDatasource = self.datasources.mongodb self.router = APIRouter(prefix=prefix) self.router.add_api_route( "/bec_access", diff --git a/backend/bec_atlas/router/deployment_access_router.py b/backend/bec_atlas/router/deployment_access_router.py index 0412e6b..aa286fb 100644 --- a/backend/bec_atlas/router/deployment_access_router.py +++ b/backend/bec_atlas/router/deployment_access_router.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import secrets import time from typing import TYPE_CHECKING, Any @@ -14,13 +16,16 @@ from bec_atlas.router.base_router import BaseRouter from bec_atlas.router.redis_router import RedisAtlasEndpoints if TYPE_CHECKING: # pragma: no cover + from bec_atlas.datasources.datasource_manager import DatasourceManager from bec_atlas.datasources.redis_datasource import RedisDatasource class DeploymentAccessRouter(BaseRouter): - def __init__(self, prefix="/api/v1", datasources=None): + def __init__(self, prefix="/api/v1", datasources: DatasourceManager | None = None): super().__init__(prefix, datasources) - self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb") + if not self.datasources: + raise RuntimeError("Datasources not loaded") + self.db: MongoDBDatasource = self.datasources.mongodb self.router = APIRouter(prefix=prefix) self.router.add_api_route( "/deployment_access", @@ -40,7 +45,7 @@ class DeploymentAccessRouter(BaseRouter): @convert_to_user async def get_deployment_access( self, deployment_id: str, current_user: User = Depends(get_current_user) - ) -> DeploymentAccess: + ) -> DeploymentAccess | None: """ Get the access lists for a specific deployment. @@ -104,7 +109,7 @@ class DeploymentAccessRouter(BaseRouter): Args: deployment_access (DeploymentAccess): The deployment access object """ - db: MongoDBDatasource = self.datasources.datasources.get("mongodb") + db: MongoDBDatasource = self.datasources.mongodb new_profiles = set( updated.user_read_access @@ -160,8 +165,8 @@ class DeploymentAccessRouter(BaseRouter): """ Refresh the redis BEC access. """ - redis: RedisDatasource = self.datasources.datasources.get("redis") - db: MongoDBDatasource = self.datasources.datasources.get("mongodb") + redis: RedisDatasource = self.datasources.redis + db: MongoDBDatasource = self.datasources.mongodb profiles = db.find( collection="bec_access_profiles", query_filter={"deployment_id": deployment_id}, @@ -190,7 +195,7 @@ class DeploymentAccessRouter(BaseRouter): Returns: bool: True if the user exists, False otherwise """ - db: MongoDBDatasource = self.datasources.datasources.get("mongodb") + db: MongoDBDatasource = self.datasources.mongodb user = db.find_one("users", {"email": user}, User) return user is not None diff --git a/backend/bec_atlas/router/deployment_credentials.py b/backend/bec_atlas/router/deployment_credentials.py index 69d2645..4a2ab6a 100644 --- a/backend/bec_atlas/router/deployment_credentials.py +++ b/backend/bec_atlas/router/deployment_credentials.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: # pragma: no cover class DeploymentCredentialsRouter(BaseRouter): def __init__(self, prefix="/api/v1", datasources=None): super().__init__(prefix, datasources) - self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb") + self.db: MongoDBDatasource = self.datasources.mongodb self.router = APIRouter(prefix=prefix) self.router.add_api_route( "/deploymentCredentials", @@ -80,7 +80,7 @@ class DeploymentCredentialsRouter(BaseRouter): raise HTTPException(status_code=404, detail="Deployment not found") # update the redis deployment key - redis: RedisDatasource = self.datasources.datasources.get("redis") + redis: RedisDatasource = self.datasources.redis redis.add_deployment_acl(out) return out diff --git a/backend/bec_atlas/router/deployments_router.py b/backend/bec_atlas/router/deployments_router.py index 78468b5..687cc14 100644 --- a/backend/bec_atlas/router/deployments_router.py +++ b/backend/bec_atlas/router/deployments_router.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: # pragma: no cover class DeploymentsRouter(BaseRouter): def __init__(self, prefix="/api/v1", datasources=None): super().__init__(prefix, datasources) - self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb") + self.db: MongoDBDatasource = self.datasources.mongodb self.router = APIRouter(prefix=prefix) self.router.add_api_route( "/deployments/realm", @@ -72,7 +72,7 @@ class DeploymentsRouter(BaseRouter): self.available_deployments = self.db.find("deployments", {}, Deployments) credentials = self.db.find("deployment_credentials", {}, DeploymentCredential) - redis: RedisDatasource = self.datasources.datasources.get("redis") + redis: RedisDatasource = self.datasources.redis msg = json.dumps([msg.model_dump() for msg in self.available_deployments]) redis.connector.set_and_publish("deployments", msg) for deployment in credentials: diff --git a/backend/bec_atlas/router/realm_router.py b/backend/bec_atlas/router/realm_router.py index f0b0eda..a302dcb 100644 --- a/backend/bec_atlas/router/realm_router.py +++ b/backend/bec_atlas/router/realm_router.py @@ -9,7 +9,7 @@ from bec_atlas.router.base_router import BaseRouter class RealmRouter(BaseRouter): def __init__(self, prefix="/api/v1", datasources=None): super().__init__(prefix, datasources) - self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb") + self.db: MongoDBDatasource = self.datasources.mongodb self.router = APIRouter(prefix=prefix) self.router.add_api_route( "/realms", diff --git a/backend/bec_atlas/router/redis_router.py b/backend/bec_atlas/router/redis_router.py index 1f828e9..fc23580 100644 --- a/backend/bec_atlas/router/redis_router.py +++ b/backend/bec_atlas/router/redis_router.py @@ -141,8 +141,8 @@ class RedisRouter(BaseRouter): def __init__(self, prefix="/api/v1", datasources: DatasourceManager = None): super().__init__(prefix, datasources) - self.redis = self.datasources.datasources["redis"].async_connector - self.db = self.datasources.datasources["mongodb"] + self.redis = self.datasources.redis.async_connector + self.db = self.datasources.mongodb self.router = APIRouter(prefix=prefix) self.router.add_api_route( @@ -474,15 +474,15 @@ class RedisWebsocket: """ def __init__(self, prefix="/api/v1", datasources=None, app: AtlasApp = None): - self.redis: RedisConnector = datasources.datasources["redis"].connector + self.redis: RedisConnector = datasources.redis.connector self.prefix = prefix self.fastapi_app = app self.redis_router = app.redis_router self.active_connections = set() - redis_host = datasources.datasources["redis"].config["host"] - redis_port = datasources.datasources["redis"].config["port"] - redis_password = datasources.datasources["redis"].config.get("password", "ingestor") - self.db = datasources.datasources["mongodb"] + redis_host = datasources.redis.config["host"] + redis_port = datasources.redis.config["port"] + redis_password = datasources.redis.config.get("password", "ingestor") + self.db = datasources.mongodb self.socket = socketio.AsyncServer( transports=["websocket"], ping_timeout=60, diff --git a/backend/bec_atlas/router/scan_router.py b/backend/bec_atlas/router/scan_router.py index 189025a..5e7180d 100644 --- a/backend/bec_atlas/router/scan_router.py +++ b/backend/bec_atlas/router/scan_router.py @@ -14,7 +14,7 @@ from bec_atlas.router.base_router import BaseRouter class ScanRouter(BaseRouter): def __init__(self, prefix="/api/v1", datasources=None): super().__init__(prefix, datasources) - self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb") + self.db: MongoDBDatasource = self.datasources.mongodb self.router = APIRouter(prefix=prefix) self.router.add_api_route( "/scans/session", diff --git a/backend/bec_atlas/router/session_router.py b/backend/bec_atlas/router/session_router.py index da988d5..60e1e95 100644 --- a/backend/bec_atlas/router/session_router.py +++ b/backend/bec_atlas/router/session_router.py @@ -11,7 +11,7 @@ from bec_atlas.router.base_router import BaseRouter class SessionRouter(BaseRouter): def __init__(self, prefix="/api/v1", datasources=None): super().__init__(prefix, datasources) - self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb") + self.db: MongoDBDatasource = self.datasources.mongodb self.router = APIRouter(prefix=prefix) self.router.add_api_route( "/sessions", diff --git a/backend/bec_atlas/router/user_router.py b/backend/bec_atlas/router/user_router.py index ecb71c1..1c7400c 100644 --- a/backend/bec_atlas/router/user_router.py +++ b/backend/bec_atlas/router/user_router.py @@ -28,7 +28,7 @@ class UserRouter(BaseRouter): def __init__(self, prefix="/api/v1", datasources=None, use_ssl=True): super().__init__(prefix, datasources) self.use_ssl = use_ssl - self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb") + self.db: MongoDBDatasource = self.datasources.mongodb self.ldap = LDAPUserService( ldap_server="ldaps://d.psi.ch", base_dn="OU=users,OU=psi,DC=d,DC=psi,DC=ch" ) diff --git a/backend/tests/test_scan_ingestor.py b/backend/tests/test_scan_ingestor.py index a19cd08..6041d94 100644 --- a/backend/tests/test_scan_ingestor.py +++ b/backend/tests/test_scan_ingestor.py @@ -25,7 +25,7 @@ def test_scan_ingestor_create_scan(scan_ingestor, backend): Test that the login endpoint returns a token. """ client, app = backend - mongo: MongoDBDatasource = app.datasources.datasources["mongodb"] + mongo: MongoDBDatasource = app.datasources.mongodb deployment_id = str(mongo.find_one("deployments", {}, dtype=Deployments).id) session_id = str(mongo.find_one("sessions", {"deployment_id": deployment_id}, dtype=Session).id) msg = messages.ScanStatusMessage(