template pedestal (#71)

template the type of m_sum and m_sum2 in pedestal
This commit is contained in:
Bechir Braham 2024-05-27 15:24:26 +02:00 committed by GitHub
parent fed362e843
commit 1dd361a343
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 76 additions and 67 deletions

View File

@ -11,7 +11,7 @@ add_library(core STATIC ${SourceFiles})
target_include_directories(core PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include) target_include_directories(core PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
target_link_libraries(core PUBLIC fmt::fmt PRIVATE aare_compiler_flags utils ) target_link_libraries(core PUBLIC utils fmt::fmt PRIVATE aare_compiler_flags )
if (AARE_PYTHON_BINDINGS) if (AARE_PYTHON_BINDINGS)
set_property(TARGET core PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET core PROPERTY POSITION_INDEPENDENT_CODE ON)

View File

@ -5,7 +5,8 @@
#include <cstddef> #include <cstddef>
namespace aare { namespace aare {
class Pedestal {
template <typename SUM_TYPE = double> class Pedestal {
public: public:
Pedestal(int rows, int cols, int n_samples = 1000); Pedestal(int rows, int cols, int n_samples = 1000);
~Pedestal(); ~Pedestal();
@ -15,16 +16,16 @@ class Pedestal {
assert(frame.size() == m_rows * m_cols); assert(frame.size() == m_rows * m_cols);
// TODO: test the effect of #pragma omp parallel for // TODO: test the effect of #pragma omp parallel for
for (int index = 0; index < m_rows * m_cols; index++) { for (int index = 0; index < m_rows * m_cols; index++) {
push<T>(index % m_cols, index / m_cols, frame(index)); push<T>(index / m_cols, index % m_cols, frame(index));
} }
} }
template <typename T> void push(Frame &frame) { template <typename T> void push(Frame &frame) {
assert(frame.rows() == static_cast<size_t>(m_rows) && frame.cols() == static_cast<size_t>(m_cols)); assert(frame.rows() == static_cast<size_t>(m_rows) && frame.cols() == static_cast<size_t>(m_cols));
push<T>(frame.view<T>()); push<T>(frame.view<T>());
} }
NDArray<double> mean(); NDArray<SUM_TYPE> mean();
NDArray<double> variance(); NDArray<SUM_TYPE> variance();
NDArray<double> standard_deviation(); NDArray<SUM_TYPE> standard_deviation();
void clear(); void clear();
// getter functions // getter functions
@ -32,33 +33,34 @@ class Pedestal {
inline int cols() const { return m_cols; } inline int cols() const { return m_cols; }
inline int n_samples() const { return m_samples; } inline int n_samples() const { return m_samples; }
inline uint32_t *cur_samples() const { return m_cur_samples; } inline uint32_t *cur_samples() const { return m_cur_samples; }
inline double *get_sum() const { return m_sum; } inline NDArray<SUM_TYPE, 2> get_sum() const { return m_sum; }
inline double *get_sum2() const { return m_sum2; } inline NDArray<SUM_TYPE, 2> get_sum2() const { return m_sum2; }
// pixel level operations (should be refactored to allow users to implement their own pixel level operations) // pixel level operations (should be refactored to allow users to implement their own pixel level operations)
template <typename T> inline void push(const int row, const int col, const T val) { template <typename T> inline void push(const int row, const int col, const T val) {
const int idx = index(row, col); const int idx = index(row, col);
if (m_cur_samples[idx] < m_samples) { if (m_cur_samples[idx] < m_samples) {
m_sum[idx] += val; m_sum(idx) += val;
m_sum2[idx] += val * val; m_sum2(idx) += val * val;
m_cur_samples[idx]++; m_cur_samples[idx]++;
} else { } else {
m_sum[idx] += val - m_sum[idx] / m_cur_samples[idx]; m_sum(idx) += val - m_sum(idx) / m_cur_samples[idx];
m_sum2[idx] += val * val - m_sum2[idx] / m_cur_samples[idx]; m_sum2(idx) += val * val - m_sum2(idx) / m_cur_samples[idx];
} }
} }
double mean(const int row, const int col) const; SUM_TYPE mean(const int row, const int col) const;
double variance(const int row, const int col) const; SUM_TYPE variance(const int row, const int col) const;
double standard_deviation(const int row, const int col) const; SUM_TYPE standard_deviation(const int row, const int col) const;
inline int index(const int row, const int col) const { return row * m_cols + col; }; inline int index(const int row, const int col) const { return row * m_cols + col; };
void clear(const int row, const int col); void clear(const int row, const int col);
private: private:
int m_rows; int m_rows;
int m_cols; int m_cols;
int m_samples; uint32_t m_samples;
uint32_t *m_cur_samples{nullptr}; uint32_t *m_cur_samples{nullptr};
double *m_sum{nullptr}; NDArray<SUM_TYPE, 2> m_sum;
double *m_sum2{nullptr}; NDArray<SUM_TYPE, 2> m_sum2;
}; };
} // namespace aare } // namespace aare

View File

@ -3,68 +3,74 @@
#include <cstddef> #include <cstddef>
namespace aare { namespace aare {
Pedestal::Pedestal(int rows, int cols, int n_samples) template <typename SUM_TYPE>
: m_rows(rows), m_cols(cols), m_samples(n_samples), m_sum(new double[rows * cols]{}), Pedestal<SUM_TYPE>::Pedestal(int rows, int cols, int n_samples)
m_sum2(new double[rows * cols]{}), m_cur_samples(new uint32_t[rows * cols]{}) { : m_rows(rows), m_cols(cols), m_samples(n_samples), m_sum(NDArray<SUM_TYPE, 2>({cols, rows})),
m_sum2(NDArray<SUM_TYPE, 2>({cols, rows})), m_cur_samples(new uint32_t[static_cast<uint64_t>(rows) * cols]{}) {
assert(rows > 0 && cols > 0 && n_samples > 0); assert(rows > 0 && cols > 0 && n_samples > 0);
m_sum = 0;
m_sum2 = 0;
} }
NDArray<double, 2> Pedestal::mean() { template <typename SUM_TYPE> NDArray<SUM_TYPE, 2> Pedestal<SUM_TYPE>::mean() {
NDArray<double, 2> mean_array({m_rows, m_cols}); NDArray<SUM_TYPE, 2> mean_array({m_rows, m_cols});
for (int i = 0; i < m_rows * m_cols; i++) { for (int i = 0; i < m_rows * m_cols; i++) {
mean_array(i % m_cols, i / m_cols) = mean(i % m_cols, i / m_cols); mean_array(i / m_cols, i % m_cols) = mean(i / m_cols, i % m_cols);
} }
return mean_array; return mean_array;
} }
NDArray<double, 2> Pedestal::variance() { template <typename SUM_TYPE> NDArray<SUM_TYPE, 2> Pedestal<SUM_TYPE>::variance() {
NDArray<double, 2> variance_array({m_rows, m_cols}); NDArray<SUM_TYPE, 2> variance_array({m_rows, m_cols});
for (int i = 0; i < m_rows * m_cols; i++) { for (int i = 0; i < m_rows * m_cols; i++) {
variance_array(i % m_cols, i / m_cols) = variance(i % m_cols, i / m_cols); variance_array(i / m_cols, i % m_cols) = variance(i / m_cols, i % m_cols);
} }
return variance_array; return variance_array;
} }
NDArray<double, 2> Pedestal::standard_deviation() { template <typename SUM_TYPE> NDArray<SUM_TYPE, 2> Pedestal<SUM_TYPE>::standard_deviation() {
NDArray<double, 2> standard_deviation_array({m_rows, m_cols}); NDArray<SUM_TYPE, 2> standard_deviation_array({m_rows, m_cols});
for (int i = 0; i < m_rows * m_cols; i++) { for (int i = 0; i < m_rows * m_cols; i++) {
standard_deviation_array(i % m_cols, i / m_cols) = standard_deviation(i % m_cols, i / m_cols); standard_deviation_array(i / m_cols, i % m_cols) = standard_deviation(i / m_cols, i % m_cols);
} }
return standard_deviation_array; return standard_deviation_array;
} }
void Pedestal::clear() { template <typename SUM_TYPE> void Pedestal<SUM_TYPE>::clear() {
for (int i = 0; i < m_rows * m_cols; i++) { for (int i = 0; i < m_rows * m_cols; i++) {
clear(i % m_cols, i / m_cols); clear(i / m_cols, i % m_cols);
} }
} }
/* /*
* index level operations * index level operations
*/ */
double Pedestal::mean(const int row, const int col) const { template <typename SUM_TYPE> SUM_TYPE Pedestal<SUM_TYPE>::mean(const int row, const int col) const {
if (m_cur_samples[index(row, col)] == 0) { if (m_cur_samples[index(row, col)] == 0) {
return 0.0; return 0.0;
} }
return m_sum[index(row, col)] / m_cur_samples[index(row, col)]; return m_sum(row, col) / m_cur_samples[index(row, col)];
} }
double Pedestal::variance(const int row, const int col) const { template <typename SUM_TYPE> SUM_TYPE Pedestal<SUM_TYPE>::variance(const int row, const int col) const {
if (m_cur_samples[index(row, col)] == 0) { if (m_cur_samples[index(row, col)] == 0) {
return 0.0; return 0.0;
} }
return m_sum2[index(row, col)] / m_cur_samples[index(row, col)] - mean(row, col) * mean(row, col); return m_sum2(row, col) / m_cur_samples[index(row, col)] - mean(row, col) * mean(row, col);
}
template <typename SUM_TYPE> SUM_TYPE Pedestal<SUM_TYPE>::standard_deviation(const int row, const int col) const {
return std::sqrt(variance(row, col));
} }
double Pedestal::standard_deviation(const int row, const int col) const { return std::sqrt(variance(row, col)); }
void Pedestal::clear(const int row, const int col) { template <typename SUM_TYPE> void Pedestal<SUM_TYPE>::clear(const int row, const int col) {
m_sum[index(row, col)] = 0; m_sum(row, col) = 0;
m_sum2[index(row, col)] = 0; m_sum2(row, col) = 0;
m_cur_samples[index(row, col)] = 0; m_cur_samples[index(row, col)] = 0;
} }
Pedestal::~Pedestal() { template <typename SUM_TYPE> Pedestal<SUM_TYPE>::~Pedestal() { delete[] m_cur_samples; }
delete[] m_sum;
delete[] m_sum2; template class Pedestal<double>;
delete[] m_cur_samples; template class Pedestal<float>;
} template class Pedestal<long double>;
} // namespace aare } // namespace aare

View File

@ -11,12 +11,10 @@ TEST_CASE("test pedestal constructor") {
REQUIRE(pedestal.cols() == 10); REQUIRE(pedestal.cols() == 10);
REQUIRE(pedestal.n_samples() == 5); REQUIRE(pedestal.n_samples() == 5);
REQUIRE(pedestal.cur_samples() != nullptr); REQUIRE(pedestal.cur_samples() != nullptr);
REQUIRE(pedestal.get_sum() != nullptr);
REQUIRE(pedestal.get_sum2() != nullptr);
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
for (int j = 0; j < 10; j++) { for (int j = 0; j < 10; j++) {
REQUIRE(pedestal.get_sum()[pedestal.index(i, j)] == 0); REQUIRE(pedestal.get_sum()(i, j) == 0);
REQUIRE(pedestal.get_sum2()[pedestal.index(i, j)] == 0); REQUIRE(pedestal.get_sum2()(i, j) == 0);
REQUIRE(pedestal.cur_samples()[pedestal.index(i, j)] == 0); REQUIRE(pedestal.cur_samples()[pedestal.index(i, j)] == 0);
} }
} }
@ -35,8 +33,8 @@ TEST_CASE("test pedestal push") {
pedestal.push<uint16_t>(frame); pedestal.push<uint16_t>(frame);
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
for (int j = 0; j < 10; j++) { for (int j = 0; j < 10; j++) {
REQUIRE(pedestal.get_sum()[pedestal.index(i, j)] == i + j); REQUIRE(pedestal.get_sum()(i, j) == i + j);
REQUIRE(pedestal.get_sum2()[pedestal.index(i, j)] == (i + j) * (i + j)); REQUIRE(pedestal.get_sum2()(i, j) == (i + j) * (i + j));
REQUIRE(pedestal.cur_samples()[pedestal.index(i, j)] == 1); REQUIRE(pedestal.cur_samples()[pedestal.index(i, j)] == 1);
} }
} }
@ -45,8 +43,8 @@ TEST_CASE("test pedestal push") {
pedestal.clear(); pedestal.clear();
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
for (int j = 0; j < 10; j++) { for (int j = 0; j < 10; j++) {
REQUIRE(pedestal.get_sum()[pedestal.index(i, j)] == 0); REQUIRE(pedestal.get_sum()(i, j) == 0);
REQUIRE(pedestal.get_sum2()[pedestal.index(i, j)] == 0); REQUIRE(pedestal.get_sum2()(i, j) == 0);
REQUIRE(pedestal.cur_samples()[pedestal.index(i, j)] == 0); REQUIRE(pedestal.cur_samples()[pedestal.index(i, j)] == 0);
} }
} }
@ -58,12 +56,12 @@ TEST_CASE("test pedestal push") {
for (int j = 0; j < 10; j++) { for (int j = 0; j < 10; j++) {
if (k < 5) { if (k < 5) {
REQUIRE(pedestal.cur_samples()[pedestal.index(i, j)] == k + 1); REQUIRE(pedestal.cur_samples()[pedestal.index(i, j)] == k + 1);
REQUIRE(pedestal.get_sum()[pedestal.index(i, j)] == (k + 1) * (i + j)); REQUIRE(pedestal.get_sum()(i, j) == (k + 1) * (i + j));
REQUIRE(pedestal.get_sum2()[pedestal.index(i, j)] == (k + 1) * (i + j) * (i + j)); REQUIRE(pedestal.get_sum2()(i, j) == (k + 1) * (i + j) * (i + j));
} else { } else {
REQUIRE(pedestal.cur_samples()[pedestal.index(i, j)] == 5); REQUIRE(pedestal.cur_samples()[pedestal.index(i, j)] == 5);
REQUIRE(pedestal.get_sum()[pedestal.index(i, j)] == 5 * (i + j)); REQUIRE(pedestal.get_sum()(i, j) == 5 * (i + j));
REQUIRE(pedestal.get_sum2()[pedestal.index(i, j)] == 5 * (i + j) * (i + j)); REQUIRE(pedestal.get_sum2()(i, j) == 5 * (i + j) * (i + j));
} }
REQUIRE(pedestal.mean(i, j) == (i + j)); REQUIRE(pedestal.mean(i, j) == (i + j));
REQUIRE(pedestal.variance(i, j) == 0); REQUIRE(pedestal.variance(i, j) == 0);
@ -80,24 +78,22 @@ TEST_CASE("test pedestal with normal distribution") {
std::default_random_engine generator(seed); std::default_random_engine generator(seed);
std::normal_distribution<double> distribution(MEAN, STD); std::normal_distribution<double> distribution(MEAN, STD);
aare::Pedestal pedestal(4, 4, 10000); aare::Pedestal pedestal(3, 5, 10000);
for (int i = 0; i < 10000; i++) { for (int i = 0; i < 10000; i++) {
aare::Frame frame(4, 4, 64); aare::Frame frame(3, 5, 64);
for (int j = 0; j < 4; j++) { for (int j = 0; j < 3; j++) {
for (int k = 0; k < 4; k++) { for (int k = 0; k < 5; k++) {
frame.set<double>(j, k, distribution(generator)); frame.set<double>(j, k, distribution(generator));
} }
} }
pedestal.push<double>(frame); pedestal.push<double>(frame);
} }
auto mean = pedestal.mean(); auto mean = pedestal.mean();
auto variance = pedestal.variance(); auto variance = pedestal.variance();
auto standard_deviation = pedestal.standard_deviation(); auto standard_deviation = pedestal.standard_deviation();
for (int i = 0; i < 4; i++) { for (int i = 0; i < 3; i++) {
for (int j = 0; j < 4; j++) { for (int j = 0; j < 5; j++) {
// 10% tolerance
REQUIRE(compare_floats<double>(mean(i, j), MEAN, MEAN * TOLERANCE)); REQUIRE(compare_floats<double>(mean(i, j), MEAN, MEAN * TOLERANCE));
REQUIRE(compare_floats<double>(variance(i, j), VAR, VAR * TOLERANCE)); REQUIRE(compare_floats<double>(variance(i, j), VAR, VAR * TOLERANCE));
REQUIRE(compare_floats<double>(standard_deviation(i, j), STD, STD * TOLERANCE)); // maybe sqrt of tolerance? REQUIRE(compare_floats<double>(standard_deviation(i, j), STD, STD * TOLERANCE)); // maybe sqrt of tolerance?

View File

@ -1,9 +1,9 @@
#include "test_config.hpp"
#include <catch2/catch_test_macros.hpp> #include <catch2/catch_test_macros.hpp>
#include <climits>
#include <filesystem> #include <filesystem>
#include <fstream> #include <fstream>
#include "test_config.hpp"
TEST_CASE("Test suite can find data assets") { TEST_CASE("Test suite can find data assets") {
auto fpath = test_data_path() / "numpy" / "test_numpy_file.npy"; auto fpath = test_data_path() / "numpy" / "test_numpy_file.npy";
REQUIRE(std::filesystem::exists(fpath)); REQUIRE(std::filesystem::exists(fpath));
@ -13,4 +13,9 @@ TEST_CASE("Test suite can open data assets") {
auto fpath = test_data_path() / "numpy" / "test_numpy_file.npy"; auto fpath = test_data_path() / "numpy" / "test_numpy_file.npy";
auto f = std::ifstream(fpath, std::ios::binary); auto f = std::ifstream(fpath, std::ios::binary);
REQUIRE(f.is_open()); REQUIRE(f.is_open());
}
TEST_CASE("Test float32 and char8") {
REQUIRE(sizeof(float) == 4);
REQUIRE(CHAR_BIT == 8);
} }