diff --git a/backend/bec_atlas/datasources/redis_datasource.py b/backend/bec_atlas/datasources/redis_datasource.py index 90874a8..c599ac0 100644 --- a/backend/bec_atlas/datasources/redis_datasource.py +++ b/backend/bec_atlas/datasources/redis_datasource.py @@ -1,4 +1,12 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from bec_lib.redis_connector import RedisConnector +from redis.exceptions import AuthenticationError + +if TYPE_CHECKING: + from bec_atlas.model.model import Deployments class RedisDatasource: @@ -6,6 +14,69 @@ class RedisDatasource: self.config = config self.connector = RedisConnector(f"{config.get('host')}:{config.get('port')}") + try: + self.connector._redis_conn.auth(config.get("password", "ingestor"), username="ingestor") + self.reconfigured_acls = False + except AuthenticationError: + self.setup_acls() + self.connector._redis_conn.auth(config.get("password", "ingestor"), username="ingestor") + self.reconfigured_acls = True + print("Connected to Redis") + + def setup_acls(self): + """ + Setup the ACLs for the Redis proxy server. + """ + + # Create the ingestor user. This user is used by the data ingestor to write data to the database. + self.connector._redis_conn.acl_setuser( + "ingestor", + enabled=True, + passwords=f'+{self.config.get("password", "ingestor")}', + categories=["+@all"], + keys=["*"], + channels=["*"], + ) + + self.connector._redis_conn.acl_setuser( + "user", + enabled=True, + passwords="+user", + categories=["+@all"], + keys=["*"], + channels=["*"], + ) + self.connector._redis_conn.acl_setuser( + "default", enabled=True, categories=["-@all"], commands=["+auth", "+acl|whoami"] + ) + + def add_deployment_acl(self, deployment: Deployments): + """ + Add ACLs for the deployment. + + Args: + deployment (Deployments): The deployment object + """ + print(f"Adding ACLs for deployment <{deployment.name}>({deployment.id})") + self.connector._redis_conn.acl_setuser( + f"ingestor_{deployment.id}", + enabled=True, + passwords=f"+{deployment.deployment_key}", + categories=["+@all", "-@dangerous"], + keys=[ + f"internal/deployment/{deployment.id}/*", + f"internal/deployment/{deployment.id}/*/state", + f"internal/deployment/{deployment.id}/*/data/*", + ], + channels=[ + f"internal/deployment/{deployment.id}/*/state", + f"internal/deployment/{deployment.id}/*", + ], + commands=[f"+keys|internal/deployment/{deployment.id}/*/state"], + reset_channels=True, + reset_keys=True, + ) + def connect(self): pass diff --git a/backend/bec_atlas/ingestor/data_ingestor.py b/backend/bec_atlas/ingestor/data_ingestor.py index 55f94f3..ead3b01 100644 --- a/backend/bec_atlas/ingestor/data_ingestor.py +++ b/backend/bec_atlas/ingestor/data_ingestor.py @@ -1,12 +1,17 @@ from __future__ import annotations +import json import os import threading +from functools import lru_cache from bec_lib import messages from bec_lib.logger import bec_logger +from bec_lib.redis_connector import RedisConnector from bec_lib.serialization import MsgpackSerialization -from redis import Redis + +# from redis import Redis +from bson import ObjectId from redis.exceptions import ResponseError from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource @@ -24,32 +29,25 @@ class DataIngestor: redis_host = config.get("redis", {}).get("host", "localhost") redis_port = config.get("redis", {}).get("port", 6380) - self.redis = Redis(host=redis_host, port=redis_port) + self.redis = RedisConnector( + f"{redis_host}:{redis_port}" # username="ingestor", password="ingestor" + ) + # self.redis.authenticate( + # config.get("redis", {}).get("password", "ingestor"), username="ingestor" + # ) + + self.redis._redis_conn.connection_pool.connection_kwargs["username"] = "ingestor" + self.redis._redis_conn.connection_pool.connection_kwargs["password"] = "ingestor" self.shutdown_event = threading.Event() - self.available_deployments = {} + self.available_deployments = [] self.deployment_listener_thread = None self.receiver_thread = None + self.reclaim_pending_messages_thread = None self.consumer_name = f"ingestor_{os.getpid()}" - self.create_consumer_group() self.start_deployment_listener() self.start_receiver() - def create_consumer_group(self): - """ - Create the consumer group for the ingestor. - - """ - try: - self.redis.xgroup_create( - name="internal/database_ingest", groupname="ingestor", mkstream=True - ) - except ResponseError as exc: - if "BUSYGROUP Consumer Group name already exists" in str(exc): - logger.info("Consumer group already exists.") - else: - raise exc - def start_deployment_listener(self): """ Start the listener for the available deployments. @@ -57,7 +55,8 @@ class DataIngestor: """ out = self.redis.get("deployments") if out: - self.available_deployments = out + self.available_deployments = json.loads(out) + self.update_consumer_groups() self.deployment_listener_thread = threading.Thread( target=self.update_available_deployments, name="deployment_listener" ) @@ -70,19 +69,65 @@ class DataIngestor: """ self.receiver_thread = threading.Thread(target=self.ingestor_loop, name="receiver") self.receiver_thread.start() + self.reclaim_pending_messages_thread = threading.Thread( + target=self.reclaim_pending_messages, name="reclaim_pending_messages" + ) + self.reclaim_pending_messages_thread.start() def update_available_deployments(self): """ Update the available deployments from the Redis queue. """ - sub = self.redis.pubsub() - sub.subscribe("deployments") + + def _update_deployments(data, parent): + parent.available_deployments = data + parent.update_consumer_groups() + + self.redis.register("deployments", cb=_update_deployments, parent=self) + + def update_consumer_groups(self): + """ + Update the consumer groups for the available deployments. + + """ + for deployment in self.available_deployments: + try: + self.redis._redis_conn.xgroup_create( + name=f"internal/deployment/{deployment['id']}/ingest", + groupname="ingestor", + mkstream=True, + ) + except ResponseError as exc: + if "BUSYGROUP Consumer Group name already exists" in str(exc): + logger.info("Consumer group already exists.") + else: + raise exc + + def reclaim_pending_messages(self): + """ + Reclaim any pending messages from the Redis queue. + + """ while not self.shutdown_event.is_set(): - out = sub.get_message(ignore_subscribe_messages=True, timeout=1) - if out: - logger.info(f"Updating available deployments: {out}") - self.available_deployments = out - sub.close() + to_process = [] + for deployment in self.available_deployments: + pending_messages = self.redis._redis_conn.xautoclaim( + f"internal/deployment/{deployment['id']}/ingest", + "ingestor", + self.consumer_name, + min_idle_time=1000, + ) + if pending_messages[1]: + to_process.append( + [ + f"internal/deployment/{deployment['id']}/ingest".encode(), + pending_messages[1], + ] + ) + + if to_process: + self._handle_stream_messages(to_process) + self.shutdown_event.wait(10) def ingestor_loop(self): """ @@ -90,24 +135,32 @@ class DataIngestor: """ while not self.shutdown_event.is_set(): - data = self.redis.xreadgroup( - groupname="ingestor", - consumername=self.consumer_name, - streams={"internal/database_ingest": ">"}, - block=1000, + streams = { + f"internal/deployment/{deployment['id']}/ingest": ">" + for deployment in self.available_deployments + } + data = self.redis._redis_conn.xreadgroup( + groupname="ingestor", consumername=self.consumer_name, streams=streams, block=1000 ) if not data: logger.debug("No messages to ingest.") continue - for stream, msgs in data: - for message in msgs: - msg_dict = MsgpackSerialization.loads(message[1]) - self.handle_message(msg_dict) - self.redis.xack(stream, "ingestor", message[0]) + self._handle_stream_messages(data) - def handle_message(self, msg_dict: dict): + def _handle_stream_messages(self, data): + for stream, msgs in data: + for message in msgs: + msg = message[1] + out = {} + for key, val in msg.items(): + out[key.decode()] = MsgpackSerialization.loads(val) + deployment_id = stream.decode().split("/")[-2] + self.handle_message(out, deployment_id) + self.redis._redis_conn.xack(stream, "ingestor", message[0]) + + def handle_message(self, msg_dict: dict, deploymend_id: str): """ Handle a message from the Redis queue. @@ -119,22 +172,34 @@ class DataIngestor: data = msg_dict.get("data") if data is None: return - deployment = msg_dict.get("deployment") - if deployment is None: - return - - if not deployment == self.available_deployments.get(deployment): - return if isinstance(data, messages.ScanStatusMessage): - self.update_scan_status(data) + self.update_scan_status(data, deploymend_id) - def update_scan_status(self, msg: messages.ScanStatusMessage): + @lru_cache() + def get_default_session_id(self, deployment_id: str): + """ + Get the session id for a deployment. + + Args: + deployment_id (str): The deployment id + + Returns: + str: The session id + + """ + out = self.datasource.db["sessions"].find_one( + {"name": "_default_", "deployment_id": ObjectId(deployment_id)} + ) + return out["_id"] + + def update_scan_status(self, msg: messages.ScanStatusMessage, deployment_id: str): """ Update the status of a scan in the database. If the scan does not exist, create it. Args: msg (messages.ScanStatusMessage): The message containing the scan status. + deployment_id (str): The deployment id """ if not hasattr(msg, "session_id"): @@ -143,17 +208,22 @@ class DataIngestor: else: session_id = msg.session_id if not session_id: - return + session_id = "_default_" + + if session_id == "_default_": + session_id = self.get_default_session_id(deployment_id) + # scans are indexed by the scan_id, hence we can use find_one and search by the ObjectId data = self.datasource.db["scans"].find_one({"_id": msg.scan_id}) if data is None: msg_conv = ScanStatus( owner_groups=["admin"], access_groups=["admin"], **msg.model_dump() ) - msg_conv._id = msg_conv.scan_id + + out = msg_conv.model_dump(exclude_none=True) + out["_id"] = msg.scan_id # TODO for compatibility with the old message format; remove once the bec_lib is updated - out = msg_conv.__dict__ out["session_id"] = session_id self.datasource.db["scans"].insert_one(out) @@ -168,4 +238,25 @@ class DataIngestor: self.deployment_listener_thread.join() if self.receiver_thread: self.receiver_thread.join() + if self.reclaim_pending_messages_thread: + self.reclaim_pending_messages_thread.join() + self.redis.shutdown() self.datasource.shutdown() + + +def main(): # pragma: no cover + from bec_atlas.main import CONFIG + + ingestor = DataIngestor(config=CONFIG) + event = threading.Event() + while not event.is_set(): + try: + event.wait(1) + except KeyboardInterrupt: + event.set() + ingestor.shutdown() + + +if __name__ == "__main__": + bec_logger.level = bec_logger.LOGLEVEL.INFO + main() diff --git a/backend/bec_atlas/main.py b/backend/bec_atlas/main.py index 35e8166..09d8b92 100644 --- a/backend/bec_atlas/main.py +++ b/backend/bec_atlas/main.py @@ -5,12 +5,12 @@ from fastapi import FastAPI from bec_atlas.datasources.datasource_manager import DatasourceManager from bec_atlas.router.deployments_router import DeploymentsRouter from bec_atlas.router.realm_router import RealmRouter -from bec_atlas.router.redis_router import RedisRouter, RedisWebsocket +from bec_atlas.router.redis_router import RedisWebsocket from bec_atlas.router.scan_router import ScanRouter from bec_atlas.router.user_router import UserRouter CONFIG = { - "redis": {"host": "localhost", "port": 6379}, + "redis": {"host": "localhost", "port": 6380}, "scylla": {"hosts": ["localhost"]}, "mongodb": {"host": "localhost", "port": 27017}, } @@ -25,15 +25,16 @@ class AtlasApp: self.server = None self.prefix = f"/api/{self.API_VERSION}" self.datasources = DatasourceManager(config=self.config) + self.datasources.connect() self.register_event_handler() - self.add_routers() + # self.add_routers() def register_event_handler(self): - self.app.add_event_handler("startup", self.on_startup) self.app.add_event_handler("shutdown", self.on_shutdown) + self.app.add_event_handler("startup", self.on_startup) async def on_startup(self): - self.datasources.connect() + self.add_routers() async def on_shutdown(self): self.datasources.shutdown() diff --git a/backend/bec_atlas/model/model.py b/backend/bec_atlas/model/model.py index fc3ec6d..ae29019 100644 --- a/backend/bec_atlas/model/model.py +++ b/backend/bec_atlas/model/model.py @@ -3,7 +3,19 @@ from typing import Literal from bec_lib import messages from bson import ObjectId -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, field_serializer + + +class MongoBaseModel(BaseModel): + id: str | ObjectId | None = Field(default=None, alias="_id") + + model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True) + + @field_serializer("id") + def serialize_id(self, id: str | ObjectId): + if isinstance(id, ObjectId): + return str(id) + return id class AccessProfile(BaseModel): @@ -11,40 +23,32 @@ class AccessProfile(BaseModel): access_groups: list[str] = [] -class ScanStatus(AccessProfile, messages.ScanStatusMessage): - model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True) +class ScanStatus(MongoBaseModel, AccessProfile, messages.ScanStatusMessage): ... -class UserCredentials(AccessProfile): +class UserCredentials(MongoBaseModel, AccessProfile): user_id: str | ObjectId password: str - model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True) - -class User(AccessProfile): - id: str | ObjectId | None = Field(default=None, alias="_id") +class User(MongoBaseModel, AccessProfile): email: str groups: list[str] first_name: str last_name: str - model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True) - class UserInfo(BaseModel): email: str groups: list[str] -class Deployments(AccessProfile): +class Deployments(MongoBaseModel, AccessProfile): realm_id: str name: str deployment_key: str = Field(default_factory=lambda: str(uuid.uuid4())) active_session_id: str | None = None - model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True) - class Experiments(AccessProfile): realm_id: str @@ -76,21 +80,16 @@ class State(AccessProfile): model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True) -class Session(AccessProfile): - realm_id: str - session_id: str - config: str - - model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True) +class Session(MongoBaseModel, AccessProfile): + deployment_id: str | ObjectId + name: str -class Realm(AccessProfile): +class Realm(MongoBaseModel, AccessProfile): realm_id: str deployments: list[Deployments] = [] name: str - model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True) - class Datasets(AccessProfile): realm_id: str diff --git a/backend/bec_atlas/router/deployments_router.py b/backend/bec_atlas/router/deployments_router.py index b524cdb..db338e5 100644 --- a/backend/bec_atlas/router/deployments_router.py +++ b/backend/bec_atlas/router/deployments_router.py @@ -1,3 +1,6 @@ +import json +from typing import TYPE_CHECKING + from fastapi import APIRouter, Depends from bec_atlas.authentication import get_current_user @@ -5,6 +8,9 @@ from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource from bec_atlas.model.model import Deployments, UserInfo from bec_atlas.router.base_router import BaseRouter +if TYPE_CHECKING: + from bec_atlas.datasources.redis_datasource import RedisDatasource + class DeploymentsRouter(BaseRouter): def __init__(self, prefix="/api/v1", datasources=None): @@ -25,6 +31,7 @@ class DeploymentsRouter(BaseRouter): description="Get a single deployment by id for a realm", response_model=Deployments, ) + self.update_available_deployments() async def deployments( self, realm: str, current_user: UserInfo = Depends(get_current_user) @@ -52,3 +59,16 @@ class DeploymentsRouter(BaseRouter): return self.db.find_one( "deployments", {"_id": deployment_id}, Deployments, user=current_user ) + + def update_available_deployments(self): + """ + Update the available deployments. + """ + self.available_deployments = self.db.find("deployments", {}, Deployments) + + redis: RedisDatasource = self.datasources.datasources.get("redis") + msg = json.dumps([msg.model_dump() for msg in self.available_deployments]) + redis.connector.set_and_publish("deployments", msg) + if redis.reconfigured_acls: + for deployment in self.available_deployments: + redis.add_deployment_acl(deployment) diff --git a/backend/bec_atlas/router/redis_router.py b/backend/bec_atlas/router/redis_router.py index 56dec3c..8257070 100644 --- a/backend/bec_atlas/router/redis_router.py +++ b/backend/bec_atlas/router/redis_router.py @@ -3,10 +3,10 @@ import functools import inspect import json import traceback -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import socketio -from bec_lib.endpoints import MessageEndpoints +from bec_lib.endpoints import EndpointInfo, MessageEndpoints, MessageOp from bec_lib.logger import bec_logger from fastapi import APIRouter @@ -18,6 +18,55 @@ if TYPE_CHECKING: from bec_lib.redis_connector import RedisConnector +class RedisAtlasEndpoints: + """ + This class contains the endpoints for the Redis API. It is used to + manage the subscriptions and the state information for the websocket + """ + + @staticmethod + def websocket_state(deployment: str, host_id: str): + """ + Endpoint for the websocket state information, containing the users and their subscriptions + per backend host. + + Args: + deployment (str): The deployment name + host_id (str): The host id of the backend + + Returns: + str: The endpoint for the websocket state information + """ + return f"internal/deployment/{deployment}/{host_id}/state" + + @staticmethod + def redis_data(deployment: str, endpoint: str): + """ + Endpoint for the redis data for a deployment and endpoint. + + Args: + deployment (str): The deployment name + endpoint (str): The endpoint name + + Returns: + str: The endpoint for the redis data + """ + return f"internal/deployment/{deployment}/data/{endpoint}" + + @staticmethod + def socketio_endpoint_room(endpoint: str): + """ + Endpoint for the socketio room for an endpoint. + + Args: + endpoint (str): The endpoint name + + Returns: + str: The endpoint for the socketio room + """ + return f"ENDPOINT::{endpoint}" + + class RedisRouter(BaseRouter): """ This class is a router for the Redis API. It exposes the redis client through @@ -47,6 +96,7 @@ def safe_socket(fcn): async def wrapper(self, sid, *args, **kwargs): try: out = await fcn(self, sid, *args, **kwargs) + # pylint: disable=broad-except except Exception as exc: content = traceback.format_exc() logger.error(content) @@ -71,15 +121,16 @@ class BECAsyncRedisManager(socketio.AsyncRedisManager): super().__init__(url, channel, write_only, logger, redis_options) self.requested_channels = [] self.started_update_loop = False + self.known_deployments = set() # task = asyncio.create_task(self._required_channel_heartbeat()) # loop.run_until_complete(task) def start_update_loop(self): self.started_update_loop = True - loop = asyncio.get_event_loop() - task = loop.create_task(self._backend_heartbeat()) - return task + # loop = asyncio.get_event_loop() + # task = loop.create_task(self._backend_heartbeat()) + # return task async def disconnect(self, sid, namespace, **kwargs): if kwargs.get("ignore_queue"): @@ -105,16 +156,25 @@ class BECAsyncRedisManager(socketio.AsyncRedisManager): async def _backend_heartbeat(self): while not self.parent.fastapi_app.server.should_exit: - await asyncio.sleep(1) - await self.redis.publish(f"deployments/{self.host_id}/heartbeat/", "ping") - data = json.dumps(self.parent.users) - print(f"Sending heartbeat: {data}") - await self.redis.set(f"deployments/{self.host_id}/state/", data, ex=30) + await asyncio.sleep(10) + await self.update_state_info() async def update_state_info(self): - data = json.dumps(self.parent.users) - await self.redis.set(f"deployments/{self.host_id}/state/", data, ex=30) - await self.redis.publish(f"deployments/{self.host_id}/state/", data) + deployments = {deployment: {} for deployment in self.known_deployments} + for user in self.parent.users: + deployment = self.parent.users[user]["deployment"] + if deployment not in deployments: + deployments[deployment] = {} + self.known_deployments.add(deployment) + deployments[deployment][user] = self.parent.users[user] + for name, data in deployments.items(): + data_json = json.dumps(data) + await self.redis.set( + RedisAtlasEndpoints.websocket_state(name, self.host_id), data_json, ex=30 + ) + await self.redis.publish( + RedisAtlasEndpoints.websocket_state(name, self.host_id), data_json + ) async def update_websocket_states(self): loop = asyncio.get_event_loop() @@ -140,11 +200,16 @@ class RedisWebsocket: self.prefix = prefix self.fastapi_app = app self.active_connections = set() + redis_host = datasources.datasources["redis"].config["host"] + redis_port = datasources.datasources["redis"].config["port"] + redis_password = datasources.datasources["redis"].config.get("password", "ingestor") self.socket = socketio.AsyncServer( cors_allowed_origins="*", async_mode="asgi", client_manager=BECAsyncRedisManager( - self, url=f"redis://{self.redis.host}:{self.redis.port}/0" + self, + url=f"redis://{redis_host}:{redis_port}/0", + redis_options={"username": "ingestor", "password": redis_password}, ), ) self.app = socketio.ASGIApp(self.socket) @@ -155,11 +220,22 @@ class RedisWebsocket: self.socket.on("register", self.redis_register) self.socket.on("unregister", self.redis.unregister) self.socket.on("disconnect", self.disconnect_client) + print("Redis websocket started") - @safe_socket - async def connect_client(self, sid, environ=None): - print("Client connected") - http_query = environ.get("HTTP_QUERY") + def _validate_new_user(self, http_query: str | None): + """ + Validate the connection of a new user. In particular, + the user must provide a valid token as well as have access + to the deployment. If subscriptions are provided, the user + must have access to the endpoints. + + Args: + http_query (str): The query parameters of the websocket connection + + Returns: + str: The user name + + """ if not http_query: raise ValueError("Query parameters not found") query = json.loads(http_query) @@ -168,28 +244,49 @@ class RedisWebsocket: raise ValueError("User not found in query parameters") user = query["user"] - if sid not in self.users: - # check if the user was already registered in redis - deployment_keys = await self.socket.manager.redis.keys("deployments/*/state/") - if not deployment_keys: - state_data = [] - else: - state_data = await self.socket.manager.redis.mget(*deployment_keys) - info = {} - for data in state_data: - if not data: - continue - obj = json.loads(data) - for value in obj.values(): - info[value["user"]] = value["subscriptions"] + # TODO: Validate the user token - if user in info: - self.users[sid] = {"user": user, "subscriptions": info[user]} - for endpoint in set(self.users[sid]["subscriptions"]): - await self.socket.enter_room(sid, f"ENDPOINT::{endpoint}") - await self.socket.manager.update_websocket_states() - else: - self.users[sid] = {"user": query["user"], "subscriptions": []} + deployment = query.get("deployment") + if not deployment: + raise ValueError("Deployment not found in query parameters") + + # TODO: Validate the user has access to the deployment + + return user, deployment + + @safe_socket + async def connect_client(self, sid, environ=None): + if sid in self.users: + logger.info("User already connected") + return + + http_query = environ.get("HTTP_QUERY") + + user, deployment = self._validate_new_user(http_query) + + # check if the user was already registered in redis + socketio_server_keys = await self.socket.manager.redis.keys( + RedisAtlasEndpoints.websocket_state(deployment, "*") + ) + if not socketio_server_keys: + state_data = [] + else: + state_data = await self.socket.manager.redis.mget(*socketio_server_keys) + info = {} + for data in state_data: + if not data: + continue + obj = json.loads(data) + for value in obj.values(): + info[value["user"]] = value["subscriptions"] + + if user in info: + self.users[sid] = {"user": user, "subscriptions": [], "deployment": deployment} + for endpoint in set(info[user]): + print(f"Registering {endpoint}") + await self._update_user_subscriptions(sid, endpoint) + else: + self.users[sid] = {"user": user, "subscriptions": [], "deployment": deployment} await self.socket.manager.update_websocket_states() @@ -214,8 +311,8 @@ class RedisWebsocket: try: print(msg) data = json.loads(msg) - except json.JSONDecodeError: - raise ValueError("Invalid JSON message") + except json.JSONDecodeError as exc: + raise ValueError("Invalid JSON message") from exc endpoint = getattr(MessageEndpoints, data.get("endpoint"), None) if endpoint is None: @@ -223,27 +320,36 @@ class RedisWebsocket: # check if the endpoint receives arguments if len(inspect.signature(endpoint).parameters) > 0: - endpoint = endpoint(data.get("args")) + endpoint: MessageEndpoints = endpoint(data.get("args")) else: - endpoint = endpoint() + endpoint: MessageEndpoints = endpoint() - self.redis.register(endpoint, cb=self.on_redis_message, parent=self) - if data.get("endpoint") not in self.users[sid]["subscriptions"]: - await self.socket.enter_room(sid, f"ENDPOINT::{data.get('endpoint')}") - self.users[sid]["subscriptions"].append(data.get("endpoint")) + await self._update_user_subscriptions(sid, endpoint.endpoint) + + async def _update_user_subscriptions(self, sid: str, endpoint: str): + deployment = self.users[sid]["deployment"] + + endpoint_info = EndpointInfo( + RedisAtlasEndpoints.redis_data(deployment, endpoint), Any, MessageOp.STREAM + ) + + room = RedisAtlasEndpoints.socketio_endpoint_room(endpoint) + self.redis.register(endpoint_info, cb=self.on_redis_message, parent=self, room=room) + if endpoint not in self.users[sid]["subscriptions"]: + await self.socket.enter_room(sid, room) + self.users[sid]["subscriptions"].append(endpoint) await self.socket.manager.update_websocket_states() @staticmethod - def on_redis_message(message, parent): + def on_redis_message(message, parent, room): async def emit_message(message): - outgoing = { - "data": message.value.model_dump_json(), - "message_type": message.value.__class__.__name__, - } - await parent.socket.emit("new_message", data=outgoing, room=message.topic) + if "pubsub_data" in message: + msg = message["pubsub_data"] + else: + msg = message["data"] + outgoing = {"data": msg.content, "metadata": msg.metadata} + outgoing = json.dumps(outgoing) + await parent.socket.emit("message", data=outgoing, room=room) - # check that the event loop is running - if not parent.loop.is_running(): - parent.loop.run_until_complete(emit_message(message)) - else: - asyncio.run_coroutine_threadsafe(emit_message(message), parent.loop) + # Run the coroutine in this loop + asyncio.run_coroutine_threadsafe(emit_message(message), parent.loop) diff --git a/backend/bec_atlas/utils/demo_database_setup.py b/backend/bec_atlas/utils/demo_database_setup.py index d2549eb..36d9dea 100644 --- a/backend/bec_atlas/utils/demo_database_setup.py +++ b/backend/bec_atlas/utils/demo_database_setup.py @@ -1,6 +1,6 @@ import pymongo -from bec_atlas.model import Deployments, Realm +from bec_atlas.model import Deployments, Realm, Session class DemoSetupLoader: @@ -30,7 +30,15 @@ class DemoSetupLoader: deployment = Deployments( realm_id="demo_beamline_1", name="Demo Deployment 1", owner_groups=["admin", "demo"] ) - self.db["deployments"].insert_one(deployment.__dict__) + if self.db["deployments"].find_one({"name": deployment.name}) is None: + self.db["deployments"].insert_one(deployment.__dict__) + + if self.db["sessions"].find_one({"name": "_default_"}) is None: + deployment = self.db["deployments"].find_one({"name": deployment["name"]}) + default_session = Session( + owner_groups=["admin", "demo"], deployment_id=deployment["_id"], name="_default_" + ) + self.db["sessions"].insert_one(default_session.model_dump(exclude_none=True)) if __name__ == "__main__":