move constructor for Ndim-1

This commit is contained in:
froejdh_e
2025-07-25 10:40:32 +02:00
parent 1347158235
commit d6222027d0
4 changed files with 79 additions and 32 deletions

View File

@@ -25,7 +25,7 @@ template <typename T, ssize_t Ndim = 2>
class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim> { class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim> {
std::array<ssize_t, Ndim> shape_; std::array<ssize_t, Ndim> shape_;
std::array<ssize_t, Ndim> strides_; std::array<ssize_t, Ndim> strides_;
size_t size_{}; size_t size_{}; //TODO! do we need to store size when we have shape?
T *data_; T *data_;
public: public:
@@ -43,8 +43,7 @@ class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim> {
*/ */
explicit NDArray(std::array<ssize_t, Ndim> shape) explicit NDArray(std::array<ssize_t, Ndim> shape)
: shape_(shape), strides_(c_strides<Ndim>(shape_)), : shape_(shape), strides_(c_strides<Ndim>(shape_)),
size_(std::accumulate(shape_.begin(), shape_.end(), 1, size_(num_elements(shape_)),
std::multiplies<>())),
data_(new T[size_]) {} data_(new T[size_]) {}
/** /**
@@ -79,6 +78,24 @@ class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim> {
other.reset(); // TODO! is this necessary? other.reset(); // TODO! is this necessary?
} }
//Move constructor from an an array with Ndim + 1
template <ssize_t M, typename = std::enable_if_t<(M == Ndim + 1)>>
NDArray(NDArray<T, M> &&other)
: shape_(drop_first_dim(other.shape())),
strides_(c_strides<Ndim>(shape_)), size_(num_elements(shape_)),
data_(other.data()) {
// For now only allow move if the size matches, to avoid unreachable data
// if the use case arises we can remove this check
if(size() != other.size()) {
data_ = nullptr; // avoid double free, other will clean up the memory in it's destructor
throw std::runtime_error(LOCATION +
"Size mismatch in move constructor of NDArray<T, Ndim-1>");
}
other.reset();
}
// Copy constructor // Copy constructor
NDArray(const NDArray &other) NDArray(const NDArray &other)
: shape_(other.shape_), strides_(c_strides<Ndim>(shape_)), : shape_(other.shape_), strides_(c_strides<Ndim>(shape_)),
@@ -212,24 +229,6 @@ class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim> {
void Print_all(); void Print_all();
void Print_some(); void Print_some();
template <ssize_t M = Ndim, typename = std::enable_if_t<(M > 2)>>
NDArray<T, Ndim - 1> drop_dimension() && {
std::array<ssize_t, Ndim - 1> new_shape;
std::copy(shape_.begin() + 1, shape_.begin() + Ndim, new_shape.begin());
NDArray<T, Ndim - 1> new_array(new_shape);
delete new_array.data();
new_array.data_ref() = data_;
this->reset();
return new_array;
}
void reset() { void reset() {
data_ = nullptr; data_ = nullptr;
size_ = 0; size_ = 0;

View File

@@ -26,6 +26,32 @@ Shape<Ndim> make_shape(const std::vector<size_t> &shape) {
return arr; return arr;
} }
/**
* @brief Helper function to drop the first dimension of a shape.
* This is useful when you want to create a 2D view from a 3D array.
* @param shape The shape to drop the first dimension from.
* @return A new shape with the first dimension dropped.
*/
template<size_t Ndim>
Shape<Ndim-1> drop_first_dim(const Shape<Ndim> &shape) {
Shape<Ndim - 1> new_shape;
std::copy(shape.begin() + 1, shape.end(), new_shape.begin());
return new_shape;
}
/**
* @brief Helper function when constructing NDArray/NDView. Calculates the number
* of elements in the resulting array from a shape.
* @param shape The shape to calculate the number of elements for.
* @return The number of elements in and NDArray/NDView of that shape.
*/
template <size_t Ndim>
size_t num_elements(const Shape<Ndim> &shape) {
return std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<size_t>());
}
template <ssize_t Dim = 0, typename Strides> template <ssize_t Dim = 0, typename Strides>
ssize_t element_offset(const Strides & /*unused*/) { ssize_t element_offset(const Strides & /*unused*/) {
return 0; return 0;

View File

@@ -163,24 +163,22 @@ calculate_pedestal(NDView<uint16_t, 3> raw_data, ssize_t n_threads) {
NDView<uint16_t, 3>)>(&sum_and_count_per_gain<only_gain0>), NDView<uint16_t, 3>)>(&sum_and_count_per_gain<only_gain0>),
view)); view));
} }
Shape<3> shape{num_gains, raw_data.shape(1), raw_data.shape(2)};
NDArray<size_t, 3> accumulator(shape, 0);
NDArray<size_t, 3> count(shape, 0);
NDArray<size_t, 3> accumulator( // Combine the results from the futures
std::array<ssize_t, 3>{num_gains, raw_data.shape(1), raw_data.shape(2)},
0);
NDArray<size_t, 3> count(
std::array<ssize_t, 3>{num_gains, raw_data.shape(1), raw_data.shape(2)},
0);
for (auto &f : futures) { for (auto &f : futures) {
auto [acc, cnt] = f.get(); auto [acc, cnt] = f.get();
accumulator += acc; accumulator += acc;
count += cnt; count += cnt;
} }
if constexpr (only_gain0) {
return safe_divide<T>(accumulator, count).drop_dimension(); // Will move to a NDArray<T, 3 - static_cast<ssize_t>(only_gain0)>
} else { // if only_gain0 is true
return safe_divide<T>(accumulator, count); return safe_divide<T>(accumulator, count);
}
} }
/** /**

View File

@@ -427,4 +427,28 @@ TEST_CASE("Construct an NDArray from an std::array") {
for (uint32_t i = 0; i < a.size(); ++i) { for (uint32_t i = 0; i < a.size(); ++i) {
REQUIRE(a(i) == b[i]); REQUIRE(a(i) == b[i]);
} }
}
TEST_CASE("Move construct from an array with Ndim + 1") {
NDArray<int, 3> a({{1,2,2}}, 0);
a(0, 0, 0) = 1;
a(0, 0, 1) = 2;
a(0, 1, 0) = 3;
a(0, 1, 1) = 4;
NDArray<int, 2> b(std::move(a));
REQUIRE(b.shape() == Shape<2>{2,2});
REQUIRE(b.size() == 4);
REQUIRE(b(0, 0) == 1);
REQUIRE(b(0, 1) == 2);
REQUIRE(b(1, 0) == 3);
REQUIRE(b(1, 1) == 4);
}
TEST_CASE("Move construct from an array with Ndim + 1 throws on size mismatch") {
NDArray<int, 3> a({{2,2,2}}, 0);
REQUIRE_THROWS(NDArray<int, 2>(std::move(a)));
} }