templated calculate_pedestal with boolean template argument only_gain0, added drop_dimension to NDArray and reference pointer to data

This commit is contained in:
2025-07-24 15:40:05 +02:00
parent 8c4d8b687e
commit 1347158235
5 changed files with 85 additions and 92 deletions

View File

@@ -33,7 +33,7 @@ class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim> {
* @brief Default constructor. Will construct an empty NDArray.
*
*/
NDArray() : shape_(), strides_(c_strides<Ndim>(shape_)), data_(nullptr){};
NDArray() : shape_(), strides_(c_strides<Ndim>(shape_)), data_(nullptr) {};
/**
* @brief Construct a new NDArray object with a given shape.
@@ -185,6 +185,7 @@ class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim> {
const T &operator[](ssize_t i) const { return data_[i]; }
T *data() { return data_; }
T *&data_ref() { return data_; }
std::byte *buffer() { return reinterpret_cast<std::byte *>(data_); }
ssize_t size() const { return static_cast<ssize_t>(size_); }
size_t total_bytes() const { return size_ * sizeof(T); }
@@ -211,13 +212,30 @@ class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim> {
void Print_all();
void Print_some();
template <ssize_t M = Ndim, typename = std::enable_if_t<(M > 2)>>
NDArray<T, Ndim - 1> drop_dimension() && {
std::array<ssize_t, Ndim - 1> new_shape;
std::copy(shape_.begin() + 1, shape_.begin() + Ndim, new_shape.begin());
NDArray<T, Ndim - 1> new_array(new_shape);
delete new_array.data();
new_array.data_ref() = data_;
this->reset();
return new_array;
}
void reset() {
data_ = nullptr;
size_ = 0;
std::fill(shape_.begin(), shape_.end(), 0);
std::fill(strides_.begin(), strides_.end(), 0);
}
};
// Move assign
@@ -382,8 +400,6 @@ NDArray<T, Ndim> NDArray<T, Ndim>::operator*(const T &value) {
return result;
}
template <typename T, ssize_t Ndim>
std::ostream &operator<<(std::ostream &os, const NDArray<T, Ndim> &arr) {
for (auto row = 0; row < arr.shape(0); ++row) {
@@ -434,17 +450,18 @@ NDArray<T, Ndim> load(const std::string &pathname,
return img;
}
template <typename RT, typename NT, typename DT, ssize_t Ndim>
NDArray<RT, Ndim> safe_divide(const NDArray<NT, Ndim> &numerator,
const NDArray<DT, Ndim> &denominator) {
const NDArray<DT, Ndim> &denominator) {
if (numerator.shape() != denominator.shape()) {
throw std::runtime_error("Shapes of numerator and denominator must match");
throw std::runtime_error(
"Shapes of numerator and denominator must match");
}
NDArray<RT, Ndim> result(numerator.shape());
for (ssize_t i = 0; i < numerator.size(); ++i) {
if (denominator[i] != 0) {
result[i] = static_cast<RT>(numerator[i]) / static_cast<RT>(denominator[i]);
result[i] =
static_cast<RT>(numerator[i]) / static_cast<RT>(denominator[i]);
} else {
result[i] = RT{0}; // or handle division by zero as needed
}