From fa6f33fd851f467e00cf38d3992cb99b7585307c Mon Sep 17 00:00:00 2001 From: siebsi Date: Thu, 17 Apr 2025 11:12:26 +0200 Subject: [PATCH] changes ctb custom weights to accept 1D arrays instead of 2D arrays --- include/aare/decode.hpp | 4 ++-- python/src/ctb_raw_file.hpp | 14 +++++++------- src/decode.cpp | 16 ++++++---------- 3 files changed, 15 insertions(+), 19 deletions(-) diff --git a/include/aare/decode.hpp b/include/aare/decode.hpp index 3e3f170..e784c4a 100644 --- a/include/aare/decode.hpp +++ b/include/aare/decode.hpp @@ -21,6 +21,6 @@ void adc_sar_04_decode64to16(NDView input, NDView outpu */ double apply_custom_weights(uint16_t input, const NDView weights); -void apply_custom_weights(NDView input, NDView output, const NDView weights); +void apply_custom_weights(NDView input, NDView output, const NDView weights); -} // namespace aare \ No newline at end of file +} // namespace aare diff --git a/python/src/ctb_raw_file.hpp b/python/src/ctb_raw_file.hpp index 17da18e..575d449 100644 --- a/python/src/ctb_raw_file.hpp +++ b/python/src/ctb_raw_file.hpp @@ -69,19 +69,19 @@ m.def("adc_sar_04_decode64to16", [](py::array_t input) { m.def("apply_custom_weights", [](py::array_t& input, py::array_t& weights) { - if (input.ndim() != 2) { - throw std::runtime_error("Only 2D arrays are supported at this moment"); + if (input.ndim() != 1) { + throw std::runtime_error("Only 1D arrays are supported at this moment"); } - // Create a 2D output array with the same shape as the input - std::vector shape{input.shape(0), input.shape(1)}; + // Create a 1D output array with the same shape as the input + std::vector shape{input.shape(0)}; py::array_t output(shape); auto weights_view = make_view_1d(weights); // Create a view of the input and output arrays - NDView input_view(input.mutable_data(), {input.shape(0), input.shape(1)}); - NDView output_view(output.mutable_data(), {output.shape(0), output.shape(1)}); + NDView input_view(input.mutable_data(), {input.shape(0)}); + NDView output_view(output.mutable_data(), {output.shape(0)}); apply_custom_weights(input_view, output_view, weights_view); @@ -119,4 +119,4 @@ m.def("apply_custom_weights", [](py::array_t& input, py::array_t input, NDView outpu } } - double apply_custom_weights(uint16_t input, const NDView weights) { if(weights.size() > 16){ throw std::invalid_argument("weights size must be less than or equal to 16"); @@ -77,7 +76,7 @@ double apply_custom_weights(uint16_t input, const NDView weights) { } -void apply_custom_weights(NDView input, NDView output, const NDView weights) { +void apply_custom_weights(NDView input, NDView output, const NDView weights) { if(input.shape() != output.shape()){ throw std::invalid_argument(LOCATION + " input and output shapes must match"); } @@ -89,15 +88,12 @@ void apply_custom_weights(NDView input, NDView output, c } // Apply custom weights to each element in the input array - for (ssize_t i = 0; i < input.shape(0); i++) { - for (ssize_t j = 0; j < input.shape(1); j++) { - - double result = 0.0; - for (ssize_t bit_index = 0; bit_index < weights_powers.size(); ++bit_index) { - result += ((input(i,j) >> bit_index) & 1) * weights_powers[bit_index]; - } - output(i,j) = result; + for (ssize_t i = 0; i < input.shape(0); i++) { + double result = 0.0; + for (ssize_t bit_index = 0; bit_index < weights_powers.size(); ++bit_index) { + result += ((input(i) >> bit_index) & 1) * weights_powers[bit_index]; } + output(i) = result; } }