move constructor for Ndim-1

This commit is contained in:
froejdh_e
2025-07-25 10:40:32 +02:00
parent 1347158235
commit d6222027d0
4 changed files with 79 additions and 32 deletions

View File

@@ -25,7 +25,7 @@ template <typename T, ssize_t Ndim = 2>
class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim> {
std::array<ssize_t, Ndim> shape_;
std::array<ssize_t, Ndim> strides_;
size_t size_{};
size_t size_{}; //TODO! do we need to store size when we have shape?
T *data_;
public:
@@ -43,8 +43,7 @@ class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim> {
*/
explicit NDArray(std::array<ssize_t, Ndim> shape)
: shape_(shape), strides_(c_strides<Ndim>(shape_)),
size_(std::accumulate(shape_.begin(), shape_.end(), 1,
std::multiplies<>())),
size_(num_elements(shape_)),
data_(new T[size_]) {}
/**
@@ -79,6 +78,24 @@ class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim> {
other.reset(); // TODO! is this necessary?
}
//Move constructor from an an array with Ndim + 1
template <ssize_t M, typename = std::enable_if_t<(M == Ndim + 1)>>
NDArray(NDArray<T, M> &&other)
: shape_(drop_first_dim(other.shape())),
strides_(c_strides<Ndim>(shape_)), size_(num_elements(shape_)),
data_(other.data()) {
// For now only allow move if the size matches, to avoid unreachable data
// if the use case arises we can remove this check
if(size() != other.size()) {
data_ = nullptr; // avoid double free, other will clean up the memory in it's destructor
throw std::runtime_error(LOCATION +
"Size mismatch in move constructor of NDArray<T, Ndim-1>");
}
other.reset();
}
// Copy constructor
NDArray(const NDArray &other)
: shape_(other.shape_), strides_(c_strides<Ndim>(shape_)),
@@ -212,24 +229,6 @@ class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim> {
void Print_all();
void Print_some();
template <ssize_t M = Ndim, typename = std::enable_if_t<(M > 2)>>
NDArray<T, Ndim - 1> drop_dimension() && {
std::array<ssize_t, Ndim - 1> new_shape;
std::copy(shape_.begin() + 1, shape_.begin() + Ndim, new_shape.begin());
NDArray<T, Ndim - 1> new_array(new_shape);
delete new_array.data();
new_array.data_ref() = data_;
this->reset();
return new_array;
}
void reset() {
data_ = nullptr;
size_ = 0;

View File

@@ -26,6 +26,32 @@ Shape<Ndim> make_shape(const std::vector<size_t> &shape) {
return arr;
}
/**
* @brief Helper function to drop the first dimension of a shape.
* This is useful when you want to create a 2D view from a 3D array.
* @param shape The shape to drop the first dimension from.
* @return A new shape with the first dimension dropped.
*/
template<size_t Ndim>
Shape<Ndim-1> drop_first_dim(const Shape<Ndim> &shape) {
Shape<Ndim - 1> new_shape;
std::copy(shape.begin() + 1, shape.end(), new_shape.begin());
return new_shape;
}
/**
* @brief Helper function when constructing NDArray/NDView. Calculates the number
* of elements in the resulting array from a shape.
* @param shape The shape to calculate the number of elements for.
* @return The number of elements in and NDArray/NDView of that shape.
*/
template <size_t Ndim>
size_t num_elements(const Shape<Ndim> &shape) {
return std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<size_t>());
}
template <ssize_t Dim = 0, typename Strides>
ssize_t element_offset(const Strides & /*unused*/) {
return 0;

View File

@@ -163,24 +163,22 @@ calculate_pedestal(NDView<uint16_t, 3> raw_data, ssize_t n_threads) {
NDView<uint16_t, 3>)>(&sum_and_count_per_gain<only_gain0>),
view));
}
Shape<3> shape{num_gains, raw_data.shape(1), raw_data.shape(2)};
NDArray<size_t, 3> accumulator(shape, 0);
NDArray<size_t, 3> count(shape, 0);
NDArray<size_t, 3> accumulator(
std::array<ssize_t, 3>{num_gains, raw_data.shape(1), raw_data.shape(2)},
0);
NDArray<size_t, 3> count(
std::array<ssize_t, 3>{num_gains, raw_data.shape(1), raw_data.shape(2)},
0);
// Combine the results from the futures
for (auto &f : futures) {
auto [acc, cnt] = f.get();
accumulator += acc;
count += cnt;
}
if constexpr (only_gain0) {
return safe_divide<T>(accumulator, count).drop_dimension();
} else {
return safe_divide<T>(accumulator, count);
}
// Will move to a NDArray<T, 3 - static_cast<ssize_t>(only_gain0)>
// if only_gain0 is true
return safe_divide<T>(accumulator, count);
}
/**