Dev/rosenblatttransform (#241)

- added rosenblatttransform 
- added 3x3 eta methods 
- interpolation can be used with various eta functions
- added documentation for interpolation, eta calculation 
- exposed full eta struct in python 
- disable ClusterFinder for 2x2 clusters 
- factory function for ClusterVector

---------

Co-authored-by: Dhanya Thattil <dhanya.thattil@psi.ch>
Co-authored-by: Erik Fröjdh <erik.frojdh@psi.ch>
This commit is contained in:
2025-11-21 14:48:46 +01:00
committed by GitHub
parent 7fb500c44c
commit 267ca87ab0
49 changed files with 3253 additions and 1172 deletions

View File

@@ -20,21 +20,21 @@ using ClusterTypes =
auto get_test_parameters() {
return GENERATE(
std::make_tuple(ClusterTypes{Cluster<int, 2, 2>{0, 0, {1, 2, 3, 1}}},
Eta2<int>{2. / 3, 3. / 4, corner::cTopLeft, 7}),
std::make_tuple(ClusterTypes{Cluster<int, 2, 2>{0, 0, {1, 2, 1, 3}}},
Eta2<int>{3. / 4, 3. / 5, corner::cTopLeft, 7}),
std::make_tuple(
ClusterTypes{Cluster<int, 3, 3>{0, 0, {1, 2, 3, 4, 5, 6, 1, 2, 7}}},
Eta2<int>{6. / 11, 2. / 7, corner::cBottomRight, 20}),
ClusterTypes{Cluster<int, 3, 3>{0, 0, {1, 2, 3, 4, 7, 6, 1, 2, 5}}},
Eta2<int>{6. / 13, 2. / 9, corner::cBottomRight, 20}),
std::make_tuple(ClusterTypes{Cluster<int, 5, 5>{
0, 0, {1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 9, 8,
0, 0, {1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 8, 9,
1, 4, 1, 6, 7, 8, 1, 1, 1, 1, 1, 1}}},
Eta2<int>{8. / 17, 7. / 15, corner::cBottomLeft, 30}),
Eta2<int>{9. / 17, 7. / 16, corner::cBottomLeft, 30}),
std::make_tuple(
ClusterTypes{Cluster<int, 4, 2>{0, 0, {1, 4, 7, 2, 5, 6, 4, 3}}},
Eta2<int>{4. / 10, 4. / 11, corner::cTopLeft, 21}),
ClusterTypes{Cluster<int, 4, 2>{0, 0, {1, 4, 4, 2, 5, 6, 7, 3}}},
Eta2<int>{7. / 13, 7. / 11, corner::cTopLeft, 21}),
std::make_tuple(
ClusterTypes{Cluster<int, 2, 3>{0, 0, {1, 3, 2, 3, 4, 2}}},
Eta2<int>{3. / 5, 2. / 5, corner::cBottomLeft, 11}));
ClusterTypes{Cluster<int, 2, 3>{0, 0, {1, 3, 2, 4, 3, 2}}},
Eta2<int>{4. / 6, 2. / 6, corner::cBottomLeft, 11}));
}
TEST_CASE("compute_largest_2x2_subcluster", "[eta_calculation]") {
@@ -61,10 +61,22 @@ TEST_CASE("calculate_eta2", "[eta_calculation]") {
CHECK(eta.sum == expected_eta.sum);
}
// 3x3 cluster layout (rotated to match the cBottomLeft enum):
// 6, 7, 8
// 3, 4, 5
// 0, 1, 2
TEST_CASE("calculate_eta2 after reduction", "[eta_calculation]") {
auto [cluster, expected_eta] = get_test_parameters();
auto eta = std::visit(
[](const auto &clustertype) {
auto reduced_cluster = reduce_to_2x2(clustertype);
return calculate_eta2(reduced_cluster);
},
cluster);
CHECK(eta.x == expected_eta.x);
CHECK(eta.y == expected_eta.y);
CHECK(eta.c == expected_eta.c);
CHECK(eta.sum == expected_eta.sum);
}
TEST_CASE("Calculate eta2 for a 3x3 int32 cluster with the largest 2x2 sum in "
"the bottom left",
@@ -74,29 +86,25 @@ TEST_CASE("Calculate eta2 for a 3x3 int32 cluster with the largest 2x2 sum in "
Cluster<int32_t, 3, 3> cl;
cl.x = 0;
cl.y = 0;
cl.data[0] = 30;
cl.data[1] = 23;
cl.data[0] = 8;
cl.data[1] = 2;
cl.data[2] = 5;
cl.data[3] = 20;
cl.data[4] = 50;
cl.data[5] = 3;
cl.data[6] = 8;
cl.data[7] = 2;
cl.data[6] = 30;
cl.data[7] = 23;
cl.data[8] = 3;
// 8, 2, 3
// 20, 50, 3
// 30, 23, 5
auto eta = calculate_eta2(cl);
CHECK(eta.c == corner::cBottomLeft);
CHECK(eta.x == 50.0 / (20 + 50)); // 4/(3+4)
CHECK(eta.y == 50.0 / (23 + 50)); // 4/(1+4)
CHECK(eta.x == 50.0 / (20 + 50));
CHECK(eta.y == 23.0 / (23 + 50));
CHECK(eta.sum == 30 + 23 + 20 + 50);
}
TEST_CASE("Calculate eta2 for a 3x3 int32 cluster with the largest 2x2 sum in "
"the top left",
"the top right",
"[eta_calculation]") {
// Create a 3x3 cluster
@@ -105,21 +113,67 @@ TEST_CASE("Calculate eta2 for a 3x3 int32 cluster with the largest 2x2 sum in "
cl.y = 0;
cl.data[0] = 8;
cl.data[1] = 12;
cl.data[2] = 5;
cl.data[2] = 82;
cl.data[3] = 77;
cl.data[4] = 80;
cl.data[5] = 3;
cl.data[6] = 82;
cl.data[7] = 91;
cl.data[5] = 91;
cl.data[6] = 5;
cl.data[7] = 3;
cl.data[8] = 3;
// 82, 91, 3
// 77, 80, 3
// 8, 12, 5
auto eta = calculate_eta2(cl);
CHECK(eta.c == corner::cTopLeft);
CHECK(eta.x == 80. / (77 + 80)); // 4/(3+4)
CHECK(eta.y == 91.0 / (91 + 80)); // 7/(7+4)
CHECK(eta.sum == 77 + 80 + 82 + 91);
CHECK(eta.c == corner::cTopRight);
CHECK(eta.x == 91. / (80 + 91));
CHECK(eta.y == 80.0 / (80 + 12));
CHECK(eta.sum == 12 + 80 + 82 + 91);
}
auto get_test_parameters_fulleta2x2() {
return GENERATE(
std::make_tuple(ClusterTypes{Cluster<int, 2, 2>{0, 0, {1, 2, 1, 3}}},
Eta2<int>{5. / 7, 4. / 7, corner::cTopLeft, 7}),
std::make_tuple(
ClusterTypes{Cluster<int, 3, 3>{0, 0, {1, 2, 3, 4, 7, 6, 1, 2, 5}}},
Eta2<int>{11. / 20, 7. / 20, corner::cBottomRight, 20}),
std::make_tuple(ClusterTypes{Cluster<int, 5, 5>{
0, 0, {1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 8, 9,
1, 4, 1, 6, 7, 8, 1, 1, 1, 1, 1, 1}}},
Eta2<int>{16. / 30, 13. / 30, corner::cBottomLeft, 30}),
std::make_tuple(
ClusterTypes{Cluster<int, 4, 2>{0, 0, {1, 4, 4, 2, 5, 6, 7, 3}}},
Eta2<int>{11. / 21, 13. / 21, corner::cTopLeft, 21}),
std::make_tuple(
ClusterTypes{Cluster<int, 2, 3>{0, 0, {1, 3, 2, 4, 3, 2}}},
Eta2<int>{6. / 11, 5. / 11, corner::cBottomLeft, 11}));
}
TEST_CASE("Calculate full eta2", "[eta_calculation]") {
auto [test_cluster, expected_eta] = get_test_parameters_fulleta2x2();
auto eta = std::visit(
[](const auto &clustertype) {
return calculate_full_eta2(clustertype);
},
test_cluster);
CHECK(expected_eta.c == eta.c);
CHECK(expected_eta.sum == eta.sum);
CHECK(expected_eta.x == eta.x);
CHECK(expected_eta.y == eta.y);
}
TEST_CASE("Calculate full eta2 after reduction", "[eta_calculation]") {
auto [test_cluster, expected_eta] = get_test_parameters_fulleta2x2();
auto eta = std::visit(
[](const auto &clustertype) {
auto reduced_cluster = reduce_to_2x2(clustertype);
return calculate_full_eta2(reduced_cluster);
},
test_cluster);
CHECK(expected_eta.c == eta.c);
CHECK(expected_eta.sum == eta.sum);
CHECK(expected_eta.x == eta.x);
CHECK(expected_eta.y == eta.y);
}

View File

@@ -14,7 +14,7 @@
using namespace aare;
TEST_CASE("Test sum of Cluster", "[.cluster]") {
TEST_CASE("Test sum of Cluster", "[cluster]") {
Cluster<int, 2, 2> cluster{0, 0, {1, 2, 3, 4}};
CHECK(cluster.sum() == 10);
@@ -26,33 +26,33 @@ using ClusterTypes = std::variant<Cluster<int, 2, 2>, Cluster<int, 3, 3>,
using ClusterTypesLargerThan2x2 =
std::variant<Cluster<int, 3, 3>, Cluster<int, 4, 4>, Cluster<int, 5, 5>>;
TEST_CASE("Test reduce to 2x2 Cluster", "[.cluster]") {
TEST_CASE("Test reduce to 2x2 Cluster", "[cluster]") {
auto [cluster, expected_reduced_cluster] = GENERATE(
std::make_tuple(ClusterTypes{Cluster<int, 2, 2>{5, 5, {1, 2, 3, 4}}},
Cluster<int, 2, 2>{4, 6, {1, 2, 3, 4}}),
Cluster<int, 2, 2>{5, 5, {1, 2, 3, 4}}),
std::make_tuple(
ClusterTypes{Cluster<int, 3, 3>{5, 5, {1, 1, 1, 1, 3, 2, 1, 2, 2}}},
Cluster<int, 2, 2>{5, 5, {3, 2, 2, 2}}),
std::make_tuple(
ClusterTypes{Cluster<int, 3, 3>{5, 5, {1, 1, 1, 2, 3, 1, 2, 2, 1}}},
Cluster<int, 2, 2>{4, 5, {2, 3, 2, 2}}),
Cluster<int, 2, 2>{5, 5, {2, 3, 2, 2}}),
std::make_tuple(
ClusterTypes{Cluster<int, 3, 3>{5, 5, {2, 2, 1, 2, 3, 1, 1, 1, 1}}},
Cluster<int, 2, 2>{4, 6, {2, 2, 2, 3}}),
Cluster<int, 2, 2>{5, 5, {2, 2, 2, 3}}),
std::make_tuple(
ClusterTypes{Cluster<int, 3, 3>{5, 5, {1, 2, 2, 1, 3, 2, 1, 1, 1}}},
Cluster<int, 2, 2>{5, 6, {2, 2, 3, 2}}),
Cluster<int, 2, 2>{5, 5, {2, 2, 3, 2}}),
std::make_tuple(ClusterTypes{Cluster<int, 5, 5>{
5, 5, {1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 3,
2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}}},
Cluster<int, 2, 2>{5, 6, {2, 2, 3, 2}}),
Cluster<int, 2, 2>{5, 5, {2, 2, 3, 2}}),
std::make_tuple(ClusterTypes{Cluster<int, 5, 5>{
5, 5, {1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 2, 3,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}}},
Cluster<int, 2, 2>{4, 6, {2, 2, 2, 3}}),
Cluster<int, 2, 2>{5, 5, {2, 2, 2, 3}}),
std::make_tuple(
ClusterTypes{Cluster<int, 2, 3>{5, 5, {2, 2, 3, 2, 1, 1}}},
Cluster<int, 2, 2>{4, 6, {2, 2, 3, 2}}));
Cluster<int, 2, 2>{5, 5, {2, 2, 3, 2}}));
auto reduced_cluster = std::visit(
[](const auto &clustertype) { return reduce_to_2x2(clustertype); },
@@ -65,7 +65,7 @@ TEST_CASE("Test reduce to 2x2 Cluster", "[.cluster]") {
expected_reduced_cluster.data.begin()));
}
TEST_CASE("Test reduce to 3x3 Cluster", "[.cluster]") {
TEST_CASE("Test reduce to 3x3 Cluster", "[cluster]") {
auto [cluster, expected_reduced_cluster] = GENERATE(
std::make_tuple(ClusterTypesLargerThan2x2{Cluster<int, 3, 3>{
5, 5, {1, 1, 1, 1, 3, 1, 1, 1, 1}}},
@@ -73,23 +73,11 @@ TEST_CASE("Test reduce to 3x3 Cluster", "[.cluster]") {
std::make_tuple(
ClusterTypesLargerThan2x2{Cluster<int, 4, 4>{
5, 5, {2, 2, 1, 1, 2, 2, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1}}},
Cluster<int, 3, 3>{4, 6, {2, 2, 1, 2, 2, 1, 1, 1, 3}}),
std::make_tuple(
ClusterTypesLargerThan2x2{Cluster<int, 4, 4>{
5, 5, {1, 1, 2, 2, 1, 1, 2, 2, 1, 1, 3, 1, 1, 1, 1, 1}}},
Cluster<int, 3, 3>{5, 6, {1, 2, 2, 1, 2, 2, 1, 3, 1}}),
std::make_tuple(
ClusterTypesLargerThan2x2{Cluster<int, 4, 4>{
5, 5, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 2, 1, 1, 2, 2}}},
Cluster<int, 3, 3>{5, 5, {1, 1, 1, 1, 3, 2, 1, 2, 2}}),
std::make_tuple(
ClusterTypesLargerThan2x2{Cluster<int, 4, 4>{
5, 5, {1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 1, 2, 2, 1, 1}}},
Cluster<int, 3, 3>{4, 5, {1, 1, 1, 2, 2, 3, 2, 2, 1}}),
Cluster<int, 3, 3>{5, 5, {2, 1, 1, 1, 3, 1, 1, 1, 1}}),
std::make_tuple(ClusterTypesLargerThan2x2{Cluster<int, 5, 5>{
5, 5, {1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 2, 3,
1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1}}},
Cluster<int, 3, 3>{4, 5, {1, 2, 1, 2, 2, 3, 1, 2, 1}}));
Cluster<int, 3, 3>{5, 5, {2, 1, 1, 2, 3, 1, 2, 1, 1}}));
auto reduced_cluster = std::visit(
[](const auto &clustertype) { return reduce_to_3x3(clustertype); },

166
src/Interpolation.test.cpp Normal file
View File

@@ -0,0 +1,166 @@
#include "aare/ClusterVector.hpp"
#include "aare/Interpolator.hpp"
#include "aare/NDArray.hpp"
#include <array>
#include <catch2/catch_all.hpp>
#include <catch2/catch_test_macros.hpp>
#include <iostream>
using namespace aare;
TEST_CASE("Test new Interpolation API", "[Interpolation]") {
NDArray<double, 1> energy_bins(std::array<ssize_t, 1>{2});
NDArray<double, 1> etax_bins(std::array<ssize_t, 1>{4}, 0.0);
NDArray<double, 1> etay_bins(std::array<ssize_t, 1>{4}, 0.0);
NDArray<double, 3> eta_distribution(std::array<ssize_t, 3>{3, 3, 1}, 0.0);
Interpolator interpolator(eta_distribution.view(), etax_bins.view(),
etay_bins.view(), energy_bins.view());
ClusterVector<Cluster<double, 3, 3>> cluster_vec{};
cluster_vec.push_back(Cluster<double, 3, 3>{
2, 2, std::array<double, 9>{1, 2, 2, 1, 4, 1, 1, 2, 1}});
auto photons =
interpolator.interpolate<calculate_eta2<double, 3, 3>>(cluster_vec);
CHECK(photons.size() == 1);
}
TEST_CASE("Test constructor", "[Interpolation]") {
NDArray<double, 1> energy_bins(std::array<ssize_t, 1>{2});
NDArray<double, 1> etax_bins(std::array<ssize_t, 1>{4}, 0.0);
NDArray<double, 1> etay_bins(std::array<ssize_t, 1>{4}, 0.0);
NDArray<double, 3> eta_distribution(std::array<ssize_t, 3>{3, 3, 1});
std::iota(eta_distribution.begin(), eta_distribution.end(), 1.0);
Interpolator interpolator(eta_distribution.view(), etax_bins.view(),
etay_bins.view(), energy_bins.view());
auto ietax = interpolator.get_ietax();
auto ietay = interpolator.get_ietay();
CHECK(ietax.shape(0) == 3);
CHECK(ietax.shape(1) == 3);
CHECK(ietax.shape(2) == 1);
CHECK(ietay.shape(0) == 3);
CHECK(ietay.shape(1) == 3);
CHECK(ietay.shape(2) == 1);
std::array<double, 9> expected_ietax{
0.0, 0.0, 0.0, 4.0 / 11.0, 5.0 / 13.0, 6.0 / 15.0, 1.0, 1.0, 1.0};
std::array<double, 9> expected_ietay{
0.0, 2.0 / 5.0, 1.0, 0.0, 5.0 / 11.0, 1.0, 0.0, 8.0 / 17.0, 1.0};
for (ssize_t i = 0; i < ietax.shape(0); i++) {
for (ssize_t j = 0; j < ietax.shape(1); j++) {
CHECK(ietax(i, j, 0) ==
Catch::Approx(expected_ietax[i * ietax.shape(1) + j]));
}
}
for (ssize_t i = 0; i < ietay.shape(0); i++) {
for (ssize_t j = 0; j < ietay.shape(1); j++) {
CHECK(ietay(i, j, 0) ==
Catch::Approx(expected_ietay[i * ietay.shape(1) + j]));
}
}
}
TEST_CASE("Test constructor with zero bins at borders", "[Interpolation]") {
NDArray<double, 1> energy_bins(std::array<ssize_t, 1>{2});
NDArray<double, 1> etax_bins(std::array<ssize_t, 1>{5}, 0.0);
NDArray<double, 1> etay_bins(std::array<ssize_t, 1>{5}, 0.0);
NDArray<double, 3> eta_distribution(std::array<ssize_t, 3>{4, 4, 1}, 0.0);
eta_distribution(1, 1, 0) = 1.0;
eta_distribution(1, 2, 0) = 2.0;
eta_distribution(2, 1, 0) = 3.0;
eta_distribution(2, 2, 0) = 4.0;
Interpolator interpolator(eta_distribution.view(), etax_bins.view(),
etay_bins.view(), energy_bins.view());
auto ietax = interpolator.get_ietax();
auto ietay = interpolator.get_ietay();
CHECK(ietax.shape(0) == 4);
CHECK(ietax.shape(1) == 4);
CHECK(ietax.shape(2) == 1);
CHECK(ietay.shape(0) == 4);
CHECK(ietay.shape(1) == 4);
CHECK(ietay.shape(2) == 1);
std::array<double, 16> expected_ietax{
0.0, 0.0, 0.0, 0.0, 0.0, 1.0 / 4.0, 2.0 / 6.0, 0.0,
0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0};
std::array<double, 16> expected_ietay{
0.0, 0.0, 0.0, 0.0, 0.0, 1.0 / 3.0, 1.0, 1.0,
0.0, 3.0 / 7.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0};
for (ssize_t i = 0; i < ietax.shape(0); i++) {
for (ssize_t j = 0; j < ietax.shape(1); j++) {
CHECK(ietax(i, j, 0) ==
Catch::Approx(expected_ietax[i * ietax.shape(1) + j]));
}
}
for (ssize_t i = 0; i < ietay.shape(0); i++) {
for (ssize_t j = 0; j < ietay.shape(1); j++) {
CHECK(ietay(i, j, 0) ==
Catch::Approx(expected_ietay[i * ietay.shape(1) + j]));
}
}
}
TEST_CASE("Test Rosenblatt", "[Interpolation]") {
NDArray<double, 1> energy_bins(std::array<ssize_t, 1>{2});
NDArray<double, 1> etax_bins(std::array<ssize_t, 1>{4}, 0.0);
NDArray<double, 1> etay_bins(std::array<ssize_t, 1>{4}, 0.0);
NDArray<double, 3> eta_distribution(std::array<ssize_t, 3>{3, 3, 1});
std::iota(eta_distribution.begin(), eta_distribution.end(), 1.0);
Interpolator interpolator(etax_bins.view(), etay_bins.view(),
energy_bins.view());
interpolator.rosenblatttransform(eta_distribution.view());
auto ietax = interpolator.get_ietax();
auto ietay = interpolator.get_ietay();
CHECK(ietax.shape(0) == 3);
CHECK(ietax.shape(1) == 3);
CHECK(ietax.shape(2) == 1);
CHECK(ietay.shape(0) == 3);
CHECK(ietay.shape(1) == 3);
CHECK(ietay.shape(2) == 1);
// marginal CDF of eta_x
std::array<double, 9> expected_ietax{
0.0, 0.0, 0.0, 15.0 / 39.0, 15.0 / 39.0, 15.0 / 39.0, 1.0, 1.0, 1.0};
// conditional CDF of eta_y
std::array<double, 9> expected_ietay{
0.0, 2.0 / 5.0, 1.0, 0.0, 5.0 / 11.0, 1.0, 0.0, 8.0 / 17.0, 1.0};
for (ssize_t i = 0; i < ietax.shape(0); i++) {
for (ssize_t j = 0; j < ietax.shape(1); j++) {
CHECK(ietax(i, j, 0) ==
Catch::Approx(expected_ietax[i * ietax.shape(1) + j]));
}
}
for (ssize_t i = 0; i < ietay.shape(0); i++) {
for (ssize_t j = 0; j < ietay.shape(1); j++) {
CHECK(ietay(i, j, 0) ==
Catch::Approx(expected_ietay[i * ietay.shape(1) + j]));
}
}
}

View File

@@ -2,55 +2,145 @@
namespace aare {
Interpolator::Interpolator(NDView<double, 1> xbins, NDView<double, 1> ybins,
NDView<double, 1> ebins)
: m_etabinsx(xbins), m_etabinsy(ybins), m_energy_bins(ebins){};
Interpolator::Interpolator(NDView<double, 3> etacube, NDView<double, 1> xbins,
NDView<double, 1> ybins, NDView<double, 1> ebins)
: m_ietax(etacube), m_ietay(etacube), m_etabinsx(xbins), m_etabinsy(ybins),
m_energy_bins(ebins) {
if (etacube.shape(0) != xbins.size() || etacube.shape(1) != ybins.size() ||
etacube.shape(2) != ebins.size()) {
: m_etabinsx(xbins), m_etabinsy(ybins), m_energy_bins(ebins) {
if (etacube.shape(0) + 1 != xbins.size() ||
etacube.shape(1) + 1 != ybins.size() ||
etacube.shape(2) + 1 != ebins.size()) {
throw std::invalid_argument(
"The shape of the etacube does not match the shape of the bins");
}
// Cumulative sum in the x direction
for (ssize_t i = 1; i < m_ietax.shape(0); i++) {
for (ssize_t j = 0; j < m_ietax.shape(1); j++) {
for (ssize_t k = 0; k < m_ietax.shape(2); k++) {
m_ietax(i, j, k) += m_ietax(i - 1, j, k);
}
}
}
m_ietax = NDArray<double, 3>(etacube);
// Normalize by the highest row, if norm less than 1 don't do anything
m_ietay = NDArray<double, 3>(etacube);
// prefix sum - conditional CDF
for (ssize_t i = 0; i < m_ietax.shape(0); i++) {
for (ssize_t j = 0; j < m_ietax.shape(1); j++) {
for (ssize_t k = 0; k < m_ietax.shape(2); k++) {
auto val = m_ietax(m_ietax.shape(0) - 1, j, k);
double norm = val < 1 ? 1 : val;
m_ietax(i, j, k) /= norm;
m_ietax(i, j, k) += (i == 0) ? 0 : m_ietax(i - 1, j, k);
m_ietay(i, j, k) += (j == 0) ? 0 : m_ietay(i, j - 1, k);
}
}
}
// Cumulative sum in the y direction
for (ssize_t i = 0; i < m_ietay.shape(0); i++) {
for (ssize_t j = 1; j < m_ietay.shape(1); j++) {
for (ssize_t k = 0; k < m_ietay.shape(2); k++) {
m_ietay(i, j, k) += m_ietay(i, j - 1, k);
}
}
}
// Standardize, if norm less than 1 don't do anything
for (ssize_t i = 0; i < m_ietax.shape(0); i++) {
for (ssize_t j = 0; j < m_ietax.shape(1); j++) {
for (ssize_t k = 0; k < m_ietax.shape(2); k++) {
auto shift_x = etacube(0, j, k);
auto val_etax = m_ietax(m_ietax.shape(0) - 1, j, k) - shift_x;
double norm_etax = val_etax == 0 ? 1 : val_etax;
m_ietax(i, j, k) -= shift_x;
m_ietax(i, j, k) /= norm_etax;
auto shift_y = etacube(i, 0, k);
auto val_etay = m_ietay(i, m_ietay.shape(1) - 1, k) - shift_y;
double norm_etay = val_etay == 0 ? 1 : val_etay;
m_ietay(i, j, k) -= shift_y;
// Normalize by the highest column, if norm less than 1 don't do anything
for (ssize_t i = 0; i < m_ietay.shape(0); i++) {
for (ssize_t j = 0; j < m_ietay.shape(1); j++) {
for (ssize_t k = 0; k < m_ietay.shape(2); k++) {
auto val = m_ietay(i, m_ietay.shape(1) - 1, k);
double norm = val < 1 ? 1 : val;
m_ietay(i, j, k) /= norm;
m_ietay(i, j, k) /= norm_etay;
}
}
}
}
void Interpolator::rosenblatttransform(NDView<double, 3> etacube) {
if (etacube.shape(0) + 1 != m_etabinsx.size() ||
etacube.shape(1) + 1 != m_etabinsy.size() ||
etacube.shape(2) + 1 != m_energy_bins.size()) {
throw std::invalid_argument(
"The shape of the etacube does not match the shape of the bins");
}
// TODO: less loops and better performance if ebins is first dimension
// (violates backwardscompatibility ieta_x and ieta_y public getters,
// previously generated etacubes)
// TODO: maybe more loops is better then storing total_sum_y and
// total_sum_x
// marginal CDF for eta_x
NDArray<double, 2> marg_CDF_EtaX(
std::array<ssize_t, 2>{m_etabinsx.size() - 1, m_energy_bins.size() - 1},
0.0); // simulate proper probability distribution with zero at start
// conditional CDF for eta_y
NDArray<double, 3> cond_CDF_EtaY(etacube);
for (ssize_t i = 0; i < cond_CDF_EtaY.shape(0); ++i) {
for (ssize_t j = 0; j < cond_CDF_EtaY.shape(1); ++j) {
for (ssize_t k = 0; k < cond_CDF_EtaY.shape(2); ++k) {
// cumsum along y-axis
marg_CDF_EtaX(i, k) +=
etacube(i, j,
k); // marginal probability for etaX
// cumsum along y-axis
cond_CDF_EtaY(i, j, k) +=
(j == 0) ? 0 : cond_CDF_EtaY(i, j - 1, k);
}
}
}
// cumsum along x-axis
for (ssize_t i = 1; i < marg_CDF_EtaX.shape(0); ++i) {
for (ssize_t k = 0; k < marg_CDF_EtaX.shape(1); ++k) {
marg_CDF_EtaX(0, k) =
0.0; // shift by first value to ensure values between 0 and 1
marg_CDF_EtaX(i, k) += marg_CDF_EtaX(i - 1, k);
}
}
// normalize marg_CDF_EtaX
for (ssize_t i = 1; i < marg_CDF_EtaX.shape(0); ++i) {
for (ssize_t k = 0; k < marg_CDF_EtaX.shape(1); ++k) {
double norm = marg_CDF_EtaX(marg_CDF_EtaX.shape(0) - 1, k) == 0
? 1
: marg_CDF_EtaX(marg_CDF_EtaX.shape(0) - 1, k);
marg_CDF_EtaX(i, k) /= norm;
}
}
// standardize, normalize conditional CDF for etaY
// Note P(EtaY|EtaX) = P(EtaY,EtaX)/P(EtaX) we dont divide by P(EtaX) as it
// cancels out during normalization
for (ssize_t i = 0; i < cond_CDF_EtaY.shape(0); ++i) {
for (ssize_t j = 0; j < cond_CDF_EtaY.shape(1); ++j) {
for (ssize_t k = 0; k < cond_CDF_EtaY.shape(2); ++k) {
double shift = etacube(i, 0, k);
double norm =
(cond_CDF_EtaY(i, cond_CDF_EtaY.shape(1) - 1, k) - shift) ==
0
? 1
: cond_CDF_EtaY(i, cond_CDF_EtaY.shape(1) - 1, k) -
shift;
cond_CDF_EtaY(i, j, k) -= shift;
cond_CDF_EtaY(i, j, k) /= norm;
}
}
}
m_ietay = std::move(
cond_CDF_EtaY); // TODO maybe rename m_ietay to lookup or CDF_EtaY_cond
// TODO: should actually be only 2dimensional keep three dimension due to
// consistency with Annas code change though
m_ietax = NDArray<double, 3>(
std::array<ssize_t, 3>{m_etabinsx.size() - 1, m_etabinsy.size() - 1,
m_energy_bins.size() - 1});
for (ssize_t i = 0; i < m_etabinsx.size() - 1; ++i)
for (ssize_t j = 0; j < m_etabinsy.size() - 1; ++j)
for (ssize_t k = 0; k < m_energy_bins.size() - 1; ++k)
m_ietax(i, j, k) = marg_CDF_EtaX(i, k);
}
} // namespace aare

View File

@@ -192,3 +192,43 @@ TEST_CASE("Last element is different", "[algorithm]") {
std::vector<int> vec = {1, 1, 1, 1, 2};
REQUIRE(aare::all_equal(vec) == false);
}
TEST_CASE("Linear interpolation", "[algorithm]") {
SECTION("interpolated mean value") {
const double interpolated_value =
aare::linear_interpolation({0.0, 1.0}, {4.0, 6.0}, 0.5);
REQUIRE(interpolated_value == 5.0);
}
SECTION("interpolate left value") {
const double interpolated_value =
aare::linear_interpolation({0.0, 1.0}, {4.0, 6.0}, 0.0);
REQUIRE(interpolated_value == 4.0);
}
SECTION("interpolate right value") {
const double interpolated_value =
aare::linear_interpolation({0.0, 1.0}, {4.0, 6.0}, 1.0);
REQUIRE(interpolated_value == 6.0);
}
SECTION("interpolate the same value") {
const double interpolated_value =
aare::linear_interpolation({0.0, 1.0}, {4.0, 4.0}, 0.5);
REQUIRE(interpolated_value == 4.0);
}
}
TEST_CASE("Bilinear interpolation", "[algorithm]") {
SECTION("interpolated mean value") {
const double interpolated_value_left =
aare::linear_interpolation({0.0, 1.0}, {4.0, 6.0}, 0.5);
const double interpolated_value_right =
aare::linear_interpolation({0.0, 1.0}, {5.0, 6.0}, 0.5);
const double interpolated_value = aare::linear_interpolation(
{0.5, 1.0}, {interpolated_value_left, interpolated_value_right},
0.75);
REQUIRE(interpolated_value == 5.25);
}
}