python bindings

This commit is contained in:
froejdh_e 2025-04-16 18:08:47 +02:00
parent 73f46e4d2b
commit 3760fd5ed0
4 changed files with 149 additions and 1 deletions

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#include <cstdint> #include <cstdint>
#include <vector>
#include <aare/NDView.hpp> #include <aare/NDView.hpp>
namespace aare { namespace aare {
@ -10,4 +11,16 @@ uint16_t adc_sar_04_decode64to16(uint64_t input);
void adc_sar_05_decode64to16(NDView<uint64_t, 2> input, NDView<uint16_t,2> output); void adc_sar_05_decode64to16(NDView<uint64_t, 2> input, NDView<uint16_t,2> output);
void adc_sar_04_decode64to16(NDView<uint64_t, 2> input, NDView<uint16_t,2> output); void adc_sar_04_decode64to16(NDView<uint64_t, 2> input, NDView<uint16_t,2> output);
/**
* @brief Apply custom weights to a 16-bit input value. Will sum up weights[i]**i
* for each bit i that is set in the input value.
* @throws std::out_of_range if weights.size() < 16
* @param input 16-bit input value
* @param weights vector of weights, size must be less than or equal to 16
*/
double apply_custom_weights(uint16_t input, const NDView<double, 1> weights);
void apply_custom_weights(NDView<uint16_t, 2> input, NDView<double, 2> output, const NDView<double, 1> weights);
} // namespace aare } // namespace aare

View File

@ -10,6 +10,8 @@
#include "aare/decode.hpp" #include "aare/decode.hpp"
// #include "aare/fClusterFileV2.hpp" // #include "aare/fClusterFileV2.hpp"
#include "np_helper.hpp"
#include <cstdint> #include <cstdint>
#include <filesystem> #include <filesystem>
#include <pybind11/iostream.h> #include <pybind11/iostream.h>
@ -65,6 +67,27 @@ m.def("adc_sar_04_decode64to16", [](py::array_t<uint8_t> input) {
return output; return output;
}); });
m.def("apply_custom_weights", [](py::array_t<uint16_t> input, py::array_t<double> weights) {
if (input.ndim() != 2) {
throw std::runtime_error("Only 2D arrays are supported at this moment");
}
// Create a 2D output array with the same shape as the input
std::vector<ssize_t> shape{input.shape(0), input.shape(1)};
py::array_t<double> output(shape);
auto weights_view = make_view_1d(weights);
// Create a view of the input and output arrays
NDView<uint16_t, 2> input_view(input.mutable_data(), {input.shape(0), input.shape(1)});
NDView<double, 2> output_view(output.mutable_data(), {output.shape(0), output.shape(1)});
apply_custom_weights(input_view, output_view, weights_view);
return output;
});
py::class_<CtbRawFile>(m, "CtbRawFile") py::class_<CtbRawFile>(m, "CtbRawFile")
.def(py::init<const std::filesystem::path &>()) .def(py::init<const std::filesystem::path &>())
.def("read_frame", .def("read_frame",

View File

@ -1,5 +1,5 @@
#include "aare/decode.hpp" #include "aare/decode.hpp"
#include <cmath>
namespace aare { namespace aare {
uint16_t adc_sar_05_decode64to16(uint64_t input){ uint16_t adc_sar_05_decode64to16(uint64_t input){
@ -22,6 +22,10 @@ uint16_t adc_sar_05_decode64to16(uint64_t input){
} }
void adc_sar_05_decode64to16(NDView<uint64_t, 2> input, NDView<uint16_t,2> output){ void adc_sar_05_decode64to16(NDView<uint64_t, 2> input, NDView<uint16_t,2> output){
if(input.shape() != output.shape()){
throw std::invalid_argument(LOCATION + " input and output shapes must match");
}
for(int64_t i = 0; i < input.shape(0); i++){ for(int64_t i = 0; i < input.shape(0); i++){
for(int64_t j = 0; j < input.shape(1); j++){ for(int64_t j = 0; j < input.shape(1); j++){
output(i,j) = adc_sar_05_decode64to16(input(i,j)); output(i,j) = adc_sar_05_decode64to16(input(i,j));
@ -49,6 +53,9 @@ uint16_t adc_sar_04_decode64to16(uint64_t input){
} }
void adc_sar_04_decode64to16(NDView<uint64_t, 2> input, NDView<uint16_t,2> output){ void adc_sar_04_decode64to16(NDView<uint64_t, 2> input, NDView<uint16_t,2> output){
if(input.shape() != output.shape()){
throw std::invalid_argument(LOCATION + " input and output shapes must match");
}
for(int64_t i = 0; i < input.shape(0); i++){ for(int64_t i = 0; i < input.shape(0); i++){
for(int64_t j = 0; j < input.shape(1); j++){ for(int64_t j = 0; j < input.shape(1); j++){
output(i,j) = adc_sar_04_decode64to16(input(i,j)); output(i,j) = adc_sar_04_decode64to16(input(i,j));
@ -57,5 +64,30 @@ void adc_sar_04_decode64to16(NDView<uint64_t, 2> input, NDView<uint16_t,2> outpu
} }
double apply_custom_weights(uint16_t input, const NDView<double, 1> weights) {
if(weights.size() > 16){
throw std::invalid_argument("weights size must be less than or equal to 16");
}
double result = 0.0;
for (ssize_t i = 0; i < weights.size(); ++i) {
result += ((input >> i) & 1) * std::pow(weights[i], i);
}
return result;
}
void apply_custom_weights(NDView<uint16_t, 2> input, NDView<double, 2> output, const NDView<double,1> weights) {
if(input.shape() != output.shape()){
throw std::invalid_argument(LOCATION + " input and output shapes must match");
}
for (int64_t i = 0; i < input.shape(0); i++) {
for (int64_t j = 0; j < input.shape(1); j++) {
output(i, j) = apply_custom_weights(input(i, j), weights);
}
}
}
} // namespace aare } // namespace aare

80
src/decode.test.cpp Normal file
View File

@ -0,0 +1,80 @@
#include "aare/decode.hpp"
#include <catch2/matchers/catch_matchers_floating_point.hpp>
#include <catch2/catch_test_macros.hpp>
#include "aare/NDArray.hpp"
using Catch::Matchers::WithinAbs;
#include <vector>
TEST_CASE("test_adc_sar_05_decode64to16"){
uint64_t input = 0;
uint16_t output = aare::adc_sar_05_decode64to16(input);
CHECK(output == 0);
// bit 29 on th input is bit 0 on the output
input = 1UL << 29;
output = aare::adc_sar_05_decode64to16(input);
CHECK(output == 1);
// test all bits by iteratting through the bitlist
std::vector<int> bitlist = {29, 19, 28, 18, 31, 21, 27, 20, 24, 23, 25, 22};
for (size_t i = 0; i < bitlist.size(); i++) {
input = 1UL << bitlist[i];
output = aare::adc_sar_05_decode64to16(input);
CHECK(output == (1 << i));
}
// test a few "random" values
input = 0;
input |= (1UL << 29);
input |= (1UL << 19);
input |= (1UL << 28);
output = aare::adc_sar_05_decode64to16(input);
CHECK(output == 7UL);
input = 0;
input |= (1UL << 18);
input |= (1UL << 27);
input |= (1UL << 25);
output = aare::adc_sar_05_decode64to16(input);
CHECK(output == 1096UL);
input = 0;
input |= (1UL << 25);
input |= (1UL << 22);
output = aare::adc_sar_05_decode64to16(input);
CHECK(output == 3072UL);
}
TEST_CASE("test_apply_custom_weights") {
uint16_t input = 1;
aare::NDArray<double, 1> weights_data({3}, 0.0);
weights_data(0) = 1.7;
weights_data(1) = 2.1;
weights_data(2) = 1.8;
auto weights = weights_data.view();
double output = aare::apply_custom_weights(input, weights);
CHECK_THAT(output, WithinAbs(1.0, 0.001));
input = 1UL << 1;
output = aare::apply_custom_weights(input, weights);
CHECK_THAT(output, WithinAbs(2.1, 0.001));
input = 1UL << 2;
output = aare::apply_custom_weights(input, weights);
CHECK_THAT(output, WithinAbs(3.24, 0.001));
input = 0b111;
output = aare::apply_custom_weights(input, weights);
CHECK_THAT(output, WithinAbs(6.34, 0.001));
}