reduction tests for python

This commit is contained in:
2025-09-01 14:15:08 +02:00
parent 8733a1d66f
commit 7926993bb2
7 changed files with 140 additions and 47 deletions

View File

@@ -34,31 +34,54 @@ void define_Cluster(py::module &m, const std::string &typestr) {
cluster.data[i] = r(i);
}
return cluster;
}));
}))
/*
//TODO! Review if to keep or not
.def_property(
"data",
[](ClusterType &c) -> py::array {
return py::array(py::buffer_info(
c.data, sizeof(Type),
py::format_descriptor<Type>::format(), // Type
// format
1, // Number of dimensions
{static_cast<ssize_t>(ClusterSizeX *
ClusterSizeY)}, // Shape (flattened)
{sizeof(Type)} // Stride (step size between elements)
));
// TODO! Review if to keep or not
.def_property_readonly(
"data",
[](Cluster<Type, ClusterSizeX, ClusterSizeY, CoordType> &c)
-> py::array {
return py::array(py::buffer_info(
c.data.data(), sizeof(Type),
py::format_descriptor<Type>::format(), // Type
// format
2, // Number of dimensions
{static_cast<ssize_t>(ClusterSizeX),
static_cast<ssize_t>(ClusterSizeY)}, // Shape (flattened)
{sizeof(Type) * ClusterSizeY, sizeof(Type)}
// Stride (step size between elements)
));
})
.def_readonly("x",
&Cluster<Type, ClusterSizeX, ClusterSizeY, CoordType>::x)
.def_readonly("y",
&Cluster<Type, ClusterSizeX, ClusterSizeY, CoordType>::y);
}
template <typename T, uint8_t ClusterSizeX, uint8_t ClusterSizeY,
typename CoordType = int16_t>
void reduce_to_3x3(py::module &m) {
m.def(
"reduce_to_3x3",
[](const Cluster<T, ClusterSizeX, ClusterSizeY, CoordType> &cl) {
return reduce_to_3x3(cl);
},
[](ClusterType &c, py::array_t<Type> arr) {
py::buffer_info buf_info = arr.request();
Type *ptr = static_cast<Type *>(buf_info.ptr);
std::copy(ptr, ptr + ClusterSizeX * ClusterSizeY,
c.data); // TODO dont iterate over centers!!!
py::return_value_policy::move);
}
});
*/
template <typename T, uint8_t ClusterSizeX, uint8_t ClusterSizeY,
typename CoordType = int16_t>
void reduce_to_2x2(py::module &m) {
m.def(
"reduce_to_2x2",
[](const Cluster<T, ClusterSizeX, ClusterSizeY, CoordType> &cl) {
return reduce_to_2x2(cl);
},
py::return_value_policy::move);
}
#pragma GCC diagnostic pop

View File

@@ -48,7 +48,8 @@ double, 'f' for float)
define_ClusterCollector<T, N, M, U>(m, "Cluster" #N "x" #M #TYPE_CODE); \
define_Cluster<T, N, M, U>(m, #N "x" #M #TYPE_CODE); \
register_calculate_eta<T, N, M, U>(m); \
define_2x2_reduction<T, N, M, U>(m);
define_2x2_reduction<T, N, M, U>(m); \
reduce_to_2x2<T, N, M, U>(m);
PYBIND11_MODULE(_aare, m) {
define_file_io_bindings(m);
@@ -86,16 +87,29 @@ PYBIND11_MODULE(_aare, m) {
DEFINE_CLUSTER_BINDINGS(double, 9, 9, uint16_t, d);
DEFINE_CLUSTER_BINDINGS(float, 9, 9, uint16_t, f);
define_3x3_reduction<int, 3, 3>(m);
define_3x3_reduction<double, 3, 3>(m);
define_3x3_reduction<float, 3, 3>(m);
define_3x3_reduction<int, 5, 5>(m);
define_3x3_reduction<double, 5, 5>(m);
define_3x3_reduction<float, 5, 5>(m);
define_3x3_reduction<int, 7, 7>(m);
define_3x3_reduction<double, 7, 7>(m);
define_3x3_reduction<float, 7, 7>(m);
define_3x3_reduction<int, 9, 9>(m);
define_3x3_reduction<double, 9, 9>(m);
define_3x3_reduction<float, 9, 9>(m);
define_3x3_reduction<int, 3, 3, uint16_t>(m);
define_3x3_reduction<double, 3, 3, uint16_t>(m);
define_3x3_reduction<float, 3, 3, uint16_t>(m);
define_3x3_reduction<int, 5, 5, uint16_t>(m);
define_3x3_reduction<double, 5, 5, uint16_t>(m);
define_3x3_reduction<float, 5, 5, uint16_t>(m);
define_3x3_reduction<int, 7, 7, uint16_t>(m);
define_3x3_reduction<double, 7, 7, uint16_t>(m);
define_3x3_reduction<float, 7, 7, uint16_t>(m);
define_3x3_reduction<int, 9, 9, uint16_t>(m);
define_3x3_reduction<double, 9, 9, uint16_t>(m);
define_3x3_reduction<float, 9, 9, uint16_t>(m);
reduce_to_3x3<int, 3, 3, uint16_t>(m);
reduce_to_3x3<double, 3, 3, uint16_t>(m);
reduce_to_3x3<float, 3, 3, uint16_t>(m);
reduce_to_3x3<int, 5, 5, uint16_t>(m);
reduce_to_3x3<double, 5, 5, uint16_t>(m);
reduce_to_3x3<float, 5, 5, uint16_t>(m);
reduce_to_3x3<int, 7, 7, uint16_t>(m);
reduce_to_3x3<double, 7, 7, uint16_t>(m);
reduce_to_3x3<float, 7, 7, uint16_t>(m);
reduce_to_3x3<int, 9, 9, uint16_t>(m);
reduce_to_3x3<double, 9, 9, uint16_t>(m);
reduce_to_3x3<float, 9, 9, uint16_t>(m);
}