mirror of
https://github.com/slsdetectorgroup/aare.git
synced 2025-04-19 21:30:02 +02:00
python bindings
This commit is contained in:
parent
73f46e4d2b
commit
3760fd5ed0
@ -1,6 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
#include <vector>
|
||||||
#include <aare/NDView.hpp>
|
#include <aare/NDView.hpp>
|
||||||
namespace aare {
|
namespace aare {
|
||||||
|
|
||||||
@ -10,4 +11,16 @@ uint16_t adc_sar_04_decode64to16(uint64_t input);
|
|||||||
void adc_sar_05_decode64to16(NDView<uint64_t, 2> input, NDView<uint16_t,2> output);
|
void adc_sar_05_decode64to16(NDView<uint64_t, 2> input, NDView<uint16_t,2> output);
|
||||||
void adc_sar_04_decode64to16(NDView<uint64_t, 2> input, NDView<uint16_t,2> output);
|
void adc_sar_04_decode64to16(NDView<uint64_t, 2> input, NDView<uint16_t,2> output);
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Apply custom weights to a 16-bit input value. Will sum up weights[i]**i
|
||||||
|
* for each bit i that is set in the input value.
|
||||||
|
* @throws std::out_of_range if weights.size() < 16
|
||||||
|
* @param input 16-bit input value
|
||||||
|
* @param weights vector of weights, size must be less than or equal to 16
|
||||||
|
*/
|
||||||
|
double apply_custom_weights(uint16_t input, const NDView<double, 1> weights);
|
||||||
|
|
||||||
|
void apply_custom_weights(NDView<uint16_t, 2> input, NDView<double, 2> output, const NDView<double, 1> weights);
|
||||||
|
|
||||||
} // namespace aare
|
} // namespace aare
|
@ -10,6 +10,8 @@
|
|||||||
#include "aare/decode.hpp"
|
#include "aare/decode.hpp"
|
||||||
// #include "aare/fClusterFileV2.hpp"
|
// #include "aare/fClusterFileV2.hpp"
|
||||||
|
|
||||||
|
#include "np_helper.hpp"
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
#include <pybind11/iostream.h>
|
#include <pybind11/iostream.h>
|
||||||
@ -65,6 +67,27 @@ m.def("adc_sar_04_decode64to16", [](py::array_t<uint8_t> input) {
|
|||||||
return output;
|
return output;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
||||||
|
m.def("apply_custom_weights", [](py::array_t<uint16_t> input, py::array_t<double> weights) {
|
||||||
|
if (input.ndim() != 2) {
|
||||||
|
throw std::runtime_error("Only 2D arrays are supported at this moment");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a 2D output array with the same shape as the input
|
||||||
|
std::vector<ssize_t> shape{input.shape(0), input.shape(1)};
|
||||||
|
py::array_t<double> output(shape);
|
||||||
|
|
||||||
|
auto weights_view = make_view_1d(weights);
|
||||||
|
|
||||||
|
// Create a view of the input and output arrays
|
||||||
|
NDView<uint16_t, 2> input_view(input.mutable_data(), {input.shape(0), input.shape(1)});
|
||||||
|
NDView<double, 2> output_view(output.mutable_data(), {output.shape(0), output.shape(1)});
|
||||||
|
|
||||||
|
apply_custom_weights(input_view, output_view, weights_view);
|
||||||
|
|
||||||
|
return output;
|
||||||
|
});
|
||||||
|
|
||||||
py::class_<CtbRawFile>(m, "CtbRawFile")
|
py::class_<CtbRawFile>(m, "CtbRawFile")
|
||||||
.def(py::init<const std::filesystem::path &>())
|
.def(py::init<const std::filesystem::path &>())
|
||||||
.def("read_frame",
|
.def("read_frame",
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
#include "aare/decode.hpp"
|
#include "aare/decode.hpp"
|
||||||
|
#include <cmath>
|
||||||
namespace aare {
|
namespace aare {
|
||||||
|
|
||||||
uint16_t adc_sar_05_decode64to16(uint64_t input){
|
uint16_t adc_sar_05_decode64to16(uint64_t input){
|
||||||
@ -22,6 +22,10 @@ uint16_t adc_sar_05_decode64to16(uint64_t input){
|
|||||||
}
|
}
|
||||||
|
|
||||||
void adc_sar_05_decode64to16(NDView<uint64_t, 2> input, NDView<uint16_t,2> output){
|
void adc_sar_05_decode64to16(NDView<uint64_t, 2> input, NDView<uint16_t,2> output){
|
||||||
|
if(input.shape() != output.shape()){
|
||||||
|
throw std::invalid_argument(LOCATION + " input and output shapes must match");
|
||||||
|
}
|
||||||
|
|
||||||
for(int64_t i = 0; i < input.shape(0); i++){
|
for(int64_t i = 0; i < input.shape(0); i++){
|
||||||
for(int64_t j = 0; j < input.shape(1); j++){
|
for(int64_t j = 0; j < input.shape(1); j++){
|
||||||
output(i,j) = adc_sar_05_decode64to16(input(i,j));
|
output(i,j) = adc_sar_05_decode64to16(input(i,j));
|
||||||
@ -49,6 +53,9 @@ uint16_t adc_sar_04_decode64to16(uint64_t input){
|
|||||||
}
|
}
|
||||||
|
|
||||||
void adc_sar_04_decode64to16(NDView<uint64_t, 2> input, NDView<uint16_t,2> output){
|
void adc_sar_04_decode64to16(NDView<uint64_t, 2> input, NDView<uint16_t,2> output){
|
||||||
|
if(input.shape() != output.shape()){
|
||||||
|
throw std::invalid_argument(LOCATION + " input and output shapes must match");
|
||||||
|
}
|
||||||
for(int64_t i = 0; i < input.shape(0); i++){
|
for(int64_t i = 0; i < input.shape(0); i++){
|
||||||
for(int64_t j = 0; j < input.shape(1); j++){
|
for(int64_t j = 0; j < input.shape(1); j++){
|
||||||
output(i,j) = adc_sar_04_decode64to16(input(i,j));
|
output(i,j) = adc_sar_04_decode64to16(input(i,j));
|
||||||
@ -57,5 +64,30 @@ void adc_sar_04_decode64to16(NDView<uint64_t, 2> input, NDView<uint16_t,2> outpu
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
double apply_custom_weights(uint16_t input, const NDView<double, 1> weights) {
|
||||||
|
if(weights.size() > 16){
|
||||||
|
throw std::invalid_argument("weights size must be less than or equal to 16");
|
||||||
|
}
|
||||||
|
|
||||||
|
double result = 0.0;
|
||||||
|
for (ssize_t i = 0; i < weights.size(); ++i) {
|
||||||
|
result += ((input >> i) & 1) * std::pow(weights[i], i);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
void apply_custom_weights(NDView<uint16_t, 2> input, NDView<double, 2> output, const NDView<double,1> weights) {
|
||||||
|
if(input.shape() != output.shape()){
|
||||||
|
throw std::invalid_argument(LOCATION + " input and output shapes must match");
|
||||||
|
}
|
||||||
|
for (int64_t i = 0; i < input.shape(0); i++) {
|
||||||
|
for (int64_t j = 0; j < input.shape(1); j++) {
|
||||||
|
output(i, j) = apply_custom_weights(input(i, j), weights);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
} // namespace aare
|
} // namespace aare
|
||||||
|
80
src/decode.test.cpp
Normal file
80
src/decode.test.cpp
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
#include "aare/decode.hpp"
|
||||||
|
|
||||||
|
#include <catch2/matchers/catch_matchers_floating_point.hpp>
|
||||||
|
#include <catch2/catch_test_macros.hpp>
|
||||||
|
#include "aare/NDArray.hpp"
|
||||||
|
using Catch::Matchers::WithinAbs;
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
TEST_CASE("test_adc_sar_05_decode64to16"){
|
||||||
|
uint64_t input = 0;
|
||||||
|
uint16_t output = aare::adc_sar_05_decode64to16(input);
|
||||||
|
CHECK(output == 0);
|
||||||
|
|
||||||
|
|
||||||
|
// bit 29 on th input is bit 0 on the output
|
||||||
|
input = 1UL << 29;
|
||||||
|
output = aare::adc_sar_05_decode64to16(input);
|
||||||
|
CHECK(output == 1);
|
||||||
|
|
||||||
|
// test all bits by iteratting through the bitlist
|
||||||
|
std::vector<int> bitlist = {29, 19, 28, 18, 31, 21, 27, 20, 24, 23, 25, 22};
|
||||||
|
for (size_t i = 0; i < bitlist.size(); i++) {
|
||||||
|
input = 1UL << bitlist[i];
|
||||||
|
output = aare::adc_sar_05_decode64to16(input);
|
||||||
|
CHECK(output == (1 << i));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// test a few "random" values
|
||||||
|
input = 0;
|
||||||
|
input |= (1UL << 29);
|
||||||
|
input |= (1UL << 19);
|
||||||
|
input |= (1UL << 28);
|
||||||
|
output = aare::adc_sar_05_decode64to16(input);
|
||||||
|
CHECK(output == 7UL);
|
||||||
|
|
||||||
|
|
||||||
|
input = 0;
|
||||||
|
input |= (1UL << 18);
|
||||||
|
input |= (1UL << 27);
|
||||||
|
input |= (1UL << 25);
|
||||||
|
output = aare::adc_sar_05_decode64to16(input);
|
||||||
|
CHECK(output == 1096UL);
|
||||||
|
|
||||||
|
input = 0;
|
||||||
|
input |= (1UL << 25);
|
||||||
|
input |= (1UL << 22);
|
||||||
|
output = aare::adc_sar_05_decode64to16(input);
|
||||||
|
CHECK(output == 3072UL);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_CASE("test_apply_custom_weights") {
|
||||||
|
|
||||||
|
uint16_t input = 1;
|
||||||
|
aare::NDArray<double, 1> weights_data({3}, 0.0);
|
||||||
|
weights_data(0) = 1.7;
|
||||||
|
weights_data(1) = 2.1;
|
||||||
|
weights_data(2) = 1.8;
|
||||||
|
|
||||||
|
auto weights = weights_data.view();
|
||||||
|
|
||||||
|
|
||||||
|
double output = aare::apply_custom_weights(input, weights);
|
||||||
|
CHECK_THAT(output, WithinAbs(1.0, 0.001));
|
||||||
|
|
||||||
|
input = 1UL << 1;
|
||||||
|
output = aare::apply_custom_weights(input, weights);
|
||||||
|
CHECK_THAT(output, WithinAbs(2.1, 0.001));
|
||||||
|
|
||||||
|
|
||||||
|
input = 1UL << 2;
|
||||||
|
output = aare::apply_custom_weights(input, weights);
|
||||||
|
CHECK_THAT(output, WithinAbs(3.24, 0.001));
|
||||||
|
|
||||||
|
input = 0b111;
|
||||||
|
output = aare::apply_custom_weights(input, weights);
|
||||||
|
CHECK_THAT(output, WithinAbs(6.34, 0.001));
|
||||||
|
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user