added safe_divide to NDArray and used it for pedestal

This commit is contained in:
froejdh_e
2025-07-24 09:40:38 +02:00
parent cb439efb48
commit 0fea0f5b0e
2 changed files with 27 additions and 36 deletions

View File

@@ -217,6 +217,7 @@ class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim> {
std::fill(shape_.begin(), shape_.end(), 0);
std::fill(strides_.begin(), strides_.end(), 0);
}
};
// Move assign
@@ -380,12 +381,8 @@ NDArray<T, Ndim> NDArray<T, Ndim>::operator*(const T &value) {
result *= value;
return result;
}
// template <typename T, ssize_t Ndim> void NDArray<T, Ndim>::Print() {
// if (shape_[0] < 20 && shape_[1] < 20)
// Print_all();
// else
// Print_some();
// }
template <typename T, ssize_t Ndim>
std::ostream &operator<<(std::ostream &os, const NDArray<T, Ndim> &arr) {
@@ -437,4 +434,22 @@ NDArray<T, Ndim> load(const std::string &pathname,
return img;
}
template <typename RT, typename NT, typename DT, ssize_t Ndim>
NDArray<RT, Ndim> safe_divide(const NDArray<NT, Ndim> &numerator,
const NDArray<DT, Ndim> &denominator) {
if (numerator.shape() != denominator.shape()) {
throw std::runtime_error("Shapes of numerator and denominator must match");
}
NDArray<RT, Ndim> result(numerator.shape());
for (ssize_t i = 0; i < numerator.size(); ++i) {
if (denominator[i] != 0) {
result[i] = static_cast<RT>(numerator[i]) / static_cast<RT>(denominator[i]);
} else {
result[i] = RT{0}; // or handle division by zero as needed
}
}
return result;
}
} // namespace aare

View File

@@ -161,6 +161,9 @@ NDArray<T, 2> calculate_pedestal_g0(NDView<uint16_t, 3> raw_data) {
}
return pedestal;
}
template <typename T>
NDArray<T, 2> calculate_pedestal_g0(NDView<uint16_t, 3> raw_data,
ssize_t n_threads) {
@@ -191,20 +194,8 @@ NDArray<T, 2> calculate_pedestal_g0(NDView<uint16_t, 3> raw_data,
accumulator += acc;
count += cnt;
}
NDArray<T, 2> pedestal(
std::array<ssize_t,2>{raw_data.shape(1), raw_data.shape(2)}, 0);
for (int gain = 0; gain < 3; ++gain) {
for (int row = 0; row < raw_data.shape(1); ++row) {
for (int col = 0; col < raw_data.shape(2); ++col) {
if (count(row, col) != 0) {
pedestal(row, col) =
static_cast<T>(accumulator(row, col)) /
static_cast<T>(count(row, col));
}
}
}
}
return pedestal;
return safe_divide<T>(accumulator, count);
}
@@ -212,8 +203,6 @@ NDArray<T, 2> calculate_pedestal_g0(NDView<uint16_t, 3> raw_data,
template <typename T>
NDArray<T, 3> calculate_pedestal(NDView<uint16_t, 3> raw_data,
ssize_t n_threads) {
NDArray<int, 2> switched(
std::array<ssize_t, 2>{raw_data.shape(1), raw_data.shape(2)}, 0);
std::vector<std::future<std::pair<NDArray<size_t, 3>, NDArray<size_t, 3>>>>
futures;
futures.reserve(n_threads);
@@ -244,20 +233,7 @@ NDArray<T, 3> calculate_pedestal(NDView<uint16_t, 3> raw_data,
count += cnt;
}
NDArray<T, 3> pedestal(
std::array<ssize_t, 3>{3, raw_data.shape(1), raw_data.shape(2)}, 0);
for (int gain = 0; gain < 3; ++gain) {
for (int row = 0; row < raw_data.shape(1); ++row) {
for (int col = 0; col < raw_data.shape(2); ++col) {
if (count(gain, row, col) != 0) {
pedestal(gain, row, col) =
static_cast<T>(accumulator(gain, row, col)) /
static_cast<T>(count(gain, row, col));
}
}
}
}
return pedestal;
return safe_divide<T>(accumulator, count);
}
/**