diff --git a/python/src/ctb_raw_file.hpp b/python/src/ctb_raw_file.hpp index 71af668..17da18e 100644 --- a/python/src/ctb_raw_file.hpp +++ b/python/src/ctb_raw_file.hpp @@ -68,7 +68,7 @@ m.def("adc_sar_04_decode64to16", [](py::array_t input) { }); -m.def("apply_custom_weights", [](py::array_t input, py::array_t weights) { +m.def("apply_custom_weights", [](py::array_t& input, py::array_t& weights) { if (input.ndim() != 2) { throw std::runtime_error("Only 2D arrays are supported at this moment"); } diff --git a/src/decode.cpp b/src/decode.cpp index f01f3b6..fc0b5f6 100644 --- a/src/decode.cpp +++ b/src/decode.cpp @@ -81,9 +81,22 @@ void apply_custom_weights(NDView input, NDView output, c if(input.shape() != output.shape()){ throw std::invalid_argument(LOCATION + " input and output shapes must match"); } - for (int64_t i = 0; i < input.shape(0); i++) { - for (int64_t j = 0; j < input.shape(1); j++) { - output(i, j) = apply_custom_weights(input(i, j), weights); + + //Calculate weights to avoid repeatedly calling std::pow + std::vector weights_powers(weights.size()); + for (ssize_t i = 0; i < weights.size(); ++i) { + weights_powers[i] = std::pow(weights[i], i); + } + + // Apply custom weights to each element in the input array + for (ssize_t i = 0; i < input.shape(0); i++) { + for (ssize_t j = 0; j < input.shape(1); j++) { + + double result = 0.0; + for (ssize_t bit_index = 0; bit_index < weights_powers.size(); ++bit_index) { + result += ((input(i,j) >> bit_index) & 1) * weights_powers[bit_index]; + } + output(i,j) = result; } } }