move constructor for Ndim-1

This commit is contained in:
froejdh_e
2025-07-25 10:40:32 +02:00
parent 1347158235
commit d6222027d0
4 changed files with 79 additions and 32 deletions

View File

@@ -25,7 +25,7 @@ template <typename T, ssize_t Ndim = 2>
class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim> {
std::array<ssize_t, Ndim> shape_;
std::array<ssize_t, Ndim> strides_;
size_t size_{};
size_t size_{}; //TODO! do we need to store size when we have shape?
T *data_;
public:
@@ -43,8 +43,7 @@ class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim> {
*/
explicit NDArray(std::array<ssize_t, Ndim> shape)
: shape_(shape), strides_(c_strides<Ndim>(shape_)),
size_(std::accumulate(shape_.begin(), shape_.end(), 1,
std::multiplies<>())),
size_(num_elements(shape_)),
data_(new T[size_]) {}
/**
@@ -79,6 +78,24 @@ class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim> {
other.reset(); // TODO! is this necessary?
}
//Move constructor from an an array with Ndim + 1
template <ssize_t M, typename = std::enable_if_t<(M == Ndim + 1)>>
NDArray(NDArray<T, M> &&other)
: shape_(drop_first_dim(other.shape())),
strides_(c_strides<Ndim>(shape_)), size_(num_elements(shape_)),
data_(other.data()) {
// For now only allow move if the size matches, to avoid unreachable data
// if the use case arises we can remove this check
if(size() != other.size()) {
data_ = nullptr; // avoid double free, other will clean up the memory in it's destructor
throw std::runtime_error(LOCATION +
"Size mismatch in move constructor of NDArray<T, Ndim-1>");
}
other.reset();
}
// Copy constructor
NDArray(const NDArray &other)
: shape_(other.shape_), strides_(c_strides<Ndim>(shape_)),
@@ -212,24 +229,6 @@ class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim> {
void Print_all();
void Print_some();
template <ssize_t M = Ndim, typename = std::enable_if_t<(M > 2)>>
NDArray<T, Ndim - 1> drop_dimension() && {
std::array<ssize_t, Ndim - 1> new_shape;
std::copy(shape_.begin() + 1, shape_.begin() + Ndim, new_shape.begin());
NDArray<T, Ndim - 1> new_array(new_shape);
delete new_array.data();
new_array.data_ref() = data_;
this->reset();
return new_array;
}
void reset() {
data_ = nullptr;
size_ = 0;