mirror of
https://github.com/slsdetectorgroup/aare.git
synced 2026-05-01 17:12:23 +02:00
Dev/enable custom etas (#305)
- Allowing the users more flexibility to play around with custom eta
functions without touching the c++ code
- passing vector of eta values to ``transform_eta_values``
```
from aare import Interpolator, ClusterVector, Etai, Cluster
import numpy as np
def custom_eta(cluster_pixel_coordinate_x, cluster_pixel_coordinate_y, cluster_data):
# dummy custom eta function that just returns the sum of the cluster data
eta = Etai()
eta.x = 0.1 # dummy x value
eta.y = 0.1 # dummy y value
eta.sum = np.sum(cluster_data) # sum of the cluster data as the "energy
return eta
# Create a dummy eta distribution and bins
eta_distribution = np.zeros((10, 10, 1)) # dummy eta distribution
etax_bins = np.linspace(0, 1.0, 11)
etay_bins = np.linspace(0, 1.0, 11)
e_bins = np.array([0., 10.]) # dummy energy bins
# Create the interpolator
interpolator = Interpolator(eta_distribution, etax_bins, etay_bins, e_bins)
# Create a dummy cluster vector
cluster_vector = ClusterVector()
cluster_vector.push_back(Cluster(10, 5, np.ones(shape=9, dtype = np.int32)))
cluster_vector.push_back(Cluster(20, 10, np.ones(shape=9, dtype = np.int32)))
# Create dummy etas for the clusters
cluster_array = np.array(cluster_vector)
etas = np.array([custom_eta(cluster["x"], cluster["y"], cluster["data"]) for cluster in cluster_array])
# transform eta values to uniform coordinates
uniform_coordinates = interpolator.transform_eta_values(etas)
# Interpolate to get the photon coordinates e.g. apply interpolation logic
photon_coordinates_x = cluster_array["x"] + uniform_coordinates["x"] # add to pixel coordinate
photon_coordinates_y = cluster_array["y"] + uniform_coordinates["y"] # add to pixel coordinate
```
advantage: full control over interpolation logic,
downside: inefficient quite some loops in python
- passing pre computed eta values to interpolate function
```
Interpolator.interpolate(cluster_vector, etas)
```
downside: less flexibility in interpolation logic.
downside: People might misuse it instead of using interpolate directly
with a pre compiled eta function implemented in c++
This commit is contained in:
@@ -26,6 +26,8 @@ from ._aare import reduce_to_2x2, reduce_to_3x3
|
||||
|
||||
from ._aare import apply_custom_weights
|
||||
|
||||
from ._aare import Etai, Etad, Etaf
|
||||
|
||||
from .CtbRawFile import CtbRawFile
|
||||
from .RawFile import RawFile
|
||||
from .ScanParameters import ScanParameters
|
||||
|
||||
@@ -13,12 +13,12 @@ void define_eta(py::module &m, const std::string &typestr) {
|
||||
|
||||
py::class_<Eta2<T>>(m, class_name.c_str())
|
||||
.def(py::init<>())
|
||||
.def_readonly("x", &Eta2<T>::x, "eta x value")
|
||||
.def_readonly("y", &Eta2<T>::y, "eta y value")
|
||||
.def_readonly("c", &Eta2<T>::c,
|
||||
"eta corner value cTopLeft, cTopRight, "
|
||||
"cBottomLeft, cBottomRight")
|
||||
.def_readonly("sum", &Eta2<T>::sum, "photon energy of cluster");
|
||||
.def_readwrite("x", &Eta2<T>::x, "eta x value")
|
||||
.def_readwrite("y", &Eta2<T>::y, "eta y value")
|
||||
.def_readwrite("c", &Eta2<T>::c,
|
||||
"eta corner value cTopLeft, cTopRight, "
|
||||
"cBottomLeft, cBottomRight")
|
||||
.def_readwrite("sum", &Eta2<T>::sum, "photon energy of cluster");
|
||||
}
|
||||
|
||||
void define_corner_enum(py::module &m) {
|
||||
|
||||
@@ -11,17 +11,20 @@
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
// clang-format off
|
||||
#define REGISTER_INTERPOLATOR_ETA2(T, N, M, U) \
|
||||
register_interpolate<T, N, M, U, aare::calculate_full_eta2<T, N, M, U>>( \
|
||||
interpolator, "_full_eta2", "full eta2"); \
|
||||
register_interpolate<T, N, M, U, aare::calculate_eta2<T, N, M, U>>( \
|
||||
interpolator, "", "eta2");
|
||||
interpolator, "", "eta2"); \
|
||||
register_interpolate_custom_eta<T, N, M, U>(interpolator);
|
||||
|
||||
#define REGISTER_INTERPOLATOR_ETA3(T, N, M, U) \
|
||||
register_interpolate<T, N, M, U, aare::calculate_eta3<T, N, M, U>>( \
|
||||
interpolator, "_eta3", "full eta3"); \
|
||||
register_interpolate<T, N, M, U, aare::calculate_cross_eta3<T, N, M, U>>( \
|
||||
interpolator, "_cross_eta3", "cross eta3");
|
||||
// clang-format on
|
||||
|
||||
template <typename Type, uint8_t CoordSizeX, uint8_t CoordSizeY,
|
||||
typename CoordType = uint16_t, auto EtaFunction>
|
||||
@@ -48,6 +51,34 @@ void register_interpolate(py::class_<aare::Interpolator> &interpolator,
|
||||
docstring.c_str(), py::arg("cluster_vector"));
|
||||
}
|
||||
|
||||
template <typename Type, uint8_t ClusterSizeX, uint8_t ClusterSizeY,
|
||||
typename CoordType = uint16_t>
|
||||
void register_interpolate_custom_eta(
|
||||
py::class_<aare::Interpolator> &interpolator) {
|
||||
|
||||
using ClusterType = Cluster<Type, ClusterSizeX, ClusterSizeY, CoordType>;
|
||||
|
||||
interpolator.def(
|
||||
"interpolate",
|
||||
[](aare::Interpolator &self, const ClusterVector<ClusterType> &clusters,
|
||||
const std::vector<Eta2<Type>> &etas) {
|
||||
auto photons = self.interpolate<Type, ClusterSizeX, ClusterSizeY, CoordType>(clusters, etas);
|
||||
auto *ptr = new std::vector<Photon>{std::move(photons)};
|
||||
return return_vector(ptr);
|
||||
},
|
||||
R"(
|
||||
Interpolation based on custom eta values provided by the user.
|
||||
|
||||
Args:
|
||||
cluster_vector: vector of clusters to interpolate
|
||||
etas: vector of eta values for each cluster (must be in the same order as the clusters)
|
||||
|
||||
Returns:
|
||||
interpolated photons
|
||||
)",
|
||||
py::arg("cluster_vector"), py::arg("etas"));
|
||||
}
|
||||
|
||||
template <typename Type>
|
||||
void register_transform_eta_values(
|
||||
py::class_<aare::Interpolator> &interpolator) {
|
||||
@@ -65,6 +96,8 @@ void define_interpolation_bindings(py::module &m) {
|
||||
|
||||
PYBIND11_NUMPY_DTYPE(aare::Photon, x, y, energy);
|
||||
|
||||
PYBIND11_NUMPY_DTYPE(aare::Coordinate2D, x, y);
|
||||
|
||||
auto interpolator =
|
||||
py::class_<aare::Interpolator>(m, "Interpolator")
|
||||
.def(py::init(
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
import pytest
|
||||
|
||||
from aare import Interpolator, ClusterVector, Etai, Cluster
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def test_interpolation_api():
|
||||
eta_distribution = np.zeros((10, 10, 1)) # dummy eta distribution
|
||||
etax_bins = np.linspace(0, 1.0, 11)
|
||||
etay_bins = np.linspace(0, 1.0, 11)
|
||||
e_bins = np.array([0., 10.]) # dummy energy bins
|
||||
interpolator = Interpolator(eta_distribution, etax_bins, etay_bins, e_bins)
|
||||
|
||||
cluster_vector = ClusterVector()
|
||||
cluster_vector.push_back(Cluster(10, 5, np.ones(shape=9, dtype=np.int32)))
|
||||
cluster_vector.push_back(Cluster(20, 10, np.ones(shape=9, dtype=np.int32)))
|
||||
|
||||
eta1 = Etai()
|
||||
eta1.x = 0.1
|
||||
eta1.y = 0.1
|
||||
eta1.sum = 5
|
||||
eta2 = Etai()
|
||||
eta2.x = 0.1
|
||||
eta2.y = 0.9
|
||||
eta2.sum = 6
|
||||
etas = np.array([eta1, eta2]) # dummy etas for the clusters
|
||||
|
||||
photons = interpolator.interpolate(cluster_vector, etas)
|
||||
|
||||
assert photons.size == cluster_vector.size # should return one photon per cluster
|
||||
Reference in New Issue
Block a user