added safe_divide to NDArray and used it for pedestal

This commit is contained in:
froejdh_e
2025-07-24 09:40:38 +02:00
parent cb439efb48
commit 0fea0f5b0e
2 changed files with 27 additions and 36 deletions

View File

@@ -217,6 +217,7 @@ class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim> {
std::fill(shape_.begin(), shape_.end(), 0);
std::fill(strides_.begin(), strides_.end(), 0);
}
};
// Move assign
@@ -380,12 +381,8 @@ NDArray<T, Ndim> NDArray<T, Ndim>::operator*(const T &value) {
result *= value;
return result;
}
// template <typename T, ssize_t Ndim> void NDArray<T, Ndim>::Print() {
// if (shape_[0] < 20 && shape_[1] < 20)
// Print_all();
// else
// Print_some();
// }
template <typename T, ssize_t Ndim>
std::ostream &operator<<(std::ostream &os, const NDArray<T, Ndim> &arr) {
@@ -437,4 +434,22 @@ NDArray<T, Ndim> load(const std::string &pathname,
return img;
}
template <typename RT, typename NT, typename DT, ssize_t Ndim>
NDArray<RT, Ndim> safe_divide(const NDArray<NT, Ndim> &numerator,
const NDArray<DT, Ndim> &denominator) {
if (numerator.shape() != denominator.shape()) {
throw std::runtime_error("Shapes of numerator and denominator must match");
}
NDArray<RT, Ndim> result(numerator.shape());
for (ssize_t i = 0; i < numerator.size(); ++i) {
if (denominator[i] != 0) {
result[i] = static_cast<RT>(numerator[i]) / static_cast<RT>(denominator[i]);
} else {
result[i] = RT{0}; // or handle division by zero as needed
}
}
return result;
}
} // namespace aare