mirror of
https://github.com/bec-project/bec_atlas.git
synced 2025-07-13 22:51:49 +02:00
refactor(backend): replaced datasources dict with properties
This commit is contained in:
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from bec_lib.logger import bec_logger
|
||||
|
||||
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
|
||||
@ -9,26 +11,36 @@ logger = bec_logger.logger
|
||||
class DatasourceManager:
|
||||
def __init__(self, config: dict):
|
||||
self.config = config
|
||||
self.datasources = {}
|
||||
self._redis: RedisDatasource | None = None
|
||||
self._mongodb: MongoDBDatasource | None = None
|
||||
self.load_datasources()
|
||||
|
||||
def connect(self):
|
||||
for datasource in self.datasources.values():
|
||||
datasource.connect()
|
||||
self.redis.connect()
|
||||
self.mongodb.connect()
|
||||
|
||||
def load_datasources(self):
|
||||
for datasource_name, datasource_config in self.config.items():
|
||||
if datasource_name == "redis":
|
||||
logger.info(
|
||||
f"Loading Redis datasource. Host: {datasource_config.get('host')}, Port: {datasource_config.get('port')}, Username: {datasource_config.get('username')}"
|
||||
)
|
||||
self.datasources[datasource_name] = RedisDatasource(datasource_config)
|
||||
if datasource_name == "mongodb":
|
||||
logger.info(
|
||||
f"Loading MongoDB datasource. Host: {datasource_config.get('host')}, Port: {datasource_config.get('port')}, Username: {datasource_config.get('username')}"
|
||||
)
|
||||
self.datasources[datasource_name] = MongoDBDatasource(datasource_config)
|
||||
redis_config = self.config.get("redis")
|
||||
if redis_config:
|
||||
self._redis = RedisDatasource(redis_config)
|
||||
mongodb_config = self.config.get("mongodb")
|
||||
if mongodb_config:
|
||||
self._mongodb = MongoDBDatasource(mongodb_config)
|
||||
|
||||
@property
|
||||
def redis(self) -> RedisDatasource:
|
||||
if not self._redis:
|
||||
raise RuntimeError("Redis datasource not loaded")
|
||||
return self._redis
|
||||
|
||||
@property
|
||||
def mongodb(self) -> MongoDBDatasource:
|
||||
if not self._mongodb:
|
||||
raise RuntimeError("MongoDB datasource not loaded")
|
||||
return self._mongodb
|
||||
|
||||
def shutdown(self):
|
||||
for datasource in self.datasources.values():
|
||||
datasource.shutdown()
|
||||
if self._redis:
|
||||
self._redis.shutdown()
|
||||
if self._mongodb:
|
||||
self._mongodb.shutdown()
|
||||
|
@ -58,7 +58,7 @@ class AtlasApp:
|
||||
|
||||
def add_routers(self):
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
if not self.datasources.datasources:
|
||||
if not self.datasources.redis or not self.datasources.mongodb:
|
||||
raise ValueError("Datasources not loaded")
|
||||
|
||||
# User
|
||||
|
@ -10,12 +10,16 @@ if TYPE_CHECKING: # pragma: no cover
|
||||
|
||||
|
||||
class BaseRouter:
|
||||
def __init__(self, prefix: str = "/api/v1", datasources: DatasourceManager = None) -> None:
|
||||
def __init__(
|
||||
self, prefix: str = "/api/v1", datasources: DatasourceManager | None = None
|
||||
) -> None:
|
||||
self.datasources = datasources
|
||||
self.prefix = prefix
|
||||
if not self.datasources:
|
||||
raise RuntimeError("Datasources not loaded")
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def get_user_from_db(self, _token: str, email: str) -> User:
|
||||
def get_user_from_db(self, _token: str, email: str) -> User | None:
|
||||
"""
|
||||
Get the user from the database. This is a helper function to be used by the
|
||||
convert_to_user decorator. The function is cached to avoid repeated database
|
||||
@ -26,4 +30,6 @@ class BaseRouter:
|
||||
_token (str): The token
|
||||
email (str): The email
|
||||
"""
|
||||
return self.datasources.datasources["mongodb"].get_user_by_email(email)
|
||||
if not self.datasources:
|
||||
raise RuntimeError("Datasources not loaded")
|
||||
return self.datasources.mongodb.get_user_by_email(email)
|
||||
|
@ -1,15 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from bec_atlas.authentication import convert_to_user, get_current_user
|
||||
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
|
||||
from bec_atlas.model.model import BECAccessProfile, User, UserInfo
|
||||
from bec_atlas.model.model import BECAccessProfile, User
|
||||
from bec_atlas.router.base_router import BaseRouter
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from bec_atlas.datasources.datasource_manager import DatasourceManager
|
||||
|
||||
|
||||
class BECAccessRouter(BaseRouter):
|
||||
def __init__(self, prefix="/api/v1", datasources=None):
|
||||
def __init__(self, prefix="/api/v1", datasources: DatasourceManager | None = None):
|
||||
super().__init__(prefix, datasources)
|
||||
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
||||
|
||||
if not self.datasources:
|
||||
raise RuntimeError("Datasources not loaded")
|
||||
|
||||
self.db: MongoDBDatasource = self.datasources.mongodb
|
||||
self.router = APIRouter(prefix=prefix)
|
||||
self.router.add_api_route(
|
||||
"/bec_access",
|
||||
|
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@ -14,13 +16,16 @@ from bec_atlas.router.base_router import BaseRouter
|
||||
from bec_atlas.router.redis_router import RedisAtlasEndpoints
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from bec_atlas.datasources.datasource_manager import DatasourceManager
|
||||
from bec_atlas.datasources.redis_datasource import RedisDatasource
|
||||
|
||||
|
||||
class DeploymentAccessRouter(BaseRouter):
|
||||
def __init__(self, prefix="/api/v1", datasources=None):
|
||||
def __init__(self, prefix="/api/v1", datasources: DatasourceManager | None = None):
|
||||
super().__init__(prefix, datasources)
|
||||
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
||||
if not self.datasources:
|
||||
raise RuntimeError("Datasources not loaded")
|
||||
self.db: MongoDBDatasource = self.datasources.mongodb
|
||||
self.router = APIRouter(prefix=prefix)
|
||||
self.router.add_api_route(
|
||||
"/deployment_access",
|
||||
@ -40,7 +45,7 @@ class DeploymentAccessRouter(BaseRouter):
|
||||
@convert_to_user
|
||||
async def get_deployment_access(
|
||||
self, deployment_id: str, current_user: User = Depends(get_current_user)
|
||||
) -> DeploymentAccess:
|
||||
) -> DeploymentAccess | None:
|
||||
"""
|
||||
Get the access lists for a specific deployment.
|
||||
|
||||
@ -104,7 +109,7 @@ class DeploymentAccessRouter(BaseRouter):
|
||||
Args:
|
||||
deployment_access (DeploymentAccess): The deployment access object
|
||||
"""
|
||||
db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
||||
db: MongoDBDatasource = self.datasources.mongodb
|
||||
|
||||
new_profiles = set(
|
||||
updated.user_read_access
|
||||
@ -160,8 +165,8 @@ class DeploymentAccessRouter(BaseRouter):
|
||||
"""
|
||||
Refresh the redis BEC access.
|
||||
"""
|
||||
redis: RedisDatasource = self.datasources.datasources.get("redis")
|
||||
db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
||||
redis: RedisDatasource = self.datasources.redis
|
||||
db: MongoDBDatasource = self.datasources.mongodb
|
||||
profiles = db.find(
|
||||
collection="bec_access_profiles",
|
||||
query_filter={"deployment_id": deployment_id},
|
||||
@ -190,7 +195,7 @@ class DeploymentAccessRouter(BaseRouter):
|
||||
Returns:
|
||||
bool: True if the user exists, False otherwise
|
||||
"""
|
||||
db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
||||
db: MongoDBDatasource = self.datasources.mongodb
|
||||
user = db.find_one("users", {"email": user}, User)
|
||||
return user is not None
|
||||
|
||||
|
@ -16,7 +16,7 @@ if TYPE_CHECKING: # pragma: no cover
|
||||
class DeploymentCredentialsRouter(BaseRouter):
|
||||
def __init__(self, prefix="/api/v1", datasources=None):
|
||||
super().__init__(prefix, datasources)
|
||||
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
||||
self.db: MongoDBDatasource = self.datasources.mongodb
|
||||
self.router = APIRouter(prefix=prefix)
|
||||
self.router.add_api_route(
|
||||
"/deploymentCredentials",
|
||||
@ -80,7 +80,7 @@ class DeploymentCredentialsRouter(BaseRouter):
|
||||
raise HTTPException(status_code=404, detail="Deployment not found")
|
||||
|
||||
# update the redis deployment key
|
||||
redis: RedisDatasource = self.datasources.datasources.get("redis")
|
||||
redis: RedisDatasource = self.datasources.redis
|
||||
redis.add_deployment_acl(out)
|
||||
|
||||
return out
|
||||
|
@ -16,7 +16,7 @@ if TYPE_CHECKING: # pragma: no cover
|
||||
class DeploymentsRouter(BaseRouter):
|
||||
def __init__(self, prefix="/api/v1", datasources=None):
|
||||
super().__init__(prefix, datasources)
|
||||
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
||||
self.db: MongoDBDatasource = self.datasources.mongodb
|
||||
self.router = APIRouter(prefix=prefix)
|
||||
self.router.add_api_route(
|
||||
"/deployments/realm",
|
||||
@ -72,7 +72,7 @@ class DeploymentsRouter(BaseRouter):
|
||||
self.available_deployments = self.db.find("deployments", {}, Deployments)
|
||||
credentials = self.db.find("deployment_credentials", {}, DeploymentCredential)
|
||||
|
||||
redis: RedisDatasource = self.datasources.datasources.get("redis")
|
||||
redis: RedisDatasource = self.datasources.redis
|
||||
msg = json.dumps([msg.model_dump() for msg in self.available_deployments])
|
||||
redis.connector.set_and_publish("deployments", msg)
|
||||
for deployment in credentials:
|
||||
|
@ -9,7 +9,7 @@ from bec_atlas.router.base_router import BaseRouter
|
||||
class RealmRouter(BaseRouter):
|
||||
def __init__(self, prefix="/api/v1", datasources=None):
|
||||
super().__init__(prefix, datasources)
|
||||
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
||||
self.db: MongoDBDatasource = self.datasources.mongodb
|
||||
self.router = APIRouter(prefix=prefix)
|
||||
self.router.add_api_route(
|
||||
"/realms",
|
||||
|
@ -141,8 +141,8 @@ class RedisRouter(BaseRouter):
|
||||
|
||||
def __init__(self, prefix="/api/v1", datasources: DatasourceManager = None):
|
||||
super().__init__(prefix, datasources)
|
||||
self.redis = self.datasources.datasources["redis"].async_connector
|
||||
self.db = self.datasources.datasources["mongodb"]
|
||||
self.redis = self.datasources.redis.async_connector
|
||||
self.db = self.datasources.mongodb
|
||||
|
||||
self.router = APIRouter(prefix=prefix)
|
||||
self.router.add_api_route(
|
||||
@ -474,15 +474,15 @@ class RedisWebsocket:
|
||||
"""
|
||||
|
||||
def __init__(self, prefix="/api/v1", datasources=None, app: AtlasApp = None):
|
||||
self.redis: RedisConnector = datasources.datasources["redis"].connector
|
||||
self.redis: RedisConnector = datasources.redis.connector
|
||||
self.prefix = prefix
|
||||
self.fastapi_app = app
|
||||
self.redis_router = app.redis_router
|
||||
self.active_connections = set()
|
||||
redis_host = datasources.datasources["redis"].config["host"]
|
||||
redis_port = datasources.datasources["redis"].config["port"]
|
||||
redis_password = datasources.datasources["redis"].config.get("password", "ingestor")
|
||||
self.db = datasources.datasources["mongodb"]
|
||||
redis_host = datasources.redis.config["host"]
|
||||
redis_port = datasources.redis.config["port"]
|
||||
redis_password = datasources.redis.config.get("password", "ingestor")
|
||||
self.db = datasources.mongodb
|
||||
self.socket = socketio.AsyncServer(
|
||||
transports=["websocket"],
|
||||
ping_timeout=60,
|
||||
|
@ -14,7 +14,7 @@ from bec_atlas.router.base_router import BaseRouter
|
||||
class ScanRouter(BaseRouter):
|
||||
def __init__(self, prefix="/api/v1", datasources=None):
|
||||
super().__init__(prefix, datasources)
|
||||
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
||||
self.db: MongoDBDatasource = self.datasources.mongodb
|
||||
self.router = APIRouter(prefix=prefix)
|
||||
self.router.add_api_route(
|
||||
"/scans/session",
|
||||
|
@ -11,7 +11,7 @@ from bec_atlas.router.base_router import BaseRouter
|
||||
class SessionRouter(BaseRouter):
|
||||
def __init__(self, prefix="/api/v1", datasources=None):
|
||||
super().__init__(prefix, datasources)
|
||||
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
||||
self.db: MongoDBDatasource = self.datasources.mongodb
|
||||
self.router = APIRouter(prefix=prefix)
|
||||
self.router.add_api_route(
|
||||
"/sessions",
|
||||
|
@ -28,7 +28,7 @@ class UserRouter(BaseRouter):
|
||||
def __init__(self, prefix="/api/v1", datasources=None, use_ssl=True):
|
||||
super().__init__(prefix, datasources)
|
||||
self.use_ssl = use_ssl
|
||||
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
||||
self.db: MongoDBDatasource = self.datasources.mongodb
|
||||
self.ldap = LDAPUserService(
|
||||
ldap_server="ldaps://d.psi.ch", base_dn="OU=users,OU=psi,DC=d,DC=psi,DC=ch"
|
||||
)
|
||||
|
@ -25,7 +25,7 @@ def test_scan_ingestor_create_scan(scan_ingestor, backend):
|
||||
Test that the login endpoint returns a token.
|
||||
"""
|
||||
client, app = backend
|
||||
mongo: MongoDBDatasource = app.datasources.datasources["mongodb"]
|
||||
mongo: MongoDBDatasource = app.datasources.mongodb
|
||||
deployment_id = str(mongo.find_one("deployments", {}, dtype=Deployments).id)
|
||||
session_id = str(mongo.find_one("sessions", {"deployment_id": deployment_id}, dtype=Session).id)
|
||||
msg = messages.ScanStatusMessage(
|
||||
|
Reference in New Issue
Block a user