diff --git a/include/aare/Cluster.hpp b/include/aare/Cluster.hpp index 889593b..aa324a9 100644 --- a/include/aare/Cluster.hpp +++ b/include/aare/Cluster.hpp @@ -74,6 +74,110 @@ struct Cluster { } }; +template +Cluster reduce_3x3_to_2x2(const Cluster &c) { + Cluster result; + + auto [s, i] = c.max_sum_2x2(); + switch (i) { + case 0: + result.x = c.x-1; + result.y = c.y+1; + result.data = {c.data[0], c.data[1], c.data[3], c.data[4]}; + break; + case 1: + result.x = c.x; + result.y = c.y + 1; + result.data = {c.data[1], c.data[2], c.data[4], c.data[5]}; + break; + case 2: + result.x = c.x -1; + result.y = c.y; + result.data = {c.data[3], c.data[4], c.data[6], c.data[7]}; + break; + case 3: + result.x = c.x; + result.y = c.y; + result.data = {c.data[4], c.data[5], c.data[7], c.data[8]}; + break; + } + + // do some stuff + return result; +} + +template +Cluster reduce_5x5_to_3x3(const Cluster &c) { + Cluster result; + + // Reduce the 5x5 cluster to a 3x3 cluster by selecting the 3x3 block with the highest sum + std::array sum_3x3_subclusters; + + //Write out the sums in the hope that the compiler can optimize this + sum_3x3_subclusters[0] = c.data[0] + c.data[1] + c.data[2] + c.data[5] + c.data[6] + c.data[7] + c.data[10] + c.data[11] + c.data[12]; + sum_3x3_subclusters[1] = c.data[1] + c.data[2] + c.data[3] + c.data[6] + c.data[7] + c.data[8] + c.data[11] + c.data[12] + c.data[13]; + sum_3x3_subclusters[2] = c.data[2] + c.data[3] + c.data[4] + c.data[7] + c.data[8] + c.data[9] + c.data[12] + c.data[13] + c.data[14]; + sum_3x3_subclusters[3] = c.data[5] + c.data[6] + c.data[7] + c.data[10] + c.data[11] + c.data[12] + c.data[15] + c.data[16] + c.data[17]; + sum_3x3_subclusters[4] = c.data[6] + c.data[7] + c.data[8] + c.data[11] + c.data[12] + c.data[13] + c.data[16] + c.data[17] + c.data[18]; + sum_3x3_subclusters[5] = c.data[7] + c.data[8] + c.data[9] + c.data[12] + c.data[13] + c.data[14] + c.data[17] + c.data[18] + c.data[19]; + sum_3x3_subclusters[6] = c.data[10] + c.data[11] + c.data[12] + c.data[15] + c.data[16] + c.data[17] + c.data[20] + c.data[21] + c.data[22]; + sum_3x3_subclusters[7] = c.data[11] + c.data[12] + c.data[13] + c.data[16] + c.data[17] + c.data[18] + c.data[21] + c.data[22] + c.data[23]; + sum_3x3_subclusters[8] = c.data[12] + c.data[13] + c.data[14] + c.data[17] + c.data[18] + c.data[19] + c.data[22] + c.data[23] + c.data[24]; + + auto index = std::max_element(sum_3x3_subclusters.begin(), sum_3x3_subclusters.end()) - sum_3x3_subclusters.begin(); + + switch (index) { + case 0: + result.x = c.x - 1; + result.y = c.y + 1; + result.data = {c.data[0], c.data[1], c.data[2], c.data[5], c.data[6], c.data[7], c.data[10], c.data[11], c.data[12]}; + break; + case 1: + result.x = c.x; + result.y = c.y + 1; + result.data = {c.data[1], c.data[2], c.data[3], c.data[6], c.data[7], c.data[8], c.data[11], c.data[12], c.data[13]}; + break; + case 2: + result.x = c.x + 1; + result.y = c.y + 1; + result.data = {c.data[2], c.data[3], c.data[4], c.data[7], c.data[8], c.data[9], c.data[12], c.data[13], c.data[14]}; + break; + case 3: + result.x = c.x - 1; + result.y = c.y; + result.data = {c.data[5], c.data[6], c.data[7], c.data[10], c.data[11], c.data[12], c.data[15], c.data[16], c.data[17]}; + break; + case 4: + result.x = c.x + 1; + result.y = c.y; + result.data = {c.data[6], c.data[7], c.data[8], c.data[11], c.data[12], c.data[13], c.data[16], c.data[17], c.data[18]}; + break; + case 5: + result.x = c.x + 1; + result.y = c.y; + result.data = {c.data[7], c.data[8], c.data[9], c.data[12], c.data[13], c.data[14], c.data[17], c.data[18], c.data[19]}; + break; + case 6: + result.x = c.x + 1; + result.y = c.y -1; + result.data = {c.data[10], c.data[11], c.data[12], c.data[15], c.data[16], c.data[17], c.data[20], c.data[21], c.data[22]}; + break; + case 7: + result.x = c.x + 1; + result.y = c.y-1; + result.data = {c.data[11], c.data[12], c.data[13], c.data[16], c.data[17], c.data[18], c.data[21], c.data[22], c.data[23]}; + break; + case 8: + result.x = c.x + 1; + result.y = c.y-1; + result.data = {c.data[12], c.data[13], c.data[14], c.data[17], c.data[18], c.data[19], c.data[22], c.data[23], c.data[24]}; + break; + } + + // do some stuff + return result; +} + // Type Traits for is_cluster_type template struct is_cluster : std::false_type {}; // Default case: Not a Cluster diff --git a/include/aare/ClusterVector.hpp b/include/aare/ClusterVector.hpp index 9d575d9..0f9ce57 100644 --- a/include/aare/ClusterVector.hpp +++ b/include/aare/ClusterVector.hpp @@ -167,4 +167,22 @@ class ClusterVector> { } }; +template +ClusterVector> reduce_3x3_to_2x2(const ClusterVector> &cv) { + ClusterVector> result; + for (const auto &c : cv) { + result.push_back(reduce_3x3_to_2x2(c)); + } + return result; +} + +template +ClusterVector> reduce_5x5_to_3x3(const ClusterVector> &cv) { + ClusterVector> result; + for (const auto &c : cv) { + result.push_back(reduce_5x5_to_3x3(c)); + } + return result; +} + } // namespace aare \ No newline at end of file diff --git a/python/src/bind_ClusterVector.hpp b/python/src/bind_ClusterVector.hpp index 9e9c4ab..198cac9 100644 --- a/python/src/bind_ClusterVector.hpp +++ b/python/src/bind_ClusterVector.hpp @@ -104,4 +104,14 @@ void define_ClusterVector(py::module &m, const std::string &typestr) { }); } +void define_reduction(py::module &m) { + m.def("reduce_3x3_to_2x2", [](const ClusterVector> &cv) { + return new ClusterVector>(reduce_3x3_to_2x2(cv)); + // return new ClusterVector>(); + }) + .def("reduce_5x5_to_3x3", [](const ClusterVector> &cv) { + return new ClusterVector>(reduce_5x5_to_3x3(cv)); + }); +} + #pragma GCC diagnostic pop \ No newline at end of file diff --git a/python/src/module.cpp b/python/src/module.cpp index fc04a9f..a4b6404 100644 --- a/python/src/module.cpp +++ b/python/src/module.cpp @@ -81,4 +81,8 @@ PYBIND11_MODULE(_aare, m) { DEFINE_CLUSTER_BINDINGS(int, 9, 9, uint16_t, i); DEFINE_CLUSTER_BINDINGS(double, 9, 9, uint16_t, d); DEFINE_CLUSTER_BINDINGS(float, 9, 9, uint16_t, f); + + + define_reduction(m); + }