diff --git a/RELEASE.md b/RELEASE.md index 9e50145..c9f5e49 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -7,7 +7,8 @@ - Added a new Minuit2-based fitting framework for ``Gaussian``, ``RisingScurve``, ``FallingScurve``, ``Pol1`` and ``Pol2`` models. - setter and getter for nSigma for ClusterFinder ``aare.ClusterFinder().nSigma = 2``, ``aare.ClusterFinderMT().set_nSigma(2)`` -- mask opeartor for ClusterVector ``masked_clustervector = aare.ClusterVector()(mask)`` +- ``aare.Interpolator.transform_eta_values(np.ndarray)`` can take an array of ``Eta`` structs +- passing pre computed eta values to ``aare.Interpolator.interpolate`` alongside clusters ### Bugfixes: diff --git a/include/aare/Interpolator.hpp b/include/aare/Interpolator.hpp index 548aa91..0fd9fbd 100644 --- a/include/aare/Interpolator.hpp +++ b/include/aare/Interpolator.hpp @@ -93,6 +93,26 @@ class Interpolator { std::vector interpolate(const ClusterVector &clusters) const; + /** + * @brief interpolates the cluster centers for all clusters to a better + * precision + * @param clusters clusters of photon hits to interpolate + * @param etas precomputed eta values for each cluster (must be in the same + * order as the clusters) + * @return interpolated photons (photon positions are given as double but + * following row column format e.g. x=0, y=0 means top row and first column + * of frame) (An interpolated photon position of (1.5, 2.5) corresponds to + * an estimated photon hit at the pixel center of pixel (1,2)) + */ + template >>> + std::vector interpolate( + const ClusterVector> + &clusters, + const std::vector> &etas) const; + /** * @brief transforms the eta values to uniform coordinates based on the CDF * ieta_x and ieta_y @@ -102,6 +122,17 @@ class Interpolator { template Coordinate2D transform_eta_values(const Eta2 &eta) const; + /** + * @brief transforms the eta values to uniform coordinates based on the CDF + * ieta_x and ieta_y for a vector of eta values + * @tparam T type of eta values + * @param etas vector of eta values to transform + * @return vector of uniform coordinates {x,y} + */ + template + std::vector + transform_eta_values(const std::vector> &etas) const; + private: /** * @brief bilinear interpolation of the transformed eta values @@ -190,6 +221,19 @@ Coordinate2D Interpolator::transform_eta_values(const Eta2 &eta) const { return Coordinate2D{m_ietax(ix, iy, ie), m_ietay(ix, iy, ie)}; } +template +std::vector +Interpolator::transform_eta_values(const std::vector> &etas) const { + std::vector uniform_coordinates; + uniform_coordinates.reserve(etas.size()); + + for (const auto &eta : etas) { + uniform_coordinates.push_back(transform_eta_values(eta)); + } + + return uniform_coordinates; +} + template std::vector Interpolator::interpolate(const ClusterVector &clusters) const { @@ -265,4 +309,49 @@ Interpolator::interpolate(const ClusterVector &clusters) const { return photons; } +template +std::vector Interpolator::interpolate( + const ClusterVector> + &clusters, + const std::vector> &etas) const { + + if (clusters.size() != etas.size()) { + throw std::runtime_error( + fmt::format("Size of clusters and precomputed etas must be the " + "same, but got {} clusters and {} etas", + clusters.size(), etas.size())); + } + + std::vector photons; + photons.reserve(clusters.size()); + + for (size_t i = 0; i < clusters.size(); ++i) { + const auto &cluster = clusters[i]; + const auto &eta = etas[i]; + + Photon photon; + photon.x = cluster.x; + photon.y = cluster.y; + photon.energy = static_cast(eta.sum); + + try { + // check if eta values are within bounds + transform_eta_values(eta); + } catch (const std::runtime_error &e) { + throw std::runtime_error( + fmt::format("{} for cluster: {}", e.what(), i)); + } + + auto uniform_coordinates = transform_eta_values(eta); + + photon.x += uniform_coordinates.x; + photon.y += uniform_coordinates.y; + + photons.push_back(photon); + } + + return photons; +} + } // namespace aare \ No newline at end of file diff --git a/python/aare/__init__.py b/python/aare/__init__.py index 7684df8..0145f37 100644 --- a/python/aare/__init__.py +++ b/python/aare/__init__.py @@ -26,6 +26,8 @@ from ._aare import reduce_to_2x2, reduce_to_3x3 from ._aare import apply_custom_weights +from ._aare import Etai, Etad, Etaf + from .CtbRawFile import CtbRawFile from .RawFile import RawFile from .ScanParameters import ScanParameters diff --git a/python/src/bind_Eta.hpp b/python/src/bind_Eta.hpp index e2f9835..c27012f 100644 --- a/python/src/bind_Eta.hpp +++ b/python/src/bind_Eta.hpp @@ -13,12 +13,12 @@ void define_eta(py::module &m, const std::string &typestr) { py::class_>(m, class_name.c_str()) .def(py::init<>()) - .def_readonly("x", &Eta2::x, "eta x value") - .def_readonly("y", &Eta2::y, "eta y value") - .def_readonly("c", &Eta2::c, - "eta corner value cTopLeft, cTopRight, " - "cBottomLeft, cBottomRight") - .def_readonly("sum", &Eta2::sum, "photon energy of cluster"); + .def_readwrite("x", &Eta2::x, "eta x value") + .def_readwrite("y", &Eta2::y, "eta y value") + .def_readwrite("c", &Eta2::c, + "eta corner value cTopLeft, cTopRight, " + "cBottomLeft, cBottomRight") + .def_readwrite("sum", &Eta2::sum, "photon energy of cluster"); } void define_corner_enum(py::module &m) { diff --git a/python/src/bind_Interpolator.hpp b/python/src/bind_Interpolator.hpp index 8bc5ebf..8ee87eb 100644 --- a/python/src/bind_Interpolator.hpp +++ b/python/src/bind_Interpolator.hpp @@ -11,17 +11,20 @@ namespace py = pybind11; +// clang-format off #define REGISTER_INTERPOLATOR_ETA2(T, N, M, U) \ register_interpolate>( \ interpolator, "_full_eta2", "full eta2"); \ register_interpolate>( \ - interpolator, "", "eta2"); + interpolator, "", "eta2"); \ + register_interpolate_costum_eta(interpolator); #define REGISTER_INTERPOLATOR_ETA3(T, N, M, U) \ register_interpolate>( \ interpolator, "_eta3", "full eta3"); \ register_interpolate>( \ interpolator, "_cross_eta3", "cross eta3"); +// clang-format on template @@ -48,6 +51,34 @@ void register_interpolate(py::class_ &interpolator, docstring.c_str(), py::arg("cluster_vector")); } +template +void register_interpolate_costum_eta( + py::class_ &interpolator) { + + using ClusterType = Cluster; + + interpolator.def( + "interpolate", + [](aare::Interpolator &self, const ClusterVector &clusters, + const std::vector> &etas) { + auto photons = self.interpolate(clusters, etas); + auto *ptr = new std::vector{photons}; + return return_vector(ptr); + }, + R"( + Interpolation based on custom eta values provided by the user. + + Args: + cluster_vector: vector of clusters to interpolate + etas: vector of eta values for each cluster (must be in the same order as the clusters) + + Returns: + interpolated photons + )", + py::arg("cluster_vector"), py::arg("etas")); +} + template void register_transform_eta_values( py::class_ &interpolator) { @@ -61,10 +92,27 @@ void register_transform_eta_values( py::arg("Eta")); } +template +void register_transform_eta_values_vector( + py::class_ &interpolator) { + interpolator.def( + "transform_eta_values", + [](Interpolator &self, const std::vector> &etas) { + auto uniform_coords = self.transform_eta_values(etas); + + auto* ptr = new std::vector{uniform_coords}; + return return_vector(ptr); + }, + R"(vector of eta values transformed to uniform coordinates based on CDF ietax, ietay)", + py::arg("Etas")); +} + void define_interpolation_bindings(py::module &m) { PYBIND11_NUMPY_DTYPE(aare::Photon, x, y, energy); + PYBIND11_NUMPY_DTYPE(aare::Coordinate2D, x, y); + auto interpolator = py::class_(m, "Interpolator") .def(py::init( @@ -157,6 +205,10 @@ void define_interpolation_bindings(py::module &m) { register_transform_eta_values(interpolator); register_transform_eta_values(interpolator); + register_transform_eta_values_vector(interpolator); + register_transform_eta_values_vector(interpolator); + register_transform_eta_values_vector(interpolator); + // TODO! Evaluate without converting to double m.def( "hej", diff --git a/python/tests/test_InterpolationAPI.py b/python/tests/test_InterpolationAPI.py new file mode 100644 index 0000000..77bef50 --- /dev/null +++ b/python/tests/test_InterpolationAPI.py @@ -0,0 +1,31 @@ +import pytest + +from aare import Interpolator, ClusterVector, Etai, Cluster + +import numpy as np + + +def test_interpolation_api(): + eta_distribution = np.zeros((10, 10, 1)) # dummy eta distribution + etax_bins = np.linspace(0, 1.0, 11) + etay_bins = np.linspace(0, 1.0, 11) + e_bins = np.array([0., 10.]) # dummy energy bins + interpolator = Interpolator(eta_distribution, etax_bins, etay_bins, e_bins) + + cluster_vector = ClusterVector() + cluster_vector.push_back(Cluster(10, 5, np.ones(shape=9, dtype=np.int32))) + cluster_vector.push_back(Cluster(20, 10, np.ones(shape=9, dtype=np.int32))) + + eta1 = Etai() + eta1.x = 0.1 + eta1.y = 0.1 + eta1.sum = 5 + eta2 = Etai() + eta2.x = 0.1 + eta2.y = 0.9 + eta2.sum = 6 + etas = np.array([eta1, eta2]) # dummy etas for the clusters + + photons = interpolator.interpolate(cluster_vector, etas) + + assert photons.size == 2 \ No newline at end of file