passing eta vector to interpolator and transform_eta_values

This commit is contained in:
2026-04-16 18:57:43 +02:00
parent 0814bd5678
commit 792b26089e
6 changed files with 183 additions and 8 deletions
+6 -6
View File
@@ -13,12 +13,12 @@ void define_eta(py::module &m, const std::string &typestr) {
py::class_<Eta2<T>>(m, class_name.c_str())
.def(py::init<>())
.def_readonly("x", &Eta2<T>::x, "eta x value")
.def_readonly("y", &Eta2<T>::y, "eta y value")
.def_readonly("c", &Eta2<T>::c,
"eta corner value cTopLeft, cTopRight, "
"cBottomLeft, cBottomRight")
.def_readonly("sum", &Eta2<T>::sum, "photon energy of cluster");
.def_readwrite("x", &Eta2<T>::x, "eta x value")
.def_readwrite("y", &Eta2<T>::y, "eta y value")
.def_readwrite("c", &Eta2<T>::c,
"eta corner value cTopLeft, cTopRight, "
"cBottomLeft, cBottomRight")
.def_readwrite("sum", &Eta2<T>::sum, "photon energy of cluster");
}
void define_corner_enum(py::module &m) {
+53 -1
View File
@@ -11,17 +11,20 @@
namespace py = pybind11;
// clang-format off
#define REGISTER_INTERPOLATOR_ETA2(T, N, M, U) \
register_interpolate<T, N, M, U, aare::calculate_full_eta2<T, N, M, U>>( \
interpolator, "_full_eta2", "full eta2"); \
register_interpolate<T, N, M, U, aare::calculate_eta2<T, N, M, U>>( \
interpolator, "", "eta2");
interpolator, "", "eta2"); \
register_interpolate_costum_eta<T, N, M, U>(interpolator);
#define REGISTER_INTERPOLATOR_ETA3(T, N, M, U) \
register_interpolate<T, N, M, U, aare::calculate_eta3<T, N, M, U>>( \
interpolator, "_eta3", "full eta3"); \
register_interpolate<T, N, M, U, aare::calculate_cross_eta3<T, N, M, U>>( \
interpolator, "_cross_eta3", "cross eta3");
// clang-format on
template <typename Type, uint8_t CoordSizeX, uint8_t CoordSizeY,
typename CoordType = uint16_t, auto EtaFunction>
@@ -48,6 +51,34 @@ void register_interpolate(py::class_<aare::Interpolator> &interpolator,
docstring.c_str(), py::arg("cluster_vector"));
}
template <typename Type, uint8_t ClusterSizeX, uint8_t ClusterSizeY,
typename CoordType = uint16_t>
void register_interpolate_costum_eta(
py::class_<aare::Interpolator> &interpolator) {
using ClusterType = Cluster<Type, ClusterSizeX, ClusterSizeY, CoordType>;
interpolator.def(
"interpolate",
[](aare::Interpolator &self, const ClusterVector<ClusterType> &clusters,
const std::vector<Eta2<Type>> &etas) {
auto photons = self.interpolate<Type, ClusterSizeX, ClusterSizeY, CoordType>(clusters, etas);
auto *ptr = new std::vector<Photon>{photons};
return return_vector(ptr);
},
R"(
Interpolation based on custom eta values provided by the user.
Args:
cluster_vector: vector of clusters to interpolate
etas: vector of eta values for each cluster (must be in the same order as the clusters)
Returns:
interpolated photons
)",
py::arg("cluster_vector"), py::arg("etas"));
}
template <typename Type>
void register_transform_eta_values(
py::class_<aare::Interpolator> &interpolator) {
@@ -61,10 +92,27 @@ void register_transform_eta_values(
py::arg("Eta"));
}
template <typename Type>
void register_transform_eta_values_vector(
py::class_<aare::Interpolator> &interpolator) {
interpolator.def(
"transform_eta_values",
[](Interpolator &self, const std::vector<Eta2<Type>> &etas) {
auto uniform_coords = self.transform_eta_values(etas);
auto* ptr = new std::vector<Coordinate2D>{uniform_coords};
return return_vector(ptr);
},
R"(vector of eta values transformed to uniform coordinates based on CDF ietax, ietay)",
py::arg("Etas"));
}
void define_interpolation_bindings(py::module &m) {
PYBIND11_NUMPY_DTYPE(aare::Photon, x, y, energy);
PYBIND11_NUMPY_DTYPE(aare::Coordinate2D, x, y);
auto interpolator =
py::class_<aare::Interpolator>(m, "Interpolator")
.def(py::init(
@@ -157,6 +205,10 @@ void define_interpolation_bindings(py::module &m) {
register_transform_eta_values<float>(interpolator);
register_transform_eta_values<double>(interpolator);
register_transform_eta_values_vector<int>(interpolator);
register_transform_eta_values_vector<float>(interpolator);
register_transform_eta_values_vector<double>(interpolator);
// TODO! Evaluate without converting to double
m.def(
"hej",