mirror of
https://github.com/bec-project/bec_atlas.git
synced 2025-07-14 07:01:48 +02:00
wip
This commit is contained in:
@ -1,4 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from bec_lib.redis_connector import RedisConnector
|
||||
from redis.exceptions import AuthenticationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bec_atlas.model.model import Deployments
|
||||
|
||||
|
||||
class RedisDatasource:
|
||||
@ -6,6 +14,69 @@ class RedisDatasource:
|
||||
self.config = config
|
||||
self.connector = RedisConnector(f"{config.get('host')}:{config.get('port')}")
|
||||
|
||||
try:
|
||||
self.connector._redis_conn.auth(config.get("password", "ingestor"), username="ingestor")
|
||||
self.reconfigured_acls = False
|
||||
except AuthenticationError:
|
||||
self.setup_acls()
|
||||
self.connector._redis_conn.auth(config.get("password", "ingestor"), username="ingestor")
|
||||
self.reconfigured_acls = True
|
||||
print("Connected to Redis")
|
||||
|
||||
def setup_acls(self):
|
||||
"""
|
||||
Setup the ACLs for the Redis proxy server.
|
||||
"""
|
||||
|
||||
# Create the ingestor user. This user is used by the data ingestor to write data to the database.
|
||||
self.connector._redis_conn.acl_setuser(
|
||||
"ingestor",
|
||||
enabled=True,
|
||||
passwords=f'+{self.config.get("password", "ingestor")}',
|
||||
categories=["+@all"],
|
||||
keys=["*"],
|
||||
channels=["*"],
|
||||
)
|
||||
|
||||
self.connector._redis_conn.acl_setuser(
|
||||
"user",
|
||||
enabled=True,
|
||||
passwords="+user",
|
||||
categories=["+@all"],
|
||||
keys=["*"],
|
||||
channels=["*"],
|
||||
)
|
||||
self.connector._redis_conn.acl_setuser(
|
||||
"default", enabled=True, categories=["-@all"], commands=["+auth", "+acl|whoami"]
|
||||
)
|
||||
|
||||
def add_deployment_acl(self, deployment: Deployments):
|
||||
"""
|
||||
Add ACLs for the deployment.
|
||||
|
||||
Args:
|
||||
deployment (Deployments): The deployment object
|
||||
"""
|
||||
print(f"Adding ACLs for deployment <{deployment.name}>({deployment.id})")
|
||||
self.connector._redis_conn.acl_setuser(
|
||||
f"ingestor_{deployment.id}",
|
||||
enabled=True,
|
||||
passwords=f"+{deployment.deployment_key}",
|
||||
categories=["+@all", "-@dangerous"],
|
||||
keys=[
|
||||
f"internal/deployment/{deployment.id}/*",
|
||||
f"internal/deployment/{deployment.id}/*/state",
|
||||
f"internal/deployment/{deployment.id}/*/data/*",
|
||||
],
|
||||
channels=[
|
||||
f"internal/deployment/{deployment.id}/*/state",
|
||||
f"internal/deployment/{deployment.id}/*",
|
||||
],
|
||||
commands=[f"+keys|internal/deployment/{deployment.id}/*/state"],
|
||||
reset_channels=True,
|
||||
reset_keys=True,
|
||||
)
|
||||
|
||||
def connect(self):
|
||||
pass
|
||||
|
||||
|
@ -1,12 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from functools import lru_cache
|
||||
|
||||
from bec_lib import messages
|
||||
from bec_lib.logger import bec_logger
|
||||
from bec_lib.redis_connector import RedisConnector
|
||||
from bec_lib.serialization import MsgpackSerialization
|
||||
from redis import Redis
|
||||
|
||||
# from redis import Redis
|
||||
from bson import ObjectId
|
||||
from redis.exceptions import ResponseError
|
||||
|
||||
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
|
||||
@ -24,32 +29,25 @@ class DataIngestor:
|
||||
|
||||
redis_host = config.get("redis", {}).get("host", "localhost")
|
||||
redis_port = config.get("redis", {}).get("port", 6380)
|
||||
self.redis = Redis(host=redis_host, port=redis_port)
|
||||
self.redis = RedisConnector(
|
||||
f"{redis_host}:{redis_port}" # username="ingestor", password="ingestor"
|
||||
)
|
||||
# self.redis.authenticate(
|
||||
# config.get("redis", {}).get("password", "ingestor"), username="ingestor"
|
||||
# )
|
||||
|
||||
self.redis._redis_conn.connection_pool.connection_kwargs["username"] = "ingestor"
|
||||
self.redis._redis_conn.connection_pool.connection_kwargs["password"] = "ingestor"
|
||||
|
||||
self.shutdown_event = threading.Event()
|
||||
self.available_deployments = {}
|
||||
self.available_deployments = []
|
||||
self.deployment_listener_thread = None
|
||||
self.receiver_thread = None
|
||||
self.reclaim_pending_messages_thread = None
|
||||
self.consumer_name = f"ingestor_{os.getpid()}"
|
||||
self.create_consumer_group()
|
||||
self.start_deployment_listener()
|
||||
self.start_receiver()
|
||||
|
||||
def create_consumer_group(self):
|
||||
"""
|
||||
Create the consumer group for the ingestor.
|
||||
|
||||
"""
|
||||
try:
|
||||
self.redis.xgroup_create(
|
||||
name="internal/database_ingest", groupname="ingestor", mkstream=True
|
||||
)
|
||||
except ResponseError as exc:
|
||||
if "BUSYGROUP Consumer Group name already exists" in str(exc):
|
||||
logger.info("Consumer group already exists.")
|
||||
else:
|
||||
raise exc
|
||||
|
||||
def start_deployment_listener(self):
|
||||
"""
|
||||
Start the listener for the available deployments.
|
||||
@ -57,7 +55,8 @@ class DataIngestor:
|
||||
"""
|
||||
out = self.redis.get("deployments")
|
||||
if out:
|
||||
self.available_deployments = out
|
||||
self.available_deployments = json.loads(out)
|
||||
self.update_consumer_groups()
|
||||
self.deployment_listener_thread = threading.Thread(
|
||||
target=self.update_available_deployments, name="deployment_listener"
|
||||
)
|
||||
@ -70,19 +69,65 @@ class DataIngestor:
|
||||
"""
|
||||
self.receiver_thread = threading.Thread(target=self.ingestor_loop, name="receiver")
|
||||
self.receiver_thread.start()
|
||||
self.reclaim_pending_messages_thread = threading.Thread(
|
||||
target=self.reclaim_pending_messages, name="reclaim_pending_messages"
|
||||
)
|
||||
self.reclaim_pending_messages_thread.start()
|
||||
|
||||
def update_available_deployments(self):
|
||||
"""
|
||||
Update the available deployments from the Redis queue.
|
||||
"""
|
||||
sub = self.redis.pubsub()
|
||||
sub.subscribe("deployments")
|
||||
|
||||
def _update_deployments(data, parent):
|
||||
parent.available_deployments = data
|
||||
parent.update_consumer_groups()
|
||||
|
||||
self.redis.register("deployments", cb=_update_deployments, parent=self)
|
||||
|
||||
def update_consumer_groups(self):
|
||||
"""
|
||||
Update the consumer groups for the available deployments.
|
||||
|
||||
"""
|
||||
for deployment in self.available_deployments:
|
||||
try:
|
||||
self.redis._redis_conn.xgroup_create(
|
||||
name=f"internal/deployment/{deployment['id']}/ingest",
|
||||
groupname="ingestor",
|
||||
mkstream=True,
|
||||
)
|
||||
except ResponseError as exc:
|
||||
if "BUSYGROUP Consumer Group name already exists" in str(exc):
|
||||
logger.info("Consumer group already exists.")
|
||||
else:
|
||||
raise exc
|
||||
|
||||
def reclaim_pending_messages(self):
|
||||
"""
|
||||
Reclaim any pending messages from the Redis queue.
|
||||
|
||||
"""
|
||||
while not self.shutdown_event.is_set():
|
||||
out = sub.get_message(ignore_subscribe_messages=True, timeout=1)
|
||||
if out:
|
||||
logger.info(f"Updating available deployments: {out}")
|
||||
self.available_deployments = out
|
||||
sub.close()
|
||||
to_process = []
|
||||
for deployment in self.available_deployments:
|
||||
pending_messages = self.redis._redis_conn.xautoclaim(
|
||||
f"internal/deployment/{deployment['id']}/ingest",
|
||||
"ingestor",
|
||||
self.consumer_name,
|
||||
min_idle_time=1000,
|
||||
)
|
||||
if pending_messages[1]:
|
||||
to_process.append(
|
||||
[
|
||||
f"internal/deployment/{deployment['id']}/ingest".encode(),
|
||||
pending_messages[1],
|
||||
]
|
||||
)
|
||||
|
||||
if to_process:
|
||||
self._handle_stream_messages(to_process)
|
||||
self.shutdown_event.wait(10)
|
||||
|
||||
def ingestor_loop(self):
|
||||
"""
|
||||
@ -90,24 +135,32 @@ class DataIngestor:
|
||||
|
||||
"""
|
||||
while not self.shutdown_event.is_set():
|
||||
data = self.redis.xreadgroup(
|
||||
groupname="ingestor",
|
||||
consumername=self.consumer_name,
|
||||
streams={"internal/database_ingest": ">"},
|
||||
block=1000,
|
||||
streams = {
|
||||
f"internal/deployment/{deployment['id']}/ingest": ">"
|
||||
for deployment in self.available_deployments
|
||||
}
|
||||
data = self.redis._redis_conn.xreadgroup(
|
||||
groupname="ingestor", consumername=self.consumer_name, streams=streams, block=1000
|
||||
)
|
||||
|
||||
if not data:
|
||||
logger.debug("No messages to ingest.")
|
||||
continue
|
||||
|
||||
self._handle_stream_messages(data)
|
||||
|
||||
def _handle_stream_messages(self, data):
|
||||
for stream, msgs in data:
|
||||
for message in msgs:
|
||||
msg_dict = MsgpackSerialization.loads(message[1])
|
||||
self.handle_message(msg_dict)
|
||||
self.redis.xack(stream, "ingestor", message[0])
|
||||
msg = message[1]
|
||||
out = {}
|
||||
for key, val in msg.items():
|
||||
out[key.decode()] = MsgpackSerialization.loads(val)
|
||||
deployment_id = stream.decode().split("/")[-2]
|
||||
self.handle_message(out, deployment_id)
|
||||
self.redis._redis_conn.xack(stream, "ingestor", message[0])
|
||||
|
||||
def handle_message(self, msg_dict: dict):
|
||||
def handle_message(self, msg_dict: dict, deploymend_id: str):
|
||||
"""
|
||||
Handle a message from the Redis queue.
|
||||
|
||||
@ -119,22 +172,34 @@ class DataIngestor:
|
||||
data = msg_dict.get("data")
|
||||
if data is None:
|
||||
return
|
||||
deployment = msg_dict.get("deployment")
|
||||
if deployment is None:
|
||||
return
|
||||
|
||||
if not deployment == self.available_deployments.get(deployment):
|
||||
return
|
||||
|
||||
if isinstance(data, messages.ScanStatusMessage):
|
||||
self.update_scan_status(data)
|
||||
self.update_scan_status(data, deploymend_id)
|
||||
|
||||
def update_scan_status(self, msg: messages.ScanStatusMessage):
|
||||
@lru_cache()
|
||||
def get_default_session_id(self, deployment_id: str):
|
||||
"""
|
||||
Get the session id for a deployment.
|
||||
|
||||
Args:
|
||||
deployment_id (str): The deployment id
|
||||
|
||||
Returns:
|
||||
str: The session id
|
||||
|
||||
"""
|
||||
out = self.datasource.db["sessions"].find_one(
|
||||
{"name": "_default_", "deployment_id": ObjectId(deployment_id)}
|
||||
)
|
||||
return out["_id"]
|
||||
|
||||
def update_scan_status(self, msg: messages.ScanStatusMessage, deployment_id: str):
|
||||
"""
|
||||
Update the status of a scan in the database. If the scan does not exist, create it.
|
||||
|
||||
Args:
|
||||
msg (messages.ScanStatusMessage): The message containing the scan status.
|
||||
deployment_id (str): The deployment id
|
||||
|
||||
"""
|
||||
if not hasattr(msg, "session_id"):
|
||||
@ -143,17 +208,22 @@ class DataIngestor:
|
||||
else:
|
||||
session_id = msg.session_id
|
||||
if not session_id:
|
||||
return
|
||||
session_id = "_default_"
|
||||
|
||||
if session_id == "_default_":
|
||||
session_id = self.get_default_session_id(deployment_id)
|
||||
|
||||
# scans are indexed by the scan_id, hence we can use find_one and search by the ObjectId
|
||||
data = self.datasource.db["scans"].find_one({"_id": msg.scan_id})
|
||||
if data is None:
|
||||
msg_conv = ScanStatus(
|
||||
owner_groups=["admin"], access_groups=["admin"], **msg.model_dump()
|
||||
)
|
||||
msg_conv._id = msg_conv.scan_id
|
||||
|
||||
out = msg_conv.model_dump(exclude_none=True)
|
||||
out["_id"] = msg.scan_id
|
||||
|
||||
# TODO for compatibility with the old message format; remove once the bec_lib is updated
|
||||
out = msg_conv.__dict__
|
||||
out["session_id"] = session_id
|
||||
|
||||
self.datasource.db["scans"].insert_one(out)
|
||||
@ -168,4 +238,25 @@ class DataIngestor:
|
||||
self.deployment_listener_thread.join()
|
||||
if self.receiver_thread:
|
||||
self.receiver_thread.join()
|
||||
if self.reclaim_pending_messages_thread:
|
||||
self.reclaim_pending_messages_thread.join()
|
||||
self.redis.shutdown()
|
||||
self.datasource.shutdown()
|
||||
|
||||
|
||||
def main(): # pragma: no cover
|
||||
from bec_atlas.main import CONFIG
|
||||
|
||||
ingestor = DataIngestor(config=CONFIG)
|
||||
event = threading.Event()
|
||||
while not event.is_set():
|
||||
try:
|
||||
event.wait(1)
|
||||
except KeyboardInterrupt:
|
||||
event.set()
|
||||
ingestor.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bec_logger.level = bec_logger.LOGLEVEL.INFO
|
||||
main()
|
||||
|
@ -5,12 +5,12 @@ from fastapi import FastAPI
|
||||
from bec_atlas.datasources.datasource_manager import DatasourceManager
|
||||
from bec_atlas.router.deployments_router import DeploymentsRouter
|
||||
from bec_atlas.router.realm_router import RealmRouter
|
||||
from bec_atlas.router.redis_router import RedisRouter, RedisWebsocket
|
||||
from bec_atlas.router.redis_router import RedisWebsocket
|
||||
from bec_atlas.router.scan_router import ScanRouter
|
||||
from bec_atlas.router.user_router import UserRouter
|
||||
|
||||
CONFIG = {
|
||||
"redis": {"host": "localhost", "port": 6379},
|
||||
"redis": {"host": "localhost", "port": 6380},
|
||||
"scylla": {"hosts": ["localhost"]},
|
||||
"mongodb": {"host": "localhost", "port": 27017},
|
||||
}
|
||||
@ -25,15 +25,16 @@ class AtlasApp:
|
||||
self.server = None
|
||||
self.prefix = f"/api/{self.API_VERSION}"
|
||||
self.datasources = DatasourceManager(config=self.config)
|
||||
self.datasources.connect()
|
||||
self.register_event_handler()
|
||||
self.add_routers()
|
||||
# self.add_routers()
|
||||
|
||||
def register_event_handler(self):
|
||||
self.app.add_event_handler("startup", self.on_startup)
|
||||
self.app.add_event_handler("shutdown", self.on_shutdown)
|
||||
self.app.add_event_handler("startup", self.on_startup)
|
||||
|
||||
async def on_startup(self):
|
||||
self.datasources.connect()
|
||||
self.add_routers()
|
||||
|
||||
async def on_shutdown(self):
|
||||
self.datasources.shutdown()
|
||||
|
@ -3,7 +3,19 @@ from typing import Literal
|
||||
|
||||
from bec_lib import messages
|
||||
from bson import ObjectId
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
||||
|
||||
|
||||
class MongoBaseModel(BaseModel):
|
||||
id: str | ObjectId | None = Field(default=None, alias="_id")
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
|
||||
|
||||
@field_serializer("id")
|
||||
def serialize_id(self, id: str | ObjectId):
|
||||
if isinstance(id, ObjectId):
|
||||
return str(id)
|
||||
return id
|
||||
|
||||
|
||||
class AccessProfile(BaseModel):
|
||||
@ -11,40 +23,32 @@ class AccessProfile(BaseModel):
|
||||
access_groups: list[str] = []
|
||||
|
||||
|
||||
class ScanStatus(AccessProfile, messages.ScanStatusMessage):
|
||||
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
|
||||
class ScanStatus(MongoBaseModel, AccessProfile, messages.ScanStatusMessage): ...
|
||||
|
||||
|
||||
class UserCredentials(AccessProfile):
|
||||
class UserCredentials(MongoBaseModel, AccessProfile):
|
||||
user_id: str | ObjectId
|
||||
password: str
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class User(AccessProfile):
|
||||
id: str | ObjectId | None = Field(default=None, alias="_id")
|
||||
class User(MongoBaseModel, AccessProfile):
|
||||
email: str
|
||||
groups: list[str]
|
||||
first_name: str
|
||||
last_name: str
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
email: str
|
||||
groups: list[str]
|
||||
|
||||
|
||||
class Deployments(AccessProfile):
|
||||
class Deployments(MongoBaseModel, AccessProfile):
|
||||
realm_id: str
|
||||
name: str
|
||||
deployment_key: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
active_session_id: str | None = None
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class Experiments(AccessProfile):
|
||||
realm_id: str
|
||||
@ -76,21 +80,16 @@ class State(AccessProfile):
|
||||
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class Session(AccessProfile):
|
||||
realm_id: str
|
||||
session_id: str
|
||||
config: str
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
|
||||
class Session(MongoBaseModel, AccessProfile):
|
||||
deployment_id: str | ObjectId
|
||||
name: str
|
||||
|
||||
|
||||
class Realm(AccessProfile):
|
||||
class Realm(MongoBaseModel, AccessProfile):
|
||||
realm_id: str
|
||||
deployments: list[Deployments] = []
|
||||
name: str
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class Datasets(AccessProfile):
|
||||
realm_id: str
|
||||
|
@ -1,3 +1,6 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from bec_atlas.authentication import get_current_user
|
||||
@ -5,6 +8,9 @@ from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
|
||||
from bec_atlas.model.model import Deployments, UserInfo
|
||||
from bec_atlas.router.base_router import BaseRouter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bec_atlas.datasources.redis_datasource import RedisDatasource
|
||||
|
||||
|
||||
class DeploymentsRouter(BaseRouter):
|
||||
def __init__(self, prefix="/api/v1", datasources=None):
|
||||
@ -25,6 +31,7 @@ class DeploymentsRouter(BaseRouter):
|
||||
description="Get a single deployment by id for a realm",
|
||||
response_model=Deployments,
|
||||
)
|
||||
self.update_available_deployments()
|
||||
|
||||
async def deployments(
|
||||
self, realm: str, current_user: UserInfo = Depends(get_current_user)
|
||||
@ -52,3 +59,16 @@ class DeploymentsRouter(BaseRouter):
|
||||
return self.db.find_one(
|
||||
"deployments", {"_id": deployment_id}, Deployments, user=current_user
|
||||
)
|
||||
|
||||
def update_available_deployments(self):
|
||||
"""
|
||||
Update the available deployments.
|
||||
"""
|
||||
self.available_deployments = self.db.find("deployments", {}, Deployments)
|
||||
|
||||
redis: RedisDatasource = self.datasources.datasources.get("redis")
|
||||
msg = json.dumps([msg.model_dump() for msg in self.available_deployments])
|
||||
redis.connector.set_and_publish("deployments", msg)
|
||||
if redis.reconfigured_acls:
|
||||
for deployment in self.available_deployments:
|
||||
redis.add_deployment_acl(deployment)
|
||||
|
@ -3,10 +3,10 @@ import functools
|
||||
import inspect
|
||||
import json
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import socketio
|
||||
from bec_lib.endpoints import MessageEndpoints
|
||||
from bec_lib.endpoints import EndpointInfo, MessageEndpoints, MessageOp
|
||||
from bec_lib.logger import bec_logger
|
||||
from fastapi import APIRouter
|
||||
|
||||
@ -18,6 +18,55 @@ if TYPE_CHECKING:
|
||||
from bec_lib.redis_connector import RedisConnector
|
||||
|
||||
|
||||
class RedisAtlasEndpoints:
|
||||
"""
|
||||
This class contains the endpoints for the Redis API. It is used to
|
||||
manage the subscriptions and the state information for the websocket
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def websocket_state(deployment: str, host_id: str):
|
||||
"""
|
||||
Endpoint for the websocket state information, containing the users and their subscriptions
|
||||
per backend host.
|
||||
|
||||
Args:
|
||||
deployment (str): The deployment name
|
||||
host_id (str): The host id of the backend
|
||||
|
||||
Returns:
|
||||
str: The endpoint for the websocket state information
|
||||
"""
|
||||
return f"internal/deployment/{deployment}/{host_id}/state"
|
||||
|
||||
@staticmethod
|
||||
def redis_data(deployment: str, endpoint: str):
|
||||
"""
|
||||
Endpoint for the redis data for a deployment and endpoint.
|
||||
|
||||
Args:
|
||||
deployment (str): The deployment name
|
||||
endpoint (str): The endpoint name
|
||||
|
||||
Returns:
|
||||
str: The endpoint for the redis data
|
||||
"""
|
||||
return f"internal/deployment/{deployment}/data/{endpoint}"
|
||||
|
||||
@staticmethod
|
||||
def socketio_endpoint_room(endpoint: str):
|
||||
"""
|
||||
Endpoint for the socketio room for an endpoint.
|
||||
|
||||
Args:
|
||||
endpoint (str): The endpoint name
|
||||
|
||||
Returns:
|
||||
str: The endpoint for the socketio room
|
||||
"""
|
||||
return f"ENDPOINT::{endpoint}"
|
||||
|
||||
|
||||
class RedisRouter(BaseRouter):
|
||||
"""
|
||||
This class is a router for the Redis API. It exposes the redis client through
|
||||
@ -47,6 +96,7 @@ def safe_socket(fcn):
|
||||
async def wrapper(self, sid, *args, **kwargs):
|
||||
try:
|
||||
out = await fcn(self, sid, *args, **kwargs)
|
||||
# pylint: disable=broad-except
|
||||
except Exception as exc:
|
||||
content = traceback.format_exc()
|
||||
logger.error(content)
|
||||
@ -71,15 +121,16 @@ class BECAsyncRedisManager(socketio.AsyncRedisManager):
|
||||
super().__init__(url, channel, write_only, logger, redis_options)
|
||||
self.requested_channels = []
|
||||
self.started_update_loop = False
|
||||
self.known_deployments = set()
|
||||
|
||||
# task = asyncio.create_task(self._required_channel_heartbeat())
|
||||
# loop.run_until_complete(task)
|
||||
|
||||
def start_update_loop(self):
|
||||
self.started_update_loop = True
|
||||
loop = asyncio.get_event_loop()
|
||||
task = loop.create_task(self._backend_heartbeat())
|
||||
return task
|
||||
# loop = asyncio.get_event_loop()
|
||||
# task = loop.create_task(self._backend_heartbeat())
|
||||
# return task
|
||||
|
||||
async def disconnect(self, sid, namespace, **kwargs):
|
||||
if kwargs.get("ignore_queue"):
|
||||
@ -105,16 +156,25 @@ class BECAsyncRedisManager(socketio.AsyncRedisManager):
|
||||
|
||||
async def _backend_heartbeat(self):
|
||||
while not self.parent.fastapi_app.server.should_exit:
|
||||
await asyncio.sleep(1)
|
||||
await self.redis.publish(f"deployments/{self.host_id}/heartbeat/", "ping")
|
||||
data = json.dumps(self.parent.users)
|
||||
print(f"Sending heartbeat: {data}")
|
||||
await self.redis.set(f"deployments/{self.host_id}/state/", data, ex=30)
|
||||
await asyncio.sleep(10)
|
||||
await self.update_state_info()
|
||||
|
||||
async def update_state_info(self):
|
||||
data = json.dumps(self.parent.users)
|
||||
await self.redis.set(f"deployments/{self.host_id}/state/", data, ex=30)
|
||||
await self.redis.publish(f"deployments/{self.host_id}/state/", data)
|
||||
deployments = {deployment: {} for deployment in self.known_deployments}
|
||||
for user in self.parent.users:
|
||||
deployment = self.parent.users[user]["deployment"]
|
||||
if deployment not in deployments:
|
||||
deployments[deployment] = {}
|
||||
self.known_deployments.add(deployment)
|
||||
deployments[deployment][user] = self.parent.users[user]
|
||||
for name, data in deployments.items():
|
||||
data_json = json.dumps(data)
|
||||
await self.redis.set(
|
||||
RedisAtlasEndpoints.websocket_state(name, self.host_id), data_json, ex=30
|
||||
)
|
||||
await self.redis.publish(
|
||||
RedisAtlasEndpoints.websocket_state(name, self.host_id), data_json
|
||||
)
|
||||
|
||||
async def update_websocket_states(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
@ -140,11 +200,16 @@ class RedisWebsocket:
|
||||
self.prefix = prefix
|
||||
self.fastapi_app = app
|
||||
self.active_connections = set()
|
||||
redis_host = datasources.datasources["redis"].config["host"]
|
||||
redis_port = datasources.datasources["redis"].config["port"]
|
||||
redis_password = datasources.datasources["redis"].config.get("password", "ingestor")
|
||||
self.socket = socketio.AsyncServer(
|
||||
cors_allowed_origins="*",
|
||||
async_mode="asgi",
|
||||
client_manager=BECAsyncRedisManager(
|
||||
self, url=f"redis://{self.redis.host}:{self.redis.port}/0"
|
||||
self,
|
||||
url=f"redis://{redis_host}:{redis_port}/0",
|
||||
redis_options={"username": "ingestor", "password": redis_password},
|
||||
),
|
||||
)
|
||||
self.app = socketio.ASGIApp(self.socket)
|
||||
@ -155,11 +220,22 @@ class RedisWebsocket:
|
||||
self.socket.on("register", self.redis_register)
|
||||
self.socket.on("unregister", self.redis.unregister)
|
||||
self.socket.on("disconnect", self.disconnect_client)
|
||||
print("Redis websocket started")
|
||||
|
||||
@safe_socket
|
||||
async def connect_client(self, sid, environ=None):
|
||||
print("Client connected")
|
||||
http_query = environ.get("HTTP_QUERY")
|
||||
def _validate_new_user(self, http_query: str | None):
|
||||
"""
|
||||
Validate the connection of a new user. In particular,
|
||||
the user must provide a valid token as well as have access
|
||||
to the deployment. If subscriptions are provided, the user
|
||||
must have access to the endpoints.
|
||||
|
||||
Args:
|
||||
http_query (str): The query parameters of the websocket connection
|
||||
|
||||
Returns:
|
||||
str: The user name
|
||||
|
||||
"""
|
||||
if not http_query:
|
||||
raise ValueError("Query parameters not found")
|
||||
query = json.loads(http_query)
|
||||
@ -168,13 +244,34 @@ class RedisWebsocket:
|
||||
raise ValueError("User not found in query parameters")
|
||||
user = query["user"]
|
||||
|
||||
if sid not in self.users:
|
||||
# TODO: Validate the user token
|
||||
|
||||
deployment = query.get("deployment")
|
||||
if not deployment:
|
||||
raise ValueError("Deployment not found in query parameters")
|
||||
|
||||
# TODO: Validate the user has access to the deployment
|
||||
|
||||
return user, deployment
|
||||
|
||||
@safe_socket
|
||||
async def connect_client(self, sid, environ=None):
|
||||
if sid in self.users:
|
||||
logger.info("User already connected")
|
||||
return
|
||||
|
||||
http_query = environ.get("HTTP_QUERY")
|
||||
|
||||
user, deployment = self._validate_new_user(http_query)
|
||||
|
||||
# check if the user was already registered in redis
|
||||
deployment_keys = await self.socket.manager.redis.keys("deployments/*/state/")
|
||||
if not deployment_keys:
|
||||
socketio_server_keys = await self.socket.manager.redis.keys(
|
||||
RedisAtlasEndpoints.websocket_state(deployment, "*")
|
||||
)
|
||||
if not socketio_server_keys:
|
||||
state_data = []
|
||||
else:
|
||||
state_data = await self.socket.manager.redis.mget(*deployment_keys)
|
||||
state_data = await self.socket.manager.redis.mget(*socketio_server_keys)
|
||||
info = {}
|
||||
for data in state_data:
|
||||
if not data:
|
||||
@ -184,12 +281,12 @@ class RedisWebsocket:
|
||||
info[value["user"]] = value["subscriptions"]
|
||||
|
||||
if user in info:
|
||||
self.users[sid] = {"user": user, "subscriptions": info[user]}
|
||||
for endpoint in set(self.users[sid]["subscriptions"]):
|
||||
await self.socket.enter_room(sid, f"ENDPOINT::{endpoint}")
|
||||
await self.socket.manager.update_websocket_states()
|
||||
self.users[sid] = {"user": user, "subscriptions": [], "deployment": deployment}
|
||||
for endpoint in set(info[user]):
|
||||
print(f"Registering {endpoint}")
|
||||
await self._update_user_subscriptions(sid, endpoint)
|
||||
else:
|
||||
self.users[sid] = {"user": query["user"], "subscriptions": []}
|
||||
self.users[sid] = {"user": user, "subscriptions": [], "deployment": deployment}
|
||||
|
||||
await self.socket.manager.update_websocket_states()
|
||||
|
||||
@ -214,8 +311,8 @@ class RedisWebsocket:
|
||||
try:
|
||||
print(msg)
|
||||
data = json.loads(msg)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("Invalid JSON message")
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ValueError("Invalid JSON message") from exc
|
||||
|
||||
endpoint = getattr(MessageEndpoints, data.get("endpoint"), None)
|
||||
if endpoint is None:
|
||||
@ -223,27 +320,36 @@ class RedisWebsocket:
|
||||
|
||||
# check if the endpoint receives arguments
|
||||
if len(inspect.signature(endpoint).parameters) > 0:
|
||||
endpoint = endpoint(data.get("args"))
|
||||
endpoint: MessageEndpoints = endpoint(data.get("args"))
|
||||
else:
|
||||
endpoint = endpoint()
|
||||
endpoint: MessageEndpoints = endpoint()
|
||||
|
||||
self.redis.register(endpoint, cb=self.on_redis_message, parent=self)
|
||||
if data.get("endpoint") not in self.users[sid]["subscriptions"]:
|
||||
await self.socket.enter_room(sid, f"ENDPOINT::{data.get('endpoint')}")
|
||||
self.users[sid]["subscriptions"].append(data.get("endpoint"))
|
||||
await self._update_user_subscriptions(sid, endpoint.endpoint)
|
||||
|
||||
async def _update_user_subscriptions(self, sid: str, endpoint: str):
|
||||
deployment = self.users[sid]["deployment"]
|
||||
|
||||
endpoint_info = EndpointInfo(
|
||||
RedisAtlasEndpoints.redis_data(deployment, endpoint), Any, MessageOp.STREAM
|
||||
)
|
||||
|
||||
room = RedisAtlasEndpoints.socketio_endpoint_room(endpoint)
|
||||
self.redis.register(endpoint_info, cb=self.on_redis_message, parent=self, room=room)
|
||||
if endpoint not in self.users[sid]["subscriptions"]:
|
||||
await self.socket.enter_room(sid, room)
|
||||
self.users[sid]["subscriptions"].append(endpoint)
|
||||
await self.socket.manager.update_websocket_states()
|
||||
|
||||
@staticmethod
|
||||
def on_redis_message(message, parent):
|
||||
def on_redis_message(message, parent, room):
|
||||
async def emit_message(message):
|
||||
outgoing = {
|
||||
"data": message.value.model_dump_json(),
|
||||
"message_type": message.value.__class__.__name__,
|
||||
}
|
||||
await parent.socket.emit("new_message", data=outgoing, room=message.topic)
|
||||
|
||||
# check that the event loop is running
|
||||
if not parent.loop.is_running():
|
||||
parent.loop.run_until_complete(emit_message(message))
|
||||
if "pubsub_data" in message:
|
||||
msg = message["pubsub_data"]
|
||||
else:
|
||||
msg = message["data"]
|
||||
outgoing = {"data": msg.content, "metadata": msg.metadata}
|
||||
outgoing = json.dumps(outgoing)
|
||||
await parent.socket.emit("message", data=outgoing, room=room)
|
||||
|
||||
# Run the coroutine in this loop
|
||||
asyncio.run_coroutine_threadsafe(emit_message(message), parent.loop)
|
||||
|
@ -1,6 +1,6 @@
|
||||
import pymongo
|
||||
|
||||
from bec_atlas.model import Deployments, Realm
|
||||
from bec_atlas.model import Deployments, Realm, Session
|
||||
|
||||
|
||||
class DemoSetupLoader:
|
||||
@ -30,8 +30,16 @@ class DemoSetupLoader:
|
||||
deployment = Deployments(
|
||||
realm_id="demo_beamline_1", name="Demo Deployment 1", owner_groups=["admin", "demo"]
|
||||
)
|
||||
if self.db["deployments"].find_one({"name": deployment.name}) is None:
|
||||
self.db["deployments"].insert_one(deployment.__dict__)
|
||||
|
||||
if self.db["sessions"].find_one({"name": "_default_"}) is None:
|
||||
deployment = self.db["deployments"].find_one({"name": deployment["name"]})
|
||||
default_session = Session(
|
||||
owner_groups=["admin", "demo"], deployment_id=deployment["_id"], name="_default_"
|
||||
)
|
||||
self.db["sessions"].insert_one(default_session.model_dump(exclude_none=True))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
loader = DemoSetupLoader({"host": "localhost", "port": 27017})
|
||||
|
Reference in New Issue
Block a user