Mirror of https://github.com/bec-project/bec_atlas.git, synced 2025-07-14 07:01:48 +02:00
feat(backend): basic login and scylladb setup
backend/bec_atlas/__init__.py (new file, empty)
backend/bec_atlas/authentication.py (new file)

import os
from datetime import datetime, timedelta, timezone
from typing import Annotated

import jwt
from bec_atlas.datasources.scylladb import scylladb_schema as schema
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jwt.exceptions import InvalidTokenError
from pwdlib import PasswordHash

ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/user/login/form")
password_hash = PasswordHash.recommended()


def get_secret_key():
    """Read the JWT signing key from the environment, falling back to a test value."""
    return os.getenv("SECRET_KEY", "test_secret")


def get_password_hash(plain_password: str) -> str:
    """Hash a plain-text password; used by the ScyllaDB datasource when seeding accounts."""
    return password_hash.hash(plain_password)


def verify_password(plain_password, hashed_password):
    """Check a plain-text password against its stored hash."""
    return password_hash.verify(plain_password, hashed_password)


def create_access_token(data: dict, expires_delta: timedelta | None = None):
    """Create a signed JWT with an expiry claim."""
    to_encode = data.copy()
    # use timezone-aware UTC so PyJWT interprets the exp claim correctly
    if expires_delta:
        expire = datetime.now(timezone.utc) + expires_delta
    else:
        expire = datetime.now(timezone.utc) + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, get_secret_key(), algorithm=ALGORITHM)
    return encoded_jwt


def decode_token(token: str):
    """Decode and validate a JWT, raising 401 if it is invalid or expired."""
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, get_secret_key(), algorithms=[ALGORITHM])
        return payload
    except InvalidTokenError as exc:
        raise credentials_exception from exc


async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> schema.User:
    """FastAPI dependency that resolves the bearer token to the requesting user."""
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = decode_token(token)
        groups = payload.get("groups")
        email = payload.get("email")
        if not groups or not email:
            raise credentials_exception
    except Exception as exc:
        raise credentials_exception from exc
    return schema.User(groups=groups, email=email)
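
For orientation, a minimal usage sketch of the helpers above (not part of the commit); the payload fields mirror what the login endpoint in user.py encodes, and the example values are made up:

from datetime import timedelta

from bec_atlas.authentication import create_access_token, decode_token

# issue a token the way the login endpoint does; "groups" and "email" are the
# claims that get_current_user expects
token = create_access_token(
    data={"groups": ["demo"], "email": "jane.doe@example.com"},
    expires_delta=timedelta(minutes=30),
)

# decode_token validates the signature and expiry with the same secret
payload = decode_token(token)
assert payload["email"] == "jane.doe@example.com"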
backend/bec_atlas/datasources/__init__.py (new file, empty)
backend/bec_atlas/datasources/datasource_manager.py (new file)

from bec_atlas.datasources.redis_datasource import RedisDatasource
from bec_atlas.datasources.scylladb.scylladb import ScylladbDatasource


class DatasourceManager:
    """Instantiate and manage the datasources defined in the service config."""

    def __init__(self, config: dict):
        self.config = config
        self.datasources = {}
        self.load_datasources()

    def connect(self):
        for datasource in self.datasources.values():
            datasource.connect()

    def load_datasources(self):
        for datasource_name, datasource_config in self.config.items():
            if datasource_name == "scylla":
                self.datasources[datasource_name] = ScylladbDatasource(datasource_config)
            if datasource_name == "redis":
                self.datasources[datasource_name] = RedisDatasource(datasource_config)

    def shutdown(self):
        for datasource in self.datasources.values():
            datasource.shutdown()
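
For orientation, a minimal sketch of how the manager is driven; this mirrors what main.py does on startup and shutdown, and the config values are the defaults from main.py:

from bec_atlas.datasources.datasource_manager import DatasourceManager

config = {"redis": {"host": "localhost", "port": 6379}, "scylla": {"hosts": ["localhost"]}}

manager = DatasourceManager(config=config)  # instantiates RedisDatasource and ScylladbDatasource
manager.connect()                           # called from the FastAPI startup hook
# ... serve requests ...
manager.shutdown()                          # called from the FastAPI shutdown hook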
backend/bec_atlas/datasources/redis_datasource.py (new file)

from bec_lib.redis_connector import RedisConnector


class RedisDatasource:
    def __init__(self, config: dict):
        self.config = config
        self.connector = RedisConnector(f"{config.get('host')}:{config.get('port')}")

    def connect(self):
        pass

    def shutdown(self):
        self.connector.shutdown()
backend/bec_atlas/datasources/scylladb/__init__.py (new file, empty)
backend/bec_atlas/datasources/scylladb/scylladb.py (new file)

import json
import os

from bec_atlas.authentication import get_password_hash
from bec_atlas.datasources.scylladb import scylladb_schema as schema
from cassandra.cluster import Cluster
from cassandra.cqlengine import connection
from cassandra.cqlengine.management import create_keyspace_simple, sync_table
from pydantic import BaseModel


class ScylladbDatasource:
    KEYSPACE = "bec_atlas"

    def __init__(self, config):
        self.config = config
        self.cluster = None
        self.session = None

    def connect(self):
        self.start_client()
        self.load_functional_accounts()

    def start_client(self):
        """
        Start the ScyllaDB client by creating a Cluster object and a Session object.
        """
        hosts = self.config.get("hosts")
        if not hosts:
            raise ValueError("Hosts are not provided in the configuration")

        connection.setup(hosts, self.KEYSPACE, protocol_version=3)
        create_keyspace_simple(self.KEYSPACE, 1)
        self._sync_tables()
        self.cluster = Cluster(hosts)
        self.session = self.cluster.connect()

    def _sync_tables(self):
        """
        Sync the tables with the schema defined in the scylladb_schema.py file.
        """
        sync_table(schema.Realm)
        sync_table(schema.Deployments)
        sync_table(schema.Experiments)
        sync_table(schema.StateCondition)
        sync_table(schema.State)
        sync_table(schema.Session)
        sync_table(schema.Datasets)
        sync_table(schema.DatasetUserData)
        sync_table(schema.Scan)
        sync_table(schema.ScanUserData)
        sync_table(schema.ScanData)
        sync_table(schema.SignalDataInt)
        sync_table(schema.SignalDataFloat)
        sync_table(schema.SignalDataString)
        sync_table(schema.SignalDataBool)
        sync_table(schema.SignalDataBlob)
        sync_table(schema.SignalDataDateTime)
        sync_table(schema.SignalDataUUID)
        sync_table(schema.User)
        sync_table(schema.UserCredentials)

    def load_functional_accounts(self):
        """
        Load the functional accounts into the database.
        """
        functional_accounts_file = os.path.join(
            os.path.dirname(__file__), "functional_accounts.json"
        )
        with open(functional_accounts_file, "r", encoding="utf-8") as file:
            functional_accounts = json.load(file)

        for account in functional_accounts:
            # hash the password up front, then skip accounts that already exist in the database
            password_hash = get_password_hash(account.pop("password"))
            result = schema.User.objects.filter(email=account["email"])
            if result.count() > 0:
                continue
            user = schema.User.create(**account)
            schema.UserCredentials.create(user_id=user.user_id, password=password_hash)

    def get(self, table_name: str, filter: str = None, parameters: tuple = None):
        """
        Get the data from the specified table.
        """
        if filter:
            query = f"SELECT * FROM {self.KEYSPACE}.{table_name} WHERE {filter};"
        else:
            query = f"SELECT * FROM {self.KEYSPACE}.{table_name};"
        if parameters:
            return self.session.execute(query, parameters)
        return self.session.execute(query)

    def post(self, table_name: str, data: BaseModel):
        """
        Post the data to the specified table.

        Args:
            table_name (str): The name of the table to post the data to.
            data (BaseModel): The data to be posted.
        """
        query = f"INSERT INTO {self.KEYSPACE}.{table_name} JSON '{data.model_dump_json(exclude_none=True)}';"
        return self.session.execute(query)

    def shutdown(self):
        self.cluster.shutdown()
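
A hedged sketch of querying the datasource once it is connected; it assumes a running ScyllaDB on localhost and uses the scan table created by _sync_tables. The %s placeholder style is the one the Cassandra/Scylla Python driver expects for simple statements:

import uuid

from bec_atlas.datasources.scylladb.scylladb import ScylladbDatasource

db = ScylladbDatasource({"hosts": ["localhost"]})
db.connect()

# unfiltered read of the scan table
rows = db.get("scan")

# parameterised read on the partition key (example UUID)
session_id = uuid.uuid4()
rows = db.get("scan", filter="session_id = %s", parameters=(session_id,))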
backend/bec_atlas/datasources/scylladb/scylladb_schema.py (new file)

import uuid

from cassandra.cqlengine import columns
from cassandra.cqlengine.models import Model


class User(Model):
    email = columns.Text(primary_key=True)
    # pass the callable (not a call) so that every new row gets its own UUID
    user_id = columns.UUID(default=uuid.uuid4)
    first_name = columns.Text()
    last_name = columns.Text()
    groups = columns.Set(columns.Text)
    created_at = columns.DateTime()
    updated_at = columns.DateTime()


class UserCredentials(Model):
    user_id = columns.UUID(primary_key=True)
    password = columns.Text()


class Realm(Model):
    realm_id = columns.Text(primary_key=True)
    deployment_id = columns.Text(primary_key=True)
    name = columns.Text()


class Deployments(Model):
    realm_id = columns.Text(primary_key=True)
    deployment_id = columns.Text(primary_key=True)
    name = columns.Text()
    active_session_id = columns.UUID()


class Experiments(Model):
    realm_id = columns.Text(primary_key=True)
    pgroup = columns.Text(primary_key=True)
    proposal = columns.Text()
    text = columns.Text()


class StateCondition(Model):
    realm_id = columns.Text(primary_key=True)
    name = columns.Text(primary_key=True)
    description = columns.Text()
    device = columns.Text()
    signal_value = columns.Text()
    signal_type = columns.Text()
    tolerance = columns.Text()


class State(Model):
    realm_id = columns.Text(primary_key=True)
    name = columns.Text(primary_key=True)
    description = columns.Text()
    conditions = columns.List(columns.Text)


class Session(Model):
    realm_id = columns.Text(primary_key=True)
    session_id = columns.UUID(primary_key=True)
    config = columns.Text()


class Datasets(Model):
    session_id = columns.UUID(primary_key=True)
    dataset_id = columns.UUID(primary_key=True)
    scan_id = columns.UUID()


class DatasetUserData(Model):
    dataset_id = columns.UUID(primary_key=True)
    name = columns.Text()
    rating = columns.Integer()
    comments = columns.Text()
    preview = columns.Blob()


class Scan(Model):
    session_id = columns.UUID(primary_key=True)
    scan_id = columns.UUID(primary_key=True)
    scan_number = columns.Integer()
    name = columns.Text()
    scan_class = columns.Text()
    parameters = columns.Text()
    start_time = columns.DateTime()
    end_time = columns.DateTime()
    exit_status = columns.Text()


class ScanUserData(Model):
    scan_id = columns.UUID(primary_key=True)
    name = columns.Text()
    rating = columns.Integer()
    comments = columns.Text()
    preview = columns.Blob()


class ScanData(Model):
    scan_id = columns.UUID(primary_key=True)
    device_name = columns.Text(primary_key=True)
    signal_name = columns.Text(primary_key=True)
    shape = columns.List(columns.Integer)
    dtype = columns.Text()


class SignalDataBase(Model):
    # abstract base: tables are created only for the typed subclasses below
    __abstract__ = True

    realm_id = columns.Text(partition_key=True)
    signal_name = columns.Text(partition_key=True)
    scan_id = columns.UUID(primary_key=True)
    index = columns.Integer(primary_key=True)


class SignalDataInt(SignalDataBase):
    data = columns.Integer()


class SignalDataFloat(SignalDataBase):
    data = columns.Float()


class SignalDataString(SignalDataBase):
    data = columns.Text()


class SignalDataBlob(SignalDataBase):
    data = columns.Blob()


class SignalDataBool(SignalDataBase):
    data = columns.Boolean()


class SignalDataDateTime(SignalDataBase):
    data = columns.DateTime()


class SignalDataUUID(SignalDataBase):
    data = columns.UUID()
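
For orientation, a short sketch of how these models are used, mirroring what load_functional_accounts in scylladb.py does; the account values are made up:

from bec_atlas.authentication import get_password_hash
from bec_atlas.datasources.scylladb import scylladb_schema as schema

# assumes connection.setup(...) and the sync_table(...) calls from scylladb.py have run
user = schema.User.create(
    email="jane.doe@example.com",
    first_name="Jane",
    last_name="Doe",
    groups={"demo"},
)
schema.UserCredentials.create(user_id=user.user_id, password=get_password_hash("secret"))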
backend/bec_atlas/main.py (new file)

import uvicorn
from bec_atlas.datasources.datasource_manager import DatasourceManager
from bec_atlas.router.redis_router import RedisWebsocket
from bec_atlas.router.scan_router import ScanRouter
from bec_atlas.router.user import UserRouter
from fastapi import FastAPI

CONFIG = {"redis": {"host": "localhost", "port": 6379}, "scylla": {"hosts": ["localhost"]}}


class HorizonApp:
    API_VERSION = "v1"

    def __init__(self):
        self.app = FastAPI()
        self.prefix = f"/api/{self.API_VERSION}"
        self.datasources = DatasourceManager(config=CONFIG)
        self.register_event_handler()
        self.add_routers()

    def register_event_handler(self):
        self.app.add_event_handler("startup", self.on_startup)
        self.app.add_event_handler("shutdown", self.on_shutdown)

    async def on_startup(self):
        self.datasources.connect()

    async def on_shutdown(self):
        self.datasources.shutdown()

    def add_routers(self):
        if not self.datasources.datasources:
            raise ValueError("Datasources not loaded")
        if "scylla" in self.datasources.datasources:
            self.scan_router = ScanRouter(prefix=self.prefix, datasources=self.datasources)
            self.app.include_router(self.scan_router.router)
            self.user_router = UserRouter(prefix=self.prefix, datasources=self.datasources)
            self.app.include_router(self.user_router.router)

        if "redis" in self.datasources.datasources:
            self.redis_websocket = RedisWebsocket(prefix=self.prefix, datasources=self.datasources)
            self.app.mount("/", self.redis_websocket.app)

    def run(self):
        uvicorn.run(self.app, host="localhost", port=8000)


if __name__ == "__main__":
    horizon_app = HorizonApp()
    horizon_app.run()
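
Running the module starts the development server; a minimal sketch of what that exposes, assuming Redis and ScyllaDB are reachable at the CONFIG defaults above:

from bec_atlas.main import HorizonApp

horizon_app = HorizonApp()
horizon_app.run()  # REST routes under http://localhost:8000/api/v1/, socket.io app mounted at "/"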
backend/bec_atlas/router/__init__.py (new file, empty)
backend/bec_atlas/router/base_router.py (new file)

class BaseRouter:
    def __init__(self, prefix: str = "/api/v1", datasources=None) -> None:
        self.datasources = datasources
        self.prefix = prefix
backend/bec_atlas/router/redis_router.py (new file)

import asyncio
import inspect
import json
from typing import TYPE_CHECKING

import socketio
from bec_atlas.router.base_router import BaseRouter
from bec_lib.endpoints import MessageEndpoints
from fastapi import APIRouter

if TYPE_CHECKING:
    from bec_lib.redis_connector import RedisConnector


class RedisRouter(BaseRouter):
    """
    This class is a router for the Redis API. It exposes the redis client through
    the API. For pub/sub and stream operations, a websocket connection can be used.
    """

    def __init__(self, prefix="/api/v1", datasources=None):
        super().__init__(prefix, datasources)
        self.redis = self.datasources.datasources["redis"].connector
        self.router = APIRouter(prefix=prefix)
        self.router.add_api_route("/redis", self.redis_get, methods=["GET"])
        self.router.add_api_route("/redis", self.redis_post, methods=["POST"])
        self.router.add_api_route("/redis", self.redis_delete, methods=["DELETE"])

    async def redis_get(self, key: str):
        return self.redis.get(key)

    async def redis_post(self, key: str, value: str):
        return self.redis.set(key, value)

    async def redis_delete(self, key: str):
        return self.redis.delete(key)


class RedisWebsocket:
    """
    This class is a websocket handler for the Redis API. It exposes the redis client
    through the websocket.
    """

    def __init__(self, prefix="/api/v1", datasources=None):
        self.redis: RedisConnector = datasources.datasources["redis"].connector
        self.prefix = prefix
        self.active_connections = set()
        self.socket = socketio.AsyncServer(cors_allowed_origins="*", async_mode="asgi")
        self.app = socketio.ASGIApp(self.socket)
        self.loop = asyncio.get_event_loop()

        self.socket.on("connect", self.connect_client)
        self.socket.on("register", self.redis_register)
        self.socket.on("disconnect", self.disconnect_client)

    def connect_client(self, sid, environ):
        print("Client connected")
        self.active_connections.add(sid)

    def disconnect_client(self, sid, _environ):
        print("Client disconnected")
        # sets have no pop(key); discard removes the sid if it is present
        self.active_connections.discard(sid)

    async def redis_register(self, sid: str, msg: str):
        if sid not in self.active_connections:
            self.active_connections.add(sid)
        try:
            data = json.loads(msg)
        except json.JSONDecodeError:
            return

        endpoint = getattr(MessageEndpoints, data.get("endpoint"))

        # check if the endpoint receives arguments
        if len(inspect.signature(endpoint).parameters) > 1:
            endpoint = endpoint(data.get("args"))
        else:
            endpoint = endpoint()

        self.redis.register(endpoint, cb=self.on_redis_message, parent=self)
        await self.socket.enter_room(sid, endpoint.endpoint)
        await self.socket.emit("registered", data={"endpoint": endpoint.endpoint}, room=sid)

    @staticmethod
    def on_redis_message(message, parent):
        async def emit_message(message):
            outgoing = {
                "data": message.value.model_dump_json(),
                "message_type": message.value.__class__.__name__,
            }
            await parent.socket.emit("new_message", data=outgoing, room=message.topic)

        # check that the event loop is running
        if not parent.loop.is_running():
            parent.loop.run_until_complete(emit_message(message))
        else:
            asyncio.run_coroutine_threadsafe(emit_message(message), parent.loop)
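
A hedged client-side sketch of the websocket flow above, using the python-socketio async client that is already a declared dependency; the endpoint name "scan_status" is only an example of a MessageEndpoints attribute:

import asyncio
import json

import socketio


async def main():
    sio = socketio.AsyncClient()

    @sio.on("new_message")
    async def on_new_message(data):
        # messages arrive as {"data": <json>, "message_type": <class name>}
        print(data["message_type"], data["data"])

    await sio.connect("http://localhost:8000")
    # ask the server to subscribe this client to a Redis endpoint
    await sio.emit("register", json.dumps({"endpoint": "scan_status"}))
    await asyncio.sleep(10)
    await sio.disconnect()


asyncio.run(main())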
backend/bec_atlas/router/scan_router.py (new file)

from bec_atlas.authentication import get_current_user
from bec_atlas.datasources.scylladb import scylladb_schema as schema
from bec_atlas.router.base_router import BaseRouter
from fastapi import APIRouter, Depends


class ScanRouter(BaseRouter):
    def __init__(self, prefix="/api/v1", datasources=None):
        super().__init__(prefix, datasources)
        self.scylla = self.datasources.datasources.get("scylla")
        self.router = APIRouter(prefix=prefix)
        self.router.add_api_route("/scan", self.scan, methods=["GET"])
        self.router.add_api_route("/scan/{scan_id}", self.scan_with_id, methods=["GET"])

    async def scan(self, current_user: schema.User = Depends(get_current_user)):
        # current_user is resolved through the auth dependency to require a valid token;
        # per-user filtering of the scan table is not implemented yet
        return self.scylla.get("scan")

    async def scan_with_id(self, scan_id: str):
        return {"scan_id": scan_id}
backend/bec_atlas/router/user.py (new file)

from typing import Annotated

from bec_atlas.authentication import create_access_token, get_current_user, verify_password
from bec_atlas.datasources.scylladb import scylladb_schema as schema
from bec_atlas.router.base_router import BaseRouter
from fastapi import APIRouter, Depends
from fastapi.exceptions import HTTPException
from fastapi.security import OAuth2PasswordRequestForm


class UserRouter(BaseRouter):
    def __init__(self, prefix="/api/v1", datasources=None):
        super().__init__(prefix, datasources)
        self.scylla = self.datasources.datasources.get("scylla")
        self.router = APIRouter(prefix=prefix)
        self.router.add_api_route("/user/me", self.user_me, methods=["GET"])
        self.router.add_api_route("/user/login", self.user_login, methods=["POST"], dependencies=[])
        self.router.add_api_route(
            "/user/login/form", self.form_login, methods=["POST"], dependencies=[]
        )

    async def user_me(self, user: schema.User = Depends(get_current_user)):
        data = schema.User.objects.filter(email=user.email)
        if data.count() == 0:
            raise HTTPException(status_code=404, detail="User not found")
        return data.first()

    async def form_login(self, form_data: Annotated[OAuth2PasswordRequestForm, Depends()]):
        out = await self.user_login(form_data.username, form_data.password)
        return {"access_token": out, "token_type": "bearer"}

    async def user_login(self, username: str, password: str):
        result = schema.User.objects.filter(email=username)
        if result.count() == 0:
            raise HTTPException(status_code=404, detail="User not found")
        user: schema.User = result.first()
        credentials = schema.UserCredentials.objects.filter(user_id=user.user_id)
        if credentials.count() == 0:
            raise HTTPException(status_code=404, detail="User not found")
        user_credentials = credentials.first()
        if not verify_password(password, user_credentials.password):
            raise HTTPException(status_code=401, detail="Invalid password")

        return create_access_token(data={"groups": list(user.groups), "email": user.email})
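
A hedged end-to-end sketch of the login flow over HTTP using httpx (pulled in via fastapi[standard]); the credentials are placeholders for whatever functional_accounts.json provides:

import httpx

BASE = "http://localhost:8000/api/v1"

# the form login endpoint is the tokenUrl configured in authentication.py
resp = httpx.post(
    f"{BASE}/user/login/form",
    data={"username": "admin@example.com", "password": "admin"},  # placeholder credentials
)
token = resp.json()["access_token"]

# authenticated request using the bearer token
me = httpx.get(f"{BASE}/user/me", headers={"Authorization": f"Bearer {token}"})
print(me.json())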
backend/pyproject.toml (new file)

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "bec_atlas"
version = "0.0.0"
description = "BEC Atlas"
requires-python = ">=3.10"
classifiers = [
    "Development Status :: 3 - Alpha",
    "Programming Language :: Python :: 3",
    "Topic :: Scientific/Engineering",
]
dependencies = [
    "fastapi[standard]",
    "pyjwt",
    "pwdlib[argon2]",
    "scylla-driver",
    "bec_lib",
    "python-socketio[asyncio_client]",
]

[project.optional-dependencies]
dev = [
    "coverage~=7.0",
    "pytest-random-order~=1.1",
    "pytest-timeout~=2.2",
    "pytest~=8.0",
]

[project.urls]
"Bug Tracker" = "https://gitlab.psi.ch/bec/bec_atlas/issues"
Homepage = "https://gitlab.psi.ch/bec/bec_atlas"

[tool.hatch.build.targets.wheel]
include = ["*"]
exclude = ["docs/**", "tests/**"]

[tool.hatch.build.targets.sdist]
include = ["*"]
exclude = ["docs/**", "tests/**"]

[tool.black]
line-length = 100
skip-magic-trailing-comma = true

[tool.isort]
profile = "black"
line_length = 100
multi_line_output = 3
include_trailing_comma = true
known_first_party = ["bec_widgets"]

[tool.semantic_release]
build_command = "python -m build"
version_toml = ["pyproject.toml:project.version"]

[tool.semantic_release.commit_author]
env = "GIT_COMMIT_AUTHOR"
default = "semantic-release <semantic-release>"

[tool.semantic_release.commit_parser_options]
allowed_tags = [
    "build",
    "chore",
    "ci",
    "docs",
    "feat",
    "fix",
    "perf",
    "style",
    "refactor",
    "test",
]
minor_tags = ["feat"]
patch_tags = ["fix", "perf"]
default_bump_level = 0

[tool.semantic_release.remote]
name = "origin"
type = "gitlab"
ignore_token_for_push = false

[tool.semantic_release.remote.token]
env = "GL_TOKEN"

[tool.semantic_release.publish]
dist_glob_patterns = ["dist/*"]
upload_to_vcs_release = true