tests: added more backend tests for scans

This commit is contained in:
2025-02-11 16:52:07 +01:00
parent ac5980a3ae
commit 27061b24e9
4 changed files with 3340 additions and 36 deletions

View File

@ -1,6 +1,9 @@
from __future__ import annotations
import json
from fastapi import APIRouter, Depends, Query
from bson import ObjectId
from fastapi import APIRouter, Depends, HTTPException, Query
from bec_atlas.authentication import get_current_user
from bec_atlas.datasources.mongodb.mongodb import MongoDBDatasource
@ -26,7 +29,7 @@ class ScanRouter(BaseRouter):
self.scans_with_id,
methods=["GET"],
description="Get a single scan by id for a session",
response_model=ScanStatusPartial,
response_model=ScanStatusPartial | None,
response_model_exclude_none=True,
)
self.router.add_api_route(
@ -73,19 +76,18 @@ class ScanRouter(BaseRouter):
"""
if fields:
fields = {
field: 1
for field in fields
if field in ScanStatusPartial.model_json_schema()["properties"].keys()
}
fields = self._update_fields(fields)
if not ObjectId.is_valid(session_id):
raise HTTPException(status_code=400, detail="Invalid session ID")
filters = {"session_id": session_id}
if filter:
filter = json.loads(filter)
filter = self._update_filter(filter)
filters.update(filter)
if sort:
sort = json.loads(sort)
sort = self._update_sort(sort)
return self.db.find(
"scans",
@ -111,15 +113,12 @@ class ScanRouter(BaseRouter):
scan_id (str): The scan id
"""
if fields:
fields = {
field: 1
for field in fields
if field in ScanStatusPartial.model_json_schema()["properties"].keys()
}
fields = self._update_fields(fields)
return self.db.find_one(
collection="scans",
query_filter={"_id": scan_id},
dtype=ScanStatusPartial,
fields=fields,
user=current_user,
)
@ -145,7 +144,7 @@ class ScanRouter(BaseRouter):
return_document=True,
)
if out is None:
return {"message": "Scan not found."}
raise HTTPException(status_code=404, detail="Scan not found")
return {"message": "Scan user data updated."}
async def count_scans(
@ -163,11 +162,81 @@ class ScanRouter(BaseRouter):
"""
pipeline = []
if filter:
filter = json.loads(filter)
filter = self._update_filter(filter)
pipeline.append({"$match": filter})
pipeline.append({"$count": "count"})
out = self.db.aggregate("scans", pipeline=pipeline, dtype=None, user=current_user)
if out:
return out[0]
return {"count": 0}
# I don't think this will ever be reached
else: # pragma: no cover
return {"count": 0}
def _update_filter(self, filter: str) -> dict:
"""
Update the filter for the query.
Args:
filter (str): JSON filter for the query, e.g. '{"name": "test"}'
Returns:
dict: The filter for the query
"""
exc = HTTPException(status_code=400, detail="Invalid filter. Must be a JSON object.")
try:
filter = json.loads(filter)
except json.JSONDecodeError:
# pylint: disable=raise-missing-from
raise exc
if not isinstance(filter, dict):
raise exc
return filter
def _update_fields(self, fields: list[str]) -> dict:
"""
Update the fields to return in the query.
Args:
fields (list[str]): List of fields to return
Returns:
dict: The fields to return
"""
exc = HTTPException(
status_code=400, detail="Invalid fields. Must be a list of valid fields."
)
if not all(
field in ScanStatusPartial.model_json_schema()["properties"].keys() for field in fields
):
raise exc
fields = {field: 1 for field in fields}
return fields
def _update_sort(self, sort: str) -> dict:
"""
Update the sort order for the query.
Args:
sort (str): Sort order for the query, e.g. '{"name": 1}' for ascending order,
'{"name": -1}' for descending order. Multiple fields can be sorted by
separating them with a comma, e.g. '{"name": 1, "description": -1}'
Returns:
dict: The sort order
"""
exc = HTTPException(
status_code=400, detail="Invalid sort order. Must be a JSON object with valid keys."
)
try:
sort = json.loads(sort)
except json.JSONDecodeError:
# pylint: disable=raise-missing-from
raise exc
if not isinstance(sort, dict):
raise exc
if not all(
key in ScanStatusPartial.model_json_schema()["properties"].keys() for key in sort.keys()
):
raise exc
return sort

View File

@ -73,13 +73,13 @@ def convert_to_object_id(data):
return data
@pytest.fixture(scope="session")
@pytest.fixture()
def redis_server():
redis_server = fakeredis.FakeServer()
yield redis_server
@pytest.fixture(scope="session")
@pytest.fixture()
def backend(redis_server):
def _fake_redis(host, port):

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,383 @@
import json
import pytest
from bson import ObjectId
@pytest.fixture
def logged_in_client(backend):
client, _ = backend
response = client.post(
"/api/v1/user/login", json={"username": "admin@bec_atlas.ch", "password": "admin"}
)
assert response.status_code == 200
token = response.json()
assert isinstance(token, str)
assert len(token) > 20
client.headers.update({"Authorization": f"Bearer {token}"})
return client
def _get_session(client):
deployments = client.get(
"/api/v1/deployments/realm", params={"realm": "demo_beamline_1"}
).json()
deployment_id = deployments[0]["_id"]
response = client.get("/api/v1/sessions", params={"deployment_id": deployment_id})
assert response.status_code == 200
session_id = response.json()[0]["_id"]
return session_id
@pytest.mark.timeout(60)
def test_get_scans_for_session(logged_in_client):
"""
Test that the scans/sessions endpoint returns the correct number of scans.
"""
client = logged_in_client
session_id = _get_session(client)
response = client.get("/api/v1/scans/session", params={"session_id": session_id})
assert response.status_code == 200
scans = response.json()
assert len(scans) == 3
# this endpoint should enforce the session_id as param
response = client.get("/api/v1/scans/session")
assert response.status_code == 422
@pytest.mark.timeout(60)
def test_get_scans_for_session_wrong_id(logged_in_client):
"""
Test that the scans/sessions endpoint returns 400 for a wrong session id.
"""
client = logged_in_client
response = client.get("/api/v1/scans/session", params={"session_id": "wrong_id"})
assert response.status_code == 400
assert response.json() == {"detail": "Invalid session ID"}
@pytest.mark.timeout(60)
def test_get_scans_for_session_with_filter(logged_in_client):
"""
Test that the scans/sessions endpoint returns the correct number of scans with a filter.
"""
client = logged_in_client
session_id = _get_session(client)
response = client.get(
"/api/v1/scans/session",
params={"session_id": session_id, "filter": '{"scan_number": 2251}'},
)
assert response.status_code == 200
scans = response.json()
assert len(scans) == 1
assert scans[0]["scan_number"] == 2251
@pytest.mark.timeout(60)
def test_get_scans_for_session_with_fields(logged_in_client):
"""
Test that the scans/session endpoint returns the correct fields.
"""
client = logged_in_client
session_id = _get_session(client)
response = client.get(
"/api/v1/scans/session", params={"session_id": session_id, "fields": ["scan_number"]}
)
assert response.status_code == 200
scans = response.json()
assert len(scans) == 3
assert "scan_number" in scans[0]
assert "num_points" not in scans[0]
@pytest.mark.timeout(60)
def test_get_scans_for_session_with_offset_limit(logged_in_client):
"""
Test that the scans/session endpoint returns the correct number of scans with offset and limit.
"""
client = logged_in_client
session_id = _get_session(client)
response = client.get(
"/api/v1/scans/session", params={"session_id": session_id, "offset": 0, "limit": 1}
)
assert response.status_code == 200
scans = response.json()
assert len(scans) == 1
assert scans[0]["scan_number"] == 2251
response = client.get(
"/api/v1/scans/session", params={"session_id": session_id, "offset": 1, "limit": 1}
)
assert response.status_code == 200
scans = response.json()
assert len(scans) == 1
assert scans[0]["scan_number"] == 2252
@pytest.mark.timeout(60)
def test_get_scans_for_session_with_sort(logged_in_client):
"""
Test that the scans/session endpoint returns the correct number of scans with sort.
"""
client = logged_in_client
session_id = _get_session(client)
response = client.get(
"/api/v1/scans/session", params={"session_id": session_id, "sort": '{"scan_number": 1}'}
)
assert response.status_code == 200
scans = response.json()
assert len(scans) == 3
assert scans[0]["scan_number"] == 2251
assert scans[1]["scan_number"] == 2252
assert scans[2]["scan_number"] == 2253
response = client.get(
"/api/v1/scans/session", params={"session_id": session_id, "sort": '{"scan_number": -1}'}
)
assert response.status_code == 200
scans = response.json()
assert len(scans) == 3
assert scans[0]["scan_number"] == 2253
assert scans[1]["scan_number"] == 2252
assert scans[2]["scan_number"] == 2251
@pytest.mark.timeout(60)
def test_get_scans_for_session_with_invalid_sort(logged_in_client):
"""
Test that the scans/session endpoint returns 400 for an invalid sort order.
"""
client = logged_in_client
session_id = _get_session(client)
response = client.get(
"/api/v1/scans/session", params={"session_id": session_id, "sort": "invalid"}
)
assert response.status_code == 400
assert response.json() == {
"detail": "Invalid sort order. Must be a JSON object with valid keys."
}
@pytest.mark.timeout(60)
@pytest.mark.parametrize(
"fields", ["invalid", "{'scan_number': 2251}", 123, [123], ["scan_number", 123]]
)
def test_get_scans_for_session_with_invalid_fields(logged_in_client, fields):
"""
Test that the scans/sessions endpoint returns 400 for invalid fields.
"""
client = logged_in_client
session_id = _get_session(client)
response = client.get(
"/api/v1/scans/session", params={"session_id": session_id, "fields": fields}
)
assert response.status_code == 400
assert response.json() == {"detail": "Invalid fields. Must be a list of valid fields."}
@pytest.mark.timeout(60)
@pytest.mark.parametrize(
"filter", ["invalid", 123, [123], '{"scan_number": 2251', '{"scan_number": 2251}}']
)
def test_get_scans_for_session_with_invalid_filter(logged_in_client, filter):
"""
Test that the scans/sessions endpoint returns 400 for invalid filter.
"""
client = logged_in_client
session_id = _get_session(client)
response = client.get(
"/api/v1/scans/session", params={"session_id": session_id, "filter": filter}
)
assert response.status_code == 400
assert response.json() == {"detail": "Invalid filter. Must be a JSON object."}
@pytest.mark.timeout(60)
@pytest.mark.parametrize(
"sort", ["invalid", 123, [123], '{"scan_number": 1', '{"scan_number": 1}}', '{"invalid": 1}']
)
def test_get_scans_for_session_with_invalid_sort_key(logged_in_client, sort):
"""
Test that the scans/sessions endpoint returns 400 for invalid sort key.
"""
client = logged_in_client
session_id = _get_session(client)
response = client.get("/api/v1/scans/session", params={"session_id": session_id, "sort": sort})
assert response.status_code == 400
assert response.json() == {
"detail": "Invalid sort order. Must be a JSON object with valid keys."
}
@pytest.mark.timeout(60)
def test_get_scan_with_id(logged_in_client):
"""
Test that scans/id endpoint returns the correct scan.
"""
client = logged_in_client
session_id = _get_session(client)
response = client.get("/api/v1/scans/session", params={"session_id": session_id})
assert response.status_code == 200
scan_id = response.json()[0]["scan_id"]
response = client.get("/api/v1/scans/id", params={"scan_id": scan_id})
assert response.status_code == 200
scan = response.json()
assert scan["scan_id"] == scan_id
@pytest.mark.timeout(60)
def test_get_scan_with_id_wrong_id(logged_in_client):
"""
Test that the scans/id endpoint returns None for a wrong scan id.
"""
client = logged_in_client
response = client.get("/api/v1/scans/id", params={"scan_id": "wrong_id"})
assert response.status_code == 200
assert response.json() is None
@pytest.mark.timeout(60)
def test_get_scan_with_id_and_fields(logged_in_client):
"""
Test that the scans/id endpoint returns the correct fields.
"""
client = logged_in_client
session_id = _get_session(client)
response = client.get("/api/v1/scans/session", params={"session_id": session_id})
assert response.status_code == 200
scan_id = response.json()[0]["scan_id"]
response = client.get(
"/api/v1/scans/id", params={"scan_id": scan_id, "fields": ["scan_number"]}
)
assert response.status_code == 200
scan = response.json()
assert "scan_number" in scan
assert "dataset_number" not in scan
@pytest.mark.timeout(60)
def test_update_scan_user_data(logged_in_client):
"""
Test that the scans/id endpoint updates the user_data.
"""
client = logged_in_client
session_id = _get_session(client)
response = client.get("/api/v1/scans/session", params={"session_id": session_id})
assert response.status_code == 200
scan_id = response.json()[0]["scan_id"]
response = client.get("/api/v1/scans/id", params={"scan_id": scan_id})
assert response.status_code == 200
scan = response.json()
assert "scan_data" not in scan
response = client.patch(
"/api/v1/scans/user_data",
params={"scan_id": scan_id},
json={"name": "test", "user_rating": 5},
)
assert response.status_code == 200
response = client.get("/api/v1/scans/id", params={"scan_id": scan_id})
assert response.status_code == 200
scan = response.json()
assert scan["user_data"] == {"name": "test", "user_rating": 5}
@pytest.mark.timeout(60)
def test_update_scan_user_data_wrong_id(logged_in_client):
"""
Test that the scans/id endpoint returns 404 for a wrong scan id.
"""
client = logged_in_client
response = client.patch(
"/api/v1/scans/user_data",
params={"scan_id": "wrong_id"},
json={"name": "test", "user_rating": 5},
)
assert response.status_code == 404
assert response.json() == {"detail": "Scan not found"}
@pytest.mark.timeout(60)
@pytest.mark.parametrize(
"filter, count", [({}, 4), ('{"scan_number": 2251}', 1), ('{"scan_number": 2}', 0)]
)
def test_count_scans(logged_in_client, filter, count):
"""
Test that the scans/count endpoint returns the correct number of scans.
"""
client = logged_in_client
response = client.get("/api/v1/scans/count", params={"filter": filter})
assert response.status_code == 200
assert response.json() == {"count": count}
@pytest.mark.timeout(60)
def test_count_scans_with_invalid_filter(logged_in_client):
"""
Test that the scans/count endpoint returns 400 for an invalid filter.
"""
client = logged_in_client
response = client.get("/api/v1/scans/count", params={"filter": "invalid"})
assert response.status_code == 400
assert response.json() == {"detail": "Invalid filter. Must be a JSON object."}
@pytest.mark.timeout(60)
def test_count_scans_with_no_results(logged_in_client):
"""
Test that the scans/count endpoint returns 0 for no results.
"""
client = logged_in_client
_filter = {"session_id": str(ObjectId())}
response = client.get("/api/v1/scans/count", params={"filter": json.dumps(_filter)})
assert response.status_code == 200
assert response.json() == {"count": 0}