2024-12-16 11:27:59 +01:00
parent 4a45119549
commit 3479e579ab
7 changed files with 432 additions and 136 deletions

View File

@@ -1,4 +1,12 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from bec_lib.redis_connector import RedisConnector
from redis.exceptions import AuthenticationError
if TYPE_CHECKING:
from bec_atlas.model.model import Deployments
class RedisDatasource:
@@ -6,6 +14,69 @@ class RedisDatasource:
self.config = config
self.connector = RedisConnector(f"{config.get('host')}:{config.get('port')}")
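# Try to authenticate as the ingestor user; on a fresh Redis instance the ACLs
# do not exist yet, so they are set up once and the login is retried.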
try:
self.connector._redis_conn.auth(config.get("password", "ingestor"), username="ingestor")
self.reconfigured_acls = False
except AuthenticationError:
self.setup_acls()
self.connector._redis_conn.auth(config.get("password", "ingestor"), username="ingestor")
self.reconfigured_acls = True
print("Connected to Redis")
def setup_acls(self):
"""
Set up the ACLs for the Redis proxy server.
"""
# Create the ingestor user. This user is used by the data ingestor to write data to the database.
self.connector._redis_conn.acl_setuser(
"ingestor",
enabled=True,
passwords=f'+{self.config.get("password", "ingestor")}',
categories=["+@all"],
keys=["*"],
channels=["*"],
)
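# Create a generic "user" account (password "user") with full access; presumably a development placeholder.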
self.connector._redis_conn.acl_setuser(
"user",
enabled=True,
passwords="+user",
categories=["+@all"],
keys=["*"],
channels=["*"],
)
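# Lock down the default (unauthenticated) user: no data access, only AUTH and
# ACL WHOAMI, so clients can log in and check their identity.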
self.connector._redis_conn.acl_setuser(
"default", enabled=True, categories=["-@all"], commands=["+auth", "+acl|whoami"]
)
def add_deployment_acl(self, deployment: Deployments):
"""
Add ACLs for the deployment.
Args:
deployment (Deployments): The deployment object
"""
print(f"Adding ACLs for deployment <{deployment.name}>({deployment.id})")
self.connector._redis_conn.acl_setuser(
f"ingestor_{deployment.id}",
enabled=True,
passwords=f"+{deployment.deployment_key}",
categories=["+@all", "-@dangerous"],
keys=[
f"internal/deployment/{deployment.id}/*",
f"internal/deployment/{deployment.id}/*/state",
f"internal/deployment/{deployment.id}/*/data/*",
],
channels=[
f"internal/deployment/{deployment.id}/*/state",
f"internal/deployment/{deployment.id}/*",
],
commands=[f"+keys|internal/deployment/{deployment.id}/*/state"],
reset_channels=True,
reset_keys=True,
)
def connect(self):
pass

View File

@@ -1,12 +1,17 @@
from __future__ import annotations
import json
import os
import threading
from functools import lru_cache
from bec_lib import messages
from bec_lib.logger import bec_logger
from bec_lib.redis_connector import RedisConnector
from bec_lib.serialization import MsgpackSerialization
from redis import Redis
# from redis import Redis
from bson import ObjectId
from redis.exceptions import ResponseError
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
@@ -24,32 +29,25 @@ class DataIngestor:
redis_host = config.get("redis", {}).get("host", "localhost")
redis_port = config.get("redis", {}).get("port", 6380)
self.redis = Redis(host=redis_host, port=redis_port)
self.redis = RedisConnector(
f"{redis_host}:{redis_port}" # username="ingestor", password="ingestor"
)
# self.redis.authenticate(
# config.get("redis", {}).get("password", "ingestor"), username="ingestor"
# )
self.redis._redis_conn.connection_pool.connection_kwargs["username"] = "ingestor"
self.redis._redis_conn.connection_pool.connection_kwargs["password"] = "ingestor"
self.shutdown_event = threading.Event()
self.available_deployments = {}
self.available_deployments = []
self.deployment_listener_thread = None
self.receiver_thread = None
self.reclaim_pending_messages_thread = None
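# One consumer per process; all consumers share the "ingestor" group, so each
# stream entry is delivered to exactly one of them.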
self.consumer_name = f"ingestor_{os.getpid()}"
self.create_consumer_group()
self.start_deployment_listener()
self.start_receiver()
def create_consumer_group(self):
"""
Create the consumer group for the ingestor.
"""
try:
self.redis.xgroup_create(
name="internal/database_ingest", groupname="ingestor", mkstream=True
)
except ResponseError as exc:
if "BUSYGROUP Consumer Group name already exists" in str(exc):
logger.info("Consumer group already exists.")
else:
raise exc
def start_deployment_listener(self):
"""
Start the listener for the available deployments.
@@ -57,7 +55,8 @@ class DataIngestor:
"""
out = self.redis.get("deployments")
if out:
self.available_deployments = out
self.available_deployments = json.loads(out)
self.update_consumer_groups()
self.deployment_listener_thread = threading.Thread(
target=self.update_available_deployments, name="deployment_listener"
)
@@ -70,19 +69,65 @@
"""
self.receiver_thread = threading.Thread(target=self.ingestor_loop, name="receiver")
self.receiver_thread.start()
self.reclaim_pending_messages_thread = threading.Thread(
target=self.reclaim_pending_messages, name="reclaim_pending_messages"
)
self.reclaim_pending_messages_thread.start()
def update_available_deployments(self):
"""
Keep the list of available deployments in sync via the "deployments" pubsub channel.
"""
sub = self.redis.pubsub()
sub.subscribe("deployments")
def _update_deployments(data, parent):
parent.available_deployments = data
parent.update_consumer_groups()
self.redis.register("deployments", cb=_update_deployments, parent=self)
def update_consumer_groups(self):
"""
Update the consumer groups for the available deployments.
"""
for deployment in self.available_deployments:
try:
self.redis._redis_conn.xgroup_create(
name=f"internal/deployment/{deployment['id']}/ingest",
groupname="ingestor",
mkstream=True,
)
except ResponseError as exc:
if "BUSYGROUP Consumer Group name already exists" in str(exc):
logger.info("Consumer group already exists.")
else:
raise exc
def reclaim_pending_messages(self):
"""
Periodically reclaim pending stream entries that were delivered to a consumer but never acknowledged.
"""
while not self.shutdown_event.is_set():
out = sub.get_message(ignore_subscribe_messages=True, timeout=1)
if out:
logger.info(f"Updating available deployments: {out}")
self.available_deployments = out
sub.close()
to_process = []
for deployment in self.available_deployments:
pending_messages = self.redis._redis_conn.xautoclaim(
f"internal/deployment/{deployment['id']}/ingest",
"ingestor",
self.consumer_name,
min_idle_time=1000,
)
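# In redis-py, xautoclaim returns the next cursor at index 0 and the list of
# claimed (id, fields) entries at index 1.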
if pending_messages[1]:
to_process.append(
[
f"internal/deployment/{deployment['id']}/ingest".encode(),
pending_messages[1],
]
)
if to_process:
self._handle_stream_messages(to_process)
self.shutdown_event.wait(10)
def ingestor_loop(self):
"""
@@ -90,24 +135,32 @@ class DataIngestor:
"""
while not self.shutdown_event.is_set():
data = self.redis.xreadgroup(
groupname="ingestor",
consumername=self.consumer_name,
streams={"internal/database_ingest": ">"},
block=1000,
streams = {
f"internal/deployment/{deployment['id']}/ingest": ">"
for deployment in self.available_deployments
}
data = self.redis._redis_conn.xreadgroup(
groupname="ingestor", consumername=self.consumer_name, streams=streams, block=1000
)
if not data:
logger.debug("No messages to ingest.")
continue
self._handle_stream_messages(data)
def _handle_stream_messages(self, data):
for stream, msgs in data:
for message in msgs:
msg_dict = MsgpackSerialization.loads(message[1])
self.handle_message(msg_dict)
self.redis.xack(stream, "ingestor", message[0])
msg = message[1]
out = {}
for key, val in msg.items():
out[key.decode()] = MsgpackSerialization.loads(val)
deployment_id = stream.decode().split("/")[-2]
self.handle_message(out, deployment_id)
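# Acknowledge only after successful handling; unacknowledged entries stay
# pending and are eventually picked up by reclaim_pending_messages.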
self.redis._redis_conn.xack(stream, "ingestor", message[0])
def handle_message(self, msg_dict: dict):
def handle_message(self, msg_dict: dict, deployment_id: str):
"""
Handle a message from the Redis queue.
@@ -119,22 +172,34 @@ class DataIngestor:
data = msg_dict.get("data")
if data is None:
return
deployment = msg_dict.get("deployment")
if deployment is None:
return
if not deployment == self.available_deployments.get(deployment):
return
if isinstance(data, messages.ScanStatusMessage):
self.update_scan_status(data)
self.update_scan_status(data, deployment_id)
def update_scan_status(self, msg: messages.ScanStatusMessage):
@lru_cache()
def get_default_session_id(self, deployment_id: str):
"""
Get the session id for a deployment.
Args:
deployment_id (str): The deployment id
Returns:
str: The session id
"""
out = self.datasource.db["sessions"].find_one(
{"name": "_default_", "deployment_id": ObjectId(deployment_id)}
)
return out["_id"]
def update_scan_status(self, msg: messages.ScanStatusMessage, deployment_id: str):
"""
Update the status of a scan in the database. If the scan does not exist, create it.
Args:
msg (messages.ScanStatusMessage): The message containing the scan status.
deployment_id (str): The deployment id
"""
if not hasattr(msg, "session_id"):
@@ -143,17 +208,22 @@ class DataIngestor:
else:
session_id = msg.session_id
if not session_id:
return
session_id = "_default_"
if session_id == "_default_":
session_id = self.get_default_session_id(deployment_id)
# scans are indexed by the scan_id, hence we can use find_one and search by the ObjectId
data = self.datasource.db["scans"].find_one({"_id": msg.scan_id})
if data is None:
msg_conv = ScanStatus(
owner_groups=["admin"], access_groups=["admin"], **msg.model_dump()
)
msg_conv._id = msg_conv.scan_id
out = msg_conv.model_dump(exclude_none=True)
out["_id"] = msg.scan_id
# TODO for compatibility with the old message format; remove once the bec_lib is updated
out = msg_conv.__dict__
out["session_id"] = session_id
self.datasource.db["scans"].insert_one(out)
@@ -168,4 +238,25 @@ class DataIngestor:
self.deployment_listener_thread.join()
if self.receiver_thread:
self.receiver_thread.join()
if self.reclaim_pending_messages_thread:
self.reclaim_pending_messages_thread.join()
self.redis.shutdown()
self.datasource.shutdown()
def main(): # pragma: no cover
from bec_atlas.main import CONFIG
ingestor = DataIngestor(config=CONFIG)
event = threading.Event()
while not event.is_set():
try:
event.wait(1)
except KeyboardInterrupt:
event.set()
ingestor.shutdown()
if __name__ == "__main__":
bec_logger.level = bec_logger.LOGLEVEL.INFO
main()

View File

@@ -5,12 +5,12 @@ from fastapi import FastAPI
from bec_atlas.datasources.datasource_manager import DatasourceManager
from bec_atlas.router.deployments_router import DeploymentsRouter
from bec_atlas.router.realm_router import RealmRouter
from bec_atlas.router.redis_router import RedisRouter, RedisWebsocket
from bec_atlas.router.redis_router import RedisWebsocket
from bec_atlas.router.scan_router import ScanRouter
from bec_atlas.router.user_router import UserRouter
CONFIG = {
"redis": {"host": "localhost", "port": 6379},
"redis": {"host": "localhost", "port": 6380},
"scylla": {"hosts": ["localhost"]},
"mongodb": {"host": "localhost", "port": 27017},
}
@@ -25,15 +25,16 @@ class AtlasApp:
self.server = None
self.prefix = f"/api/{self.API_VERSION}"
self.datasources = DatasourceManager(config=self.config)
self.datasources.connect()
self.register_event_handler()
self.add_routers()
# self.add_routers()
def register_event_handler(self):
self.app.add_event_handler("startup", self.on_startup)
self.app.add_event_handler("shutdown", self.on_shutdown)
self.app.add_event_handler("startup", self.on_startup)
async def on_startup(self):
self.datasources.connect()
self.add_routers()
async def on_shutdown(self):
self.datasources.shutdown()

View File

@@ -3,7 +3,19 @@ from typing import Literal
from bec_lib import messages
from bson import ObjectId
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, field_serializer
class MongoBaseModel(BaseModel):
id: str | ObjectId | None = Field(default=None, alias="_id")
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
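# ObjectId is not JSON-serializable; render it as a plain string whenever the model is dumped.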
@field_serializer("id")
def serialize_id(self, id: str | ObjectId):
if isinstance(id, ObjectId):
return str(id)
return id
class AccessProfile(BaseModel):
@@ -11,40 +23,32 @@ class AccessProfile(BaseModel):
access_groups: list[str] = []
class ScanStatus(AccessProfile, messages.ScanStatusMessage):
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
class ScanStatus(MongoBaseModel, AccessProfile, messages.ScanStatusMessage): ...
class UserCredentials(AccessProfile):
class UserCredentials(MongoBaseModel, AccessProfile):
user_id: str | ObjectId
password: str
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
class User(AccessProfile):
id: str | ObjectId | None = Field(default=None, alias="_id")
class User(MongoBaseModel, AccessProfile):
email: str
groups: list[str]
first_name: str
last_name: str
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
class UserInfo(BaseModel):
email: str
groups: list[str]
class Deployments(AccessProfile):
class Deployments(MongoBaseModel, AccessProfile):
realm_id: str
name: str
deployment_key: str = Field(default_factory=lambda: str(uuid.uuid4()))
active_session_id: str | None = None
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
class Experiments(AccessProfile):
realm_id: str
@@ -76,21 +80,16 @@ class State(AccessProfile):
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
class Session(AccessProfile):
realm_id: str
session_id: str
config: str
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
class Session(MongoBaseModel, AccessProfile):
deployment_id: str | ObjectId
name: str
class Realm(AccessProfile):
class Realm(MongoBaseModel, AccessProfile):
realm_id: str
deployments: list[Deployments] = []
name: str
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
class Datasets(AccessProfile):
realm_id: str

View File

@@ -1,3 +1,6 @@
import json
from typing import TYPE_CHECKING
from fastapi import APIRouter, Depends
from bec_atlas.authentication import get_current_user
@@ -5,6 +8,9 @@ from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model.model import Deployments, UserInfo
from bec_atlas.router.base_router import BaseRouter
if TYPE_CHECKING:
from bec_atlas.datasources.redis_datasource import RedisDatasource
class DeploymentsRouter(BaseRouter):
def __init__(self, prefix="/api/v1", datasources=None):
@@ -25,6 +31,7 @@ class DeploymentsRouter(BaseRouter):
description="Get a single deployment by id for a realm",
response_model=Deployments,
)
self.update_available_deployments()
async def deployments(
self, realm: str, current_user: UserInfo = Depends(get_current_user)
@@ -52,3 +59,16 @@ class DeploymentsRouter(BaseRouter):
return self.db.find_one(
"deployments", {"_id": deployment_id}, Deployments, user=current_user
)
def update_available_deployments(self):
"""
Update the available deployments.
"""
self.available_deployments = self.db.find("deployments", {}, Deployments)
redis: RedisDatasource = self.datasources.datasources.get("redis")
msg = json.dumps([d.model_dump() for d in self.available_deployments])
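# set_and_publish stores the payload and notifies subscribers (e.g. the data
# ingestor) in one step, so late subscribers can still read the current value.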
redis.connector.set_and_publish("deployments", msg)
if redis.reconfigured_acls:
for deployment in self.available_deployments:
redis.add_deployment_acl(deployment)

View File

@@ -3,10 +3,10 @@ import functools
import inspect
import json
import traceback
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
import socketio
from bec_lib.endpoints import MessageEndpoints
from bec_lib.endpoints import EndpointInfo, MessageEndpoints, MessageOp
from bec_lib.logger import bec_logger
from fastapi import APIRouter
@@ -18,6 +55,55 @@ if TYPE_CHECKING:
from bec_lib.redis_connector import RedisConnector
class RedisAtlasEndpoints:
"""
This class contains the endpoints for the Redis API. It is used to
manage the subscriptions and the state information for the websocket connections.
"""
@staticmethod
def websocket_state(deployment: str, host_id: str):
"""
Endpoint for the websocket state information, containing the users and their subscriptions
per backend host.
Args:
deployment (str): The deployment name
host_id (str): The host id of the backend
Returns:
str: The endpoint for the websocket state information
"""
return f"internal/deployment/{deployment}/{host_id}/state"
@staticmethod
def redis_data(deployment: str, endpoint: str):
"""
Endpoint for the redis data for a deployment and endpoint.
Args:
deployment (str): The deployment name
endpoint (str): The endpoint name
Returns:
str: The endpoint for the redis data
"""
return f"internal/deployment/{deployment}/data/{endpoint}"
@staticmethod
def socketio_endpoint_room(endpoint: str):
"""
Endpoint for the socketio room for an endpoint.
Args:
endpoint (str): The endpoint name
Returns:
str: The endpoint for the socketio room
"""
return f"ENDPOINT::{endpoint}"
class RedisRouter(BaseRouter):
"""
This class is a router for the Redis API. It exposes the redis client through
@@ -47,6 +96,7 @@ def safe_socket(fcn):
async def wrapper(self, sid, *args, **kwargs):
try:
out = await fcn(self, sid, *args, **kwargs)
# pylint: disable=broad-except
except Exception as exc:
content = traceback.format_exc()
logger.error(content)
@@ -71,15 +121,16 @@ class BECAsyncRedisManager(socketio.AsyncRedisManager):
super().__init__(url, channel, write_only, logger, redis_options)
self.requested_channels = []
self.started_update_loop = False
self.known_deployments = set()
# task = asyncio.create_task(self._required_channel_heartbeat())
# loop.run_until_complete(task)
def start_update_loop(self):
self.started_update_loop = True
loop = asyncio.get_event_loop()
task = loop.create_task(self._backend_heartbeat())
return task
# loop = asyncio.get_event_loop()
# task = loop.create_task(self._backend_heartbeat())
# return task
async def disconnect(self, sid, namespace, **kwargs):
if kwargs.get("ignore_queue"):
@@ -105,16 +156,25 @@ class BECAsyncRedisManager(socketio.AsyncRedisManager):
async def _backend_heartbeat(self):
while not self.parent.fastapi_app.server.should_exit:
await asyncio.sleep(1)
await self.redis.publish(f"deployments/{self.host_id}/heartbeat/", "ping")
data = json.dumps(self.parent.users)
print(f"Sending heartbeat: {data}")
await self.redis.set(f"deployments/{self.host_id}/state/", data, ex=30)
await asyncio.sleep(10)
await self.update_state_info()
async def update_state_info(self):
data = json.dumps(self.parent.users)
await self.redis.set(f"deployments/{self.host_id}/state/", data, ex=30)
await self.redis.publish(f"deployments/{self.host_id}/state/", data)
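# Start from the known deployments so that deployments without connected users still publish an (empty) state.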
deployments = {deployment: {} for deployment in self.known_deployments}
for user in self.parent.users:
deployment = self.parent.users[user]["deployment"]
if deployment not in deployments:
deployments[deployment] = {}
self.known_deployments.add(deployment)
deployments[deployment][user] = self.parent.users[user]
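# Write each deployment's state with a 30 s TTL and publish it; if this backend dies, its state expires on its own.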
for name, data in deployments.items():
data_json = json.dumps(data)
await self.redis.set(
RedisAtlasEndpoints.websocket_state(name, self.host_id), data_json, ex=30
)
await self.redis.publish(
RedisAtlasEndpoints.websocket_state(name, self.host_id), data_json
)
async def update_websocket_states(self):
loop = asyncio.get_event_loop()
@@ -140,11 +200,16 @@ class RedisWebsocket:
self.prefix = prefix
self.fastapi_app = app
self.active_connections = set()
redis_host = datasources.datasources["redis"].config["host"]
redis_port = datasources.datasources["redis"].config["port"]
redis_password = datasources.datasources["redis"].config.get("password", "ingestor")
self.socket = socketio.AsyncServer(
cors_allowed_origins="*",
async_mode="asgi",
client_manager=BECAsyncRedisManager(
self, url=f"redis://{self.redis.host}:{self.redis.port}/0"
self,
url=f"redis://{redis_host}:{redis_port}/0",
redis_options={"username": "ingestor", "password": redis_password},
),
)
self.app = socketio.ASGIApp(self.socket)
@@ -155,11 +220,22 @@ class RedisWebsocket:
self.socket.on("register", self.redis_register)
self.socket.on("unregister", self.redis.unregister)
self.socket.on("disconnect", self.disconnect_client)
print("Redis websocket started")
@safe_socket
async def connect_client(self, sid, environ=None):
print("Client connected")
http_query = environ.get("HTTP_QUERY")
def _validate_new_user(self, http_query: str | None):
"""
Validate the connection of a new user. In particular,
the user must provide a valid token as well as have access
to the deployment. If subscriptions are provided, the user
must have access to the endpoints.
Args:
http_query (str): The query parameters of the websocket connection
Returns:
tuple[str, str]: The user name and the deployment from the query parameters
"""
if not http_query:
raise ValueError("Query parameters not found")
query = json.loads(http_query)
@@ -168,13 +244,34 @@ class RedisWebsocket:
raise ValueError("User not found in query parameters")
user = query["user"]
if sid not in self.users:
# TODO: Validate the user token
deployment = query.get("deployment")
if not deployment:
raise ValueError("Deployment not found in query parameters")
# TODO: Validate the user has access to the deployment
return user, deployment
@safe_socket
async def connect_client(self, sid, environ=None):
if sid in self.users:
logger.info("User already connected")
return
http_query = environ.get("HTTP_QUERY")
user, deployment = self._validate_new_user(http_query)
# check if the user was already registered in redis
deployment_keys = await self.socket.manager.redis.keys("deployments/*/state/")
if not deployment_keys:
socketio_server_keys = await self.socket.manager.redis.keys(
RedisAtlasEndpoints.websocket_state(deployment, "*")
)
if not socketio_server_keys:
state_data = []
else:
state_data = await self.socket.manager.redis.mget(*deployment_keys)
state_data = await self.socket.manager.redis.mget(*socketio_server_keys)
info = {}
for data in state_data:
if not data:
@@ -184,12 +281,12 @@ class RedisWebsocket:
info[value["user"]] = value["subscriptions"]
if user in info:
self.users[sid] = {"user": user, "subscriptions": info[user]}
for endpoint in set(self.users[sid]["subscriptions"]):
await self.socket.enter_room(sid, f"ENDPOINT::{endpoint}")
await self.socket.manager.update_websocket_states()
self.users[sid] = {"user": user, "subscriptions": [], "deployment": deployment}
for endpoint in set(info[user]):
print(f"Registering {endpoint}")
await self._update_user_subscriptions(sid, endpoint)
else:
self.users[sid] = {"user": query["user"], "subscriptions": []}
self.users[sid] = {"user": user, "subscriptions": [], "deployment": deployment}
await self.socket.manager.update_websocket_states()
@@ -214,8 +311,8 @@ class RedisWebsocket:
try:
print(msg)
data = json.loads(msg)
except json.JSONDecodeError:
raise ValueError("Invalid JSON message")
except json.JSONDecodeError as exc:
raise ValueError("Invalid JSON message") from exc
endpoint = getattr(MessageEndpoints, data.get("endpoint"), None)
if endpoint is None:
@@ -223,27 +320,36 @@ class RedisWebsocket:
# check if the endpoint receives arguments
if len(inspect.signature(endpoint).parameters) > 0:
endpoint = endpoint(data.get("args"))
endpoint: MessageEndpoints = endpoint(data.get("args"))
else:
endpoint = endpoint()
endpoint: MessageEndpoints = endpoint()
self.redis.register(endpoint, cb=self.on_redis_message, parent=self)
if data.get("endpoint") not in self.users[sid]["subscriptions"]:
await self.socket.enter_room(sid, f"ENDPOINT::{data.get('endpoint')}")
self.users[sid]["subscriptions"].append(data.get("endpoint"))
await self._update_user_subscriptions(sid, endpoint.endpoint)
async def _update_user_subscriptions(self, sid: str, endpoint: str):
deployment = self.users[sid]["deployment"]
endpoint_info = EndpointInfo(
RedisAtlasEndpoints.redis_data(deployment, endpoint), Any, MessageOp.STREAM
)
room = RedisAtlasEndpoints.socketio_endpoint_room(endpoint)
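# Relay messages from the deployment-scoped Redis endpoint into the socket.io room for this endpoint.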
self.redis.register(endpoint_info, cb=self.on_redis_message, parent=self, room=room)
if endpoint not in self.users[sid]["subscriptions"]:
await self.socket.enter_room(sid, room)
self.users[sid]["subscriptions"].append(endpoint)
await self.socket.manager.update_websocket_states()
@staticmethod
def on_redis_message(message, parent):
def on_redis_message(message, parent, room):
async def emit_message(message):
outgoing = {
"data": message.value.model_dump_json(),
"message_type": message.value.__class__.__name__,
}
await parent.socket.emit("new_message", data=outgoing, room=message.topic)
# check that the event loop is running
if not parent.loop.is_running():
parent.loop.run_until_complete(emit_message(message))
if "pubsub_data" in message:
msg = message["pubsub_data"]
else:
msg = message["data"]
outgoing = {"data": msg.content, "metadata": msg.metadata}
outgoing = json.dumps(outgoing)
await parent.socket.emit("message", data=outgoing, room=room)
# Called from the Redis listener thread: hand the coroutine to the server's event loop
asyncio.run_coroutine_threadsafe(emit_message(message), parent.loop)

View File

@@ -1,6 +1,6 @@
import pymongo
from bec_atlas.model import Deployments, Realm
from bec_atlas.model import Deployments, Realm, Session
class DemoSetupLoader:
@@ -30,8 +30,16 @@ class DemoSetupLoader:
deployment = Deployments(
realm_id="demo_beamline_1", name="Demo Deployment 1", owner_groups=["admin", "demo"]
)
if self.db["deployments"].find_one({"name": deployment.name}) is None:
self.db["deployments"].insert_one(deployment.__dict__)
if self.db["sessions"].find_one({"name": "_default_"}) is None:
deployment = self.db["deployments"].find_one({"name": deployment.name})
default_session = Session(
owner_groups=["admin", "demo"], deployment_id=deployment["_id"], name="_default_"
)
self.db["sessions"].insert_one(default_session.model_dump(exclude_none=True))
if __name__ == "__main__":
loader = DemoSetupLoader({"host": "localhost", "port": 27017})