v1.0.0-rc.38

This commit is contained in:
2025-05-12 14:17:24 +02:00
parent 19d6f22136
commit b245967df3
201 changed files with 2380 additions and 1432 deletions

View File

@@ -6,7 +6,7 @@
#include <catch2/catch_all.hpp>
#include "../writer/HDF5Objects.h"
#include "../image_analysis/NeuralNetResPredictor.h"
#include "../image_analysis/NeuralNetInferenceClient.h"
TEST_CASE("NeuralNetResPredictor_Prepare", "[LinearAlgebra][Coord]") {
DiffractionExperiment experiment(DetJF4M());
@@ -20,41 +20,31 @@ TEST_CASE("NeuralNetResPredictor_Prepare", "[LinearAlgebra][Coord]") {
v[1050 * experiment.GetXPixelsNum() + 1050] = 52;
v[2000 * experiment.GetXPixelsNum() + 1500] = 160;
NeuralNetResPredictor predictor("../../resonet/traced_resnet_model.pt");
v[800 * experiment.GetXPixelsNum() + 600] = 49;
v[1200 * experiment.GetXPixelsNum() + 600] = 36;
v[800 * experiment.GetXPixelsNum() + 1400] = 64;
NeuralNetInferenceClient predictor;
REQUIRE(predictor.GetMaxPoolFactor(experiment) == 2);
predictor.Prepare(experiment, v.data());
auto nn_input = predictor.GetModelInput();
REQUIRE(nn_input[0] == 10);
REQUIRE(nn_input[25 * 512 + 25] == 7);
REQUIRE(nn_input[500 * 512 + 250] == 12);
}
TEST_CASE("NeuralNetResPredictor_Inference", "[LinearAlgebra][Coord]") {
RegisterHDF5Filter();
DiffractionExperiment experiment(DetJF4M());
experiment.DetectorDistance_mm(75).IncidentEnergy_keV(12.4).BeamY_pxl(1136).BeamX_pxl(1090);
NeuralNetResPredictor predictor("../../resonet/traced_resnet_model.pt");
HDF5ReadOnlyFile data("../../tests/test_data/compression_benchmark.h5");
HDF5DataSet dataset(data, "/entry/data/data");
HDF5DataSpace file_space(dataset);
std::vector<int16_t> image_conv (file_space.GetDimensions()[1] * file_space.GetDimensions()[2]);
std::vector<hsize_t> start = {4,0,0};
std::vector<hsize_t> file_size = {1, file_space.GetDimensions()[1], file_space.GetDimensions()[2]};
dataset.ReadVector(image_conv, start, file_size);
auto res = predictor.Inference(experiment, image_conv.data());
#ifdef JFJOCH_USE_TORCH
REQUIRE(res);
REQUIRE(res.value() < 1.5);
REQUIRE(res.value() > 1.4);
#else
REQUIRE(!res);
#endif
auto br = predictor.Prepare(experiment, v.data(), Quarter::BottomRight);
REQUIRE(br.size() == 512 * 512);
CHECK(br[0] == 10);
CHECK(br[25 * 512 + 25] == 7);
CHECK(br[500 * 512 + 250] == 12);
auto tl = predictor.Prepare(experiment, v.data(), Quarter::TopLeft);
REQUIRE(tl.size() == 512 * 512);
CHECK(tl[100 * 512 + 200] == 7);
auto tr = predictor.Prepare(experiment, v.data(), Quarter::TopRight);
REQUIRE(tr.size() == 512 * 512);
CHECK(tr[100 * 512 + 200] == 8);
auto bl = predictor.Prepare(experiment, v.data(), Quarter::BottomLeft);
REQUIRE(bl.size() == 512 * 512);
CHECK(bl[100 * 512 + 200] == 6);
}