feat(backend): basic login and scylladb setup

This commit is contained in:
2024-11-19 20:35:49 +01:00
parent ee06100c23
commit ff64cbd619
23 changed files with 1569 additions and 0 deletions

View File

View File

@ -0,0 +1,66 @@
import os
from datetime import datetime, timedelta
from typing import Annotated
import jwt
from bec_atlas.datasources.scylladb import scylladb_schema as schema
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jwt.exceptions import InvalidTokenError
from pwdlib import PasswordHash
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/user/login/form")
password_hash = PasswordHash.recommended()
def get_secret_key():
val = os.getenv("SECRET_KEY", "test_secret")
return val
def verify_password(plain_password, hashed_password):
return password_hash.verify(plain_password, hashed_password)
def create_access_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.now() + expires_delta
else:
expire = datetime.now() + timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, get_secret_key(), algorithm=ALGORITHM)
return encoded_jwt
def decode_token(token: str):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, get_secret_key(), algorithms=[ALGORITHM])
return payload
except InvalidTokenError as exc:
raise credentials_exception from exc
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> schema.User:
credentials_exception = HTTPException(
status_code=401,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = decode_token(token)
groups = payload.get("groups")
email = payload.get("email")
if not groups or not email:
raise credentials_exception
except Exception as exc:
raise credentials_exception from exc
return schema.User(groups=groups, email=email)

View File

@ -0,0 +1,24 @@
from bec_atlas.datasources.redis_datasource import RedisDatasource
from bec_atlas.datasources.scylladb.scylladb import ScylladbDatasource
class DatasourceManager:
def __init__(self, config: dict):
self.config = config
self.datasources = {}
self.load_datasources()
def connect(self):
for datasource in self.datasources.values():
datasource.connect()
def load_datasources(self):
for datasource_name, datasource_config in self.config.items():
if datasource_name == "scylla":
self.datasources[datasource_name] = ScylladbDatasource(datasource_config)
if datasource_name == "redis":
self.datasources[datasource_name] = RedisDatasource(datasource_config)
def shutdown(self):
for datasource in self.datasources.values():
datasource.shutdown()

View File

@ -0,0 +1,13 @@
from bec_lib.redis_connector import RedisConnector
class RedisDatasource:
def __init__(self, config: dict):
self.config = config
self.connector = RedisConnector(f"{config.get('host')}:{config.get('port')}")
def connect(self):
pass
def shutdown(self):
self.connector.shutdown()

View File

@ -0,0 +1,112 @@
import json
import os
import uuid
from datetime import datetime
from bec_atlas.authentication import get_password_hash
from bec_atlas.datasources.scylladb import scylladb_schema as schema
from cassandra.cluster import Cluster
from cassandra.cqlengine import columns, connection
from cassandra.cqlengine.management import create_keyspace_simple, sync_table
from pydantic import BaseModel
class ScylladbDatasource:
KEYSPACE = "bec_atlas"
def __init__(self, config):
self.config = config
self.cluster = None
self.session = None
def connect(self):
self.start_client()
self.load_functional_accounts()
def start_client(self):
"""
Start the ScyllaDB client by creating a Cluster object and a Session object.
"""
hosts = self.config.get("hosts")
if not hosts:
raise ValueError("Hosts are not provided in the configuration")
#
connection.setup(hosts, self.KEYSPACE, protocol_version=3)
create_keyspace_simple(self.KEYSPACE, 1)
self._sync_tables()
self.cluster = Cluster(hosts)
self.session = self.cluster.connect()
def _sync_tables(self):
"""
Sync the tables with the schema defined in the scylladb_schema.py file.
"""
sync_table(schema.Realm)
sync_table(schema.Deployments)
sync_table(schema.Experiments)
sync_table(schema.StateCondition)
sync_table(schema.State)
sync_table(schema.Session)
sync_table(schema.Datasets)
sync_table(schema.DatasetUserData)
sync_table(schema.Scan)
sync_table(schema.ScanUserData)
sync_table(schema.ScanData)
sync_table(schema.SignalDataInt)
sync_table(schema.SignalDataFloat)
sync_table(schema.SignalDataString)
sync_table(schema.SignalDataBool)
sync_table(schema.SignalDataBlob)
sync_table(schema.SignalDataDateTime)
sync_table(schema.SignalDataUUID)
sync_table(schema.User)
sync_table(schema.UserCredentials)
def load_functional_accounts(self):
"""
Load the functional accounts to the database.
"""
functional_accounts_file = os.path.join(
os.path.dirname(__file__), "functional_accounts.json"
)
with open(functional_accounts_file, "r", encoding="utf-8") as file:
functional_accounts = json.load(file)
for account in functional_accounts:
# check if the account already exists in the database
password_hash = get_password_hash(account.pop("password"))
result = schema.User.objects.filter(email=account["email"])
if result.count() > 0:
continue
user = schema.User.create(**account)
schema.UserCredentials.create(user_id=user.user_id, password=password_hash)
def get(self, table_name: str, filter: str = None, parameters: tuple = None):
"""
Get the data from the specified table.
"""
# schema.User.objects.get(email=)
if filter:
query = f"SELECT * FROM {self.KEYSPACE}.{table_name} WHERE {filter};"
else:
query = f"SELECT * FROM {self.KEYSPACE}.{table_name};"
if parameters:
return self.session.execute(query, parameters)
return self.session.execute(query)
def post(self, table_name: str, data: BaseModel):
"""
Post the data to the specified table.
Args:
table_name (str): The name of the table to post the data.
data (BaseModel): The data to be posted.
"""
query = f"INSERT INTO {self.KEYSPACE}.{table_name} JSON '{data.model_dump_json(exclude_none=True)}';"
return self.session.execute(query)
def shutdown(self):
self.cluster.shutdown()

View File

@ -0,0 +1,139 @@
import uuid
from cassandra.cqlengine import columns
from cassandra.cqlengine.models import Model
class User(Model):
email = columns.Text(primary_key=True)
user_id = columns.UUID(default=uuid.uuid4())
first_name = columns.Text()
last_name = columns.Text()
groups = columns.Set(columns.Text)
created_at = columns.DateTime()
updated_at = columns.DateTime()
class UserCredentials(Model):
user_id = columns.UUID(primary_key=True)
password = columns.Text()
class Realm(Model):
realm_id = columns.Text(primary_key=True)
deployment_id = columns.Text(primary_key=True)
name = columns.Text()
class Deployments(Model):
realm_id = columns.Text(primary_key=True)
deployment_id = columns.Text(primary_key=True)
name = columns.Text()
active_session_id = columns.UUID()
class Experiments(Model):
realm_id = columns.Text(primary_key=True)
pgroup = columns.Text(primary_key=True)
proposal = columns.Text()
text = columns.Text()
class StateCondition(Model):
realm_id = columns.Text(primary_key=True)
name = columns.Text(primary_key=True)
description = columns.Text()
device = columns.Text()
signal_value = columns.Text()
signal_type = columns.Text()
tolerance = columns.Text()
class State(Model):
realm_id = columns.Text(primary_key=True)
name = columns.Text(primary_key=True)
description = columns.Text()
conditions = columns.List(columns.Text)
class Session(Model):
realm_id = columns.Text(primary_key=True)
session_id = columns.UUID(primary_key=True)
config = columns.Text()
class Datasets(Model):
session_id = columns.UUID(primary_key=True)
dataset_id = columns.UUID(primary_key=True)
scan_id = columns.UUID()
class DatasetUserData(Model):
dataset_id = columns.UUID(primary_key=True)
name = columns.Text()
rating = columns.Integer()
comments = columns.Text()
preview = columns.Blob()
class Scan(Model):
session_id = columns.UUID(primary_key=True)
scan_id = columns.UUID(primary_key=True)
scan_number = columns.Integer()
name = columns.Text()
scan_class = columns.Text()
parameters = columns.Text()
start_time = columns.DateTime()
end_time = columns.DateTime()
exit_status = columns.Text()
class ScanUserData(Model):
scan_id = columns.UUID(primary_key=True)
name = columns.Text()
rating = columns.Integer()
comments = columns.Text()
preview = columns.Blob()
class ScanData(Model):
scan_id = columns.UUID(primary_key=True)
device_name = columns.Text(primary_key=True)
signal_name = columns.Text(primary_key=True)
shape = columns.List(columns.Integer)
dtype = columns.Text()
class SignalDataBase(Model):
realm_id = columns.Text(partition_key=True)
signal_name = columns.Text(partition_key=True)
scan_id = columns.UUID(primary_key=True)
index = columns.Integer(primary_key=True)
class SignalDataInt(SignalDataBase):
data = columns.Integer()
class SignalDataFloat(SignalDataBase):
data = columns.Float()
class SignalDataString(SignalDataBase):
data = columns.Text()
class SignalDataBlob(SignalDataBase):
data = columns.Blob()
class SignalDataBool(SignalDataBase):
data = columns.Boolean()
class SignalDataDateTime(SignalDataBase):
data = columns.DateTime()
class SignalDataUUID(SignalDataBase):
data = columns.UUID()

51
backend/bec_atlas/main.py Normal file
View File

@ -0,0 +1,51 @@
import socketio
import uvicorn
from bec_atlas.datasources.datasource_manager import DatasourceManager
from bec_atlas.router.redis_router import RedisRouter, RedisWebsocket
from bec_atlas.router.scan_router import ScanRouter
from bec_atlas.router.user import UserRouter
from fastapi import FastAPI
CONFIG = {"redis": {"host": "localhost", "port": 6379}, "scylla": {"hosts": ["localhost"]}}
class HorizonApp:
API_VERSION = "v1"
def __init__(self):
self.app = FastAPI()
self.prefix = f"/api/{self.API_VERSION}"
self.datasources = DatasourceManager(config=CONFIG)
self.register_event_handler()
self.add_routers()
def register_event_handler(self):
self.app.add_event_handler("startup", self.on_startup)
self.app.add_event_handler("shutdown", self.on_shutdown)
async def on_startup(self):
self.datasources.connect()
async def on_shutdown(self):
self.datasources.shutdown()
def add_routers(self):
if not self.datasources.datasources:
raise ValueError("Datasources not loaded")
if "scylla" in self.datasources.datasources:
self.scan_router = ScanRouter(prefix=self.prefix, datasources=self.datasources)
self.app.include_router(self.scan_router.router)
self.user_router = UserRouter(prefix=self.prefix, datasources=self.datasources)
self.app.include_router(self.user_router.router)
if "redis" in self.datasources.datasources:
self.redis_websocket = RedisWebsocket(prefix=self.prefix, datasources=self.datasources)
self.app.mount("/", self.redis_websocket.app)
def run(self):
uvicorn.run(self.app, host="localhost", port=8000)
if __name__ == "__main__":
horizon_app = HorizonApp()
horizon_app.run()

View File

View File

@ -0,0 +1,4 @@
class BaseRouter:
def __init__(self, prefix: str = "/api/v1", datasources=None) -> None:
self.datasources = datasources
self.prefix = prefix

View File

@ -0,0 +1,98 @@
import asyncio
import inspect
import json
from typing import TYPE_CHECKING
import socketio
from bec_atlas.router.base_router import BaseRouter
from bec_lib.endpoints import MessageEndpoints
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
if TYPE_CHECKING:
from bec_lib.redis_connector import RedisConnector
class RedisRouter(BaseRouter):
"""
This class is a router for the Redis API. It exposes the redis client through
the API. For pub/sub and stream operations, a websocket connection can be used.
"""
def __init__(self, prefix="/api/v1", datasources=None):
super().__init__(prefix, datasources)
self.redis = self.datasources.datasources["redis"].connector
self.router = APIRouter(prefix=prefix)
self.router.add_api_route("/redis", self.redis_get, methods=["GET"])
self.router.add_api_route("/redis", self.redis_post, methods=["POST"])
self.router.add_api_route("/redis", self.redis_delete, methods=["DELETE"])
async def redis_get(self, key: str):
return self.redis.get(key)
async def redis_post(self, key: str, value: str):
return self.redis.set(key, value)
async def redis_delete(self, key: str):
return self.redis.delete(key)
class RedisWebsocket:
"""
This class is a websocket handler for the Redis API. It exposes the redis client through
the websocket.
"""
def __init__(self, prefix="/api/v1", datasources=None):
self.redis: RedisConnector = datasources.datasources["redis"].connector
self.prefix = prefix
self.active_connections = set()
self.socket = socketio.AsyncServer(cors_allowed_origins="*", async_mode="asgi")
self.app = socketio.ASGIApp(self.socket)
self.loop = asyncio.get_event_loop()
self.socket.on("connect", self.connect_client)
self.socket.on("register", self.redis_register)
self.socket.on("disconnect", self.disconnect_client)
def connect_client(self, sid, environ):
print("Client connected")
self.active_connections.add(sid)
def disconnect_client(self, sid, _environ):
print("Client disconnected")
self.active_connections.pop(sid)
async def redis_register(self, sid: str, msg: str):
if sid not in self.active_connections:
self.active_connections.add(sid)
try:
data = json.loads(msg)
except json.JSONDecodeError:
return
endpoint = getattr(MessageEndpoints, data.get("endpoint"))
# check if the endpoint receives arguments
if len(inspect.signature(endpoint).parameters) > 1:
endpoint = endpoint(data.get("args"))
else:
endpoint = endpoint()
self.redis.register(endpoint, cb=self.on_redis_message, parent=self)
await self.socket.enter_room(sid, endpoint.endpoint)
await self.socket.emit("registered", data={"endpoint": endpoint.endpoint}, room=sid)
@staticmethod
def on_redis_message(message, parent):
async def emit_message(message):
outgoing = {
"data": message.value.model_dump_json(),
"message_type": message.value.__class__.__name__,
}
await parent.socket.emit("new_message", data=outgoing, room=message.topic)
# check that the event loop is running
if not parent.loop.is_running():
parent.loop.run_until_complete(emit_message(message))
else:
asyncio.run_coroutine_threadsafe(emit_message(message), parent.loop)

View File

@ -0,0 +1,19 @@
from bec_atlas.authentication import get_current_user
from bec_atlas.models import User
from bec_atlas.router.base_router import BaseRouter
from fastapi import APIRouter, Depends
class ScanRouter(BaseRouter):
def __init__(self, prefix="/api/v1", datasources=None):
super().__init__(prefix, datasources)
self.scylla = self.datasources.datasources.get("scylla")
self.router = APIRouter(prefix=prefix)
self.router.add_api_route("/scan", self.scan, methods=["GET"])
self.router.add_api_route("/scan/{scan_id}", self.scan_with_id, methods=["GET"])
async def scan(self, current_user: User = Depends(get_current_user)):
return self.scylla.get("scan", current_user=current_user)
async def scan_with_id(self, scan_id: str):
return {"scan_id": scan_id}

View File

@ -0,0 +1,45 @@
from typing import Annotated
from bec_atlas.authentication import create_access_token, get_current_user, verify_password
from bec_atlas.datasources.scylladb import scylladb_schema as schema
from bec_atlas.models import User
from bec_atlas.router.base_router import BaseRouter
from fastapi import APIRouter, Depends
from fastapi.exceptions import HTTPException
from fastapi.security import OAuth2PasswordRequestForm
class UserRouter(BaseRouter):
def __init__(self, prefix="/api/v1", datasources=None):
super().__init__(prefix, datasources)
self.scylla = self.datasources.datasources.get("scylla")
self.router = APIRouter(prefix=prefix)
self.router.add_api_route("/user/me", self.user_me, methods=["GET"])
self.router.add_api_route("/user/login", self.user_login, methods=["POST"], dependencies=[])
self.router.add_api_route(
"/user/login/form", self.form_login, methods=["POST"], dependencies=[]
)
async def user_me(self, user: User = Depends(get_current_user)):
data = schema.User.objects.filter(email=user.email)
if data.count() == 0:
raise HTTPException(status_code=404, detail="User not found")
return data.first()
async def form_login(self, form_data: Annotated[OAuth2PasswordRequestForm, Depends()]):
out = await self.user_login(form_data.username, form_data.password)
return {"access_token": out, "token_type": "bearer"}
async def user_login(self, username: str, password: str):
result = schema.User.objects.filter(email=username)
if result.count() == 0:
raise HTTPException(status_code=404, detail="User not found")
user: schema.User = result.first()
credentials = schema.UserCredentials.objects.filter(user_id=user.user_id)
if credentials.count() == 0:
raise HTTPException(status_code=404, detail="User not found")
user_credentials = credentials.first()
if not verify_password(password, user_credentials.password):
raise HTTPException(status_code=401, detail="Invalid password")
return create_access_token(data={"groups": list(user.groups), "email": user.email})

92
backend/pyproject.toml Normal file
View File

@ -0,0 +1,92 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "bec_atlas"
version = "0.0.0"
description = "BEC Atlas"
requires-python = ">=3.10"
classifiers = [
"Development Status :: 3 - Alpha",
"Programming Language :: Python :: 3",
"Topic :: Scientific/Engineering",
]
dependencies = [
"fastapi[standard]",
"pyjwt",
"pwdlib[argon2]",
"scylla-driver",
"bec_lib",
"python-socketio[asyncio_client]",
]
[project.optional-dependencies]
dev = [
"coverage~=7.0",
"pytest-random-order~=1.1",
"pytest-timeout~=2.2",
"pytest~=8.0",
]
[project.urls]
"Bug Tracker" = "https://gitlab.psi.ch/bec/bec_atlas/issues"
Homepage = "https://gitlab.psi.ch/bec/bec_atlas"
[tool.hatch.build.targets.wheel]
include = ["*"]
exclude = ["docs/**", "tests/**"]
[tool.hatch.build.targets.sdist]
include = ["*"]
exclude = ["docs/**", "tests/**"]
[tool.black]
line-length = 100
skip-magic-trailing-comma = true
[tool.isort]
profile = "black"
line_length = 100
multi_line_output = 3
include_trailing_comma = true
known_first_party = ["bec_widgets"]
[tool.semantic_release]
build_command = "python -m build"
version_toml = ["pyproject.toml:project.version"]
[tool.semantic_release.commit_author]
env = "GIT_COMMIT_AUTHOR"
default = "semantic-release <semantic-release>"
[tool.semantic_release.commit_parser_options]
allowed_tags = [
"build",
"chore",
"ci",
"docs",
"feat",
"fix",
"perf",
"style",
"refactor",
"test",
]
minor_tags = ["feat"]
patch_tags = ["fix", "perf"]
default_bump_level = 0
[tool.semantic_release.remote]
name = "origin"
type = "gitlab"
ignore_token_for_push = false
[tool.semantic_release.remote.token]
env = "GL_TOKEN"
[tool.semantic_release.publish]
dist_glob_patterns = ["dist/*"]
upload_to_vcs_release = true