mirror of
https://github.com/bec-project/bec_atlas.git
synced 2025-07-13 22:51:49 +02:00
tests: added more backend tests for scans
This commit is contained in:
@ -1,6 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from bson import ObjectId
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from bec_atlas.authentication import get_current_user
|
||||
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
|
||||
@ -26,7 +29,7 @@ class ScanRouter(BaseRouter):
|
||||
self.scans_with_id,
|
||||
methods=["GET"],
|
||||
description="Get a single scan by id for a session",
|
||||
response_model=ScanStatusPartial,
|
||||
response_model=ScanStatusPartial | None,
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
self.router.add_api_route(
|
||||
@ -73,19 +76,18 @@ class ScanRouter(BaseRouter):
|
||||
"""
|
||||
|
||||
if fields:
|
||||
fields = {
|
||||
field: 1
|
||||
for field in fields
|
||||
if field in ScanStatusPartial.model_json_schema()["properties"].keys()
|
||||
}
|
||||
fields = self._update_fields(fields)
|
||||
|
||||
if not ObjectId.is_valid(session_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid session ID")
|
||||
filters = {"session_id": session_id}
|
||||
|
||||
if filter:
|
||||
filter = json.loads(filter)
|
||||
filter = self._update_filter(filter)
|
||||
filters.update(filter)
|
||||
|
||||
if sort:
|
||||
sort = json.loads(sort)
|
||||
sort = self._update_sort(sort)
|
||||
|
||||
return self.db.find(
|
||||
"scans",
|
||||
@ -111,15 +113,12 @@ class ScanRouter(BaseRouter):
|
||||
scan_id (str): The scan id
|
||||
"""
|
||||
if fields:
|
||||
fields = {
|
||||
field: 1
|
||||
for field in fields
|
||||
if field in ScanStatusPartial.model_json_schema()["properties"].keys()
|
||||
}
|
||||
fields = self._update_fields(fields)
|
||||
return self.db.find_one(
|
||||
collection="scans",
|
||||
query_filter={"_id": scan_id},
|
||||
dtype=ScanStatusPartial,
|
||||
fields=fields,
|
||||
user=current_user,
|
||||
)
|
||||
|
||||
@ -145,7 +144,7 @@ class ScanRouter(BaseRouter):
|
||||
return_document=True,
|
||||
)
|
||||
if out is None:
|
||||
return {"message": "Scan not found."}
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
return {"message": "Scan user data updated."}
|
||||
|
||||
async def count_scans(
|
||||
@ -163,11 +162,81 @@ class ScanRouter(BaseRouter):
|
||||
"""
|
||||
pipeline = []
|
||||
if filter:
|
||||
filter = json.loads(filter)
|
||||
filter = self._update_filter(filter)
|
||||
pipeline.append({"$match": filter})
|
||||
pipeline.append({"$count": "count"})
|
||||
|
||||
out = self.db.aggregate("scans", pipeline=pipeline, dtype=None, user=current_user)
|
||||
if out:
|
||||
return out[0]
|
||||
return {"count": 0}
|
||||
# I don't think this will ever be reached
|
||||
else: # pragma: no cover
|
||||
return {"count": 0}
|
||||
|
||||
def _update_filter(self, filter: str) -> dict:
|
||||
"""
|
||||
Update the filter for the query.
|
||||
|
||||
Args:
|
||||
filter (str): JSON filter for the query, e.g. '{"name": "test"}'
|
||||
|
||||
Returns:
|
||||
dict: The filter for the query
|
||||
"""
|
||||
exc = HTTPException(status_code=400, detail="Invalid filter. Must be a JSON object.")
|
||||
try:
|
||||
filter = json.loads(filter)
|
||||
except json.JSONDecodeError:
|
||||
# pylint: disable=raise-missing-from
|
||||
raise exc
|
||||
if not isinstance(filter, dict):
|
||||
raise exc
|
||||
return filter
|
||||
|
||||
def _update_fields(self, fields: list[str]) -> dict:
|
||||
"""
|
||||
Update the fields to return in the query.
|
||||
|
||||
Args:
|
||||
fields (list[str]): List of fields to return
|
||||
|
||||
Returns:
|
||||
dict: The fields to return
|
||||
"""
|
||||
exc = HTTPException(
|
||||
status_code=400, detail="Invalid fields. Must be a list of valid fields."
|
||||
)
|
||||
if not all(
|
||||
field in ScanStatusPartial.model_json_schema()["properties"].keys() for field in fields
|
||||
):
|
||||
raise exc
|
||||
fields = {field: 1 for field in fields}
|
||||
return fields
|
||||
|
||||
def _update_sort(self, sort: str) -> dict:
|
||||
"""
|
||||
Update the sort order for the query.
|
||||
|
||||
Args:
|
||||
sort (str): Sort order for the query, e.g. '{"name": 1}' for ascending order,
|
||||
'{"name": -1}' for descending order. Multiple fields can be sorted by
|
||||
separating them with a comma, e.g. '{"name": 1, "description": -1}'
|
||||
|
||||
Returns:
|
||||
dict: The sort order
|
||||
"""
|
||||
exc = HTTPException(
|
||||
status_code=400, detail="Invalid sort order. Must be a JSON object with valid keys."
|
||||
)
|
||||
try:
|
||||
sort = json.loads(sort)
|
||||
except json.JSONDecodeError:
|
||||
# pylint: disable=raise-missing-from
|
||||
raise exc
|
||||
if not isinstance(sort, dict):
|
||||
raise exc
|
||||
if not all(
|
||||
key in ScanStatusPartial.model_json_schema()["properties"].keys() for key in sort.keys()
|
||||
):
|
||||
raise exc
|
||||
return sort
|
||||
|
@ -73,13 +73,13 @@ def convert_to_object_id(data):
|
||||
return data
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@pytest.fixture()
|
||||
def redis_server():
|
||||
redis_server = fakeredis.FakeServer()
|
||||
yield redis_server
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@pytest.fixture()
|
||||
def backend(redis_server):
|
||||
|
||||
def _fake_redis(host, port):
|
||||
|
File diff suppressed because it is too large
Load Diff
383
backend/tests/test_scans_router.py
Normal file
383
backend/tests/test_scans_router.py
Normal file
@ -0,0 +1,383 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def logged_in_client(backend):
|
||||
client, _ = backend
|
||||
response = client.post(
|
||||
"/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
token = response.json()
|
||||
assert isinstance(token, str)
|
||||
assert len(token) > 20
|
||||
client.headers.update({"Authorization": f"Bearer {token}"})
|
||||
return client
|
||||
|
||||
|
||||
def _get_session(client):
|
||||
deployments = client.get(
|
||||
"/api/v1/deployments/realm", params={"realm": "demo_beamline_1"}
|
||||
).json()
|
||||
deployment_id = deployments[0]["_id"]
|
||||
|
||||
response = client.get("/api/v1/sessions", params={"deployment_id": deployment_id})
|
||||
assert response.status_code == 200
|
||||
|
||||
session_id = response.json()[0]["_id"]
|
||||
return session_id
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_get_scans_for_session(logged_in_client):
|
||||
"""
|
||||
Test that the scans/sessions endpoint returns the correct number of scans.
|
||||
"""
|
||||
client = logged_in_client
|
||||
|
||||
session_id = _get_session(client)
|
||||
|
||||
response = client.get("/api/v1/scans/session", params={"session_id": session_id})
|
||||
|
||||
assert response.status_code == 200
|
||||
scans = response.json()
|
||||
assert len(scans) == 3
|
||||
|
||||
# this endpoint should enforce the session_id as param
|
||||
response = client.get("/api/v1/scans/session")
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_get_scans_for_session_wrong_id(logged_in_client):
|
||||
"""
|
||||
Test that the scans/sessions endpoint returns 400 for a wrong session id.
|
||||
"""
|
||||
client = logged_in_client
|
||||
|
||||
response = client.get("/api/v1/scans/session", params={"session_id": "wrong_id"})
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {"detail": "Invalid session ID"}
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_get_scans_for_session_with_filter(logged_in_client):
|
||||
"""
|
||||
Test that the scans/sessions endpoint returns the correct number of scans with a filter.
|
||||
"""
|
||||
client = logged_in_client
|
||||
|
||||
session_id = _get_session(client)
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/scans/session",
|
||||
params={"session_id": session_id, "filter": '{"scan_number": 2251}'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
scans = response.json()
|
||||
assert len(scans) == 1
|
||||
assert scans[0]["scan_number"] == 2251
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_get_scans_for_session_with_fields(logged_in_client):
|
||||
"""
|
||||
Test that the scans/session endpoint returns the correct fields.
|
||||
"""
|
||||
client = logged_in_client
|
||||
|
||||
session_id = _get_session(client)
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/scans/session", params={"session_id": session_id, "fields": ["scan_number"]}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
scans = response.json()
|
||||
assert len(scans) == 3
|
||||
assert "scan_number" in scans[0]
|
||||
assert "num_points" not in scans[0]
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_get_scans_for_session_with_offset_limit(logged_in_client):
|
||||
"""
|
||||
Test that the scans/session endpoint returns the correct number of scans with offset and limit.
|
||||
"""
|
||||
client = logged_in_client
|
||||
|
||||
session_id = _get_session(client)
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/scans/session", params={"session_id": session_id, "offset": 0, "limit": 1}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
scans = response.json()
|
||||
assert len(scans) == 1
|
||||
assert scans[0]["scan_number"] == 2251
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/scans/session", params={"session_id": session_id, "offset": 1, "limit": 1}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
scans = response.json()
|
||||
assert len(scans) == 1
|
||||
assert scans[0]["scan_number"] == 2252
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_get_scans_for_session_with_sort(logged_in_client):
|
||||
"""
|
||||
Test that the scans/session endpoint returns the correct number of scans with sort.
|
||||
"""
|
||||
client = logged_in_client
|
||||
|
||||
session_id = _get_session(client)
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/scans/session", params={"session_id": session_id, "sort": '{"scan_number": 1}'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
scans = response.json()
|
||||
assert len(scans) == 3
|
||||
assert scans[0]["scan_number"] == 2251
|
||||
assert scans[1]["scan_number"] == 2252
|
||||
assert scans[2]["scan_number"] == 2253
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/scans/session", params={"session_id": session_id, "sort": '{"scan_number": -1}'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
scans = response.json()
|
||||
assert len(scans) == 3
|
||||
assert scans[0]["scan_number"] == 2253
|
||||
assert scans[1]["scan_number"] == 2252
|
||||
assert scans[2]["scan_number"] == 2251
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_get_scans_for_session_with_invalid_sort(logged_in_client):
|
||||
"""
|
||||
Test that the scans/session endpoint returns 400 for an invalid sort order.
|
||||
"""
|
||||
client = logged_in_client
|
||||
|
||||
session_id = _get_session(client)
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/scans/session", params={"session_id": session_id, "sort": "invalid"}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {
|
||||
"detail": "Invalid sort order. Must be a JSON object with valid keys."
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
@pytest.mark.parametrize(
|
||||
"fields", ["invalid", "{'scan_number': 2251}", 123, [123], ["scan_number", 123]]
|
||||
)
|
||||
def test_get_scans_for_session_with_invalid_fields(logged_in_client, fields):
|
||||
"""
|
||||
Test that the scans/sessions endpoint returns 400 for invalid fields.
|
||||
"""
|
||||
client = logged_in_client
|
||||
|
||||
session_id = _get_session(client)
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/scans/session", params={"session_id": session_id, "fields": fields}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {"detail": "Invalid fields. Must be a list of valid fields."}
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
@pytest.mark.parametrize(
|
||||
"filter", ["invalid", 123, [123], '{"scan_number": 2251', '{"scan_number": 2251}}']
|
||||
)
|
||||
def test_get_scans_for_session_with_invalid_filter(logged_in_client, filter):
|
||||
"""
|
||||
Test that the scans/sessions endpoint returns 400 for invalid filter.
|
||||
"""
|
||||
client = logged_in_client
|
||||
|
||||
session_id = _get_session(client)
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/scans/session", params={"session_id": session_id, "filter": filter}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {"detail": "Invalid filter. Must be a JSON object."}
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
@pytest.mark.parametrize(
|
||||
"sort", ["invalid", 123, [123], '{"scan_number": 1', '{"scan_number": 1}}', '{"invalid": 1}']
|
||||
)
|
||||
def test_get_scans_for_session_with_invalid_sort_key(logged_in_client, sort):
|
||||
"""
|
||||
Test that the scans/sessions endpoint returns 400 for invalid sort key.
|
||||
"""
|
||||
client = logged_in_client
|
||||
|
||||
session_id = _get_session(client)
|
||||
|
||||
response = client.get("/api/v1/scans/session", params={"session_id": session_id, "sort": sort})
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {
|
||||
"detail": "Invalid sort order. Must be a JSON object with valid keys."
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_get_scan_with_id(logged_in_client):
|
||||
"""
|
||||
Test that scans/id endpoint returns the correct scan.
|
||||
"""
|
||||
client = logged_in_client
|
||||
|
||||
session_id = _get_session(client)
|
||||
|
||||
response = client.get("/api/v1/scans/session", params={"session_id": session_id})
|
||||
assert response.status_code == 200
|
||||
scan_id = response.json()[0]["scan_id"]
|
||||
|
||||
response = client.get("/api/v1/scans/id", params={"scan_id": scan_id})
|
||||
assert response.status_code == 200
|
||||
scan = response.json()
|
||||
assert scan["scan_id"] == scan_id
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_get_scan_with_id_wrong_id(logged_in_client):
|
||||
"""
|
||||
Test that the scans/id endpoint returns None for a wrong scan id.
|
||||
"""
|
||||
client = logged_in_client
|
||||
|
||||
response = client.get("/api/v1/scans/id", params={"scan_id": "wrong_id"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() is None
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_get_scan_with_id_and_fields(logged_in_client):
|
||||
"""
|
||||
Test that the scans/id endpoint returns the correct fields.
|
||||
"""
|
||||
client = logged_in_client
|
||||
|
||||
session_id = _get_session(client)
|
||||
|
||||
response = client.get("/api/v1/scans/session", params={"session_id": session_id})
|
||||
assert response.status_code == 200
|
||||
scan_id = response.json()[0]["scan_id"]
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/scans/id", params={"scan_id": scan_id, "fields": ["scan_number"]}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
scan = response.json()
|
||||
assert "scan_number" in scan
|
||||
assert "dataset_number" not in scan
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_update_scan_user_data(logged_in_client):
|
||||
"""
|
||||
Test that the scans/id endpoint updates the user_data.
|
||||
"""
|
||||
client = logged_in_client
|
||||
|
||||
session_id = _get_session(client)
|
||||
|
||||
response = client.get("/api/v1/scans/session", params={"session_id": session_id})
|
||||
assert response.status_code == 200
|
||||
scan_id = response.json()[0]["scan_id"]
|
||||
|
||||
response = client.get("/api/v1/scans/id", params={"scan_id": scan_id})
|
||||
assert response.status_code == 200
|
||||
scan = response.json()
|
||||
assert "scan_data" not in scan
|
||||
|
||||
response = client.patch(
|
||||
"/api/v1/scans/user_data",
|
||||
params={"scan_id": scan_id},
|
||||
json={"name": "test", "user_rating": 5},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response = client.get("/api/v1/scans/id", params={"scan_id": scan_id})
|
||||
assert response.status_code == 200
|
||||
scan = response.json()
|
||||
assert scan["user_data"] == {"name": "test", "user_rating": 5}
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_update_scan_user_data_wrong_id(logged_in_client):
|
||||
"""
|
||||
Test that the scans/id endpoint returns 404 for a wrong scan id.
|
||||
"""
|
||||
client = logged_in_client
|
||||
|
||||
response = client.patch(
|
||||
"/api/v1/scans/user_data",
|
||||
params={"scan_id": "wrong_id"},
|
||||
json={"name": "test", "user_rating": 5},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
assert response.json() == {"detail": "Scan not found"}
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
@pytest.mark.parametrize(
|
||||
"filter, count", [({}, 4), ('{"scan_number": 2251}', 1), ('{"scan_number": 2}', 0)]
|
||||
)
|
||||
def test_count_scans(logged_in_client, filter, count):
|
||||
"""
|
||||
Test that the scans/count endpoint returns the correct number of scans.
|
||||
"""
|
||||
client = logged_in_client
|
||||
|
||||
response = client.get("/api/v1/scans/count", params={"filter": filter})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"count": count}
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_count_scans_with_invalid_filter(logged_in_client):
|
||||
"""
|
||||
Test that the scans/count endpoint returns 400 for an invalid filter.
|
||||
"""
|
||||
client = logged_in_client
|
||||
|
||||
response = client.get("/api/v1/scans/count", params={"filter": "invalid"})
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {"detail": "Invalid filter. Must be a JSON object."}
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_count_scans_with_no_results(logged_in_client):
|
||||
"""
|
||||
Test that the scans/count endpoint returns 0 for no results.
|
||||
"""
|
||||
client = logged_in_client
|
||||
|
||||
_filter = {"session_id": str(ObjectId())}
|
||||
response = client.get("/api/v1/scans/count", params={"filter": json.dumps(_filter)})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"count": 0}
|
Reference in New Issue
Block a user