Mirror of https://github.com/bec-project/bec_atlas.git, synced 2025-07-14 07:01:48 +02:00
fix: towards a first working version
@@ -17,11 +17,16 @@ class DatasourceManager:
         datasource.connect()

     def load_datasources(self):
+        logger.info(f"Loading datasources with config: {self.config}")
         for datasource_name, datasource_config in self.config.items():
             if datasource_name == "redis":
+                logger.info(
+                    f"Loading Redis datasource. Host: {datasource_config.get('host')}, Port: {datasource_config.get('port')}, Username: {datasource_config.get('username')}"
+                )
                 self.datasources[datasource_name] = RedisDatasource(datasource_config)
             if datasource_name == "mongodb":
+                logger.info(
+                    f"Loading MongoDB datasource. Host: {datasource_config.get('host')}, Port: {datasource_config.get('port')}, Username: {datasource_config.get('username')}"
+                )
                 self.datasources[datasource_name] = MongoDBDatasource(datasource_config)

     def shutdown(self):
@@ -1,5 +1,6 @@
 import json
 import os
+from typing import Literal

 import pymongo
 from bec_lib.logger import bec_logger
@@ -21,10 +22,21 @@ class MongoDBDatasource:
         """
         Connect to the MongoDB database.
         """
-        host = self.config.get("host", "localhost")
-        port = self.config.get("port", 27017)
+        host = self.config.get("host")
+        port = self.config.get("port")
+        username = self.config.get("username")
+        password = self.config.get("password")
+        if username and password:
+            self.client = pymongo.MongoClient(
+                f"mongodb://{username}:{password}@{host}:{port}/?authSource=bec_atlas"
+            )
+        else:
+            self.client = pymongo.MongoClient(f"mongodb://{host}:{port}/")
+
+        # Check if the connection is successful
+        self.client.list_databases()
+
         logger.info(f"Connecting to MongoDB at {host}:{port}")
-        self.client = pymongo.MongoClient(f"mongodb://{host}:{port}/")
         self.db = self.client["bec_atlas"]
         if include_setup:
             self.db["users"].create_index([("email", 1)], unique=True)
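Note on the hunk above: when credentials are present, the client authenticates against the `bec_atlas` database via `authSource`. A minimal sketch of the same URI logic in isolation; the helper name is illustrative and not part of the commit:

```python
# Hypothetical helper mirroring the connect() logic above; not part of the commit.
def build_mongo_uri(config: dict) -> str:
    host = config.get("host")
    port = config.get("port")
    username = config.get("username")
    password = config.get("password")
    if username and password:
        # authSource tells MongoDB which database holds the user credentials
        return f"mongodb://{username}:{password}@{host}:{port}/?authSource=bec_atlas"
    return f"mongodb://{host}:{port}/"

assert build_mongo_uri({"host": "localhost", "port": 27017}) == "mongodb://localhost:27017/"
```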
@@ -55,7 +67,7 @@ class MongoDBDatasource:
             {
                 "email": "jane.doe@bec_atlas.ch",
                 "password": "atlas",
-                "groups": ["demo_user"],
+                "groups": ["demo"],
                 "first_name": "Jane",
                 "last_name": "Doe",
                 "owner_groups": ["admin"],
@@ -136,30 +148,91 @@ class MongoDBDatasource:
         out = self.db[collection].find(query_filter)
         return [dtype(**x) for x in out]

-    def add_user_filter(self, user: User, query_filter: dict) -> dict:
+    def aggregate(
+        self, collection: str, pipeline: list[dict], dtype: BaseModel, user: User | None = None
+    ) -> list[BaseModel]:
+        """
+        Aggregate documents in the collection.
+
+        Args:
+            collection (str): The collection name
+            pipeline (list[dict]): The aggregation pipeline
+            dtype (BaseModel): The data type to return
+            user (User): The user making the request
+
+        Returns:
+            list[BaseModel]: The data type with the document data
+        """
+        if user is not None:
+            # Add the user filter to the lookup pipeline
+            for pipe in pipeline:
+                if "$lookup" not in pipe:
+                    continue
+                if "pipeline" not in pipe["$lookup"]:
+                    continue
+                lookup = pipe["$lookup"]
+                lookup_pipeline = lookup["pipeline"]
+                access_filter = {"$match": self._read_only_user_filter(user)}
+                lookup_pipeline.insert(0, access_filter)
+        # pipeline = self.add_user_filter(user, pipeline)
+        out = self.db[collection].aggregate(pipeline)
+        return [dtype(**x) for x in out]
+
+    def add_user_filter(
+        self, user: User, query_filter: dict, operation: Literal["r", "w"] = "r"
+    ) -> dict:
         """
         Add the user filter to the query filter.

         Args:
             user (User): The user making the request
             query_filter (dict): The query filter
+            operation (Literal["r", "w"]): The operation to perform

         Returns:
             dict: The updated query filter
         """
+        if operation == "r":
+            user_filter = self._read_only_user_filter(user)
+        else:
+            user_filter = self._write_user_filter(user)
+        if user_filter:
+            query_filter = {"$and": [query_filter, user_filter]}
+        return query_filter
+
+    def _read_only_user_filter(self, user: User) -> dict:
+        """
+        Add the user filter to the query filter.
+
+        Args:
+            user (User): The user making the request
+
+        Returns:
+            dict: The updated query filter
+        """
         if "admin" not in user.groups:
-            query_filter = {
-                "$and": [
-                    query_filter,
-                    {
-                        "$or": [
-                            {"owner_groups": {"$in": user.groups}},
-                            {"access_groups": {"$in": user.groups}},
-                        ]
-                    },
-                ]
-            }
-        return query_filter
+            return {
+                "$or": [
+                    {"owner_groups": {"$in": user.groups}},
+                    {"access_groups": {"$in": user.groups}},
+                ]
+            }
+        return {}
+
+    def _write_user_filter(self, user: User) -> dict:
+        """
+        Add the user filter to the query filter.
+
+        Args:
+            user (User): The user making the request
+
+        Returns:
+            dict: The updated query filter
+        """
+        if "admin" not in user.groups:
+            return {"$or": [{"owner_groups": {"$in": user.groups}}]}
+        return {}

     def shutdown(self):
         """
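To make the access-control semantics above concrete: for a non-admin user, a read query ends up wrapped in `$and` with an `$or` over `owner_groups` and `access_groups`, while writes check `owner_groups` only. A small sketch with a stubbed user:

```python
from types import SimpleNamespace

user = SimpleNamespace(groups=["demo"])  # stand-in for the User model
query = {"name": "Demo Deployment 1"}

read_filter = {
    "$or": [
        {"owner_groups": {"$in": user.groups}},
        {"access_groups": {"$in": user.groups}},
    ]
}
# Equivalent to add_user_filter(user, query, operation="r") for a non-admin:
combined = {"$and": [query, read_filter]}
print(combined)
```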
@@ -3,6 +3,7 @@ from __future__ import annotations
 from typing import TYPE_CHECKING

 from bec_lib.redis_connector import RedisConnector
+from redis.asyncio import Redis as AsyncRedis
 from redis.exceptions import AuthenticationError

 if TYPE_CHECKING:
@@ -13,14 +14,26 @@ class RedisDatasource:
     def __init__(self, config: dict):
         self.config = config
         self.connector = RedisConnector(f"{config.get('host')}:{config.get('port')}")
+        username = config.get("username")
+        password = config.get("password")
+
         try:
-            self.connector._redis_conn.auth(config.get("password", "ingestor"), username="ingestor")
+            self.connector._redis_conn.auth(password, username=username)
             self.reconfigured_acls = False
         except AuthenticationError:
             self.setup_acls()
-            self.connector._redis_conn.auth(config.get("password", "ingestor"), username="ingestor")
+            self.connector._redis_conn.auth(password, username=username)
             self.reconfigured_acls = True

+        self.connector._redis_conn.connection_pool.connection_kwargs["username"] = username
+        self.connector._redis_conn.connection_pool.connection_kwargs["password"] = password
+
+        self.async_connector = AsyncRedis(
+            host=config.get("host"),
+            port=config.get("port"),
+            username="ingestor",
+            password=config.get("password"),
+        )
         print("Connected to Redis")

     def setup_acls(self):
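The try/except around `auth` is a bootstrap pattern: on a fresh Redis instance the configured user does not exist yet, so the first `AUTH` fails, the ACLs get created, and the `AUTH` is retried. The same pattern in a generic sketch (names illustrative):

```python
from redis.exceptions import AuthenticationError

def auth_with_bootstrap(conn, username: str, password: str, setup_acls) -> bool:
    """Return True if the ACLs had to be (re)created on this connect."""
    try:
        conn.auth(password, username=username)
        return False
    except AuthenticationError:
        setup_acls()  # create the user and its permissions, then retry once
        conn.auth(password, username=username)
        return True
```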
@@ -32,7 +45,7 @@ class RedisDatasource:
         self.connector._redis_conn.acl_setuser(
             "ingestor",
             enabled=True,
-            passwords=f'+{self.config.get("password", "ingestor")}',
+            passwords=f'+{self.config.get("password")}',
             categories=["+@all"],
             keys=["*"],
             channels=["*"],
@@ -71,6 +84,8 @@ class RedisDatasource:
             channels=[
                 f"internal/deployment/{deployment.id}/*/state",
                 f"internal/deployment/{deployment.id}/*",
+                f"internal/deployment/{deployment.id}/request",
+                f"internal/deployment/{deployment.id}/request_response/*",
             ],
             commands=[f"+keys|internal/deployment/{deployment.id}/*/state"],
             reset_channels=True,
@@ -135,6 +135,9 @@ class DataIngestor:

         """
         while not self.shutdown_event.is_set():
+            if not self.available_deployments:
+                self.shutdown_event.wait(1)
+                continue
             streams = {
                 f"internal/deployment/{deployment['id']}/ingest": ">"
                 for deployment in self.available_deployments
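The added guard stops the ingestor from calling into Redis with an empty stream list; `Event.wait(1)` doubles as an interruptible sleep, so shutdown stays responsive. The pattern in isolation:

```python
import threading

shutdown_event = threading.Event()
available_deployments: list[dict] = []

def run_once() -> bool:
    """One loop iteration; returns False when there is nothing to read yet."""
    if not available_deployments:
        # wait() returns early if shutdown_event is set, unlike time.sleep()
        shutdown_event.wait(1)
        return False
    return True
```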
@@ -1,17 +1,15 @@
 import socketio
 import uvicorn
 from fastapi import FastAPI

 from bec_atlas.datasources.datasource_manager import DatasourceManager
 from bec_atlas.router.deployments_router import DeploymentsRouter
 from bec_atlas.router.realm_router import RealmRouter
-from bec_atlas.router.redis_router import RedisWebsocket
+from bec_atlas.router.redis_router import RedisRouter, RedisWebsocket
 from bec_atlas.router.scan_router import ScanRouter
 from bec_atlas.router.user_router import UserRouter

 CONFIG = {
     "redis": {"host": "localhost", "port": 6380},
-    "scylla": {"hosts": ["localhost"]},
     "mongodb": {"host": "localhost", "port": 27017},
 }
@@ -40,6 +38,7 @@ class AtlasApp:
         self.datasources.shutdown()

     def add_routers(self):
+        # pylint: disable=attribute-defined-outside-init
         if not self.datasources.datasources:
             raise ValueError("Datasources not loaded")
         self.scan_router = ScanRouter(prefix=self.prefix, datasources=self.datasources)
@@ -50,12 +49,13 @@ class AtlasApp:
         self.app.include_router(self.deployment_router.router, tags=["Deployment"])
         self.realm_router = RealmRouter(prefix=self.prefix, datasources=self.datasources)
         self.app.include_router(self.realm_router.router, tags=["Realm"])
+        self.redis_router = RedisRouter(prefix=self.prefix, datasources=self.datasources)
+        self.app.include_router(self.redis_router.router, tags=["Redis"])

-        if "redis" in self.datasources.datasources:
-            self.redis_websocket = RedisWebsocket(
-                prefix=self.prefix, datasources=self.datasources, app=self
-            )
-            self.app.mount("/", self.redis_websocket.app)
+        self.redis_websocket = RedisWebsocket(
+            prefix=self.prefix, datasources=self.datasources, app=self
+        )
+        self.app.mount("/", self.redis_websocket.app)

     def run(self, port=8000):
         config = uvicorn.Config(self.app, host="localhost", port=port)
@@ -66,12 +66,18 @@ class AtlasApp:

 def main():
     import argparse
+    import logging
+
+    from bec_atlas.utils.env_loader import load_env
+
+    config = load_env()
+    logging.basicConfig(level=logging.INFO)

     parser = argparse.ArgumentParser(description="Run the BEC Atlas API")
     parser.add_argument("--port", type=int, default=8000, help="Port to run the API on")

     args = parser.parse_args()
-    horizon_app = AtlasApp()
+    horizon_app = AtlasApp(config=config)
     horizon_app.run(port=args.port)
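The CLI surface stays small; a self-contained mirror of the argument parsing, for reference:

```python
import argparse

def parse_port(argv: list[str]) -> int:
    """Mirror of the CLI added above, testable without touching sys.argv."""
    parser = argparse.ArgumentParser(description="Run the BEC Atlas API")
    parser.add_argument("--port", type=int, default=8000, help="Port to run the API on")
    return parser.parse_args(argv).port

assert parse_port([]) == 8000
assert parse_port(["--port", "9000"]) == 9000
```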
@@ -1,5 +1,7 @@
 from __future__ import annotations

+import uuid
-from typing import Literal
+from typing import Any, Literal

 from bec_lib import messages
 from bson import ObjectId
@@ -23,6 +25,11 @@ class AccessProfile(BaseModel):
     access_groups: list[str] = []


+class AccessProfilePartial(AccessProfile):
+    owner_groups: list[str] | None = None
+    access_groups: list[str] | None = None
+
+
 class ScanStatus(MongoBaseModel, AccessProfile, messages.ScanStatusMessage): ...
@@ -44,10 +51,25 @@ class UserInfo(BaseModel):


 class Deployments(MongoBaseModel, AccessProfile):
-    realm_id: str
+    realm_id: str | ObjectId
     name: str
     deployment_key: str = Field(default_factory=lambda: str(uuid.uuid4()))
-    active_session_id: str | None = None
+    active_session_id: str | ObjectId | None = None
+    config_templates: list[str | ObjectId] = []
+
+
+class DeploymentsPartial(MongoBaseModel, AccessProfilePartial):
+    realm_id: str | ObjectId | None = None
+    name: str | None = None
+    deployment_key: str | None = None
+    active_session_id: str | ObjectId | None = None
+    config_templates: list[str | ObjectId] | None = None
+
+
+class Realm(MongoBaseModel, AccessProfile):
+    realm_id: str
+    deployments: list[Deployments | DeploymentsPartial] = []
+    name: str


 class Experiments(AccessProfile):
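The new `*Partial` models make every field optional so that sparse or projected documents (for example, those joined in via `$lookup`) still validate. A minimal Pydantic sketch of the idea, independent of the models above:

```python
from pydantic import BaseModel

class Deployment(BaseModel):
    realm_id: str
    name: str

class DeploymentPartial(BaseModel):
    realm_id: str | None = None
    name: str | None = None

# A sparse $lookup result would fail Deployment validation but not the partial:
DeploymentPartial(realm_id="demo_beamline_1")
```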
@@ -85,12 +107,6 @@ class Session(MongoBaseModel, AccessProfile):
     name: str


-class Realm(MongoBaseModel, AccessProfile):
-    realm_id: str
-    deployments: list[Deployments] = []
-    name: str
-
-
 class Datasets(AccessProfile):
     realm_id: str
     dataset_id: str
@@ -126,18 +142,33 @@ class DeviceConfig(AccessProfile):
     software_trigger: bool


-class SignalData(AccessProfile):
-    scan_id: str
-    device_id: str
-    device_name: str
+class SignalData(AccessProfile, MongoBaseModel):
+    """
+    Signal data for a device. This is the ophyd signal data,
+    aggregated for a single scan. Upon completion of a scan,
+    the data is aggregated and stored in this format. If possible,
+    the data ingestor will calculate the average, standard deviation,
+    min, and max values for the signal.
+    """
+
+    scan_id: str | ObjectId | None = None
+    device_id: str | ObjectId
     signal_name: str
-    data: float | int | str | bool | bytes | dict | list | None
-    timestamp: float
-    kind: Literal["hinted", "omitted", "normal", "config"]
+    data: list[Any]
+    timestamps: list[float]
+    kind: Literal["hinted", "normal", "config", "omitted"]
+    average: float | None = None
+    std_dev: float | None = None
+    min: float | None = None
+    max: float | None = None


-class DeviceData(AccessProfile):
-    scan_id: str | None
-    device_name: str
-    device_config_id: str
+class DeviceData(AccessProfile, MongoBaseModel):
+    scan_id: str | ObjectId | None = None
+    name: str
+    device_config_id: str | ObjectId
     signals: list[SignalData]
+
+
+if __name__ == "__main__":
+    out = DeploymentsPartial(realm_id="123")
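`SignalData` now stores per-scan arrays plus optional summary statistics. A sketch of how an ingestor might compute those optional fields; the aggregation code itself is not part of this diff:

```python
import statistics

def summarize(data: list[float]) -> dict:
    # Mirrors the optional average/std_dev/min/max fields on SignalData
    return {
        "average": statistics.fmean(data),
        "std_dev": statistics.stdev(data) if len(data) > 1 else 0.0,
        "min": min(data),
        "max": max(data),
    }

print(summarize([1.0, 2.0, 3.0]))
```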
@@ -1,7 +1,8 @@
-from fastapi import APIRouter
+from fastapi import APIRouter, Depends

+from bec_atlas.authentication import get_current_user
 from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
-from bec_atlas.model.model import Realm
+from bec_atlas.model.model import Realm, UserInfo
 from bec_atlas.router.base_router import BaseRouter
@@ -14,27 +15,48 @@ class RealmRouter(BaseRouter):
             "/realms",
             self.realms,
             methods=["GET"],
-            description="Get all deployments for the realm",
+            description="Get all realms",
             response_model=list[Realm],
             response_model_exclude_none=True,
         )
         self.router.add_api_route(
             "/realms/{realm_id}",
             self.realm_with_id,
             methods=["GET"],
-            description="Get a single deployment by id for a realm",
+            description="Get a single realm by id",
             response_model=Realm,
             response_model_exclude_none=True,
         )

-    async def realms(self) -> list[Realm]:
+    async def realms(
+        self, include_deployments: bool = False, current_user: UserInfo = Depends(get_current_user)
+    ) -> list[Realm]:
         """
         Get all realms.

+        Args:
+            include_deployments (bool): Include deployments in the response
+
         Returns:
             list[Realm]: List of realms
         """
-        return self.db.find("realms", {}, Realm)
+        if include_deployments:
+            include = [
+                {
+                    "$lookup": {
+                        "from": "deployments",
+                        "let": {"realm_id": "$_id"},
+                        "pipeline": [{"$match": {"$expr": {"$eq": ["$realm_id", "$$realm_id"]}}}],
+                        "as": "deployments",
+                    }
+                }
+            ]
+            return self.db.aggregate("realms", include, Realm, user=current_user)
+        return self.db.find("realms", {}, Realm, user=current_user)

-    async def realm_with_id(self, realm_id: str):
+    async def realm_with_id(
+        self, realm_id: str, current_user: UserInfo = Depends(get_current_user)
+    ):
         """
         Get realm with id.
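Client-side, the new `include_deployments` query parameter toggles the `$lookup`. A hedged example with `requests`; the base URL and auth flow are assumptions, not shown in the diff:

```python
import requests

# Hypothetical base URL and token; authentication details are not part of this diff.
resp = requests.get(
    "http://localhost:8000/api/v1/realms",
    params={"include_deployments": True},
    headers={"Authorization": "Bearer <token>"},
    timeout=10,
)
realms = resp.json()  # list of Realm documents, deployments embedded
```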
@@ -44,4 +66,4 @@ class RealmRouter(BaseRouter):
         Returns:
             Realm: The realm with the id
         """
-        return self.db.find_one("realms", {"_id": realm_id}, Realm)
+        return self.db.find_one("realms", {"_id": realm_id}, Realm, user=current_user)
@@ -3,13 +3,14 @@ import functools
 import inspect
 import json
 import traceback
+import uuid
 from typing import TYPE_CHECKING, Any

 import socketio
 from bec_lib.endpoints import EndpointInfo, MessageEndpoints, MessageOp
 from bec_lib.logger import bec_logger
-from bec_lib.serialization import json_ext
-from fastapi import APIRouter
+from bec_lib.serialization import MsgpackSerialization, json_ext
+from fastapi import APIRouter, Query, Response

 from bec_atlas.router.base_router import BaseRouter
@@ -67,6 +68,40 @@ class RedisAtlasEndpoints:
         """
         return f"socketio/rooms/{deployment}/{endpoint}"

+    @staticmethod
+    def redis_request(deployment: str):
+        """
+        Endpoint for the redis request for a deployment and endpoint.
+
+        Args:
+            deployment (str): The deployment name
+
+        Returns:
+            str: The endpoint for the redis request
+        """
+        return f"internal/deployment/{deployment}/request"
+
+    @staticmethod
+    def redis_request_response(deployment: str, request_id: str):
+        """
+        Endpoint for the redis request response for a deployment and endpoint.
+
+        Args:
+            deployment (str): The deployment name
+            request_id (str): The request id
+
+        Returns:
+            str: The endpoint for the redis request response
+        """
+        return f"internal/deployment/{deployment}/request_response/{request_id}"
+
+
+class MsgResponse(Response):
+    media_type = "application/json"
+
+    def render(self, content: Any) -> bytes:
+        return content.encode()
+
+
 class RedisRouter(BaseRouter):
     """
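The two helpers define a small RPC-over-pub/sub convention: a request is published to `.../request` carrying a `response_endpoint`, and the service behind that deployment replies on `.../request_response/<request_id>`. A rough sketch of a responder under that assumption; the responder itself is not in this diff, and the real reply is msgpack-serialized rather than JSON:

```python
import json

def handle_get_request(redis_conn, message: dict) -> None:
    """Hypothetical deployment-side responder; not part of this diff."""
    payload = json.loads(message["data"])
    # A real responder would read payload["key"] from its local Redis and
    # serialize the reply with MsgpackSerialization.dumps().
    reply = {"key": payload["key"], "value": None}
    redis_conn.publish(payload["response_endpoint"], json.dumps(reply))
```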
@@ -76,14 +111,30 @@ class RedisRouter(BaseRouter):

     def __init__(self, prefix="/api/v1", datasources=None):
         super().__init__(prefix, datasources)
-        self.redis = self.datasources.datasources["redis"].connector
+        self.redis = self.datasources.datasources["redis"].async_connector

         self.router = APIRouter(prefix=prefix)
-        self.router.add_api_route("/redis", self.redis_get, methods=["GET"])
+        self.router.add_api_route(
+            "/redis/{deployment}", self.redis_get, methods=["GET"], response_class=MsgResponse
+        )
         self.router.add_api_route("/redis", self.redis_post, methods=["POST"])
         self.router.add_api_route("/redis", self.redis_delete, methods=["DELETE"])

-    async def redis_get(self, key: str):
-        return self.redis.get(key)
+    async def redis_get(self, deployment: str, key: str = Query(...)):
+        request_id = uuid.uuid4().hex
+        response_endpoint = RedisAtlasEndpoints.redis_request_response(deployment, request_id)
+        request_endpoint = RedisAtlasEndpoints.redis_request(deployment)
+        pubsub = self.redis.pubsub()
+        pubsub.ignore_subscribe_messages = True
+        await pubsub.subscribe(response_endpoint)
+        data = {"action": "get", "key": key, "response_endpoint": response_endpoint}
+        await self.redis.publish(request_endpoint, json.dumps(data))
+        response = await pubsub.get_message(timeout=10)
+        print(response)
+        response = await pubsub.get_message(timeout=10)
+        out = MsgpackSerialization.loads(response["data"])
+
+        return json_ext.dumps({"data": out.content, "metadata": out.metadata})

     async def redis_post(self, key: str, value: str):
         return self.redis.set(key, value)
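The rewritten GET handler is deployment-scoped and proxies through pub/sub. A usage sketch; host, deployment id, and key name are illustrative:

```python
import requests

# GET /api/v1/redis/{deployment}?key=... returns {"data": ..., "metadata": ...}
resp = requests.get(
    "http://localhost:8000/api/v1/redis/<deployment_id>",
    params={"key": "internal/devices/read/samx"},  # key name illustrative
    timeout=15,
)
print(resp.json())
```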
@@ -129,9 +180,9 @@ class BECAsyncRedisManager(socketio.AsyncRedisManager):

     def start_update_loop(self):
         self.started_update_loop = True
-        # loop = asyncio.get_event_loop()
-        # task = loop.create_task(self._backend_heartbeat())
-        # return task
+        loop = asyncio.get_event_loop()
+        task = loop.create_task(self._backend_heartbeat())
+        return task

     async def disconnect(self, sid, namespace, **kwargs):
         if kwargs.get("ignore_queue"):
@@ -205,6 +256,8 @@ class RedisWebsocket:
         redis_port = datasources.datasources["redis"].config["port"]
+        redis_password = datasources.datasources["redis"].config.get("password", "ingestor")
         self.socket = socketio.AsyncServer(
+            transports=["websocket"],
             ping_timeout=60,
             cors_allowed_origins="*",
             async_mode="asgi",
             client_manager=BECAsyncRedisManager(
@@ -239,7 +292,10 @@ class RedisWebsocket:
         """
         if not http_query:
             raise ValueError("Query parameters not found")
-        query = json.loads(http_query)
+        if isinstance(http_query, str):
+            query = json.loads(http_query)
+        else:
+            query = http_query

         if "user" not in query:
             raise ValueError("User not found in query parameters")
@@ -256,12 +312,12 @@
         return user, deployment

     @safe_socket
-    async def connect_client(self, sid, environ=None):
+    async def connect_client(self, sid, environ=None, auth=None, **kwargs):
         if sid in self.users:
             logger.info("User already connected")
             return

-        http_query = environ.get("HTTP_QUERY")
+        http_query = environ.get("HTTP_QUERY") or auth

         user, deployment = self._validate_new_user(http_query)
@@ -283,9 +339,9 @@

         if user in info:
             self.users[sid] = {"user": user, "subscriptions": [], "deployment": deployment}
-            for endpoint in set(info[user]):
+            for endpoint, endpoint_request in info[user]:
                 print(f"Registering {endpoint}")
-                await self._update_user_subscriptions(sid, endpoint)
+                await self._update_user_subscriptions(sid, endpoint, endpoint_request)
         else:
             self.users[sid] = {"user": user, "subscriptions": [], "deployment": deployment}
@@ -321,13 +377,16 @@

         # check if the endpoint receives arguments
         if len(inspect.signature(endpoint).parameters) > 0:
-            endpoint: MessageEndpoints = endpoint(data.get("args"))
+            args = data.get("args", [])
+            if not isinstance(args, list):
+                args = [args]
+            endpoint: MessageEndpoints = endpoint(*args)
         else:
             endpoint: MessageEndpoints = endpoint()

-        await self._update_user_subscriptions(sid, endpoint.endpoint)
+        await self._update_user_subscriptions(sid, endpoint.endpoint, msg)

-    async def _update_user_subscriptions(self, sid: str, endpoint: str):
+    async def _update_user_subscriptions(self, sid: str, endpoint: str, endpoint_request: str):
         deployment = self.users[sid]["deployment"]

         endpoint_info = EndpointInfo(
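The normalization above lets clients pass endpoint arguments either as a scalar or as a list. In isolation:

```python
def normalize_args(data: dict) -> list:
    args = data.get("args", [])
    if not isinstance(args, list):
        args = [args]
    return args

assert normalize_args({"args": "samx"}) == ["samx"]
assert normalize_args({"args": ["samx", "samy"]}) == ["samx", "samy"]
assert normalize_args({}) == []
```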
@@ -335,20 +394,31 @@
         )

         room = RedisAtlasEndpoints.socketio_endpoint_room(deployment, endpoint)
-        self.redis.register(endpoint_info, cb=self.on_redis_message, parent=self, room=room)
+        self.redis.register(
+            endpoint_info,
+            cb=self.on_redis_message,
+            parent=self,
+            room=room,
+            endpoint_request=endpoint_request,
+        )
         if endpoint not in self.users[sid]["subscriptions"]:
             await self.socket.enter_room(sid, room)
-            self.users[sid]["subscriptions"].append(endpoint)
+            self.users[sid]["subscriptions"].append((endpoint, endpoint_request))
             await self.socket.manager.update_websocket_states()

     @staticmethod
-    def on_redis_message(message, parent, room):
+    def on_redis_message(message, parent, room, endpoint_request):
         async def emit_message(message):
             if "pubsub_data" in message:
                 msg = message["pubsub_data"]
             else:
                 msg = message["data"]
-            outgoing = {"data": msg.content, "metadata": msg.metadata}
+            outgoing = {
+                "data": msg.content,
+                "metadata": msg.metadata,
+                "endpoint": room.split("/", 3)[-1],
+                "endpoint_request": endpoint_request,
+            }
             outgoing = json_ext.dumps(outgoing)
             await parent.socket.emit("message", data=outgoing, room=room)
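After this change, every socket.io `message` event carries the originating endpoint and the original request, so clients sharing a room can demultiplex. The emitted payload has roughly this shape (values illustrative):

```python
outgoing = {
    "data": {"signals": {"samx": {"value": 1.0}}},  # msg.content
    "metadata": {"scan_id": "..."},  # msg.metadata
    "endpoint": "internal/devices/readback/samx",  # room suffix after the 3rd "/"
    "endpoint_request": '{"endpoint": "device_readback", "args": "samx"}',
}
```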
@@ -16,19 +16,32 @@ class DemoSetupLoader:
         self.load_deployments()

     def load_realm(self):
-        realm = Realm(realm_id="demo_beamline_1", name="Demo Beamline 1", owner_groups=["admin"])
+        realm = Realm(
+            realm_id="demo_beamline_1",
+            name="Demo Beamline 1",
+            owner_groups=["admin"],
+            access_groups=["auth_user"],
+        )
         realm._id = realm.realm_id
         if self.db["realms"].find_one({"realm_id": realm.realm_id}) is None:
             self.db["realms"].insert_one(realm.__dict__)

-        realm = Realm(realm_id="demo_beamline_2", name="Demo Beamline 2", owner_groups=["admin"])
+        realm = Realm(
+            realm_id="demo_beamline_2",
+            name="Demo Beamline 2",
+            owner_groups=["admin"],
+            access_groups=["auth_user"],
+        )
         realm._id = realm.realm_id
         if self.db["realms"].find_one({"realm_id": realm.realm_id}) is None:
             self.db["realms"].insert_one(realm.__dict__)

     def load_deployments(self):
         deployment = Deployments(
-            realm_id="demo_beamline_1", name="Demo Deployment 1", owner_groups=["admin", "demo"]
+            realm_id="demo_beamline_1",
+            name="Demo Deployment 1",
+            owner_groups=["admin", "demo"],
+            access_groups=["demo"],
         )
         if self.db["deployments"].find_one({"name": deployment.name}) is None:
             self.db["deployments"].insert_one(deployment.__dict__)
@@ -36,7 +49,10 @@ class DemoSetupLoader:
         if self.db["sessions"].find_one({"name": "_default_"}) is None:
             deployment = self.db["deployments"].find_one({"name": deployment["name"]})
             default_session = Session(
-                owner_groups=["admin", "demo"], deployment_id=deployment["_id"], name="_default_"
+                owner_groups=["admin", "demo"],
+                access_groups=["demo"],
+                deployment_id=deployment["_id"],
+                name="_default_",
             )
             self.db["sessions"].insert_one(default_session.model_dump(exclude_none=True))
backend/bec_atlas/utils/env_loader.py (new file, 25 lines)
@@ -0,0 +1,25 @@
+import os
+
+import yaml
+
+
+def load_env() -> dict:
+    """
+    Load the environment variables from the .env file.
+    """
+    env_file = "./.env.yaml"
+
+    if not os.path.exists(env_file):
+        env_file = os.path.join(os.path.dirname(__file__), ".env.yaml")
+
+    if not os.path.exists(env_file):
+        # check if there is an env file in the parent directory
+        current_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+        env_file = os.path.join(current_dir, ".env.yaml")
+
+    if not os.path.exists(env_file):
+        raise FileNotFoundError(f"Could not find .env file in {os.getcwd()} or {current_dir}")
+
+    with open(env_file, "r", encoding="utf-8") as file:
+        yaml_config = yaml.safe_load(file)
+    return yaml_config
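`load_env` returns the parsed YAML mapping that `AtlasApp(config=...)` consumes. Based on the `CONFIG` dict in the app module, a plausible `.env.yaml` looks like the following; the exact keys are inferred, not specified by the diff:

```python
import yaml

sample = """
redis:
  host: localhost
  port: 6380
  username: ingestor  # assumed; mirrors the Redis datasource credentials
  password: ingestor
mongodb:
  host: localhost
  port: 27017
"""
config = yaml.safe_load(sample)
assert config["redis"]["port"] == 6380
```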
@@ -132,7 +132,12 @@ def backend(redis_container, mongo_container):
     redis_host, redis_port = redis_container
     mongo_host, mongo_port = mongo_container
     config = {
-        "redis": {"host": redis_host, "port": redis_port},
+        "redis": {
+            "host": redis_host,
+            "port": redis_port,
+            "username": "ingestor",
+            "password": "ingestor",
+        },
         "mongodb": {"host": mongo_host, "port": mongo_port},
     }
backend/utils/sls_deployments.yaml (new file, 67 lines)
@@ -0,0 +1,67 @@
+ADDAMS:
+  x04sa-bec-001.psi.ch:
+    name: production
+    description: Primary deployment for ADDAMS
+cSAXS:
+  x12sa-bec-001.psi.ch:
+    name: production
+    description: Primary deployment for cSAXS
+  x12sa-bec-002.psi.ch:
+    name: test
+    description: Test environment for cSAXS
+Debye:
+  x01da-bec-001.psi.ch:
+    name: production
+    description: Primary deployment for Debye
+MicroXAS:
+  x05la-bec-001.psi.ch:
+    name: production
+    description: Primary deployment for MicroXAS
+  x05la-bec-002.psi.ch:
+    name: test
+    description: Test environment for MicroXAS
+Phoenix:
+  x07mb-bec-001.psi.ch:
+    name: production
+    description: Primary deployment for Phoenix
+PolLux:
+  x07da-bec-001.psi.ch:
+    name: production
+    description: Primary deployment for PolLux
+PXI:
+  x06sa-bec-001.psi.ch:
+    name: production
+    description: Primary deployment for PXI
+PXII:
+  x10sa-bec-001.psi.ch:
+    name: production
+    description: Primary deployment for PXII
+PXIII:
+  x06da-bec-001.psi.ch:
+    name: production
+    description: Primary deployment for PXIII
+SIM:
+  x11ma-bec-001.psi.ch:
+    name: production
+    description: Primary deployment for SIM
+SuperXAS:
+  x10da-bec-001.psi.ch:
+    name: production
+    description: Primary deployment for SuperXAS
+I-TOMCAT:
+  x02da-bec-001.psi.ch:
+    name: production
+    description: Primary deployment for I-TOMCAT
+  x02da-bec-002.psi.ch:
+    name: test
+    description: Test environment for I-TOMCAT
+X-Treme:
+  x07ma-bec-001.psi.ch:
+    name: production
+    description: Primary deployment for X-Treme