Files
aare/src/Interpolation.test.cpp
AliceMazzoleni99 267ca87ab0 Dev/rosenblatttransform (#241)
- added rosenblatttransform 
- added 3x3 eta methods 
- interpolation can be used with various eta functions
- added documentation for interpolation, eta calculation 
- exposed full eta struct in python 
- disable ClusterFinder for 2x2 clusters 
- factory function for ClusterVector

---------

Co-authored-by: Dhanya Thattil <dhanya.thattil@psi.ch>
Co-authored-by: Erik Fröjdh <erik.frojdh@psi.ch>
2025-11-21 14:48:46 +01:00

166 lines
5.7 KiB
C++

#include "aare/ClusterVector.hpp"
#include "aare/Interpolator.hpp"
#include "aare/NDArray.hpp"

#include <array>
#include <cstddef>
#include <iostream>
#include <numeric>

#include <catch2/catch_all.hpp>
#include <catch2/catch_test_macros.hpp>
using namespace aare;
// Smoke test for the templated interpolate() entry point: an interpolator is
// built from an all-zero 3x3x1 eta distribution, a single 3x3 cluster is fed
// through interpolate<calculate_eta2<...>>, and exactly one result is expected.
TEST_CASE("Test new Interpolation API", "[Interpolation]") {
    using Cluster3x3 = Cluster<double, 3, 3>;

    // Eta distribution and both eta bin arrays are zero-filled.
    NDArray<double, 3> eta_distribution(std::array<ssize_t, 3>{3, 3, 1}, 0.0);
    NDArray<double, 1> etax_bins(std::array<ssize_t, 1>{4}, 0.0);
    NDArray<double, 1> etay_bins(std::array<ssize_t, 1>{4}, 0.0);
    // NOTE(review): unlike the arrays above, energy_bins is constructed
    // without a fill value — confirm its contents do not matter for this test.
    NDArray<double, 1> energy_bins(std::array<ssize_t, 1>{2});

    Interpolator interpolator(eta_distribution.view(), etax_bins.view(),
                              etay_bins.view(), energy_bins.view());

    // One cluster whose first two initializers are 2, 2 (presumably its
    // coordinates — verify against Cluster) followed by nine pixel values.
    const std::array<double, 9> pixel_values{1, 2, 2, 1, 4, 1, 1, 2, 1};
    ClusterVector<Cluster3x3> cluster_vec{};
    cluster_vec.push_back(Cluster3x3{2, 2, pixel_values});

    auto photons =
        interpolator.interpolate<calculate_eta2<double, 3, 3>>(cluster_vec);

    // One input cluster must produce one output entry.
    CHECK(photons.size() == 1);
}
// Checks the tables exposed by get_ietax()/get_ietay() after constructing an
// Interpolator from a 3x3x1 eta distribution filled with 1..9 (via std::iota
// over begin()/end()).
TEST_CASE("Test constructor", "[Interpolation]") {
    // NOTE(review): energy_bins is constructed without a fill value, unlike
    // the eta bin arrays — confirm its contents are irrelevant to the tables
    // checked here.
    NDArray<double, 1> energy_bins(std::array<ssize_t, 1>{2});
    NDArray<double, 1> etax_bins(std::array<ssize_t, 1>{4}, 0.0);
    NDArray<double, 1> etay_bins(std::array<ssize_t, 1>{4}, 0.0);
    NDArray<double, 3> eta_distribution(std::array<ssize_t, 3>{3, 3, 1});
    std::iota(eta_distribution.begin(), eta_distribution.end(), 1.0);

    Interpolator interpolator(eta_distribution.view(), etax_bins.view(),
                              etay_bins.view(), energy_bins.view());
    auto ietax = interpolator.get_ietax();
    auto ietay = interpolator.get_ietay();

    // REQUIRE (not CHECK): the expected arrays below hold exactly 3 * 3
    // values and are indexed from the runtime shape, so continuing after a
    // shape mismatch could read past their end.
    REQUIRE(ietax.shape(0) == 3);
    REQUIRE(ietax.shape(1) == 3);
    REQUIRE(ietax.shape(2) == 1);
    REQUIRE(ietay.shape(0) == 3);
    REQUIRE(ietay.shape(1) == 3);
    REQUIRE(ietay.shape(2) == 1);

    // Expected values, flattened as [i * shape(1) + j]. Along the first index
    // each column reads 0, d1 / (d1 + d2), 1 where d1, d2 are the distribution
    // values at i = 1, 2 (e.g. 4 / (4 + 7) for j = 0).
    const std::array<double, 9> expected_ietax{
        0.0, 0.0, 0.0, 4.0 / 11.0, 5.0 / 13.0, 6.0 / 15.0, 1.0, 1.0, 1.0};
    // Same construction along the second index, per row (e.g. 2 / (2 + 3)
    // for i = 0).
    const std::array<double, 9> expected_ietay{
        0.0, 2.0 / 5.0, 1.0, 0.0, 5.0 / 11.0, 1.0, 0.0, 8.0 / 17.0, 1.0};

    // Compares the energy-index-0 slice of `actual` element-wise against the
    // flattened `expected` values. `actual` is taken by non-const reference so
    // the same accessors are used as when calling it directly.
    const auto check_table = [](auto &actual, const auto &expected) {
        for (ssize_t row = 0; row < actual.shape(0); ++row) {
            for (ssize_t col = 0; col < actual.shape(1); ++col) {
                const auto flat =
                    static_cast<std::size_t>(row * actual.shape(1) + col);
                CHECK(actual(row, col, 0) == Catch::Approx(expected[flat]));
            }
        }
    };

    check_table(ietax, expected_ietax);
    check_table(ietay, expected_ietay);
}
// Checks the tables exposed by get_ietax()/get_ietay() when the 4x4x1 eta
// distribution is zero everywhere except its four central bins.
TEST_CASE("Test constructor with zero bins at borders", "[Interpolation]") {
    // NOTE(review): energy_bins is constructed without a fill value, unlike
    // the eta bin arrays — confirm its contents are irrelevant to the tables
    // checked here.
    NDArray<double, 1> energy_bins(std::array<ssize_t, 1>{2});
    NDArray<double, 1> etax_bins(std::array<ssize_t, 1>{5}, 0.0);
    NDArray<double, 1> etay_bins(std::array<ssize_t, 1>{5}, 0.0);

    // Only the inner 2x2 block is populated; the border bins stay 0.
    NDArray<double, 3> eta_distribution(std::array<ssize_t, 3>{4, 4, 1}, 0.0);
    eta_distribution(1, 1, 0) = 1.0;
    eta_distribution(1, 2, 0) = 2.0;
    eta_distribution(2, 1, 0) = 3.0;
    eta_distribution(2, 2, 0) = 4.0;

    Interpolator interpolator(eta_distribution.view(), etax_bins.view(),
                              etay_bins.view(), energy_bins.view());
    auto ietax = interpolator.get_ietax();
    auto ietay = interpolator.get_ietay();

    // REQUIRE (not CHECK): the expected arrays below hold exactly 4 * 4
    // values and are indexed from the runtime shape, so continuing after a
    // shape mismatch could read past their end.
    REQUIRE(ietax.shape(0) == 4);
    REQUIRE(ietax.shape(1) == 4);
    REQUIRE(ietax.shape(2) == 1);
    REQUIRE(ietay.shape(0) == 4);
    REQUIRE(ietay.shape(1) == 4);
    REQUIRE(ietay.shape(2) == 1);

    // Expected values, flattened as [i * shape(1) + j]. Columns j = 1, 2 read
    // 0, d1 / (d1 + d2), 1, 1 along the first index (1 / 4 and 2 / 6); the
    // all-zero columns j = 0, 3 are expected to stay 0 throughout.
    const std::array<double, 16> expected_ietax{
        0.0, 0.0, 0.0, 0.0, 0.0, 1.0 / 4.0, 2.0 / 6.0, 0.0,
        0.0, 1.0, 1.0, 0.0, 0.0, 1.0,       1.0,       0.0};
    // Rows i = 1, 2 read 0, d1 / (d1 + d2), 1, 1 along the second index
    // (1 / 3 and 3 / 7); the all-zero rows i = 0, 3 are expected to stay 0.
    const std::array<double, 16> expected_ietay{
        0.0, 0.0,       0.0, 0.0, 0.0, 1.0 / 3.0, 1.0, 1.0,
        0.0, 3.0 / 7.0, 1.0, 1.0, 0.0, 0.0,       0.0, 0.0};

    // Compares the energy-index-0 slice of `actual` element-wise against the
    // flattened `expected` values. `actual` is taken by non-const reference so
    // the same accessors are used as when calling it directly.
    const auto check_table = [](auto &actual, const auto &expected) {
        for (ssize_t row = 0; row < actual.shape(0); ++row) {
            for (ssize_t col = 0; col < actual.shape(1); ++col) {
                const auto flat =
                    static_cast<std::size_t>(row * actual.shape(1) + col);
                CHECK(actual(row, col, 0) == Catch::Approx(expected[flat]));
            }
        }
    };

    check_table(ietax, expected_ietax);
    check_table(ietay, expected_ietay);
}
// Checks the tables exposed by get_ietax()/get_ietay() after calling
// rosenblatttransform() with a 3x3x1 eta distribution filled with 1..9 on an
// Interpolator constructed from the bin arrays only.
TEST_CASE("Test Rosenblatt", "[Interpolation]") {
    // NOTE(review): energy_bins is constructed without a fill value, unlike
    // the eta bin arrays — confirm its contents are irrelevant to the tables
    // checked here.
    NDArray<double, 1> energy_bins(std::array<ssize_t, 1>{2});
    NDArray<double, 1> etax_bins(std::array<ssize_t, 1>{4}, 0.0);
    NDArray<double, 1> etay_bins(std::array<ssize_t, 1>{4}, 0.0);
    NDArray<double, 3> eta_distribution(std::array<ssize_t, 3>{3, 3, 1});
    std::iota(eta_distribution.begin(), eta_distribution.end(), 1.0);

    Interpolator interpolator(etax_bins.view(), etay_bins.view(),
                              energy_bins.view());
    interpolator.rosenblatttransform(eta_distribution.view());
    auto ietax = interpolator.get_ietax();
    auto ietay = interpolator.get_ietay();

    // REQUIRE (not CHECK): the expected arrays below hold exactly 3 * 3
    // values and are indexed from the runtime shape, so continuing after a
    // shape mismatch could read past their end.
    REQUIRE(ietax.shape(0) == 3);
    REQUIRE(ietax.shape(1) == 3);
    REQUIRE(ietax.shape(2) == 1);
    REQUIRE(ietay.shape(0) == 3);
    REQUIRE(ietay.shape(1) == 3);
    REQUIRE(ietay.shape(2) == 1);

    // marginal CDF of eta_x: the sums over the second index are 6, 15, 24,
    // and the middle entry is 15 / (15 + 24), identical for every j.
    const std::array<double, 9> expected_ietax{
        0.0, 0.0, 0.0, 15.0 / 39.0, 15.0 / 39.0, 15.0 / 39.0, 1.0, 1.0, 1.0};
    // conditional CDF of eta_y: per row, 0, d1 / (d1 + d2), 1
    // (e.g. 2 / (2 + 3) for i = 0).
    const std::array<double, 9> expected_ietay{
        0.0, 2.0 / 5.0, 1.0, 0.0, 5.0 / 11.0, 1.0, 0.0, 8.0 / 17.0, 1.0};

    // Compares the energy-index-0 slice of `actual` element-wise against the
    // flattened `expected` values. `actual` is taken by non-const reference so
    // the same accessors are used as when calling it directly.
    const auto check_table = [](auto &actual, const auto &expected) {
        for (ssize_t row = 0; row < actual.shape(0); ++row) {
            for (ssize_t col = 0; col < actual.shape(1); ++col) {
                const auto flat =
                    static_cast<std::size_t>(row * actual.shape(1) + col);
                CHECK(actual(row, col, 0) == Catch::Approx(expected[flat]));
            }
        }
    };

    check_table(ietax, expected_ietax);
    check_table(ietay, expected_ietay);
}