52 lines
2.0 KiB
C++
52 lines
2.0 KiB
C++
// Copyright (2019-2024) Paul Scherrer Institute
|
|
|
|
#include <catch2/catch_all.hpp>
|
|
|
|
#include "../writer/HDF5Objects.h"
|
|
#include "../resonet/NeuralNetResPredictor.h"
|
|
|
|
// Checks that GetMaxPoolFactor() reports 2 for this detector geometry, then feeds a
// synthetic image (all zeros except a few hand-placed pixels) through Prepare() and
// probes the resulting model input at three fixed positions.
TEST_CASE("NeuralNetResPredictor_Prepare", "[LinearAlgebra][Coord]") {
    DiffractionExperiment experiment(DetectorGeometry(8, 2, 8, 36));
    experiment.DetectorDistance_mm(75).PhotonEnergy_keV(12.4).BeamX_pxl(1000).BeamY_pxl(1000);

    // Linear index of pixel (row, col) in the row-major detector image
    // (row * image width + col, same arithmetic as used for every write below).
    auto pixel = [&experiment](auto row, auto col) {
        return row * experiment.GetXPixelsNum() + col;
    };

    std::vector<int16_t> image(experiment.GetPixelsNum(), 0);
    image[pixel(1000, 1000)] = 100;
    image[pixel(1000, 1001)] = 20;
    image[pixel(1001, 1000)] = 30;
    image[pixel(1001, 1001)] = INT16_MIN; // NOTE(review): presumably a bad/masked-pixel marker — confirm
    image[pixel(1050, 1050)] = 52;
    image[pixel(2000, 1500)] = 160;

    NeuralNetResPredictor predictor("../../resonet/traced_resnet_model.pt");

    REQUIRE(predictor.GetMaxPoolFactor(experiment) == 2);

    predictor.Prepare(experiment, image.data());
    auto model_input = predictor.GetModelInput();

    // Probe offsets are written as row * 512 + col; 512 looks like the model input
    // side length — TODO confirm against NeuralNetResPredictor.
    constexpr int side = 512;
    REQUIRE(model_input[0] == 10);
    REQUIRE(model_input[25 * side + 25] == 7);
    REQUIRE(model_input[500 * side + 250] == 12);
}
|
|
|
|
// End-to-end check: reads frame 4 of the compression benchmark dataset, runs
// NeuralNetResPredictor::Inference() on it and requires a result strictly between
// 1.4 and 1.5.
TEST_CASE("NeuralNetResPredictor_Inference", "[LinearAlgebra][Coord]") {
    DiffractionExperiment experiment(DetectorGeometry(8, 2, 8, 36));
    experiment.DetectorDistance_mm(75).PhotonEnergy_keV(12.4).BeamY_pxl(1136).BeamX_pxl(1090);

    NeuralNetResPredictor predictor("../../resonet/traced_resnet_model.pt");

    HDF5ReadOnlyFile data("../../tests/test_data/compression_benchmark.h5");
    HDF5DataSet dataset(data, "/entry/data/data");
    HDF5DataSpace file_space(dataset);

    // The dataset is indexed as (frame, y, x) below. Guard that shape so a wrong or
    // truncated test file fails with a clear assertion rather than undefined behaviour.
    const auto dims = file_space.GetDimensions();
    REQUIRE(dims.size() == 3u);

    const hsize_t frame = 4;
    REQUIRE(dims[0] > frame);

    std::vector<int16_t> image_conv(dims[1] * dims[2]);

    // Inference() only receives a raw pointer. It presumably reads GetPixelsNum() pixels,
    // so the stored frame must match the detector geometry configured above — TODO confirm.
    REQUIRE(image_conv.size() == static_cast<size_t>(experiment.GetPixelsNum()));

    std::vector<hsize_t> start = {frame, 0, 0};
    std::vector<hsize_t> file_size = {1, dims[1], dims[2]};
    dataset.ReadVector(image_conv, start, file_size);

    auto res = predictor.Inference(experiment, image_conv.data());
    REQUIRE(res);
    REQUIRE(res.value() < 1.5);
    REQUIRE(res.value() > 1.4);
}
|