feat: add support for MongoDB

This commit is contained in:
2024-11-27 10:46:41 +01:00
parent 8160d9a383
commit 4a45119549
25 changed files with 923 additions and 348 deletions

View File

@ -4,6 +4,13 @@ variables:
SCYLLA_HOST: scylla SCYLLA_HOST: scylla
SCYLLA_PORT: 9042 SCYLLA_PORT: 9042
SCYLLA_KEYSPACE: bec_atlas SCYLLA_KEYSPACE: bec_atlas
REDIS_HOST: redis
REDIS_PORT: 6380
DOCKER_TLS_CERTDIR: ""
CHILD_PIPELINE_BRANCH: $CI_DEFAULT_BRANCH
BEC_CORE_BRANCH:
description: "BEC branch to use for testing"
value: main
workflow: workflow:
rules: rules:
@ -68,20 +75,29 @@ pylint:
interruptible: true interruptible: true
backend_pytest: backend_pytest:
services:
- name: $CI_DEPENDENCY_PROXY_GROUP_IMAGE_PREFIX/scylladb/scylla:latest
alias: scylla
- name: $CI_DEPENDENCY_PROXY_GROUP_IMAGE_PREFIX/redis:latest
alias: redis
command: ["redis-server", "--port", "6380"]
stage: test stage: test
image: $CI_DEPENDENCY_PROXY_GROUP_IMAGE_PREFIX/docker:23-dind
services:
- name: docker:dind
entrypoint: ["dockerd-entrypoint.sh", "--tls=false"]
needs: [] needs: []
script: script:
- pip install ./backend[dev] - if [[ "$CI_PROJECT_PATH" != "bec/bec_atlas" ]]; then
- pip install coverage pytest-asyncio apk update; apk add git; echo -e "\033[35;1m Using branch $CHILD_PIPELINE_BRANCH of BEC Atlas \033[0;m";
- coverage run --concurrency=thread --source=./backend --omit=*/backend/tests/* -m pytest -v --junitxml=report.xml --skip-docker --random-order --full-trace ./backend/tests test -d bec_atlas || git clone --branch $CHILD_PIPELINE_BRANCH https://gitlab.psi.ch/bec/bec_atlas.git; cd bec_atlas;
- coverage report TARGET_BRANCH=$CHILD_PIPELINE_BRANCH;
- coverage xml else
TARGET_BRANCH=$CI_COMMIT_REF_NAME;
fi
# start services
- docker-compose -f ./backend/tests/docker-compose.yml up -d
# build test environment
- echo "$CI_DEPENDENCY_PROXY_PASSWORD" | docker login $CI_DEPENDENCY_PROXY_SERVER --username $CI_DEPENDENCY_PROXY_USER --password-stdin
- docker build -t bec_atlas_backend:test -f ./backend/tests/Dockerfile.run_pytest --build-arg PY_VERSION=3.10 --build-arg BEC_ATLAS_BRANCH=$TARGET_BRANCH --build-arg BEC_CORE_BRANCH=$BEC_CORE_BRANCH --build-arg CI_DEPENDENCY_PROXY_GROUP_IMAGE_PREFIX .
- docker run --network=host --name bec_atlas_backend bec_atlas_backend:test
after_script:
- docker cp bec_atlas_backend:/code/bec_atlas/test_files/. $CI_PROJECT_DIR
coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/' coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/'
artifacts: artifacts:
reports: reports:
@ -116,7 +132,7 @@ backend_pytest:
# - semantic-release publish # - semantic-release publish
allow_failure: false # allow_failure: false
rules: # rules:
- if: '$CI_COMMIT_REF_NAME == "main" && $CI_PROJECT_PATH == "bec/bec_atlas"' # - if: '$CI_COMMIT_REF_NAME == "main" && $CI_PROJECT_PATH == "bec/bec_atlas"'
interruptible: true # interruptible: true

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import os import os
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Annotated from typing import Annotated
@ -8,7 +10,7 @@ from fastapi.security import OAuth2PasswordBearer
from jwt.exceptions import InvalidTokenError from jwt.exceptions import InvalidTokenError
from pwdlib import PasswordHash from pwdlib import PasswordHash
from bec_atlas.datasources.scylladb import scylladb_schema as schema from bec_atlas.model import UserInfo
ALGORITHM = "HS256" ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30 ACCESS_TOKEN_EXPIRE_MINUTES = 30
@ -54,7 +56,7 @@ def decode_token(token: str):
raise credentials_exception from exc raise credentials_exception from exc
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> schema.User: async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> UserInfo:
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=401, status_code=401,
detail="Could not validate credentials", detail="Could not validate credentials",
@ -68,4 +70,4 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> sch
raise credentials_exception raise credentials_exception
except Exception as exc: except Exception as exc:
raise credentials_exception from exc raise credentials_exception from exc
return schema.User(groups=groups, email=email) return UserInfo(groups=groups, email=email)

View File

@ -1,5 +1,9 @@
from bec_lib.logger import bec_logger
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.datasources.redis_datasource import RedisDatasource from bec_atlas.datasources.redis_datasource import RedisDatasource
from bec_atlas.datasources.scylladb.scylladb import ScylladbDatasource
logger = bec_logger.logger
class DatasourceManager: class DatasourceManager:
@ -13,11 +17,12 @@ class DatasourceManager:
datasource.connect() datasource.connect()
def load_datasources(self): def load_datasources(self):
logger.info(f"Loading datasources with config: {self.config}")
for datasource_name, datasource_config in self.config.items(): for datasource_name, datasource_config in self.config.items():
if datasource_name == "scylla":
self.datasources[datasource_name] = ScylladbDatasource(datasource_config)
if datasource_name == "redis": if datasource_name == "redis":
self.datasources[datasource_name] = RedisDatasource(datasource_config) self.datasources[datasource_name] = RedisDatasource(datasource_config)
if datasource_name == "mongodb":
self.datasources[datasource_name] = MongoDBDatasource(datasource_config)
def shutdown(self): def shutdown(self):
for datasource in self.datasources.values(): for datasource in self.datasources.values():

View File

@ -0,0 +1,170 @@
import json
import os
import pymongo
from bec_lib.logger import bec_logger
from pydantic import BaseModel
from bec_atlas.authentication import get_password_hash
from bec_atlas.model.model import User, UserCredentials
logger = bec_logger.logger
class MongoDBDatasource:
    """MongoDB-backed datasource.

    Manages the connection to the ``bec_atlas`` database, performs one-time
    setup (unique email index, functional accounts) and provides
    access-controlled ``find``/``find_one`` helpers that restrict results to
    the requesting user's groups unless the user is an admin.
    """

    def __init__(self, config: dict) -> None:
        # config keys used: "host" (default "localhost") and "port" (default 27017)
        self.config = config
        self.client = None  # pymongo.MongoClient, set in connect()
        self.db = None  # handle to the "bec_atlas" database, set in connect()

    def connect(self, include_setup: bool = True):
        """
        Connect to the MongoDB database.

        Args:
            include_setup (bool): If True, create indexes and load the
                functional accounts. Read-only consumers (e.g. the data
                ingestor) pass False to skip the setup work.
        """
        host = self.config.get("host", "localhost")
        port = self.config.get("port", 27017)
        logger.info(f"Connecting to MongoDB at {host}:{port}")
        self.client = pymongo.MongoClient(f"mongodb://{host}:{port}/")
        self.db = self.client["bec_atlas"]
        if include_setup:
            # emails must be unique; the index also speeds up login lookups
            self.db["users"].create_index([("email", 1)], unique=True)
            self.load_functional_accounts()

    def load_functional_accounts(self):
        """
        Load the functional accounts to the database.

        Accounts are read from ``functional_accounts.json`` next to this
        module; if the file is missing, built-in demo accounts are used.
        Accounts that already exist (matched by email) are left untouched.
        """
        functional_accounts_file = os.path.join(
            os.path.dirname(__file__), "functional_accounts.json"
        )
        if os.path.exists(functional_accounts_file):
            with open(functional_accounts_file, "r", encoding="utf-8") as file:
                functional_accounts = json.load(file)
        else:
            # use the module logger instead of print for consistent log output
            logger.warning("Functional accounts file not found. Using default demo accounts.")
            # Demo accounts
            functional_accounts = [
                {
                    "email": "admin@bec_atlas.ch",
                    "password": "admin",
                    "groups": ["demo", "admin"],
                    "first_name": "Admin",
                    "last_name": "Admin",
                    "owner_groups": ["admin"],
                },
                {
                    "email": "jane.doe@bec_atlas.ch",
                    "password": "atlas",
                    "groups": ["demo_user"],
                    "first_name": "Jane",
                    "last_name": "Doe",
                    "owner_groups": ["admin"],
                },
            ]
        for account in functional_accounts:
            # never persist the clear-text password; only the hash is stored
            password = account.pop("password")
            password_hash = get_password_hash(password)
            # check if the account already exists in the database
            result = self.db["users"].find_one({"email": account["email"]})
            if result is not None:
                continue
            user = User(**account)
            user = self.db["users"].insert_one(user.__dict__)
            credentials = UserCredentials(
                owner_groups=["admin"], user_id=user.inserted_id, password=password_hash
            )
            self.db["user_credentials"].insert_one(credentials.__dict__)

    def get_user_by_email(self, email: str) -> User | None:
        """
        Get the user from the database.

        Args:
            email (str): The user's email address

        Returns:
            User | None: The user, or None if no user with that email exists
        """
        out = self.db["users"].find_one({"email": email})
        if out is None:
            return None
        return User(**out)

    def get_user_credentials(self, user_id: str) -> UserCredentials | None:
        """
        Get the user credentials from the database.

        Args:
            user_id (str): The user id the credentials belong to

        Returns:
            UserCredentials | None: The credentials, or None if not found
        """
        out = self.db["user_credentials"].find_one({"user_id": user_id})
        if out is None:
            return None
        return UserCredentials(**out)

    def find_one(
        self, collection: str, query_filter: dict, dtype: type[BaseModel], user: User | None = None
    ) -> BaseModel | None:
        """
        Find one document in the collection.

        Args:
            collection (str): The collection name
            query_filter (dict): The filter to apply
            dtype (type[BaseModel]): The data type to return
            user (User): The user making the request; if given, results are
                restricted to documents the user may access

        Returns:
            BaseModel: The data type with the document data
        """
        if user is not None:
            query_filter = self.add_user_filter(user, query_filter)
        out = self.db[collection].find_one(query_filter)
        if out is None:
            return None
        return dtype(**out)

    def find(
        self, collection: str, query_filter: dict, dtype: type[BaseModel], user: User | None = None
    ) -> list[BaseModel]:
        """
        Find all documents in the collection.

        Args:
            collection (str): The collection name
            query_filter (dict): The filter to apply
            dtype (type[BaseModel]): The data type to return
            user (User): The user making the request; if given, results are
                restricted to documents the user may access

        Returns:
            list[BaseModel]: The data type with the document data
        """
        if user is not None:
            query_filter = self.add_user_filter(user, query_filter)
        out = self.db[collection].find(query_filter)
        return [dtype(**x) for x in out]

    def add_user_filter(self, user: User, query_filter: dict) -> dict:
        """
        Add the user filter to the query filter.

        Admin users see everything; everyone else only sees documents whose
        ``owner_groups`` or ``access_groups`` intersect the user's groups.

        Args:
            user (User): The user making the request
            query_filter (dict): The query filter

        Returns:
            dict: The updated query filter
        """
        if "admin" not in user.groups:
            query_filter = {
                "$and": [
                    query_filter,
                    {
                        "$or": [
                            {"owner_groups": {"$in": user.groups}},
                            {"access_groups": {"$in": user.groups}},
                        ]
                    },
                ]
            }
        return query_filter

    def shutdown(self):
        """
        Shutdown the connection to the database.
        """
        if self.client is not None:
            self.client.close()
            logger.info("Connection to MongoDB closed.")

View File

@ -1,132 +0,0 @@
import json
import os
from cassandra.cluster import Cluster
from cassandra.cqlengine import connection
from cassandra.cqlengine.management import create_keyspace_simple, sync_table
from pydantic import BaseModel
from bec_atlas.authentication import get_password_hash
from bec_atlas.datasources.scylladb import scylladb_schema as schema
class ScylladbDatasource:
    """ScyllaDB-backed datasource: connection management, schema sync and
    functional-account bootstrap (superseded by the MongoDB datasource)."""

    KEYSPACE = "bec_atlas"

    def __init__(self, config):
        # config must contain "hosts": list of cluster contact points
        self.config = config
        self.cluster = None
        self.session = None

    def connect(self):
        """Connect to the cluster and bootstrap schema and functional accounts."""
        self.start_client()
        self.load_functional_accounts()

    def start_client(self):
        """
        Start the ScyllaDB client by creating a Cluster object and a Session object.

        Raises:
            ValueError: If no hosts are provided in the configuration.
        """
        hosts = self.config.get("hosts")
        if not hosts:
            raise ValueError("Hosts are not provided in the configuration")

        # cqlengine needs its own connection for keyspace/table management
        connection.setup(hosts, self.KEYSPACE, protocol_version=3)
        create_keyspace_simple(self.KEYSPACE, 1)  # replication factor 1
        self._sync_tables()
        self.cluster = Cluster(hosts)
        self.session = self.cluster.connect()

    def _sync_tables(self):
        """
        Sync the tables with the schema defined in the scylladb_schema.py file.
        """
        sync_table(schema.Realm)
        sync_table(schema.Deployments)
        sync_table(schema.Experiments)
        sync_table(schema.StateCondition)
        sync_table(schema.State)
        sync_table(schema.Session)
        sync_table(schema.Datasets)
        sync_table(schema.DatasetUserData)
        sync_table(schema.Scan)
        sync_table(schema.ScanUserData)
        sync_table(schema.ScanData)
        sync_table(schema.SignalDataInt)
        sync_table(schema.SignalDataFloat)
        sync_table(schema.SignalDataString)
        sync_table(schema.SignalDataBool)
        sync_table(schema.SignalDataBlob)
        sync_table(schema.SignalDataDateTime)
        sync_table(schema.SignalDataUUID)
        sync_table(schema.User)
        sync_table(schema.UserCredentials)

    def load_functional_accounts(self):
        """
        Load the functional accounts to the database.

        Reads accounts from functional_accounts.json next to this module,
        falling back to built-in demo accounts. Existing accounts (matched
        by email) are skipped; only the password hash is persisted.
        """
        functional_accounts_file = os.path.join(
            os.path.dirname(__file__), "functional_accounts.json"
        )
        if os.path.exists(functional_accounts_file):
            with open(functional_accounts_file, "r", encoding="utf-8") as file:
                functional_accounts = json.load(file)
        else:
            print("Functional accounts file not found. Using default demo accounts.")
            # Demo accounts
            functional_accounts = [
                {
                    "email": "admin@bec_atlas.ch",
                    "password": "admin",
                    "groups": ["demo"],
                    "first_name": "Admin",
                    "last_name": "Admin",
                },
                {
                    "email": "jane.doe@bec_atlas.ch",
                    "password": "atlas",
                    "groups": ["demo_user"],
                    "first_name": "Jane",
                    "last_name": "Doe",
                },
            ]
        for account in functional_accounts:
            # check if the account already exists in the database
            password = account.pop("password")
            password_hash = get_password_hash(password)
            result = schema.User.objects.filter(email=account["email"])
            if result.count() > 0:
                continue
            user = schema.User.create(**account)
            schema.UserCredentials.create(user_id=user.user_id, password=password_hash)

    def get(self, table_name: str, filter: str = None, parameters: tuple = None):
        """
        Get the data from the specified table.

        Args:
            table_name (str): Table within the bec_atlas keyspace.
                NOTE: `filter` shadows the builtin of the same name.
            filter (str): Raw WHERE clause appended to the query.
            parameters (tuple): Bind parameters passed to the driver.

        SECURITY NOTE(review): table_name and filter are interpolated into the
        CQL string via f-string; if either can come from untrusted input this
        is injectable — prefer driver-side bind parameters throughout.
        """
        if filter:
            query = f"SELECT * FROM {self.KEYSPACE}.{table_name} WHERE {filter};"
        else:
            query = f"SELECT * FROM {self.KEYSPACE}.{table_name};"
        if parameters:
            return self.session.execute(query, parameters)
        return self.session.execute(query)

    def post(self, table_name: str, data: BaseModel):
        """
        Post the data to the specified table.

        Args:
            table_name (str): The name of the table to post the data.
            data (BaseModel): The data to be posted.

        SECURITY NOTE(review): the JSON payload is interpolated directly into
        the CQL string; a single quote in any field breaks/injects the query.
        """
        query = f"INSERT INTO {self.KEYSPACE}.{table_name} JSON '{data.model_dump_json(exclude_none=True)}';"
        return self.session.execute(query)

    def shutdown(self):
        """Shut down the cluster connection."""
        self.cluster.shutdown()

View File

@ -1,139 +0,0 @@
import uuid
from cassandra.cqlengine import columns
from cassandra.cqlengine.models import Model
# cqlengine table models for the bec_atlas keyspace.
# Primary-key columns define the partition/clustering layout per table.


class User(Model):
    # user identity record; looked up by email during login
    email = columns.Text(primary_key=True)
    user_id = columns.UUID(default=uuid.uuid4)
    first_name = columns.Text()
    last_name = columns.Text()
    groups = columns.Set(columns.Text)
    created_at = columns.DateTime()
    updated_at = columns.DateTime()


class UserCredentials(Model):
    # password hash, kept separate from the user profile
    user_id = columns.UUID(primary_key=True)
    password = columns.Text()


class Realm(Model):
    realm_id = columns.Text(primary_key=True)
    deployment_id = columns.Text(primary_key=True)
    name = columns.Text()


class Deployments(Model):
    realm_id = columns.Text(primary_key=True)
    deployment_id = columns.Text(primary_key=True)
    name = columns.Text()
    active_session_id = columns.UUID()


class Experiments(Model):
    realm_id = columns.Text(primary_key=True)
    pgroup = columns.Text(primary_key=True)
    proposal = columns.Text()
    text = columns.Text()


class StateCondition(Model):
    # named condition on a device signal, keyed per realm
    realm_id = columns.Text(primary_key=True)
    name = columns.Text(primary_key=True)
    description = columns.Text()
    device = columns.Text()
    signal_value = columns.Text()
    signal_type = columns.Text()
    tolerance = columns.Text()


class State(Model):
    realm_id = columns.Text(primary_key=True)
    name = columns.Text(primary_key=True)
    description = columns.Text()
    conditions = columns.List(columns.Text)


class Session(Model):
    realm_id = columns.Text(primary_key=True)
    session_id = columns.UUID(primary_key=True)
    config = columns.Text()


class Datasets(Model):
    session_id = columns.UUID(primary_key=True)
    dataset_id = columns.UUID(primary_key=True)
    scan_id = columns.UUID()


class DatasetUserData(Model):
    # user-provided annotations on a dataset
    dataset_id = columns.UUID(primary_key=True)
    name = columns.Text()
    rating = columns.Integer()
    comments = columns.Text()
    preview = columns.Blob()


class Scan(Model):
    session_id = columns.UUID(primary_key=True)
    scan_id = columns.UUID(primary_key=True)
    scan_number = columns.Integer()
    name = columns.Text()
    scan_class = columns.Text()
    parameters = columns.Text()
    start_time = columns.DateTime()
    end_time = columns.DateTime()
    exit_status = columns.Text()


class ScanUserData(Model):
    # user-provided annotations on a scan
    scan_id = columns.UUID(primary_key=True)
    name = columns.Text()
    rating = columns.Integer()
    comments = columns.Text()
    preview = columns.Blob()


class ScanData(Model):
    scan_id = columns.UUID(primary_key=True)
    device_name = columns.Text(primary_key=True)
    signal_name = columns.Text(primary_key=True)
    shape = columns.List(columns.Integer)
    dtype = columns.Text()


class SignalDataBase(Model):
    # base layout for per-signal time series: partitioned by (realm, signal),
    # clustered by scan and sample index; subclasses add the typed data column
    realm_id = columns.Text(partition_key=True)
    signal_name = columns.Text(partition_key=True)
    scan_id = columns.UUID(primary_key=True)
    index = columns.Integer(primary_key=True)


class SignalDataInt(SignalDataBase):
    data = columns.Integer()


class SignalDataFloat(SignalDataBase):
    data = columns.Float()


class SignalDataString(SignalDataBase):
    data = columns.Text()


class SignalDataBlob(SignalDataBase):
    data = columns.Blob()


class SignalDataBool(SignalDataBase):
    data = columns.Boolean()


class SignalDataDateTime(SignalDataBase):
    data = columns.DateTime()


class SignalDataUUID(SignalDataBase):
    data = columns.UUID()

View File

View File

@ -0,0 +1,171 @@
from __future__ import annotations
import os
import threading
from bec_lib import messages
from bec_lib.logger import bec_logger
from bec_lib.serialization import MsgpackSerialization
from redis import Redis
from redis.exceptions import ResponseError
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model.model import ScanStatus
logger = bec_logger.logger
class DataIngestor:
    """Consumes BEC messages from the Redis stream ``internal/database_ingest``
    (consumer group ``ingestor``) and persists them to MongoDB.

    Runs two daemon-style threads: one keeps the set of available deployments
    up to date via the ``deployments`` pub/sub channel, the other drains the
    ingest stream and acknowledges handled entries.
    """

    def __init__(self, config: dict) -> None:
        # config must provide "mongodb" (host/port); "redis" is optional
        self.config = config
        self.datasource = MongoDBDatasource(config=self.config["mongodb"])
        # setup (indexes, functional accounts) is owned by the API server
        self.datasource.connect(include_setup=False)
        redis_host = config.get("redis", {}).get("host", "localhost")
        redis_port = config.get("redis", {}).get("port", 6380)
        self.redis = Redis(host=redis_host, port=redis_port)
        self.shutdown_event = threading.Event()
        # consulted in handle_message before a message is ingested
        self.available_deployments = {}
        self.deployment_listener_thread = None
        self.receiver_thread = None
        # unique per process so several ingestors can share the consumer group
        self.consumer_name = f"ingestor_{os.getpid()}"
        self.create_consumer_group()
        self.start_deployment_listener()
        self.start_receiver()

    def create_consumer_group(self):
        """
        Create the consumer group for the ingestor.

        Creating an already-existing group raises BUSYGROUP, which is expected
        on restart and therefore only logged.
        """
        try:
            self.redis.xgroup_create(
                name="internal/database_ingest", groupname="ingestor", mkstream=True
            )
        except ResponseError as exc:
            if "BUSYGROUP Consumer Group name already exists" in str(exc):
                logger.info("Consumer group already exists.")
            else:
                raise exc

    def start_deployment_listener(self):
        """
        Start the listener for the available deployments.
        """
        out = self.redis.get("deployments")
        if out:
            # NOTE(review): redis-py returns bytes here, but
            # available_deployments is used as a dict in handle_message —
            # presumably this needs deserialization; confirm the stored format.
            self.available_deployments = out
        self.deployment_listener_thread = threading.Thread(
            target=self.update_available_deployments, name="deployment_listener"
        )
        self.deployment_listener_thread.start()

    def start_receiver(self):
        """
        Start the receiver for the Redis queue.
        """
        self.receiver_thread = threading.Thread(target=self.ingestor_loop, name="receiver")
        self.receiver_thread.start()

    def update_available_deployments(self):
        """
        Update the available deployments from the Redis queue.

        Polls the "deployments" pub/sub channel with a 1 s timeout so the
        shutdown_event is checked regularly.
        """
        sub = self.redis.pubsub()
        sub.subscribe("deployments")
        while not self.shutdown_event.is_set():
            out = sub.get_message(ignore_subscribe_messages=True, timeout=1)
            if out:
                logger.info(f"Updating available deployments: {out}")
                # NOTE(review): `out` is the raw pub/sub envelope
                # ({"type": ..., "data": ...}); assigning it wholesale looks
                # like it should be out["data"] (deserialized) — confirm.
                self.available_deployments = out
        sub.close()

    def ingestor_loop(self):
        """
        The main loop for the ingestor.

        Reads new entries from the stream on behalf of the consumer group and
        acknowledges each entry only after it was handled.
        """
        while not self.shutdown_event.is_set():
            data = self.redis.xreadgroup(
                groupname="ingestor",
                consumername=self.consumer_name,
                streams={"internal/database_ingest": ">"},
                block=1000,  # ms; keeps the shutdown_event responsive
            )
            if not data:
                logger.debug("No messages to ingest.")
                continue
            for stream, msgs in data:
                for message in msgs:
                    # message is (entry_id, field_dict)
                    msg_dict = MsgpackSerialization.loads(message[1])
                    self.handle_message(msg_dict)
                    self.redis.xack(stream, "ingestor", message[0])

    def handle_message(self, msg_dict: dict):
        """
        Handle a message from the Redis queue.

        Dispatches by payload type; messages without data, without a
        deployment, or for an unknown deployment are dropped silently.

        Args:
            msg_dict (dict): The message dictionary.
        """
        data = msg_dict.get("data")
        if data is None:
            return
        deployment = msg_dict.get("deployment")
        if deployment is None:
            return
        # NOTE(review): this compares the deployment *key* against the stored
        # value for that key — verify available_deployments maps id -> id,
        # otherwise this check never passes.
        if not deployment == self.available_deployments.get(deployment):
            return
        if isinstance(data, messages.ScanStatusMessage):
            self.update_scan_status(data)

    def update_scan_status(self, msg: messages.ScanStatusMessage):
        """
        Update the status of a scan in the database. If the scan does not exist, create it.

        Args:
            msg (messages.ScanStatusMessage): The message containing the scan status.
        """
        if not hasattr(msg, "session_id"):
            # TODO for compatibility with the old message format; remove once the bec_lib is updated
            session_id = msg.info.get("session_id")
        else:
            session_id = msg.session_id
        if not session_id:
            return
        # scans are indexed by the scan_id, hence we can use find_one and search by the ObjectId
        data = self.datasource.db["scans"].find_one({"_id": msg.scan_id})
        if data is None:
            msg_conv = ScanStatus(
                owner_groups=["admin"], access_groups=["admin"], **msg.model_dump()
            )
            # NOTE(review): assigning an undeclared attribute on a pydantic
            # model — confirm ScanStatus permits extra attributes, otherwise
            # this raises at runtime.
            msg_conv._id = msg_conv.scan_id
            # TODO for compatibility with the old message format; remove once the bec_lib is updated
            out = msg_conv.__dict__
            out["session_id"] = session_id
            self.datasource.db["scans"].insert_one(out)
        else:
            self.datasource.db["scans"].update_one(
                {"_id": msg.scan_id}, {"$set": {"status": msg.status}}
            )

    def shutdown(self):
        """Signal both worker threads to stop, wait for them, and close the datasource."""
        self.shutdown_event.set()
        if self.deployment_listener_thread:
            self.deployment_listener_thread.join()
        if self.receiver_thread:
            self.receiver_thread.join()
        self.datasource.shutdown()

View File

@ -3,11 +3,17 @@ import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from bec_atlas.datasources.datasource_manager import DatasourceManager from bec_atlas.datasources.datasource_manager import DatasourceManager
from bec_atlas.router.deployments_router import DeploymentsRouter
from bec_atlas.router.realm_router import RealmRouter
from bec_atlas.router.redis_router import RedisRouter, RedisWebsocket from bec_atlas.router.redis_router import RedisRouter, RedisWebsocket
from bec_atlas.router.scan_router import ScanRouter from bec_atlas.router.scan_router import ScanRouter
from bec_atlas.router.user import UserRouter from bec_atlas.router.user_router import UserRouter
CONFIG = {"redis": {"host": "localhost", "port": 6380}, "scylla": {"hosts": ["localhost"]}} CONFIG = {
"redis": {"host": "localhost", "port": 6379},
"scylla": {"hosts": ["localhost"]},
"mongodb": {"host": "localhost", "port": 27017},
}
class AtlasApp: class AtlasApp:
@ -35,11 +41,14 @@ class AtlasApp:
def add_routers(self): def add_routers(self):
if not self.datasources.datasources: if not self.datasources.datasources:
raise ValueError("Datasources not loaded") raise ValueError("Datasources not loaded")
if "scylla" in self.datasources.datasources: self.scan_router = ScanRouter(prefix=self.prefix, datasources=self.datasources)
self.scan_router = ScanRouter(prefix=self.prefix, datasources=self.datasources) self.app.include_router(self.scan_router.router, tags=["Scan"])
self.app.include_router(self.scan_router.router) self.user_router = UserRouter(prefix=self.prefix, datasources=self.datasources)
self.user_router = UserRouter(prefix=self.prefix, datasources=self.datasources) self.app.include_router(self.user_router.router, tags=["User"])
self.app.include_router(self.user_router.router) self.deployment_router = DeploymentsRouter(prefix=self.prefix, datasources=self.datasources)
self.app.include_router(self.deployment_router.router, tags=["Deployment"])
self.realm_router = RealmRouter(prefix=self.prefix, datasources=self.datasources)
self.app.include_router(self.realm_router.router, tags=["Realm"])
if "redis" in self.datasources.datasources: if "redis" in self.datasources.datasources:
self.redis_websocket = RedisWebsocket( self.redis_websocket = RedisWebsocket(

View File

@ -0,0 +1 @@
from .model import *

View File

@ -0,0 +1,144 @@
import uuid
from typing import Literal
from bec_lib import messages
from bson import ObjectId
from pydantic import BaseModel, ConfigDict, Field
class AccessProfile(BaseModel):
    """Base for all stored documents: group-based access control fields."""

    # groups that own the document; owners always have access
    owner_groups: list[str]
    # additional groups with read access
    access_groups: list[str] = []


class ScanStatus(AccessProfile, messages.ScanStatusMessage):
    """Scan status message enriched with access-control fields."""

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class UserCredentials(AccessProfile):
    """Password hash for a user, stored separately from the profile."""

    user_id: str | ObjectId
    password: str

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class User(AccessProfile):
    """User profile document; ``id`` maps to MongoDB's ``_id``."""

    id: str | ObjectId | None = Field(default=None, alias="_id")
    email: str
    groups: list[str]
    first_name: str
    last_name: str

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class UserInfo(BaseModel):
    """Minimal user identity extracted from an access token."""

    email: str
    groups: list[str]


class Deployments(AccessProfile):
    """A deployment within a realm; deployment_key is generated on creation."""

    realm_id: str
    name: str
    deployment_key: str = Field(default_factory=lambda: str(uuid.uuid4()))
    active_session_id: str | None = None

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class Experiments(AccessProfile):
    realm_id: str
    pgroup: str
    proposal: str
    text: str

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class StateCondition(AccessProfile):
    """Named condition on a device signal within a realm."""

    realm_id: str
    name: str
    description: str
    device: str
    signal_value: str
    signal_type: str
    tolerance: str

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class State(AccessProfile):
    realm_id: str
    name: str
    description: str
    conditions: list[str]

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class Session(AccessProfile):
    realm_id: str
    session_id: str
    config: str

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class Realm(AccessProfile):
    """A realm with its (embedded) deployments."""

    realm_id: str
    deployments: list[Deployments] = []
    name: str

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class Datasets(AccessProfile):
    realm_id: str
    dataset_id: str
    name: str
    description: str

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class DatasetUserData(AccessProfile):
    """User-provided annotations on a dataset."""

    dataset_id: str
    name: str

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class ScanUserData(AccessProfile):
    """User-provided annotations on a scan."""

    scan_id: str
    name: str
    rating: int
    comments: str
    preview: bytes

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class DeviceConfig(AccessProfile):
    """Configuration of a single device."""

    device_name: str
    readout_priority: Literal["monitored", "baseline", "on_request", "async", "continuous"]
    device_config: dict
    device_class: str
    tags: list[str] = []
    software_trigger: bool


class SignalData(AccessProfile):
    """One recorded signal value within a scan."""

    scan_id: str
    device_id: str
    device_name: str
    signal_name: str
    data: float | int | str | bool | bytes | dict | list | None
    timestamp: float
    kind: Literal["hinted", "omitted", "normal", "config"]


class DeviceData(AccessProfile):
    """All signals of a device, optionally tied to a scan."""

    scan_id: str | None
    device_name: str
    device_config_id: str
    signals: list[SignalData]

View File

@ -0,0 +1,54 @@
from fastapi import APIRouter, Depends
from bec_atlas.authentication import get_current_user
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model.model import Deployments, UserInfo
from bec_atlas.router.base_router import BaseRouter
class DeploymentsRouter(BaseRouter):
    """REST endpoints for querying deployments, filtered by the user's groups."""

    def __init__(self, prefix="/api/v1", datasources=None):
        super().__init__(prefix, datasources)
        self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
        self.router = APIRouter(prefix=prefix)
        self.router.add_api_route(
            "/deployments/realm/{realm}",
            self.deployments,
            methods=["GET"],
            description="Get all deployments for the realm",
            response_model=list[Deployments],
        )
        self.router.add_api_route(
            "/deployments/id/{deployment_id}",
            self.deployment_with_id,
            methods=["GET"],
            description="Get a single deployment by id for a realm",
            response_model=Deployments,
        )

    async def deployments(
        self, realm: str, current_user: UserInfo = Depends(get_current_user)
    ) -> list[Deployments]:
        """
        Get all deployments for a realm.

        Args:
            realm (str): The realm id
            current_user (UserInfo): The authenticated user; results are
                restricted to deployments the user may access

        Returns:
            list[Deployments]: List of deployments for the realm
        """
        return self.db.find("deployments", {"realm_id": realm}, Deployments, user=current_user)

    async def deployment_with_id(
        self, deployment_id: str, current_user: UserInfo = Depends(get_current_user)
    ):
        """
        Get deployment with id from realm

        Args:
            deployment_id (str): The deployment id
            current_user (UserInfo): The authenticated user
        """
        # NOTE(review): deployment_id is matched as a raw string against
        # "_id"; if deployments use ObjectId primary keys this never matches —
        # confirm how deployments are inserted.
        return self.db.find_one(
            "deployments", {"_id": deployment_id}, Deployments, user=current_user
        )

View File

@ -0,0 +1,47 @@
from fastapi import APIRouter
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model.model import Realm
from bec_atlas.router.base_router import BaseRouter
class RealmRouter(BaseRouter):
    """REST endpoints for querying realms."""

    def __init__(self, prefix="/api/v1", datasources=None):
        super().__init__(prefix, datasources)
        self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
        self.router = APIRouter(prefix=prefix)
        # descriptions fixed: they were copy-pasted from the deployments router
        self.router.add_api_route(
            "/realms",
            self.realms,
            methods=["GET"],
            description="Get all realms",
            response_model=list[Realm],
        )
        self.router.add_api_route(
            "/realms/{realm_id}",
            self.realm_with_id,
            methods=["GET"],
            description="Get a single realm by id",
            response_model=Realm,
        )

    async def realms(self) -> list[Realm]:
        """
        Get all realms.

        Returns:
            list[Realm]: List of realms
        """
        # NOTE(review): no user filter is applied, so realms are visible to
        # any caller — confirm realms are meant to be public.
        return self.db.find("realms", {}, Realm)

    async def realm_with_id(self, realm_id: str):
        """
        Get realm with id.

        Args:
            realm_id (str): The realm id

        Returns:
            Realm: The realm with the id
        """
        # NOTE(review): realm_id is matched as a raw string against "_id";
        # if realms use ObjectId primary keys this never matches — confirm.
        return self.db.find_one("realms", {"_id": realm_id}, Realm)

View File

@ -1,20 +1,47 @@
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from bec_atlas.authentication import get_current_user from bec_atlas.authentication import get_current_user
from bec_atlas.datasources.scylladb import scylladb_schema as schema from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model.model import ScanStatus, UserInfo
from bec_atlas.router.base_router import BaseRouter from bec_atlas.router.base_router import BaseRouter
class ScanRouter(BaseRouter): class ScanRouter(BaseRouter):
def __init__(self, prefix="/api/v1", datasources=None): def __init__(self, prefix="/api/v1", datasources=None):
super().__init__(prefix, datasources) super().__init__(prefix, datasources)
self.scylla = self.datasources.datasources.get("scylla") self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
self.router = APIRouter(prefix=prefix) self.router = APIRouter(prefix=prefix)
self.router.add_api_route("/scan", self.scan, methods=["GET"]) self.router.add_api_route(
self.router.add_api_route("/scan/{scan_id}", self.scan_with_id, methods=["GET"]) "/scans/session/{session_id}",
self.scans,
methods=["GET"],
description="Get all scans for a session",
response_model=list[ScanStatus],
)
self.router.add_api_route(
"/scans/id/{scan_id}",
self.scans_with_id,
methods=["GET"],
description="Get a single scan by id for a session",
response_model=ScanStatus,
)
async def scan(self, current_user: schema.User = Depends(get_current_user)): async def scans(
return self.scylla.get("scan", current_user=current_user) self, session_id: str, current_user: UserInfo = Depends(get_current_user)
) -> list[ScanStatus]:
"""
Get all scans for a session.
async def scan_with_id(self, scan_id: str): Args:
return {"scan_id": scan_id} session_id (str): The session id
"""
return self.db.find("scans", {"session_id": session_id}, ScanStatus)
async def scans_with_id(self, scan_id: str, current_user: UserInfo = Depends(get_current_user)):
"""
Get scan with id from session
Args:
scan_id (str): The scan id
"""
return self.db.find_one("scans", {"_id": scan_id}, ScanStatus)

View File

@ -6,7 +6,8 @@ from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel from pydantic import BaseModel
from bec_atlas.authentication import create_access_token, get_current_user, verify_password from bec_atlas.authentication import create_access_token, get_current_user, verify_password
from bec_atlas.datasources.scylladb import scylladb_schema as schema from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model import UserInfo
from bec_atlas.router.base_router import BaseRouter from bec_atlas.router.base_router import BaseRouter
@ -18,7 +19,7 @@ class UserLoginRequest(BaseModel):
class UserRouter(BaseRouter): class UserRouter(BaseRouter):
def __init__(self, prefix="/api/v1", datasources=None): def __init__(self, prefix="/api/v1", datasources=None):
super().__init__(prefix, datasources) super().__init__(prefix, datasources)
self.scylla = self.datasources.datasources.get("scylla") self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
self.router = APIRouter(prefix=prefix) self.router = APIRouter(prefix=prefix)
self.router.add_api_route("/user/me", self.user_me, methods=["GET"]) self.router.add_api_route("/user/me", self.user_me, methods=["GET"])
self.router.add_api_route("/user/login", self.user_login, methods=["POST"], dependencies=[]) self.router.add_api_route("/user/login", self.user_login, methods=["POST"], dependencies=[])
@ -26,11 +27,11 @@ class UserRouter(BaseRouter):
"/user/login/form", self.form_login, methods=["POST"], dependencies=[] "/user/login/form", self.form_login, methods=["POST"], dependencies=[]
) )
async def user_me(self, user: schema.User = Depends(get_current_user)): async def user_me(self, user: UserInfo = Depends(get_current_user)):
data = schema.User.objects.filter(email=user.email) data = self.db.get_user_by_email(user.email)
if data.count() == 0: if data is None:
raise HTTPException(status_code=404, detail="User not found") raise HTTPException(status_code=404, detail="User not found")
return data.first() return data
async def form_login(self, form_data: Annotated[OAuth2PasswordRequestForm, Depends()]): async def form_login(self, form_data: Annotated[OAuth2PasswordRequestForm, Depends()]):
user_login = UserLoginRequest(username=form_data.username, password=form_data.password) user_login = UserLoginRequest(username=form_data.username, password=form_data.password)
@ -38,15 +39,13 @@ class UserRouter(BaseRouter):
return {"access_token": out, "token_type": "bearer"} return {"access_token": out, "token_type": "bearer"}
async def user_login(self, user_login: UserLoginRequest): async def user_login(self, user_login: UserLoginRequest):
result = schema.User.objects.filter(email=user_login.username) user = self.db.get_user_by_email(user_login.username)
if result.count() == 0: if user is None:
raise HTTPException(status_code=404, detail="User not found") raise HTTPException(status_code=404, detail="User not found")
user: schema.User = result.first() credentials = self.db.get_user_credentials(user.id)
credentials = schema.UserCredentials.objects.filter(user_id=user.user_id) if credentials is None:
if credentials.count() == 0:
raise HTTPException(status_code=404, detail="User not found") raise HTTPException(status_code=404, detail="User not found")
user_credentials = credentials.first() if not verify_password(user_login.password, credentials.password):
if not verify_password(user_login.password, user_credentials.password):
raise HTTPException(status_code=401, detail="Invalid password") raise HTTPException(status_code=401, detail="Invalid password")
return create_access_token(data={"groups": list(user.groups), "email": user.email}) return create_access_token(data={"groups": list(user.groups), "email": user.email})

View File

@ -0,0 +1,38 @@
import pymongo
from bec_atlas.model import Deployments, Realm
class DemoSetupLoader:
    """Populate the bec_atlas MongoDB with demo realms and deployments.

    The loader is idempotent: documents are only inserted if no document
    with the same identity exists yet, so it is safe to run repeatedly
    (e.g. on every dev-container start).
    """

    def __init__(self, config: dict):
        """
        Args:
            config (dict): connection settings; expects "host" and "port" keys.
        """
        self.config = config
        self.client = pymongo.MongoClient(config.get("host"), config.get("port"))
        self.db = self.client["bec_atlas"]
        self.data = {}

    def load(self):
        """Load all demo data (realms first, then deployments that reference them)."""
        self.load_realm()
        self.load_deployments()

    def load_realm(self):
        """Insert the demo realms, skipping any realm_id that already exists."""
        for realm_id, name in [
            ("demo_beamline_1", "Demo Beamline 1"),
            ("demo_beamline_2", "Demo Beamline 2"),
        ]:
            realm = Realm(realm_id=realm_id, name=name, owner_groups=["admin"])
            # use the realm_id as the document _id so repeated loads map to one doc
            realm._id = realm.realm_id
            if self.db["realms"].find_one({"realm_id": realm.realm_id}) is None:
                self.db["realms"].insert_one(realm.__dict__)

    def load_deployments(self):
        """Insert the demo deployment, skipping it if one with the same name exists.

        The existence check keeps this consistent with load_realm; the original
        unconditional insert_one created a duplicate deployment on every run.
        """
        deployment = Deployments(
            realm_id="demo_beamline_1", name="Demo Deployment 1", owner_groups=["admin", "demo"]
        )
        if (
            self.db["deployments"].find_one(
                {"realm_id": deployment.realm_id, "name": deployment.name}
            )
            is None
        ):
            self.db["deployments"].insert_one(deployment.__dict__)
# Entry point for local development: seed a MongoDB running on the default
# localhost port with the demo realms and deployments.
if __name__ == "__main__":
    loader = DemoSetupLoader({"host": "localhost", "port": 27017})
    loader.load()

View File

@ -26,6 +26,7 @@ def wait_for_scylladb(scylla_host: str = SCYLLA_HOST, scylla_port: int = SCYLLA_
print("Connected to ScyllaDB") print("Connected to ScyllaDB")
return session return session
except Exception as e: except Exception as e:
# breakpoint()
print(f"ScyllaDB not ready yet: {e}") print(f"ScyllaDB not ready yet: {e}")
time.sleep(5) time.sleep(5)

View File

@ -21,6 +21,8 @@ dependencies = [
"python-socketio[asyncio_client]", "python-socketio[asyncio_client]",
"libtmux", "libtmux",
"websocket-client", "websocket-client",
"pydantic",
"pymongo",
] ]
@ -32,6 +34,7 @@ dev = [
"pytest~=8.0", "pytest~=8.0",
"pytest-docker", "pytest-docker",
"isort~=5.13, >=5.13.2", "isort~=5.13, >=5.13.2",
"pytest-asyncio",
] ]
[project.scripts] [project.scripts]

View File

@ -0,0 +1,32 @@
# set base image (host OS)
# NOTE(review): CI_DEPENDENCY_PROXY_GROUP_IMAGE_PREFIX is declared (and passed by the
# CI job) but never referenced in FROM, so the base image bypasses the dependency
# proxy — confirm whether FROM should be prefixed with it.
ARG PY_VERSION=3.10 CI_DEPENDENCY_PROXY_GROUP_IMAGE_PREFIX
FROM python:${PY_VERSION}

# branches to build against; overridden via --build-arg in CI
ARG BEC_ATLAS_BRANCH=main BEC_CORE_BRANCH=main
RUN echo "Building BEC Atlas environment for branch ${BEC_ATLAS_BRANCH} with BEC branch ${BEC_CORE_BRANCH}"

# install build/test prerequisites in a single layer so the apt metadata from
# `apt update` is never stale for the install step, and clean the lists to keep
# the image small; netcat is used by coverage_run.sh to probe service ports
RUN apt update && \
    apt install -y --no-install-recommends git netcat-openbsd && \
    rm -rf /var/lib/apt/lists/*

# set the working directory in the container
WORKDIR /code

# clone the bec repo and install bec_lib (dev extras) in editable mode
RUN git clone --branch ${BEC_CORE_BRANCH} https://gitlab.psi.ch/bec/bec.git
WORKDIR /code/bec/
RUN pip install -e bec_lib[dev]

# clone and install the backend under test
WORKDIR /code
RUN git clone --branch ${BEC_ATLAS_BRANCH} https://gitlab.psi.ch/bec/bec_atlas.git
WORKDIR /code/bec_atlas
RUN pip install -e ./backend[dev]

# directory for the junit/coverage artifacts produced by the test run
RUN mkdir -p /code/bec_atlas/test_files

# command to run on container start
ENTRYPOINT [ "./backend/tests/coverage_run.sh" ]

View File

@ -3,11 +3,11 @@ import os
from typing import Iterator from typing import Iterator
import pytest import pytest
from bec_atlas.main import AtlasApp
from bec_atlas.utils.setup_database import setup_database
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from pytest_docker.plugin import DockerComposeExecutor, Services from pytest_docker.plugin import DockerComposeExecutor, Services
from bec_atlas.main import AtlasApp
def pytest_addoption(parser): def pytest_addoption(parser):
parser.addoption( parser.addoption(
@ -101,42 +101,39 @@ def docker_services(
yield docker_service yield docker_service
@pytest.fixture(scope="session")
def scylla_container(docker_ip, docker_services):
host = docker_ip
if os.path.exists("/.dockerenv"):
# if we are running in the CI, scylla was started as 'scylla' service
host = "scylla"
if docker_services is None:
port = 9042
else:
port = docker_services.port_for("scylla", 9042)
setup_database(host=host, port=port)
return host, port
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def redis_container(docker_ip, docker_services): def redis_container(docker_ip, docker_services):
host = docker_ip host = docker_ip
if os.path.exists("/.dockerenv"): if os.path.exists("/.dockerenv"):
# if we are running in the CI, scylla was started as 'scylla' service
host = "redis" host = "redis"
if docker_services is None: if docker_services is None:
port = 6380 port = 6380
else: else:
port = docker_services.port_for("redis", 6379) port = docker_services.port_for("redis", 6379)
return host, port return "localhost", port
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def backend(scylla_container, redis_container): def mongo_container(docker_ip, docker_services):
scylla_host, scylla_port = scylla_container host = docker_ip
if os.path.exists("/.dockerenv"):
host = "mongo"
if docker_services is None:
port = 27017
else:
port = docker_services.port_for("mongodb", 27017)
return "localhost", port
@pytest.fixture(scope="session")
def backend(redis_container, mongo_container):
redis_host, redis_port = redis_container redis_host, redis_port = redis_container
mongo_host, mongo_port = mongo_container
config = { config = {
"scylla": {"hosts": [(scylla_host, scylla_port)]},
"redis": {"host": redis_host, "port": redis_port}, "redis": {"host": redis_host, "port": redis_port},
"mongodb": {"host": mongo_host, "port": mongo_port},
} }
app = AtlasApp(config) app = AtlasApp(config)

23
backend/tests/coverage_run.sh Executable file
View File

@ -0,0 +1,23 @@
#!/bin/bash
# Run the backend test suite under coverage and emit junit + coverage XML
# into ./test_files. Aborts early if the required services (Redis on 6380,
# MongoDB on 27017) are not reachable on localhost.

# require_port <port> <service name>: exit 1 if nothing listens on the port
require_port() {
    if ! nc -z localhost "$1"; then
        echo "$2 is not running on port $1"
        exit 1
    fi
}

require_port 6380 "Redis"
require_port 27017 "MongoDB"

coverage run --concurrency=thread --source=./backend --omit=*/backend/tests/* -m pytest -v --junitxml=./test_files/report.xml --skip-docker --random-order --full-trace ./backend/tests
EXIT_STATUS=$?

# propagate a test failure without producing (misleading) coverage reports
if [ $EXIT_STATUS -ne 0 ]; then
    exit $EXIT_STATUS
fi

coverage report
coverage xml -o ./test_files/coverage.xml

View File

@ -1,10 +1,10 @@
version: '2' version: '2'
services: services:
scylla:
image: scylladb/scylla:latest
ports:
- "9070:9042"
redis: redis:
image: redis:latest image: redis:latest
ports: ports:
- "6380:6379" - "6380:6379"
mongodb:
image: mongo:latest
ports:
- "27017:27017"

View File

@ -72,7 +72,7 @@ async def test_redis_websocket_multiple_disconnect_same_sid(backend_client):
async def test_redis_websocket_register_wrong_endpoint_raises(backend_client): async def test_redis_websocket_register_wrong_endpoint_raises(backend_client):
client, app = backend_client client, app = backend_client
with mock.patch.object(app.redis_websocket.socket, "emit") as emit: with mock.patch.object(app.redis_websocket.socket, "emit") as emit:
app.redis_websocket.socket.handlers["/"]["connect"]("sid") await app.redis_websocket.socket.handlers["/"]["connect"]("sid")
await app.redis_websocket.socket.handlers["/"]["register"]( await app.redis_websocket.socket.handlers["/"]["register"](
"sid", json.dumps({"endpoint": "wrong_endpoint"}) "sid", json.dumps({"endpoint": "wrong_endpoint"})
) )

View File

@ -0,0 +1,107 @@
import pytest
from bec_lib import messages
from bec_atlas.ingestor.data_ingestor import DataIngestor
@pytest.fixture
def scan_ingestor(backend):
    """Yield a DataIngestor bound to the test backend; shut it down on teardown."""
    _, atlas_app = backend
    # every test starts from an empty websocket user registry
    atlas_app.redis_websocket.users = {}
    data_ingestor = DataIngestor(config=atlas_app.config)
    yield data_ingestor
    data_ingestor.shutdown()
@pytest.mark.timeout(60)
def test_scan_ingestor_create_scan(scan_ingestor, backend):
"""
Test that the login endpoint returns a token.
"""
client, app = backend
msg = messages.ScanStatusMessage(
metadata={},
scan_id="92429a81-4bd4-41c2-82df-eccfaddf3d96",
status="open",
# session_id="5cc67967-744d-4115-a46b-13246580cb3f",
info={
"readout_priority": {
"monitored": ["bpm3i", "diode", "ftp", "bpm5c", "bpm3x", "bpm3z", "bpm4x"],
"baseline": ["ddg1a", "bs1y", "mobdco"],
"async": ["eiger", "monitor_async", "waveform"],
"continuous": [],
"on_request": ["flyer_sim"],
},
"file_suffix": None,
"file_directory": None,
"user_metadata": {"sample_name": "testA"},
"RID": "5cc67967-744d-4115-a46b-13246580cb3f",
"scan_id": "92429a81-4bd4-41c2-82df-eccfaddf3d96",
"queue_id": "7d77d976-bee0-4bb8-aabb-2b862b4506ec",
"session_id": "5cc67967-744d-4115-a46b-13246580cb3f",
"scan_motors": ["samx"],
"num_points": 10,
"positions": [
[-5.0024118137239455],
[-3.8913007026128343],
[-2.780189591501723],
[-1.6690784803906122],
[-0.557967369279501],
[0.5531437418316097],
[1.6642548529427212],
[2.775365964053833],
[3.886477075164944],
[4.9975881862760545],
],
"scan_name": "line_scan",
"scan_type": "step",
"scan_number": 2,
"dataset_number": 2,
"exp_time": 0,
"frames_per_trigger": 1,
"settling_time": 0,
"readout_time": 0,
"acquisition_config": {"default": {"exp_time": 0, "readout_time": 0}},
"scan_report_devices": ["samx"],
"monitor_sync": "bec",
"scan_msgs": [
"metadata={'file_suffix': None, 'file_directory': None, 'user_metadata': {'sample_name': 'testA'}, 'RID': '5cc67967-744d-4115-a46b-13246580cb3f'} scan_type='line_scan' parameter={'args': {'samx': [-5, 5]}, 'kwargs': {'steps': 10, 'exp_time': 0, 'relative': True, 'system_config': {'file_suffix': None, 'file_directory': None}}} queue='primary'"
],
"args": {"samx": [-5, 5]},
"kwargs": {
"steps": 10,
"exp_time": 0,
"relative": True,
"system_config": {"file_suffix": None, "file_directory": None},
},
},
timestamp=1732610545.15924,
)
scan_ingestor.update_scan_status(msg)
response = client.post(
"/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"}
)
client.headers.update({"Authorization": f"Bearer {response.json()}"})
session_id = msg.info.get("session_id")
scan_id = msg.scan_id
response = client.get(f"/api/v1/scans/session/{session_id}")
assert response.status_code == 200
out = response.json()[0]
# assert out["session_id"] == session_id
assert out["scan_id"] == scan_id
assert out["status"] == "open"
msg.status = "closed"
scan_ingestor.update_scan_status(msg)
response = client.get(f"/api/v1/scans/id/{scan_id}")
assert response.status_code == 200
out = response.json()
assert out["status"] == "closed"
assert out["scan_id"] == scan_id
response = client.get(f"/api/v1/scans/session/{session_id}")
assert response.status_code == 200
out = response.json()
assert len(out) == 1