mirror of https://github.com/bec-project/bec_atlas.git
synced 2025-07-14 07:01:48 +02:00

feat: added support for mongodb

.gitlab-ci.yml

@@ -4,6 +4,13 @@ variables:
  SCYLLA_HOST: scylla
  SCYLLA_PORT: 9042
  SCYLLA_KEYSPACE: bec_atlas
  REDIS_HOST: redis
  REDIS_PORT: 6380
  DOCKER_TLS_CERTDIR: ""
  CHILD_PIPELINE_BRANCH: $CI_DEFAULT_BRANCH
  BEC_CORE_BRANCH:
    description: "BEC branch to use for testing"
    value: main

workflow:
  rules:

@@ -68,20 +75,29 @@ pylint:
  interruptible: true

backend_pytest:
  services:
    - name: $CI_DEPENDENCY_PROXY_GROUP_IMAGE_PREFIX/scylladb/scylla:latest
      alias: scylla
    - name: $CI_DEPENDENCY_PROXY_GROUP_IMAGE_PREFIX/redis:latest
      alias: redis
      command: ["redis-server", "--port", "6380"]
  stage: test
  image: $CI_DEPENDENCY_PROXY_GROUP_IMAGE_PREFIX/docker:23-dind
  services:
    - name: docker:dind
      entrypoint: ["dockerd-entrypoint.sh", "--tls=false"]
  needs: []
  script:
    - pip install ./backend[dev]
    - pip install coverage pytest-asyncio
    - coverage run --concurrency=thread --source=./backend --omit=*/backend/tests/* -m pytest -v --junitxml=report.xml --skip-docker --random-order --full-trace ./backend/tests
    - coverage report
    - coverage xml
    - if [[ "$CI_PROJECT_PATH" != "bec/bec_atlas" ]]; then
        apk update; apk add git; echo -e "\033[35;1m Using branch $CHILD_PIPELINE_BRANCH of BEC Atlas \033[0;m";
        test -d bec_atlas || git clone --branch $CHILD_PIPELINE_BRANCH https://gitlab.psi.ch/bec/bec_atlas.git; cd bec_atlas;
        TARGET_BRANCH=$CHILD_PIPELINE_BRANCH;
      else
        TARGET_BRANCH=$CI_COMMIT_REF_NAME;
      fi
    # start services
    - docker-compose -f ./backend/tests/docker-compose.yml up -d

    # build test environment
    - echo "$CI_DEPENDENCY_PROXY_PASSWORD" | docker login $CI_DEPENDENCY_PROXY_SERVER --username $CI_DEPENDENCY_PROXY_USER --password-stdin
    - docker build -t bec_atlas_backend:test -f ./backend/tests/Dockerfile.run_pytest --build-arg PY_VERSION=3.10 --build-arg BEC_ATLAS_BRANCH=$TARGET_BRANCH --build-arg BEC_CORE_BRANCH=$BEC_CORE_BRANCH --build-arg CI_DEPENDENCY_PROXY_GROUP_IMAGE_PREFIX .
    - docker run --network=host --name bec_atlas_backend bec_atlas_backend:test
  after_script:
    - docker cp bec_atlas_backend:/code/bec_atlas/test_files/. $CI_PROJECT_DIR
  coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/'
  artifacts:
    reports:

@@ -116,7 +132,7 @@ backend_pytest:
   # - semantic-release publish


-  allow_failure: false
-  rules:
-    - if: '$CI_COMMIT_REF_NAME == "main" && $CI_PROJECT_PATH == "bec/bec_atlas"'
-  interruptible: true
+  # allow_failure: false
+  # rules:
+  #   - if: '$CI_COMMIT_REF_NAME == "main" && $CI_PROJECT_PATH == "bec/bec_atlas"'
+  #   interruptible: true

backend/bec_atlas/authentication.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 from datetime import datetime, timedelta
 from typing import Annotated

@@ -8,7 +10,7 @@ from fastapi.security import OAuth2PasswordBearer
 from jwt.exceptions import InvalidTokenError
 from pwdlib import PasswordHash
 
-from bec_atlas.datasources.scylladb import scylladb_schema as schema
+from bec_atlas.model import UserInfo
 
 ALGORITHM = "HS256"
 ACCESS_TOKEN_EXPIRE_MINUTES = 30

@@ -54,7 +56,7 @@ def decode_token(token: str):
         raise credentials_exception from exc
 
 
-async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> schema.User:
+async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> UserInfo:
     credentials_exception = HTTPException(
         status_code=401,
         detail="Could not validate credentials",

@@ -68,4 +70,4 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> schema.User:
         raise credentials_exception
     except Exception as exc:
         raise credentials_exception from exc
-    return schema.User(groups=groups, email=email)
+    return UserInfo(groups=groups, email=email)
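
Not part of the commit: a minimal round-trip sketch of the token helpers this file exposes, assuming create_access_token signs the given claims and decode_token returns the decoded JWT payload on success (it raises the 401 credentials exception otherwise):

from bec_atlas.authentication import create_access_token, decode_token

# Encode the claims that get_current_user later reads back (email, groups).
token = create_access_token(data={"groups": ["demo"], "email": "jane.doe@bec_atlas.ch"})

# Assumption: decode_token returns the payload dictionary on success.
payload = decode_token(token)
print(payload["email"], payload["groups"])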

backend/bec_atlas/datasources/datasource_manager.py

@@ -1,5 +1,9 @@
 from bec_lib.logger import bec_logger
 
+from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
 from bec_atlas.datasources.redis_datasource import RedisDatasource
 from bec_atlas.datasources.scylladb.scylladb import ScylladbDatasource
 
 logger = bec_logger.logger
 
 
 class DatasourceManager:

@@ -13,11 +17,12 @@ class DatasourceManager:
             datasource.connect()
 
     def load_datasources(self):
         logger.info(f"Loading datasources with config: {self.config}")
         for datasource_name, datasource_config in self.config.items():
             if datasource_name == "scylla":
                 self.datasources[datasource_name] = ScylladbDatasource(datasource_config)
             if datasource_name == "redis":
                 self.datasources[datasource_name] = RedisDatasource(datasource_config)
+            if datasource_name == "mongodb":
+                self.datasources[datasource_name] = MongoDBDatasource(datasource_config)
 
     def shutdown(self):
         for datasource in self.datasources.values():

backend/bec_atlas/datasources/mongodb/mongodb.py (new file, 170 lines)

@@ -0,0 +1,170 @@
import json
import os

import pymongo
from bec_lib.logger import bec_logger
from pydantic import BaseModel

from bec_atlas.authentication import get_password_hash
from bec_atlas.model.model import User, UserCredentials

logger = bec_logger.logger


class MongoDBDatasource:
    def __init__(self, config: dict) -> None:
        self.config = config
        self.client = None
        self.db = None

    def connect(self, include_setup: bool = True):
        """
        Connect to the MongoDB database.
        """
        host = self.config.get("host", "localhost")
        port = self.config.get("port", 27017)
        logger.info(f"Connecting to MongoDB at {host}:{port}")
        self.client = pymongo.MongoClient(f"mongodb://{host}:{port}/")
        self.db = self.client["bec_atlas"]
        if include_setup:
            self.db["users"].create_index([("email", 1)], unique=True)
            self.load_functional_accounts()

    def load_functional_accounts(self):
        """
        Load the functional accounts into the database.
        """
        functional_accounts_file = os.path.join(
            os.path.dirname(__file__), "functional_accounts.json"
        )
        if os.path.exists(functional_accounts_file):
            with open(functional_accounts_file, "r", encoding="utf-8") as file:
                functional_accounts = json.load(file)
        else:
            print("Functional accounts file not found. Using default demo accounts.")
            # Demo accounts
            functional_accounts = [
                {
                    "email": "admin@bec_atlas.ch",
                    "password": "admin",
                    "groups": ["demo", "admin"],
                    "first_name": "Admin",
                    "last_name": "Admin",
                    "owner_groups": ["admin"],
                },
                {
                    "email": "jane.doe@bec_atlas.ch",
                    "password": "atlas",
                    "groups": ["demo_user"],
                    "first_name": "Jane",
                    "last_name": "Doe",
                    "owner_groups": ["admin"],
                },
            ]

        for account in functional_accounts:
            # check if the account already exists in the database
            password = account.pop("password")
            password_hash = get_password_hash(password)
            result = self.db["users"].find_one({"email": account["email"]})
            if result is not None:
                continue
            user = User(**account)
            user = self.db["users"].insert_one(user.__dict__)
            credentials = UserCredentials(
                owner_groups=["admin"], user_id=user.inserted_id, password=password_hash
            )
            self.db["user_credentials"].insert_one(credentials.__dict__)

    def get_user_by_email(self, email: str) -> User | None:
        """
        Get the user from the database.
        """
        out = self.db["users"].find_one({"email": email})
        if out is None:
            return None
        return User(**out)

    def get_user_credentials(self, user_id: str) -> UserCredentials | None:
        """
        Get the user credentials from the database.
        """
        out = self.db["user_credentials"].find_one({"user_id": user_id})
        if out is None:
            return None
        return UserCredentials(**out)

    def find_one(
        self, collection: str, query_filter: dict, dtype: BaseModel, user: User | None = None
    ) -> BaseModel | None:
        """
        Find one document in the collection.

        Args:
            collection (str): The collection name
            query_filter (dict): The filter to apply
            dtype (BaseModel): The data type to return
            user (User): The user making the request

        Returns:
            BaseModel: The data type with the document data
        """
        if user is not None:
            query_filter = self.add_user_filter(user, query_filter)
        out = self.db[collection].find_one(query_filter)
        if out is None:
            return None
        return dtype(**out)

    def find(
        self, collection: str, query_filter: dict, dtype: BaseModel, user: User | None = None
    ) -> list[BaseModel]:
        """
        Find all documents in the collection.

        Args:
            collection (str): The collection name
            query_filter (dict): The filter to apply
            dtype (BaseModel): The data type to return
            user (User): The user making the request

        Returns:
            list[BaseModel]: The data type with the document data
        """
        if user is not None:
            query_filter = self.add_user_filter(user, query_filter)
        out = self.db[collection].find(query_filter)
        return [dtype(**x) for x in out]

    def add_user_filter(self, user: User, query_filter: dict) -> dict:
        """
        Add the user filter to the query filter.

        Args:
            user (User): The user making the request
            query_filter (dict): The query filter

        Returns:
            dict: The updated query filter
        """
        if "admin" not in user.groups:
            query_filter = {
                "$and": [
                    query_filter,
                    {
                        "$or": [
                            {"owner_groups": {"$in": user.groups}},
                            {"access_groups": {"$in": user.groups}},
                        ]
                    },
                ]
            }
        return query_filter

    def shutdown(self):
        """
        Shutdown the connection to the database.
        """
        if self.client is not None:
            self.client.close()
            logger.info("Connection to MongoDB closed.")
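
An illustrative sketch, not part of the commit: how add_user_filter composes a caller's query with the group-based ACL. It assumes a local MongoDB seeded with the demo accounts above; admins bypass the filter, everyone else only sees documents whose owner_groups or access_groups intersect their groups:

from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource

# Assumed defaults: a local MongoDB on 27017, as in the CONFIG in main.py.
datasource = MongoDBDatasource({"host": "localhost", "port": 27017})
datasource.connect(include_setup=True)

# Jane is in "demo_user" and not in "admin", so her queries get wrapped.
user = datasource.get_user_by_email("jane.doe@bec_atlas.ch")
query = datasource.add_user_filter(user, {"realm_id": "demo_beamline_1"})
print(query)
# {'$and': [{'realm_id': 'demo_beamline_1'},
#           {'$or': [{'owner_groups': {'$in': ['demo_user']}},
#                    {'access_groups': {'$in': ['demo_user']}}]}]}
datasource.shutdown()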

backend/bec_atlas/datasources/scylladb/scylladb.py (deleted file, 132 lines)

@@ -1,132 +0,0 @@
import json
import os

from cassandra.cluster import Cluster
from cassandra.cqlengine import connection
from cassandra.cqlengine.management import create_keyspace_simple, sync_table
from pydantic import BaseModel

from bec_atlas.authentication import get_password_hash
from bec_atlas.datasources.scylladb import scylladb_schema as schema


class ScylladbDatasource:
    KEYSPACE = "bec_atlas"

    def __init__(self, config):
        self.config = config
        self.cluster = None
        self.session = None

    def connect(self):
        self.start_client()
        self.load_functional_accounts()

    def start_client(self):
        """
        Start the ScyllaDB client by creating a Cluster object and a Session object.
        """
        hosts = self.config.get("hosts")
        if not hosts:
            raise ValueError("Hosts are not provided in the configuration")

        #
        connection.setup(hosts, self.KEYSPACE, protocol_version=3)
        create_keyspace_simple(self.KEYSPACE, 1)
        self._sync_tables()
        self.cluster = Cluster(hosts)
        self.session = self.cluster.connect()

    def _sync_tables(self):
        """
        Sync the tables with the schema defined in the scylladb_schema.py file.
        """
        sync_table(schema.Realm)
        sync_table(schema.Deployments)
        sync_table(schema.Experiments)
        sync_table(schema.StateCondition)
        sync_table(schema.State)
        sync_table(schema.Session)
        sync_table(schema.Datasets)
        sync_table(schema.DatasetUserData)
        sync_table(schema.Scan)
        sync_table(schema.ScanUserData)
        sync_table(schema.ScanData)
        sync_table(schema.SignalDataInt)
        sync_table(schema.SignalDataFloat)
        sync_table(schema.SignalDataString)
        sync_table(schema.SignalDataBool)
        sync_table(schema.SignalDataBlob)
        sync_table(schema.SignalDataDateTime)
        sync_table(schema.SignalDataUUID)
        sync_table(schema.User)
        sync_table(schema.UserCredentials)

    def load_functional_accounts(self):
        """
        Load the functional accounts to the database.
        """
        functional_accounts_file = os.path.join(
            os.path.dirname(__file__), "functional_accounts.json"
        )
        if os.path.exists(functional_accounts_file):
            with open(functional_accounts_file, "r", encoding="utf-8") as file:
                functional_accounts = json.load(file)
        else:
            print("Functional accounts file not found. Using default demo accounts.")
            # Demo accounts
            functional_accounts = [
                {
                    "email": "admin@bec_atlas.ch",
                    "password": "admin",
                    "groups": ["demo"],
                    "first_name": "Admin",
                    "last_name": "Admin",
                },
                {
                    "email": "jane.doe@bec_atlas.ch",
                    "password": "atlas",
                    "groups": ["demo_user"],
                    "first_name": "Jane",
                    "last_name": "Doe",
                },
            ]

        for account in functional_accounts:
            # check if the account already exists in the database
            password = account.pop("password")
            password_hash = get_password_hash(password)
            result = schema.User.objects.filter(email=account["email"])
            if result.count() > 0:
                continue
            user = schema.User.create(**account)

            schema.UserCredentials.create(user_id=user.user_id, password=password_hash)

    def get(self, table_name: str, filter: str = None, parameters: tuple = None):
        """
        Get the data from the specified table.
        """
        # schema.User.objects.get(email=)
        if filter:
            query = f"SELECT * FROM {self.KEYSPACE}.{table_name} WHERE {filter};"
        else:
            query = f"SELECT * FROM {self.KEYSPACE}.{table_name};"
        if parameters:
            return self.session.execute(query, parameters)
        return self.session.execute(query)

    def post(self, table_name: str, data: BaseModel):
        """
        Post the data to the specified table.

        Args:
            table_name (str): The name of the table to post the data.
            data (BaseModel): The data to be posted.

        """
        query = f"INSERT INTO {self.KEYSPACE}.{table_name} JSON '{data.model_dump_json(exclude_none=True)}';"
        return self.session.execute(query)

    def shutdown(self):
        self.cluster.shutdown()

backend/bec_atlas/datasources/scylladb/scylladb_schema.py (deleted file, 139 lines)

@@ -1,139 +0,0 @@
import uuid

from cassandra.cqlengine import columns
from cassandra.cqlengine.models import Model


class User(Model):
    email = columns.Text(primary_key=True)
    user_id = columns.UUID(default=uuid.uuid4)
    first_name = columns.Text()
    last_name = columns.Text()
    groups = columns.Set(columns.Text)
    created_at = columns.DateTime()
    updated_at = columns.DateTime()


class UserCredentials(Model):
    user_id = columns.UUID(primary_key=True)
    password = columns.Text()


class Realm(Model):
    realm_id = columns.Text(primary_key=True)
    deployment_id = columns.Text(primary_key=True)
    name = columns.Text()


class Deployments(Model):
    realm_id = columns.Text(primary_key=True)
    deployment_id = columns.Text(primary_key=True)
    name = columns.Text()
    active_session_id = columns.UUID()


class Experiments(Model):
    realm_id = columns.Text(primary_key=True)
    pgroup = columns.Text(primary_key=True)
    proposal = columns.Text()
    text = columns.Text()


class StateCondition(Model):
    realm_id = columns.Text(primary_key=True)
    name = columns.Text(primary_key=True)
    description = columns.Text()
    device = columns.Text()
    signal_value = columns.Text()
    signal_type = columns.Text()
    tolerance = columns.Text()


class State(Model):
    realm_id = columns.Text(primary_key=True)
    name = columns.Text(primary_key=True)
    description = columns.Text()
    conditions = columns.List(columns.Text)


class Session(Model):
    realm_id = columns.Text(primary_key=True)
    session_id = columns.UUID(primary_key=True)
    config = columns.Text()


class Datasets(Model):
    session_id = columns.UUID(primary_key=True)
    dataset_id = columns.UUID(primary_key=True)
    scan_id = columns.UUID()


class DatasetUserData(Model):
    dataset_id = columns.UUID(primary_key=True)
    name = columns.Text()
    rating = columns.Integer()
    comments = columns.Text()
    preview = columns.Blob()


class Scan(Model):
    session_id = columns.UUID(primary_key=True)
    scan_id = columns.UUID(primary_key=True)
    scan_number = columns.Integer()
    name = columns.Text()
    scan_class = columns.Text()
    parameters = columns.Text()
    start_time = columns.DateTime()
    end_time = columns.DateTime()
    exit_status = columns.Text()


class ScanUserData(Model):
    scan_id = columns.UUID(primary_key=True)
    name = columns.Text()
    rating = columns.Integer()
    comments = columns.Text()
    preview = columns.Blob()


class ScanData(Model):
    scan_id = columns.UUID(primary_key=True)
    device_name = columns.Text(primary_key=True)
    signal_name = columns.Text(primary_key=True)
    shape = columns.List(columns.Integer)
    dtype = columns.Text()


class SignalDataBase(Model):
    realm_id = columns.Text(partition_key=True)
    signal_name = columns.Text(partition_key=True)
    scan_id = columns.UUID(primary_key=True)
    index = columns.Integer(primary_key=True)


class SignalDataInt(SignalDataBase):
    data = columns.Integer()


class SignalDataFloat(SignalDataBase):
    data = columns.Float()


class SignalDataString(SignalDataBase):
    data = columns.Text()


class SignalDataBlob(SignalDataBase):
    data = columns.Blob()


class SignalDataBool(SignalDataBase):
    data = columns.Boolean()


class SignalDataDateTime(SignalDataBase):
    data = columns.DateTime()


class SignalDataUUID(SignalDataBase):
    data = columns.UUID()

backend/bec_atlas/ingestor/__init__.py (new file, empty)

backend/bec_atlas/ingestor/data_ingestor.py (new file, 171 lines)

@@ -0,0 +1,171 @@
from __future__ import annotations

import os
import threading

from bec_lib import messages
from bec_lib.logger import bec_logger
from bec_lib.serialization import MsgpackSerialization
from redis import Redis
from redis.exceptions import ResponseError

from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model.model import ScanStatus

logger = bec_logger.logger


class DataIngestor:

    def __init__(self, config: dict) -> None:
        self.config = config
        self.datasource = MongoDBDatasource(config=self.config["mongodb"])
        self.datasource.connect(include_setup=False)

        redis_host = config.get("redis", {}).get("host", "localhost")
        redis_port = config.get("redis", {}).get("port", 6380)
        self.redis = Redis(host=redis_host, port=redis_port)

        self.shutdown_event = threading.Event()
        self.available_deployments = {}
        self.deployment_listener_thread = None
        self.receiver_thread = None
        self.consumer_name = f"ingestor_{os.getpid()}"
        self.create_consumer_group()
        self.start_deployment_listener()
        self.start_receiver()

    def create_consumer_group(self):
        """
        Create the consumer group for the ingestor.
        """
        try:
            self.redis.xgroup_create(
                name="internal/database_ingest", groupname="ingestor", mkstream=True
            )
        except ResponseError as exc:
            if "BUSYGROUP Consumer Group name already exists" in str(exc):
                logger.info("Consumer group already exists.")
            else:
                raise exc

    def start_deployment_listener(self):
        """
        Start the listener for the available deployments.
        """
        out = self.redis.get("deployments")
        if out:
            self.available_deployments = out
        self.deployment_listener_thread = threading.Thread(
            target=self.update_available_deployments, name="deployment_listener"
        )
        self.deployment_listener_thread.start()

    def start_receiver(self):
        """
        Start the receiver for the Redis queue.
        """
        self.receiver_thread = threading.Thread(target=self.ingestor_loop, name="receiver")
        self.receiver_thread.start()

    def update_available_deployments(self):
        """
        Update the available deployments from the Redis queue.
        """
        sub = self.redis.pubsub()
        sub.subscribe("deployments")
        while not self.shutdown_event.is_set():
            out = sub.get_message(ignore_subscribe_messages=True, timeout=1)
            if out:
                logger.info(f"Updating available deployments: {out}")
                self.available_deployments = out
        sub.close()

    def ingestor_loop(self):
        """
        The main loop for the ingestor.
        """
        while not self.shutdown_event.is_set():
            data = self.redis.xreadgroup(
                groupname="ingestor",
                consumername=self.consumer_name,
                streams={"internal/database_ingest": ">"},
                block=1000,
            )

            if not data:
                logger.debug("No messages to ingest.")
                continue

            for stream, msgs in data:
                for message in msgs:
                    msg_dict = MsgpackSerialization.loads(message[1])
                    self.handle_message(msg_dict)
                    self.redis.xack(stream, "ingestor", message[0])

    def handle_message(self, msg_dict: dict):
        """
        Handle a message from the Redis queue.

        Args:
            msg_dict (dict): The message dictionary.
        """
        data = msg_dict.get("data")
        if data is None:
            return
        deployment = msg_dict.get("deployment")
        if deployment is None:
            return

        if not deployment == self.available_deployments.get(deployment):
            return

        if isinstance(data, messages.ScanStatusMessage):
            self.update_scan_status(data)

    def update_scan_status(self, msg: messages.ScanStatusMessage):
        """
        Update the status of a scan in the database. If the scan does not exist, create it.

        Args:
            msg (messages.ScanStatusMessage): The message containing the scan status.
        """
        if not hasattr(msg, "session_id"):
            # TODO for compatibility with the old message format; remove once the bec_lib is updated
            session_id = msg.info.get("session_id")
        else:
            session_id = msg.session_id
        if not session_id:
            return
        # scans are indexed by the scan_id, hence we can use find_one and search by the ObjectId
        data = self.datasource.db["scans"].find_one({"_id": msg.scan_id})
        if data is None:
            msg_conv = ScanStatus(
                owner_groups=["admin"], access_groups=["admin"], **msg.model_dump()
            )
            msg_conv._id = msg_conv.scan_id

            # TODO for compatibility with the old message format; remove once the bec_lib is updated
            out = msg_conv.__dict__
            out["session_id"] = session_id

            self.datasource.db["scans"].insert_one(out)
        else:
            self.datasource.db["scans"].update_one(
                {"_id": msg.scan_id}, {"$set": {"status": msg.status}}
            )

    def shutdown(self):
        self.shutdown_event.set()
        if self.deployment_listener_thread:
            self.deployment_listener_thread.join()
        if self.receiver_thread:
            self.receiver_thread.join()
        self.datasource.shutdown()
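
For illustration only: the consumer-group mechanics the ingestor relies on, reduced to plain redis-py calls. The stream entry payload below is a placeholder, since the exact field layout the producers use is not shown in this commit; the port assumes the test Redis from docker-compose:

from redis import Redis
from redis.exceptions import ResponseError

r = Redis(host="localhost", port=6380)

# Same idempotent group creation the ingestor performs.
try:
    r.xgroup_create(name="internal/database_ingest", groupname="ingestor", mkstream=True)
except ResponseError as exc:
    if "BUSYGROUP" not in str(exc):
        raise

# A producer appends an entry; the payload layout here is illustrative only.
r.xadd("internal/database_ingest", {"data": b"<msgpack-encoded ScanStatusMessage>"})

# A consumer in the group claims and acknowledges it, mirroring ingestor_loop.
entries = r.xreadgroup(
    groupname="ingestor",
    consumername="ingestor_demo",
    streams={"internal/database_ingest": ">"},
    block=1000,
)
for stream, msgs in entries:
    for msg_id, fields in msgs:
        print(msg_id, fields)
        r.xack(stream, "ingestor", msg_id)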

backend/bec_atlas/main.py

@@ -3,11 +3,17 @@ import uvicorn
 from fastapi import FastAPI
 
 from bec_atlas.datasources.datasource_manager import DatasourceManager
+from bec_atlas.router.deployments_router import DeploymentsRouter
+from bec_atlas.router.realm_router import RealmRouter
 from bec_atlas.router.redis_router import RedisRouter, RedisWebsocket
 from bec_atlas.router.scan_router import ScanRouter
-from bec_atlas.router.user import UserRouter
+from bec_atlas.router.user_router import UserRouter
 
-CONFIG = {"redis": {"host": "localhost", "port": 6380}, "scylla": {"hosts": ["localhost"]}}
+CONFIG = {
+    "redis": {"host": "localhost", "port": 6379},
+    "scylla": {"hosts": ["localhost"]},
+    "mongodb": {"host": "localhost", "port": 27017},
+}
 
 
 class AtlasApp:

@@ -35,11 +41,14 @@ class AtlasApp:
     def add_routers(self):
         if not self.datasources.datasources:
             raise ValueError("Datasources not loaded")
-        if "scylla" in self.datasources.datasources:
-            self.scan_router = ScanRouter(prefix=self.prefix, datasources=self.datasources)
-            self.app.include_router(self.scan_router.router)
-            self.user_router = UserRouter(prefix=self.prefix, datasources=self.datasources)
-            self.app.include_router(self.user_router.router)
+        self.scan_router = ScanRouter(prefix=self.prefix, datasources=self.datasources)
+        self.app.include_router(self.scan_router.router, tags=["Scan"])
+        self.user_router = UserRouter(prefix=self.prefix, datasources=self.datasources)
+        self.app.include_router(self.user_router.router, tags=["User"])
+        self.deployment_router = DeploymentsRouter(prefix=self.prefix, datasources=self.datasources)
+        self.app.include_router(self.deployment_router.router, tags=["Deployment"])
+        self.realm_router = RealmRouter(prefix=self.prefix, datasources=self.datasources)
+        self.app.include_router(self.realm_router.router, tags=["Realm"])
 
         if "redis" in self.datasources.datasources:
             self.redis_websocket = RedisWebsocket(

backend/bec_atlas/model/__init__.py (new file, 1 line)

@@ -0,0 +1 @@
from .model import *

backend/bec_atlas/model/model.py (new file, 144 lines)

@@ -0,0 +1,144 @@
import uuid
from typing import Literal

from bec_lib import messages
from bson import ObjectId
from pydantic import BaseModel, ConfigDict, Field


class AccessProfile(BaseModel):
    owner_groups: list[str]
    access_groups: list[str] = []


class ScanStatus(AccessProfile, messages.ScanStatusMessage):
    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class UserCredentials(AccessProfile):
    user_id: str | ObjectId
    password: str

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class User(AccessProfile):
    id: str | ObjectId | None = Field(default=None, alias="_id")
    email: str
    groups: list[str]
    first_name: str
    last_name: str

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class UserInfo(BaseModel):
    email: str
    groups: list[str]


class Deployments(AccessProfile):
    realm_id: str
    name: str
    deployment_key: str = Field(default_factory=lambda: str(uuid.uuid4()))
    active_session_id: str | None = None

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class Experiments(AccessProfile):
    realm_id: str
    pgroup: str
    proposal: str
    text: str

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class StateCondition(AccessProfile):
    realm_id: str
    name: str
    description: str
    device: str
    signal_value: str
    signal_type: str
    tolerance: str

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class State(AccessProfile):
    realm_id: str
    name: str
    description: str
    conditions: list[str]

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class Session(AccessProfile):
    realm_id: str
    session_id: str
    config: str

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class Realm(AccessProfile):
    realm_id: str
    deployments: list[Deployments] = []
    name: str

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class Datasets(AccessProfile):
    realm_id: str
    dataset_id: str
    name: str
    description: str

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class DatasetUserData(AccessProfile):
    dataset_id: str
    name: str

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class ScanUserData(AccessProfile):
    scan_id: str
    name: str
    rating: int
    comments: str
    preview: bytes

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)


class DeviceConfig(AccessProfile):
    device_name: str
    readout_priority: Literal["monitored", "baseline", "on_request", "async", "continuous"]
    device_config: dict
    device_class: str
    tags: list[str] = []
    software_trigger: bool


class SignalData(AccessProfile):
    scan_id: str
    device_id: str
    device_name: str
    signal_name: str
    data: float | int | str | bool | bytes | dict | list | None
    timestamp: float
    kind: Literal["hinted", "omitted", "normal", "config"]


class DeviceData(AccessProfile):
    scan_id: str | None
    device_name: str
    device_config_id: str
    signals: list[SignalData]
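
A small sketch (not part of the commit) of the pattern the ingestor uses with these models: ScanStatus inherits from both AccessProfile and messages.ScanStatusMessage, so a BEC message plus ownership groups becomes a single storable document. Field values below are arbitrary:

from bec_lib import messages

from bec_atlas.model.model import ScanStatus

msg = messages.ScanStatusMessage(
    metadata={}, scan_id="0001", status="open", info={"session_id": "abc"}
)
# The message payload and the ACL fields merge into one pydantic model.
doc = ScanStatus(owner_groups=["admin"], access_groups=["admin"], **msg.model_dump())
print(doc.owner_groups, doc.status)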

backend/bec_atlas/router/deployments_router.py (new file, 54 lines)

@@ -0,0 +1,54 @@
from fastapi import APIRouter, Depends

from bec_atlas.authentication import get_current_user
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model.model import Deployments, UserInfo
from bec_atlas.router.base_router import BaseRouter


class DeploymentsRouter(BaseRouter):
    def __init__(self, prefix="/api/v1", datasources=None):
        super().__init__(prefix, datasources)
        self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
        self.router = APIRouter(prefix=prefix)
        self.router.add_api_route(
            "/deployments/realm/{realm}",
            self.deployments,
            methods=["GET"],
            description="Get all deployments for the realm",
            response_model=list[Deployments],
        )
        self.router.add_api_route(
            "/deployments/id/{deployment_id}",
            self.deployment_with_id,
            methods=["GET"],
            description="Get a single deployment by id for a realm",
            response_model=Deployments,
        )

    async def deployments(
        self, realm: str, current_user: UserInfo = Depends(get_current_user)
    ) -> list[Deployments]:
        """
        Get all deployments for a realm.

        Args:
            realm (str): The realm id

        Returns:
            list[Deployments]: List of deployments for the realm
        """
        return self.db.find("deployments", {"realm_id": realm}, Deployments, user=current_user)

    async def deployment_with_id(
        self, deployment_id: str, current_user: UserInfo = Depends(get_current_user)
    ):
        """
        Get a single deployment by id.

        Args:
            deployment_id (str): The deployment id
        """
        return self.db.find_one(
            "deployments", {"_id": deployment_id}, Deployments, user=current_user
        )
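
Hypothetical usage, not part of the commit: exercising the new endpoint with FastAPI's TestClient, reusing the demo admin account and the login flow from user_router. The response of /user/login is assumed to be the bare token string, as in the test at the end of this diff:

from fastapi.testclient import TestClient

from bec_atlas.main import AtlasApp, CONFIG

client = TestClient(AtlasApp(CONFIG).app)
token = client.post(
    "/api/v1/user/login",
    json={"username": "admin@bec_atlas.ch", "password": "admin"},
).json()
client.headers.update({"Authorization": f"Bearer {token}"})
print(client.get("/api/v1/deployments/realm/demo_beamline_1").json())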

backend/bec_atlas/router/realm_router.py (new file, 47 lines)

@@ -0,0 +1,47 @@
from fastapi import APIRouter

from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model.model import Realm
from bec_atlas.router.base_router import BaseRouter


class RealmRouter(BaseRouter):
    def __init__(self, prefix="/api/v1", datasources=None):
        super().__init__(prefix, datasources)
        self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
        self.router = APIRouter(prefix=prefix)
        self.router.add_api_route(
            "/realms",
            self.realms,
            methods=["GET"],
            description="Get all realms",
            response_model=list[Realm],
        )
        self.router.add_api_route(
            "/realms/{realm_id}",
            self.realm_with_id,
            methods=["GET"],
            description="Get a single realm by id",
            response_model=Realm,
        )

    async def realms(self) -> list[Realm]:
        """
        Get all realms.

        Returns:
            list[Realm]: List of realms
        """
        return self.db.find("realms", {}, Realm)

    async def realm_with_id(self, realm_id: str):
        """
        Get realm with id.

        Args:
            realm_id (str): The realm id

        Returns:
            Realm: The realm with the id
        """
        return self.db.find_one("realms", {"_id": realm_id}, Realm)

backend/bec_atlas/router/scan_router.py

@@ -1,20 +1,47 @@
 from fastapi import APIRouter, Depends
 
 from bec_atlas.authentication import get_current_user
-from bec_atlas.datasources.scylladb import scylladb_schema as schema
+from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
+from bec_atlas.model.model import ScanStatus, UserInfo
 from bec_atlas.router.base_router import BaseRouter
 
 
 class ScanRouter(BaseRouter):
     def __init__(self, prefix="/api/v1", datasources=None):
         super().__init__(prefix, datasources)
-        self.scylla = self.datasources.datasources.get("scylla")
+        self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
         self.router = APIRouter(prefix=prefix)
-        self.router.add_api_route("/scan", self.scan, methods=["GET"])
-        self.router.add_api_route("/scan/{scan_id}", self.scan_with_id, methods=["GET"])
+        self.router.add_api_route(
+            "/scans/session/{session_id}",
+            self.scans,
+            methods=["GET"],
+            description="Get all scans for a session",
+            response_model=list[ScanStatus],
+        )
+        self.router.add_api_route(
+            "/scans/id/{scan_id}",
+            self.scans_with_id,
+            methods=["GET"],
+            description="Get a single scan by id for a session",
+            response_model=ScanStatus,
+        )
 
-    async def scan(self, current_user: schema.User = Depends(get_current_user)):
-        return self.scylla.get("scan", current_user=current_user)
+    async def scans(
+        self, session_id: str, current_user: UserInfo = Depends(get_current_user)
+    ) -> list[ScanStatus]:
+        """
+        Get all scans for a session.
 
-    async def scan_with_id(self, scan_id: str):
-        return {"scan_id": scan_id}
+        Args:
+            session_id (str): The session id
+        """
+        return self.db.find("scans", {"session_id": session_id}, ScanStatus)
+
+    async def scans_with_id(self, scan_id: str, current_user: UserInfo = Depends(get_current_user)):
+        """
+        Get scan with id from session
+
+        Args:
+            scan_id (str): The scan id
+        """
+        return self.db.find_one("scans", {"_id": scan_id}, ScanStatus)

backend/bec_atlas/router/user_router.py

@@ -6,7 +6,8 @@ from fastapi.security import OAuth2PasswordRequestForm
 from pydantic import BaseModel
 
 from bec_atlas.authentication import create_access_token, get_current_user, verify_password
-from bec_atlas.datasources.scylladb import scylladb_schema as schema
+from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
+from bec_atlas.model import UserInfo
 from bec_atlas.router.base_router import BaseRouter
 
 
@@ -18,7 +19,7 @@ class UserLoginRequest(BaseModel):
 class UserRouter(BaseRouter):
     def __init__(self, prefix="/api/v1", datasources=None):
         super().__init__(prefix, datasources)
-        self.scylla = self.datasources.datasources.get("scylla")
+        self.db: MongoDBDatasource = self.datasources.datasources.get("mongodb")
         self.router = APIRouter(prefix=prefix)
         self.router.add_api_route("/user/me", self.user_me, methods=["GET"])
         self.router.add_api_route("/user/login", self.user_login, methods=["POST"], dependencies=[])

@@ -26,11 +27,11 @@ class UserRouter(BaseRouter):
             "/user/login/form", self.form_login, methods=["POST"], dependencies=[]
         )
 
-    async def user_me(self, user: schema.User = Depends(get_current_user)):
-        data = schema.User.objects.filter(email=user.email)
-        if data.count() == 0:
+    async def user_me(self, user: UserInfo = Depends(get_current_user)):
+        data = self.db.get_user_by_email(user.email)
+        if data is None:
             raise HTTPException(status_code=404, detail="User not found")
-        return data.first()
+        return data
 
     async def form_login(self, form_data: Annotated[OAuth2PasswordRequestForm, Depends()]):
         user_login = UserLoginRequest(username=form_data.username, password=form_data.password)

@@ -38,15 +39,13 @@ class UserRouter(BaseRouter):
         return {"access_token": out, "token_type": "bearer"}
 
     async def user_login(self, user_login: UserLoginRequest):
-        result = schema.User.objects.filter(email=user_login.username)
-        if result.count() == 0:
+        user = self.db.get_user_by_email(user_login.username)
+        if user is None:
             raise HTTPException(status_code=404, detail="User not found")
-        user: schema.User = result.first()
-        credentials = schema.UserCredentials.objects.filter(user_id=user.user_id)
-        if credentials.count() == 0:
+        credentials = self.db.get_user_credentials(user.id)
+        if credentials is None:
             raise HTTPException(status_code=404, detail="User not found")
-        user_credentials = credentials.first()
-        if not verify_password(user_login.password, user_credentials.password):
+        if not verify_password(user_login.password, credentials.password):
             raise HTTPException(status_code=401, detail="Invalid password")
 
         return create_access_token(data={"groups": list(user.groups), "email": user.email})

backend/bec_atlas/utils/demo_database_setup.py (new file, 38 lines)

@@ -0,0 +1,38 @@
import pymongo

from bec_atlas.model import Deployments, Realm


class DemoSetupLoader:

    def __init__(self, config: dict):
        self.config = config
        self.client = pymongo.MongoClient(config.get("host"), config.get("port"))
        self.db = self.client["bec_atlas"]
        self.data = {}

    def load(self):
        self.load_realm()
        self.load_deployments()

    def load_realm(self):
        realm = Realm(realm_id="demo_beamline_1", name="Demo Beamline 1", owner_groups=["admin"])
        realm._id = realm.realm_id
        if self.db["realms"].find_one({"realm_id": realm.realm_id}) is None:
            self.db["realms"].insert_one(realm.__dict__)

        realm = Realm(realm_id="demo_beamline_2", name="Demo Beamline 2", owner_groups=["admin"])
        realm._id = realm.realm_id
        if self.db["realms"].find_one({"realm_id": realm.realm_id}) is None:
            self.db["realms"].insert_one(realm.__dict__)

    def load_deployments(self):
        deployment = Deployments(
            realm_id="demo_beamline_1", name="Demo Deployment 1", owner_groups=["admin", "demo"]
        )
        self.db["deployments"].insert_one(deployment.__dict__)


if __name__ == "__main__":
    loader = DemoSetupLoader({"host": "localhost", "port": 27017})
    loader.load()
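
Illustrative only (assumes a local MongoDB on the default port): seeding the demo data above and reading it back through the new datasource. Note that load_realm is idempotent but load_deployments inserts a new demo deployment on every run:

from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
from bec_atlas.model import Realm
from bec_atlas.utils.demo_database_setup import DemoSetupLoader

DemoSetupLoader({"host": "localhost", "port": 27017}).load()

ds = MongoDBDatasource({"host": "localhost", "port": 27017})
ds.connect(include_setup=False)
# No user argument: the query is unfiltered, as in RealmRouter.realms.
print([r.name for r in ds.find("realms", {}, Realm)])
ds.shutdown()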

backend/bec_atlas/utils/setup_database.py

@@ -26,6 +26,7 @@ def wait_for_scylladb(scylla_host: str = SCYLLA_HOST, scylla_port: int = SCYLLA_PORT):
             print("Connected to ScyllaDB")
             return session
         except Exception as e:
+            # breakpoint()
             print(f"ScyllaDB not ready yet: {e}")
             time.sleep(5)

backend/pyproject.toml

@@ -21,6 +21,8 @@ dependencies = [
     "python-socketio[asyncio_client]",
     "libtmux",
     "websocket-client",
+    "pydantic",
+    "pymongo",
 ]

@@ -32,6 +34,7 @@ dev = [
     "pytest~=8.0",
     "pytest-docker",
     "isort~=5.13, >=5.13.2",
+    "pytest-asyncio",
 ]
 
 [project.scripts]

backend/tests/Dockerfile.run_pytest (new file, 32 lines)

@@ -0,0 +1,32 @@
# set base image (host OS)
ARG PY_VERSION=3.10 CI_DEPENDENCY_PROXY_GROUP_IMAGE_PREFIX

FROM python:${PY_VERSION}

ARG BEC_ATLAS_BRANCH=main BEC_CORE_BRANCH=main

RUN echo "Building BEC Atlas environment for branch ${BEC_ATLAS_BRANCH} with BEC branch ${BEC_CORE_BRANCH}"

RUN apt update
RUN apt install git -y
RUN apt install netcat-openbsd -y

# set the working directory in the container
WORKDIR /code

# clone the bec repo
RUN git clone --branch ${BEC_CORE_BRANCH} https://gitlab.psi.ch/bec/bec.git

WORKDIR /code/bec/
RUN pip install -e bec_lib[dev]

WORKDIR /code
RUN git clone --branch ${BEC_ATLAS_BRANCH} https://gitlab.psi.ch/bec/bec_atlas.git

WORKDIR /code/bec_atlas
RUN pip install -e ./backend[dev]

RUN mkdir -p /code/bec_atlas/test_files

# command to run on container start
ENTRYPOINT [ "./backend/tests/coverage_run.sh" ]

backend/tests/conftest.py

@@ -3,11 +3,11 @@ import os
 from typing import Iterator
 
 import pytest
-from bec_atlas.main import AtlasApp
-from bec_atlas.utils.setup_database import setup_database
 from fastapi.testclient import TestClient
 from pytest_docker.plugin import DockerComposeExecutor, Services
 
+from bec_atlas.main import AtlasApp
+
 
 def pytest_addoption(parser):
     parser.addoption(

@@ -101,42 +101,39 @@ def docker_services(
     yield docker_service
 
 
-@pytest.fixture(scope="session")
-def scylla_container(docker_ip, docker_services):
-    host = docker_ip
-    if os.path.exists("/.dockerenv"):
-        # if we are running in the CI, scylla was started as the 'scylla' service
-        host = "scylla"
-    if docker_services is None:
-        port = 9042
-    else:
-        port = docker_services.port_for("scylla", 9042)
-
-    setup_database(host=host, port=port)
-    return host, port
-
-
 @pytest.fixture(scope="session")
 def redis_container(docker_ip, docker_services):
     host = docker_ip
     if os.path.exists("/.dockerenv"):
         # if we are running in the CI, redis was started as the 'redis' service
         host = "redis"
     if docker_services is None:
         port = 6380
     else:
         port = docker_services.port_for("redis", 6379)
 
-    return host, port
+    return "localhost", port
 
 
 @pytest.fixture(scope="session")
-def backend(scylla_container, redis_container):
-    scylla_host, scylla_port = scylla_container
+def mongo_container(docker_ip, docker_services):
+    host = docker_ip
+    if os.path.exists("/.dockerenv"):
+        host = "mongo"
+    if docker_services is None:
+        port = 27017
+    else:
+        port = docker_services.port_for("mongodb", 27017)
+
+    return "localhost", port
+
+
+@pytest.fixture(scope="session")
+def backend(redis_container, mongo_container):
+    redis_host, redis_port = redis_container
+    mongo_host, mongo_port = mongo_container
     config = {
-        "scylla": {"hosts": [(scylla_host, scylla_port)]},
         "redis": {"host": redis_host, "port": redis_port},
+        "mongodb": {"host": mongo_host, "port": mongo_port},
     }
 
     app = AtlasApp(config)
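
Illustrative only: a minimal test consuming the session-scoped backend fixture above, which (as in the tests later in this diff) yields a (client, app) tuple:

def test_backend_smoke(backend):
    client, app = backend
    # The realms endpoint requires no auth in this commit; an empty list is
    # a valid response on a fresh database.
    response = client.get("/api/v1/realms")
    assert response.status_code == 200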

backend/tests/coverage_run.sh (new executable file, 23 lines)

@@ -0,0 +1,23 @@
#!/bin/bash

# check that redis is running on port 6380 and mongodb on port 27017
if [ "$(nc -z localhost 6380; echo $?)" -ne 0 ]; then
    echo "Redis is not running on port 6380"
    exit 1
fi

if [ "$(nc -z localhost 27017; echo $?)" -ne 0 ]; then
    echo "MongoDB is not running on port 27017"
    exit 1
fi

coverage run --concurrency=thread --source=./backend --omit=*/backend/tests/* -m pytest -v --junitxml=./test_files/report.xml --skip-docker --random-order --full-trace ./backend/tests

EXIT_STATUS=$?

if [ $EXIT_STATUS -ne 0 ]; then
    exit $EXIT_STATUS
fi

coverage report
coverage xml -o ./test_files/coverage.xml

backend/tests/docker-compose.yml

@@ -1,10 +1,10 @@
 version: '2'
 services:
-  scylla:
-    image: scylladb/scylla:latest
-    ports:
-      - "9070:9042"
   redis:
     image: redis:latest
     ports:
       - "6380:6379"
+  mongodb:
+    image: mongo:latest
+    ports:
+      - "27017:27017"

backend/tests/test_redis_websocket.py

@@ -72,7 +72,7 @@ async def test_redis_websocket_multiple_disconnect_same_sid(backend_client):
 async def test_redis_websocket_register_wrong_endpoint_raises(backend_client):
     client, app = backend_client
     with mock.patch.object(app.redis_websocket.socket, "emit") as emit:
-        app.redis_websocket.socket.handlers["/"]["connect"]("sid")
+        await app.redis_websocket.socket.handlers["/"]["connect"]("sid")
         await app.redis_websocket.socket.handlers["/"]["register"](
             "sid", json.dumps({"endpoint": "wrong_endpoint"})
         )

backend/tests/test_scan_ingestor.py (new file, 107 lines)

@@ -0,0 +1,107 @@
import pytest
from bec_lib import messages

from bec_atlas.ingestor.data_ingestor import DataIngestor


@pytest.fixture
def scan_ingestor(backend):
    client, app = backend
    app.redis_websocket.users = {}
    ingestor = DataIngestor(config=app.config)
    yield ingestor
    ingestor.shutdown()


@pytest.mark.timeout(60)
def test_scan_ingestor_create_scan(scan_ingestor, backend):
    """
    Test that the ingestor creates and updates a scan in the database.
    """
    client, app = backend
    msg = messages.ScanStatusMessage(
        metadata={},
        scan_id="92429a81-4bd4-41c2-82df-eccfaddf3d96",
        status="open",
        # session_id="5cc67967-744d-4115-a46b-13246580cb3f",
        info={
            "readout_priority": {
                "monitored": ["bpm3i", "diode", "ftp", "bpm5c", "bpm3x", "bpm3z", "bpm4x"],
                "baseline": ["ddg1a", "bs1y", "mobdco"],
                "async": ["eiger", "monitor_async", "waveform"],
                "continuous": [],
                "on_request": ["flyer_sim"],
            },
            "file_suffix": None,
            "file_directory": None,
            "user_metadata": {"sample_name": "testA"},
            "RID": "5cc67967-744d-4115-a46b-13246580cb3f",
            "scan_id": "92429a81-4bd4-41c2-82df-eccfaddf3d96",
            "queue_id": "7d77d976-bee0-4bb8-aabb-2b862b4506ec",
            "session_id": "5cc67967-744d-4115-a46b-13246580cb3f",
            "scan_motors": ["samx"],
            "num_points": 10,
            "positions": [
                [-5.0024118137239455],
                [-3.8913007026128343],
                [-2.780189591501723],
                [-1.6690784803906122],
                [-0.557967369279501],
                [0.5531437418316097],
                [1.6642548529427212],
                [2.775365964053833],
                [3.886477075164944],
                [4.9975881862760545],
            ],
            "scan_name": "line_scan",
            "scan_type": "step",
            "scan_number": 2,
            "dataset_number": 2,
            "exp_time": 0,
            "frames_per_trigger": 1,
            "settling_time": 0,
            "readout_time": 0,
            "acquisition_config": {"default": {"exp_time": 0, "readout_time": 0}},
            "scan_report_devices": ["samx"],
            "monitor_sync": "bec",
            "scan_msgs": [
                "metadata={'file_suffix': None, 'file_directory': None, 'user_metadata': {'sample_name': 'testA'}, 'RID': '5cc67967-744d-4115-a46b-13246580cb3f'} scan_type='line_scan' parameter={'args': {'samx': [-5, 5]}, 'kwargs': {'steps': 10, 'exp_time': 0, 'relative': True, 'system_config': {'file_suffix': None, 'file_directory': None}}} queue='primary'"
            ],
            "args": {"samx": [-5, 5]},
            "kwargs": {
                "steps": 10,
                "exp_time": 0,
                "relative": True,
                "system_config": {"file_suffix": None, "file_directory": None},
            },
        },
        timestamp=1732610545.15924,
    )
    scan_ingestor.update_scan_status(msg)

    response = client.post(
        "/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"}
    )
    client.headers.update({"Authorization": f"Bearer {response.json()}"})

    session_id = msg.info.get("session_id")
    scan_id = msg.scan_id
    response = client.get(f"/api/v1/scans/session/{session_id}")
    assert response.status_code == 200
    out = response.json()[0]
    # assert out["session_id"] == session_id
    assert out["scan_id"] == scan_id
    assert out["status"] == "open"

    msg.status = "closed"
    scan_ingestor.update_scan_status(msg)
    response = client.get(f"/api/v1/scans/id/{scan_id}")
    assert response.status_code == 200
    out = response.json()
    assert out["status"] == "closed"
    assert out["scan_id"] == scan_id

    response = client.get(f"/api/v1/scans/session/{session_id}")
    assert response.status_code == 200
    out = response.json()
    assert len(out) == 1