// Unit tests for aare::Interpolator: construction from an eta distribution,
// the Rosenblatt-transform path, and the cluster -> photon interpolate() API.
#include "aare/ClusterVector.hpp"
#include "aare/Interpolator.hpp"
#include "aare/NDArray.hpp"
// NOTE(review): the four directives below have no header name. This file
// appears to have passed through a tool that deleted every "<...>" span, so
// these include targets AND all template argument lists are missing (e.g. the
// bare `NDArray` declarations, the stray `>` in `ClusterVector>` and in
// `interpolate>(...)`). As written the file cannot compile; restore it from
// version control. Presumably the missing headers are the Catch2 headers
// (TEST_CASE / CHECK / Catch::Approx) and <numeric> (std::iota) - TODO confirm.
#include
#include
#include
#include

using namespace aare;

// Smoke test for Interpolator::interpolate. The eta distribution is built from
// {3, 3, 1} with fill value 0.0 (presumably shape + fill, i.e. all zeros), the
// eta-x / eta-y bin edges are {0, 0.3, 0.6, 1.0} and the energy bin edges are
// {0, 100}. A single cluster (first two fields 2, 2 - presumably x, y) with 9
// pixel values is interpolated; only the number of returned photons is checked.
TEST_CASE("Test new Interpolation API", "[Interpolation]") {
    NDArray energy_bins(std::array{0.0, 100.0});
    NDArray etax_bins(std::array{0.0, 0.3, 0.6, 1.0});
    NDArray etay_bins(std::array{0.0, 0.3, 0.6, 1.0});
    NDArray eta_distribution(std::array{3, 3, 1}, 0.0);

    Interpolator interpolator(eta_distribution.view(), etax_bins.view(),
                              etay_bins.view(), energy_bins.view());

    ClusterVector> cluster_vec{};
    cluster_vec.push_back(Cluster{
        2, 2, std::array{1, 2, 2, 1, 4, 1, 1, 2, 1}});

    auto photons = interpolator.interpolate>(cluster_vec);

    // One input cluster must yield exactly one photon.
    CHECK(photons.size() == 1);
}

// Constructs an Interpolator directly from a 3x3x1 eta distribution filled by
// std::iota with 1..9 and checks the shape and values of get_ietax() /
// get_ietay(). Assuming row-major storage, value(i, j) = 3 * i + j + 1; the
// expected values below are consistent with that layout.
//
// The expected values equal, per line of the distribution,
//   (cumsum up to index k - first element) / (total - first element):
//  - ietax(i, j): along axis 0 within column j. Column 0 = {1, 4, 7} gives
//    (5 - 1) / (12 - 1) = 4/11 at i = 1.
//  - ietay(i, j): along axis 1 within row i. Row 0 = {1, 2, 3} gives
//    (3 - 1) / (6 - 1) = 2/5 at j = 1.
TEST_CASE("Test constructor", "[Interpolation]") {
    NDArray energy_bins(std::array{2});
    NDArray etax_bins(std::array{4}, 0.0);
    NDArray etay_bins(std::array{4}, 0.0);
    NDArray eta_distribution(std::array{3, 3, 1});
    std::iota(eta_distribution.begin(), eta_distribution.end(), 1.0);

    Interpolator interpolator(eta_distribution.view(), etax_bins.view(),
                              etay_bins.view(), energy_bins.view());

    auto ietax = interpolator.get_ietax();
    auto ietay = interpolator.get_ietay();

    // Output tables must have the same shape as the input distribution.
    CHECK(ietax.shape(0) == 3);
    CHECK(ietax.shape(1) == 3);
    CHECK(ietax.shape(2) == 1);
    CHECK(ietay.shape(0) == 3);
    CHECK(ietay.shape(1) == 3);
    CHECK(ietay.shape(2) == 1);

    // Flattened row-major expectations, indexed as [i * shape(1) + j].
    std::array expected_ietax{
        0.0, 0.0, 0.0, 4.0 / 11.0, 5.0 / 13.0, 6.0 / 15.0, 1.0, 1.0, 1.0};
    std::array expected_ietay{
        0.0, 2.0 / 5.0, 1.0, 0.0, 5.0 / 11.0, 1.0, 0.0, 8.0 / 17.0, 1.0};

    for (ssize_t i = 0; i < ietax.shape(0); i++) {
        for (ssize_t j = 0; j < ietax.shape(1); j++) {
            CHECK(ietax(i, j, 0) ==
                  Catch::Approx(expected_ietax[i * ietax.shape(1) + j]));
        }
    }

    for (ssize_t i = 0; i < ietay.shape(0); i++) {
        for (ssize_t j = 0; j < ietay.shape(1); j++) {
            CHECK(ietay(i, j, 0) ==
                  Catch::Approx(expected_ietay[i * ietay.shape(1) + j]));
        }
    }
}

// Same check as "Test constructor", but on a 4x4x1 distribution that is zero
// everywhere except the central 2x2 block ({1, 2} / {3, 4}).
//  - Columns (for ietax) and rows (for ietay) whose total is zero are expected
//    to be all 0.0.
//  - Every other line follows the same
//    (cumsum - first element) / (total - first element) rule. For example,
//    column 1 = {0, 1, 3, 0} gives 0, 1/4, 1, 1.
TEST_CASE("Test constructor with zero bins at borders", "[Interpolation]") {
    NDArray energy_bins(std::array{2});
    NDArray etax_bins(std::array{5}, 0.0);
    NDArray etay_bins(std::array{5}, 0.0);
    NDArray eta_distribution(std::array{4, 4, 1}, 0.0);

    eta_distribution(1, 1, 0) = 1.0;
    eta_distribution(1, 2, 0) = 2.0;
    eta_distribution(2, 1, 0) = 3.0;
    eta_distribution(2, 2, 0) = 4.0;

    Interpolator interpolator(eta_distribution.view(), etax_bins.view(),
                              etay_bins.view(), energy_bins.view());

    auto ietax = interpolator.get_ietax();
    auto ietay = interpolator.get_ietay();

    CHECK(ietax.shape(0) == 4);
    CHECK(ietax.shape(1) == 4);
    CHECK(ietax.shape(2) == 1);
    CHECK(ietay.shape(0) == 4);
    CHECK(ietay.shape(1) == 4);
    CHECK(ietay.shape(2) == 1);

    // Flattened row-major expectations, indexed as [i * shape(1) + j].
    std::array expected_ietax{
        0.0, 0.0, 0.0, 0.0, 0.0, 1.0 / 4.0, 2.0 / 6.0, 0.0,
        0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0};
    std::array expected_ietay{
        0.0, 0.0, 0.0, 0.0, 0.0, 1.0 / 3.0, 1.0, 1.0,
        0.0, 3.0 / 7.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0};

    for (ssize_t i = 0; i < ietax.shape(0); i++) {
        for (ssize_t j = 0; j < ietax.shape(1); j++) {
            CHECK(ietax(i, j, 0) ==
                  Catch::Approx(expected_ietax[i * ietax.shape(1) + j]));
        }
    }

    for (ssize_t i = 0; i < ietay.shape(0); i++) {
        for (ssize_t j = 0; j < ietay.shape(1); j++) {
            CHECK(ietay(i, j, 0) ==
                  Catch::Approx(expected_ietay[i * ietay.shape(1) + j]));
        }
    }
}

// Builds the Interpolator from bin arrays only (no distribution), then feeds
// the same 1..9 distribution as "Test constructor" through
// rosenblatttransform().
//  - ietax must hold the marginal CDF of eta_x, identical for every j. With
//    row sums {6, 15, 24}, the middle entry is (21 - 6) / (45 - 6) = 15/39.
//  - ietay must hold the conditional CDF of eta_y within each row. These
//    expected values are identical to those in "Test constructor".
TEST_CASE("Test Rosenblatt", "[Interpolation]") {
    NDArray energy_bins(std::array{2});
    NDArray etax_bins(std::array{4}, 0.0);
    NDArray etay_bins(std::array{4}, 0.0);
    NDArray eta_distribution(std::array{3, 3, 1});
    std::iota(eta_distribution.begin(), eta_distribution.end(), 1.0);

    Interpolator interpolator(etax_bins.view(), etay_bins.view(),
                              energy_bins.view());
    interpolator.rosenblatttransform(eta_distribution.view());

    auto ietax = interpolator.get_ietax();
    auto ietay = interpolator.get_ietay();

    CHECK(ietax.shape(0) == 3);
    CHECK(ietax.shape(1) == 3);
    CHECK(ietax.shape(2) == 1);
    CHECK(ietay.shape(0) == 3);
    CHECK(ietay.shape(1) == 3);
    CHECK(ietay.shape(2) == 1);

    // marginal CDF of eta_x
    std::array expected_ietax{
        0.0, 0.0, 0.0, 15.0 / 39.0, 15.0 / 39.0,
        15.0 / 39.0, 1.0, 1.0, 1.0};
    // conditional CDF of eta_y
    std::array expected_ietay{
        0.0, 2.0 / 5.0, 1.0, 0.0, 5.0 / 11.0, 1.0, 0.0, 8.0 / 17.0, 1.0};

    for (ssize_t i = 0; i < ietax.shape(0); i++) {
        for (ssize_t j = 0; j < ietax.shape(1); j++) {
            CHECK(ietax(i, j, 0) ==
                  Catch::Approx(expected_ietax[i * ietax.shape(1) + j]));
        }
    }

    for (ssize_t i = 0; i < ietay.shape(0); i++) {
        for (ssize_t j = 0; j < ietay.shape(1); j++) {
            CHECK(ietay(i, j, 0) ==
                  Catch::Approx(expected_ietay[i * ietay.shape(1) + j]));
        }
    }
}