mirror of
https://github.com/bec-project/bec_atlas.git
synced 2025-07-13 22:51:49 +02:00
refactor(backend): replaced datasources dict with properties
This commit is contained in:
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from bec_lib.logger import bec_logger
|
from bec_lib.logger import bec_logger
|
||||||
|
|
||||||
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
|
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
|
||||||
@ -9,26 +11,36 @@ logger = bec_logger.logger
|
|||||||
class DatasourceManager:
|
class DatasourceManager:
|
||||||
def __init__(self, config: dict):
|
def __init__(self, config: dict):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.datasources = {}
|
self._redis: RedisDatasource | None = None
|
||||||
|
self._mongodb: MongoDBDatasource | None = None
|
||||||
self.load_datasources()
|
self.load_datasources()
|
||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
for datasource in self.datasources.values():
|
self.redis.connect()
|
||||||
datasource.connect()
|
self.mongodb.connect()
|
||||||
|
|
||||||
def load_datasources(self):
|
def load_datasources(self):
|
||||||
for datasource_name, datasource_config in self.config.items():
|
redis_config = self.config.get("redis")
|
||||||
if datasource_name == "redis":
|
if redis_config:
|
||||||
logger.info(
|
self._redis = RedisDatasource(redis_config)
|
||||||
f"Loading Redis datasource. Host: {datasource_config.get('host')}, Port: {datasource_config.get('port')}, Username: {datasource_config.get('username')}"
|
mongodb_config = self.config.get("mongodb")
|
||||||
)
|
if mongodb_config:
|
||||||
self.datasources[datasource_name] = RedisDatasource(datasource_config)
|
self._mongodb = MongoDBDatasource(mongodb_config)
|
||||||
if datasource_name == "mongodb":
|
|
||||||
logger.info(
|
@property
|
||||||
f"Loading MongoDB datasource. Host: {datasource_config.get('host')}, Port: {datasource_config.get('port')}, Username: {datasource_config.get('username')}"
|
def redis(self) -> RedisDatasource:
|
||||||
)
|
if not self._redis:
|
||||||
self.datasources[datasource_name] = MongoDBDatasource(datasource_config)
|
raise RuntimeError("Redis datasource not loaded")
|
||||||
|
return self._redis
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mongodb(self) -> MongoDBDatasource:
|
||||||
|
if not self._mongodb:
|
||||||
|
raise RuntimeError("MongoDB datasource not loaded")
|
||||||
|
return self._mongodb
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
for datasource in self.datasources.values():
|
if self._redis:
|
||||||
datasource.shutdown()
|
self._redis.shutdown()
|
||||||
|
if self._mongodb:
|
||||||
|
self._mongodb.shutdown()
|
||||||
|
@ -58,7 +58,7 @@ class AtlasApp:
|
|||||||
|
|
||||||
def add_routers(self):
|
def add_routers(self):
|
||||||
# pylint: disable=attribute-defined-outside-init
|
# pylint: disable=attribute-defined-outside-init
|
||||||
if not self.datasources.datasources:
|
if not self.datasources.redis or not self.datasources.mongodb:
|
||||||
raise ValueError("Datasources not loaded")
|
raise ValueError("Datasources not loaded")
|
||||||
|
|
||||||
# User
|
# User
|
||||||
|
@ -10,12 +10,16 @@ if TYPE_CHECKING: # pragma: no cover
|
|||||||
|
|
||||||
|
|
||||||
class BaseRouter:
|
class BaseRouter:
|
||||||
def __init__(self, prefix: str = "/api/v1", datasources: DatasourceManager = None) -> None:
|
def __init__(
|
||||||
|
self, prefix: str = "/api/v1", datasources: DatasourceManager | None = None
|
||||||
|
) -> None:
|
||||||
self.datasources = datasources
|
self.datasources = datasources
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
|
if not self.datasources:
|
||||||
|
raise RuntimeError("Datasources not loaded")
|
||||||
|
|
||||||
@lru_cache(maxsize=128)
|
@lru_cache(maxsize=128)
|
||||||
def get_user_from_db(self, _token: str, email: str) -> User:
|
def get_user_from_db(self, _token: str, email: str) -> User | None:
|
||||||
"""
|
"""
|
||||||
Get the user from the database. This is a helper function to be used by the
|
Get the user from the database. This is a helper function to be used by the
|
||||||
convert_to_user decorator. The function is cached to avoid repeated database
|
convert_to_user decorator. The function is cached to avoid repeated database
|
||||||
@ -26,4 +30,6 @@ class BaseRouter:
|
|||||||
_token (str): The token
|
_token (str): The token
|
||||||
email (str): The email
|
email (str): The email
|
||||||
"""
|
"""
|
||||||
return self.datasources.datasources["mongodb"].get_user_by_email(email)
|
if not self.datasources:
|
||||||
|
raise RuntimeError("Datasources not loaded")
|
||||||
|
return self.datasources.mongodb.get_user_by_email(email)
|
||||||
|
@ -1,15 +1,26 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
|
|
||||||
from bec_atlas.authentication import convert_to_user, get_current_user
|
from bec_atlas.authentication import convert_to_user, get_current_user
|
||||||
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
|
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
|
||||||
from bec_atlas.model.model import BECAccessProfile, User, UserInfo
|
from bec_atlas.model.model import BECAccessProfile, User
|
||||||
from bec_atlas.router.base_router import BaseRouter
|
from bec_atlas.router.base_router import BaseRouter
|
||||||
|
|
||||||
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
|
from bec_atlas.datasources.datasource_manager import DatasourceManager
|
||||||
|
|
||||||
|
|
||||||
class BECAccessRouter(BaseRouter):
|
class BECAccessRouter(BaseRouter):
|
||||||
def __init__(self, prefix="/api/v1", datasources=None):
|
def __init__(self, prefix="/api/v1", datasources: DatasourceManager | None = None):
|
||||||
super().__init__(prefix, datasources)
|
super().__init__(prefix, datasources)
|
||||||
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
|
||||||
|
if not self.datasources:
|
||||||
|
raise RuntimeError("Datasources not loaded")
|
||||||
|
|
||||||
|
self.db: MongoDBDatasource = self.datasources.mongodb
|
||||||
self.router = APIRouter(prefix=prefix)
|
self.router = APIRouter(prefix=prefix)
|
||||||
self.router.add_api_route(
|
self.router.add_api_route(
|
||||||
"/bec_access",
|
"/bec_access",
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import secrets
|
import secrets
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
@ -14,13 +16,16 @@ from bec_atlas.router.base_router import BaseRouter
|
|||||||
from bec_atlas.router.redis_router import RedisAtlasEndpoints
|
from bec_atlas.router.redis_router import RedisAtlasEndpoints
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma: no cover
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
|
from bec_atlas.datasources.datasource_manager import DatasourceManager
|
||||||
from bec_atlas.datasources.redis_datasource import RedisDatasource
|
from bec_atlas.datasources.redis_datasource import RedisDatasource
|
||||||
|
|
||||||
|
|
||||||
class DeploymentAccessRouter(BaseRouter):
|
class DeploymentAccessRouter(BaseRouter):
|
||||||
def __init__(self, prefix="/api/v1", datasources=None):
|
def __init__(self, prefix="/api/v1", datasources: DatasourceManager | None = None):
|
||||||
super().__init__(prefix, datasources)
|
super().__init__(prefix, datasources)
|
||||||
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
if not self.datasources:
|
||||||
|
raise RuntimeError("Datasources not loaded")
|
||||||
|
self.db: MongoDBDatasource = self.datasources.mongodb
|
||||||
self.router = APIRouter(prefix=prefix)
|
self.router = APIRouter(prefix=prefix)
|
||||||
self.router.add_api_route(
|
self.router.add_api_route(
|
||||||
"/deployment_access",
|
"/deployment_access",
|
||||||
@ -40,7 +45,7 @@ class DeploymentAccessRouter(BaseRouter):
|
|||||||
@convert_to_user
|
@convert_to_user
|
||||||
async def get_deployment_access(
|
async def get_deployment_access(
|
||||||
self, deployment_id: str, current_user: User = Depends(get_current_user)
|
self, deployment_id: str, current_user: User = Depends(get_current_user)
|
||||||
) -> DeploymentAccess:
|
) -> DeploymentAccess | None:
|
||||||
"""
|
"""
|
||||||
Get the access lists for a specific deployment.
|
Get the access lists for a specific deployment.
|
||||||
|
|
||||||
@ -104,7 +109,7 @@ class DeploymentAccessRouter(BaseRouter):
|
|||||||
Args:
|
Args:
|
||||||
deployment_access (DeploymentAccess): The deployment access object
|
deployment_access (DeploymentAccess): The deployment access object
|
||||||
"""
|
"""
|
||||||
db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
db: MongoDBDatasource = self.datasources.mongodb
|
||||||
|
|
||||||
new_profiles = set(
|
new_profiles = set(
|
||||||
updated.user_read_access
|
updated.user_read_access
|
||||||
@ -160,8 +165,8 @@ class DeploymentAccessRouter(BaseRouter):
|
|||||||
"""
|
"""
|
||||||
Refresh the redis BEC access.
|
Refresh the redis BEC access.
|
||||||
"""
|
"""
|
||||||
redis: RedisDatasource = self.datasources.datasources.get("redis")
|
redis: RedisDatasource = self.datasources.redis
|
||||||
db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
db: MongoDBDatasource = self.datasources.mongodb
|
||||||
profiles = db.find(
|
profiles = db.find(
|
||||||
collection="bec_access_profiles",
|
collection="bec_access_profiles",
|
||||||
query_filter={"deployment_id": deployment_id},
|
query_filter={"deployment_id": deployment_id},
|
||||||
@ -190,7 +195,7 @@ class DeploymentAccessRouter(BaseRouter):
|
|||||||
Returns:
|
Returns:
|
||||||
bool: True if the user exists, False otherwise
|
bool: True if the user exists, False otherwise
|
||||||
"""
|
"""
|
||||||
db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
db: MongoDBDatasource = self.datasources.mongodb
|
||||||
user = db.find_one("users", {"email": user}, User)
|
user = db.find_one("users", {"email": user}, User)
|
||||||
return user is not None
|
return user is not None
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ if TYPE_CHECKING: # pragma: no cover
|
|||||||
class DeploymentCredentialsRouter(BaseRouter):
|
class DeploymentCredentialsRouter(BaseRouter):
|
||||||
def __init__(self, prefix="/api/v1", datasources=None):
|
def __init__(self, prefix="/api/v1", datasources=None):
|
||||||
super().__init__(prefix, datasources)
|
super().__init__(prefix, datasources)
|
||||||
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
self.db: MongoDBDatasource = self.datasources.mongodb
|
||||||
self.router = APIRouter(prefix=prefix)
|
self.router = APIRouter(prefix=prefix)
|
||||||
self.router.add_api_route(
|
self.router.add_api_route(
|
||||||
"/deploymentCredentials",
|
"/deploymentCredentials",
|
||||||
@ -80,7 +80,7 @@ class DeploymentCredentialsRouter(BaseRouter):
|
|||||||
raise HTTPException(status_code=404, detail="Deployment not found")
|
raise HTTPException(status_code=404, detail="Deployment not found")
|
||||||
|
|
||||||
# update the redis deployment key
|
# update the redis deployment key
|
||||||
redis: RedisDatasource = self.datasources.datasources.get("redis")
|
redis: RedisDatasource = self.datasources.redis
|
||||||
redis.add_deployment_acl(out)
|
redis.add_deployment_acl(out)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
@ -16,7 +16,7 @@ if TYPE_CHECKING: # pragma: no cover
|
|||||||
class DeploymentsRouter(BaseRouter):
|
class DeploymentsRouter(BaseRouter):
|
||||||
def __init__(self, prefix="/api/v1", datasources=None):
|
def __init__(self, prefix="/api/v1", datasources=None):
|
||||||
super().__init__(prefix, datasources)
|
super().__init__(prefix, datasources)
|
||||||
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
self.db: MongoDBDatasource = self.datasources.mongodb
|
||||||
self.router = APIRouter(prefix=prefix)
|
self.router = APIRouter(prefix=prefix)
|
||||||
self.router.add_api_route(
|
self.router.add_api_route(
|
||||||
"/deployments/realm",
|
"/deployments/realm",
|
||||||
@ -72,7 +72,7 @@ class DeploymentsRouter(BaseRouter):
|
|||||||
self.available_deployments = self.db.find("deployments", {}, Deployments)
|
self.available_deployments = self.db.find("deployments", {}, Deployments)
|
||||||
credentials = self.db.find("deployment_credentials", {}, DeploymentCredential)
|
credentials = self.db.find("deployment_credentials", {}, DeploymentCredential)
|
||||||
|
|
||||||
redis: RedisDatasource = self.datasources.datasources.get("redis")
|
redis: RedisDatasource = self.datasources.redis
|
||||||
msg = json.dumps([msg.model_dump() for msg in self.available_deployments])
|
msg = json.dumps([msg.model_dump() for msg in self.available_deployments])
|
||||||
redis.connector.set_and_publish("deployments", msg)
|
redis.connector.set_and_publish("deployments", msg)
|
||||||
for deployment in credentials:
|
for deployment in credentials:
|
||||||
|
@ -9,7 +9,7 @@ from bec_atlas.router.base_router import BaseRouter
|
|||||||
class RealmRouter(BaseRouter):
|
class RealmRouter(BaseRouter):
|
||||||
def __init__(self, prefix="/api/v1", datasources=None):
|
def __init__(self, prefix="/api/v1", datasources=None):
|
||||||
super().__init__(prefix, datasources)
|
super().__init__(prefix, datasources)
|
||||||
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
self.db: MongoDBDatasource = self.datasources.mongodb
|
||||||
self.router = APIRouter(prefix=prefix)
|
self.router = APIRouter(prefix=prefix)
|
||||||
self.router.add_api_route(
|
self.router.add_api_route(
|
||||||
"/realms",
|
"/realms",
|
||||||
|
@ -141,8 +141,8 @@ class RedisRouter(BaseRouter):
|
|||||||
|
|
||||||
def __init__(self, prefix="/api/v1", datasources: DatasourceManager = None):
|
def __init__(self, prefix="/api/v1", datasources: DatasourceManager = None):
|
||||||
super().__init__(prefix, datasources)
|
super().__init__(prefix, datasources)
|
||||||
self.redis = self.datasources.datasources["redis"].async_connector
|
self.redis = self.datasources.redis.async_connector
|
||||||
self.db = self.datasources.datasources["mongodb"]
|
self.db = self.datasources.mongodb
|
||||||
|
|
||||||
self.router = APIRouter(prefix=prefix)
|
self.router = APIRouter(prefix=prefix)
|
||||||
self.router.add_api_route(
|
self.router.add_api_route(
|
||||||
@ -474,15 +474,15 @@ class RedisWebsocket:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, prefix="/api/v1", datasources=None, app: AtlasApp = None):
|
def __init__(self, prefix="/api/v1", datasources=None, app: AtlasApp = None):
|
||||||
self.redis: RedisConnector = datasources.datasources["redis"].connector
|
self.redis: RedisConnector = datasources.redis.connector
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
self.fastapi_app = app
|
self.fastapi_app = app
|
||||||
self.redis_router = app.redis_router
|
self.redis_router = app.redis_router
|
||||||
self.active_connections = set()
|
self.active_connections = set()
|
||||||
redis_host = datasources.datasources["redis"].config["host"]
|
redis_host = datasources.redis.config["host"]
|
||||||
redis_port = datasources.datasources["redis"].config["port"]
|
redis_port = datasources.redis.config["port"]
|
||||||
redis_password = datasources.datasources["redis"].config.get("password", "ingestor")
|
redis_password = datasources.redis.config.get("password", "ingestor")
|
||||||
self.db = datasources.datasources["mongodb"]
|
self.db = datasources.mongodb
|
||||||
self.socket = socketio.AsyncServer(
|
self.socket = socketio.AsyncServer(
|
||||||
transports=["websocket"],
|
transports=["websocket"],
|
||||||
ping_timeout=60,
|
ping_timeout=60,
|
||||||
|
@ -14,7 +14,7 @@ from bec_atlas.router.base_router import BaseRouter
|
|||||||
class ScanRouter(BaseRouter):
|
class ScanRouter(BaseRouter):
|
||||||
def __init__(self, prefix="/api/v1", datasources=None):
|
def __init__(self, prefix="/api/v1", datasources=None):
|
||||||
super().__init__(prefix, datasources)
|
super().__init__(prefix, datasources)
|
||||||
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
self.db: MongoDBDatasource = self.datasources.mongodb
|
||||||
self.router = APIRouter(prefix=prefix)
|
self.router = APIRouter(prefix=prefix)
|
||||||
self.router.add_api_route(
|
self.router.add_api_route(
|
||||||
"/scans/session",
|
"/scans/session",
|
||||||
|
@ -11,7 +11,7 @@ from bec_atlas.router.base_router import BaseRouter
|
|||||||
class SessionRouter(BaseRouter):
|
class SessionRouter(BaseRouter):
|
||||||
def __init__(self, prefix="/api/v1", datasources=None):
|
def __init__(self, prefix="/api/v1", datasources=None):
|
||||||
super().__init__(prefix, datasources)
|
super().__init__(prefix, datasources)
|
||||||
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
self.db: MongoDBDatasource = self.datasources.mongodb
|
||||||
self.router = APIRouter(prefix=prefix)
|
self.router = APIRouter(prefix=prefix)
|
||||||
self.router.add_api_route(
|
self.router.add_api_route(
|
||||||
"/sessions",
|
"/sessions",
|
||||||
|
@ -28,7 +28,7 @@ class UserRouter(BaseRouter):
|
|||||||
def __init__(self, prefix="/api/v1", datasources=None, use_ssl=True):
|
def __init__(self, prefix="/api/v1", datasources=None, use_ssl=True):
|
||||||
super().__init__(prefix, datasources)
|
super().__init__(prefix, datasources)
|
||||||
self.use_ssl = use_ssl
|
self.use_ssl = use_ssl
|
||||||
self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
|
self.db: MongoDBDatasource = self.datasources.mongodb
|
||||||
self.ldap = LDAPUserService(
|
self.ldap = LDAPUserService(
|
||||||
ldap_server="ldaps://d.psi.ch", base_dn="OU=users,OU=psi,DC=d,DC=psi,DC=ch"
|
ldap_server="ldaps://d.psi.ch", base_dn="OU=users,OU=psi,DC=d,DC=psi,DC=ch"
|
||||||
)
|
)
|
||||||
|
@ -25,7 +25,7 @@ def test_scan_ingestor_create_scan(scan_ingestor, backend):
|
|||||||
Test that the login endpoint returns a token.
|
Test that the login endpoint returns a token.
|
||||||
"""
|
"""
|
||||||
client, app = backend
|
client, app = backend
|
||||||
mongo: MongoDBDatasource = app.datasources.datasources["mongodb"]
|
mongo: MongoDBDatasource = app.datasources.mongodb
|
||||||
deployment_id = str(mongo.find_one("deployments", {}, dtype=Deployments).id)
|
deployment_id = str(mongo.find_one("deployments", {}, dtype=Deployments).id)
|
||||||
session_id = str(mongo.find_one("sessions", {"deployment_id": deployment_id}, dtype=Session).id)
|
session_id = str(mongo.find_one("sessions", {"deployment_id": deployment_id}, dtype=Session).id)
|
||||||
msg = messages.ScanStatusMessage(
|
msg = messages.ScanStatusMessage(
|
||||||
|
Reference in New Issue
Block a user