mirror of
https://github.com/slsdetectorgroup/aare.git
synced 2025-06-08 05:30:41 +02:00
changes ctb custom weights to accept 1D arrays instead of 2D arrays
This commit is contained in:
parent
88a8eddacf
commit
fa6f33fd85
@ -21,6 +21,6 @@ void adc_sar_04_decode64to16(NDView<uint64_t, 2> input, NDView<uint16_t,2> outpu
|
|||||||
*/
|
*/
|
||||||
double apply_custom_weights(uint16_t input, const NDView<double, 1> weights);
|
double apply_custom_weights(uint16_t input, const NDView<double, 1> weights);
|
||||||
|
|
||||||
void apply_custom_weights(NDView<uint16_t, 2> input, NDView<double, 2> output, const NDView<double, 1> weights);
|
void apply_custom_weights(NDView<uint16_t, 1> input, NDView<double, 1> output, const NDView<double, 1> weights);
|
||||||
|
|
||||||
} // namespace aare
|
} // namespace aare
|
@ -69,19 +69,19 @@ m.def("adc_sar_04_decode64to16", [](py::array_t<uint8_t> input) {
|
|||||||
|
|
||||||
|
|
||||||
m.def("apply_custom_weights", [](py::array_t<uint16_t>& input, py::array_t<double>& weights) {
|
m.def("apply_custom_weights", [](py::array_t<uint16_t>& input, py::array_t<double>& weights) {
|
||||||
if (input.ndim() != 2) {
|
if (input.ndim() != 1) {
|
||||||
throw std::runtime_error("Only 2D arrays are supported at this moment");
|
throw std::runtime_error("Only 1D arrays are supported at this moment");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a 2D output array with the same shape as the input
|
// Create a 1D output array with the same shape as the input
|
||||||
std::vector<ssize_t> shape{input.shape(0), input.shape(1)};
|
std::vector<ssize_t> shape{input.shape(0)};
|
||||||
py::array_t<double> output(shape);
|
py::array_t<double> output(shape);
|
||||||
|
|
||||||
auto weights_view = make_view_1d(weights);
|
auto weights_view = make_view_1d(weights);
|
||||||
|
|
||||||
// Create a view of the input and output arrays
|
// Create a view of the input and output arrays
|
||||||
NDView<uint16_t, 2> input_view(input.mutable_data(), {input.shape(0), input.shape(1)});
|
NDView<uint16_t, 1> input_view(input.mutable_data(), {input.shape(0)});
|
||||||
NDView<double, 2> output_view(output.mutable_data(), {output.shape(0), output.shape(1)});
|
NDView<double, 1> output_view(output.mutable_data(), {output.shape(0)});
|
||||||
|
|
||||||
apply_custom_weights(input_view, output_view, weights_view);
|
apply_custom_weights(input_view, output_view, weights_view);
|
||||||
|
|
||||||
|
@ -63,7 +63,6 @@ void adc_sar_04_decode64to16(NDView<uint64_t, 2> input, NDView<uint16_t,2> outpu
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
double apply_custom_weights(uint16_t input, const NDView<double, 1> weights) {
|
double apply_custom_weights(uint16_t input, const NDView<double, 1> weights) {
|
||||||
if(weights.size() > 16){
|
if(weights.size() > 16){
|
||||||
throw std::invalid_argument("weights size must be less than or equal to 16");
|
throw std::invalid_argument("weights size must be less than or equal to 16");
|
||||||
@ -77,7 +76,7 @@ double apply_custom_weights(uint16_t input, const NDView<double, 1> weights) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void apply_custom_weights(NDView<uint16_t, 2> input, NDView<double, 2> output, const NDView<double,1> weights) {
|
void apply_custom_weights(NDView<uint16_t, 1> input, NDView<double, 1> output, const NDView<double,1> weights) {
|
||||||
if(input.shape() != output.shape()){
|
if(input.shape() != output.shape()){
|
||||||
throw std::invalid_argument(LOCATION + " input and output shapes must match");
|
throw std::invalid_argument(LOCATION + " input and output shapes must match");
|
||||||
}
|
}
|
||||||
@ -90,14 +89,11 @@ void apply_custom_weights(NDView<uint16_t, 2> input, NDView<double, 2> output, c
|
|||||||
|
|
||||||
// Apply custom weights to each element in the input array
|
// Apply custom weights to each element in the input array
|
||||||
for (ssize_t i = 0; i < input.shape(0); i++) {
|
for (ssize_t i = 0; i < input.shape(0); i++) {
|
||||||
for (ssize_t j = 0; j < input.shape(1); j++) {
|
|
||||||
|
|
||||||
double result = 0.0;
|
double result = 0.0;
|
||||||
for (ssize_t bit_index = 0; bit_index < weights_powers.size(); ++bit_index) {
|
for (ssize_t bit_index = 0; bit_index < weights_powers.size(); ++bit_index) {
|
||||||
result += ((input(i,j) >> bit_index) & 1) * weights_powers[bit_index];
|
result += ((input(i) >> bit_index) & 1) * weights_powers[bit_index];
|
||||||
}
|
|
||||||
output(i,j) = result;
|
|
||||||
}
|
}
|
||||||
|
output(i) = result;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user