Compare commits

1 Commits

Author SHA1 Message Date
Erik Fröjdh
1f2d937d74 test implementation 2025-12-19 14:57:07 +01:00
4 changed files with 56 additions and 92 deletions

View File

@@ -5,9 +5,49 @@
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <type_traits>
namespace aare {
template <ssize_t Dim = 0, typename Strides>
ssize_t element_offset(const Strides & /*unused*/) {
return 0;
}
template <ssize_t Dim = 0, typename Strides, typename... Ix>
ssize_t element_offset(const Strides &strides, ssize_t i, Ix... index) {
return i * strides[Dim] + element_offset<Dim + 1>(strides, index...);
}
template <typename Derived, typename T, ssize_t Ndim>
class NDIndexOps {
public:
template <typename... Ix>
std::enable_if_t<sizeof...(Ix) == Ndim, T &> operator()(Ix... index) {
return derived().data()[element_offset(derived().strides(), index...)];
}
template <typename... Ix>
std::enable_if_t<sizeof...(Ix) == Ndim, const T &> operator()(Ix... index) const {
return derived().data()[element_offset(derived().strides(), index...)];
}
T &operator()(ssize_t i) {
return derived().data()[i];
}
const T &operator()(ssize_t i) const {
return derived().data()[i];
}
T &operator[](ssize_t i) { return derived().data()[i]; }
const T &operator[](ssize_t i) const { return derived().data()[i]; }
private:
Derived &derived() { return static_cast<Derived &>(*this); }
const Derived &derived() const { return static_cast<const Derived &>(*this); }
};
template <typename E, ssize_t Ndim> class ArrayExpr {
public:
static constexpr bool is_leaf = false;
@@ -96,4 +136,4 @@ auto operator/(const ArrayExpr<A, Ndim> &arr1, const ArrayExpr<B, Ndim> &arr2) {
return ArrayDiv<ArrayExpr<A, Ndim>, ArrayExpr<B, Ndim>, Ndim>(arr1, arr2);
}
} // namespace aare
} // namespace aare

View File

@@ -20,7 +20,8 @@
namespace aare {
template <typename T, ssize_t Ndim = 2>
class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim> {
class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim>,
public NDIndexOps<NDArray<T, Ndim>, T, Ndim> {
std::array<ssize_t, Ndim> shape_;
std::array<ssize_t, Ndim> strides_;
size_t size_{}; // TODO! do we need to store size when we have shape?
@@ -151,61 +152,15 @@ class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim> {
//
///////////////////////////////////////////////////////////////////////////////
using NDIndexOps<NDArray<T, Ndim>, T, Ndim>::operator();
using NDIndexOps<NDArray<T, Ndim>, T, Ndim>::operator[];
auto *begin() { return data_; }
const auto *begin() const { return data_; }
auto *end() { return data_ + size_; }
const auto *end() const { return data_ + size_; }
/*
* @brief Access element at given multi-dimensional index.
* i.e. arr(i,j,k,...)
*
* @note The fast index is the last index. Please take care when iterating
* through the array.
*/
template <typename... Ix>
std::enable_if_t<sizeof...(Ix) == Ndim, T &> operator()(Ix... index) {
return data_[element_offset(strides_, index...)];
}
/*
* @brief Access element at given multi-dimensional index (const version).
* i.e. arr(i,j,k,...)
*
* @note The fast index is the last index. Please take care when iterating
* through the array.
*/
template <typename... Ix>
std::enable_if_t<sizeof...(Ix) == Ndim, const T &>
operator()(Ix... index) const {
return data_[element_offset(strides_, index...)];
}
/*
@brief Index the array as it would be a 1D array. To get a certain
pixel in a multidimensional array use the (i,j,k,...) operator instead.
*/
T &operator()(ssize_t i) { return data_[i]; }
/*
@brief Index the array as it would be a 1D array. To get a certain
pixel in a multidimensional array use the (i,j,k,...) operator instead.
*/
const T &operator()(ssize_t i) const { return data_[i]; }
/*
@brief Index the array as it would be a 1D array. To get a certain
pixel in a multidimensional array use the (i,j,k,...) operator instead.
*/
T &operator[](ssize_t i) { return data_[i]; }
/*
@brief Index the array as it would be a 1D array. To get a certain
pixel in a multidimensional array use the (i,j,k,...) operator instead.
*/
const T &operator[](ssize_t i) const { return data_[i]; }
/* @brief Return a raw pointer to the data */
T *data() { return data_; }

View File

@@ -54,16 +54,6 @@ size_t num_elements(const Shape<Ndim> &shape) {
std::multiplies<size_t>());
}
template <ssize_t Dim = 0, typename Strides>
ssize_t element_offset(const Strides & /*unused*/) {
return 0;
}
template <ssize_t Dim = 0, typename Strides, typename... Ix>
ssize_t element_offset(const Strides &strides, ssize_t i, Ix... index) {
return i * strides[Dim] + element_offset<Dim + 1>(strides, index...);
}
template <ssize_t Ndim>
std::array<ssize_t, Ndim> c_strides(const std::array<ssize_t, Ndim> &shape) {
std::array<ssize_t, Ndim> strides{};
@@ -83,7 +73,8 @@ std::array<ssize_t, Ndim> make_array(const std::vector<ssize_t> &vec) {
}
template <typename T, ssize_t Ndim = 2>
class NDView : public ArrayExpr<NDView<T, Ndim>, Ndim> {
class NDView : public ArrayExpr<NDView<T, Ndim>, Ndim>,
public NDIndexOps<NDView<T, Ndim>, T, Ndim> {
public:
NDView() = default;
~NDView() = default;
@@ -94,26 +85,9 @@ class NDView : public ArrayExpr<NDView<T, Ndim>, Ndim> {
: buffer_(buffer), strides_(c_strides<Ndim>(shape)), shape_(shape),
size_(std::accumulate(std::begin(shape), std::end(shape), 1,
std::multiplies<>())) {}
template <typename... Ix>
std::enable_if_t<sizeof...(Ix) == Ndim, T &> operator()(Ix... index) {
return buffer_[element_offset(strides_, index...)];
}
template <typename... Ix>
std::enable_if_t<sizeof...(Ix) == 1 && (Ndim > 1), NDView<T, Ndim - 1>> operator()(Ix... index) {
// return a view of the next dimension
std::array<ssize_t, Ndim - 1> new_shape{};
std::copy_n(shape_.begin() + 1, Ndim - 1, new_shape.begin());
return NDView<T, Ndim - 1>(&buffer_[element_offset(strides_, index...)],
new_shape);
}
template <typename... Ix>
std::enable_if_t<sizeof...(Ix) == Ndim, const T &> operator()(Ix... index) const {
return buffer_[element_offset(strides_, index...)];
}
using NDIndexOps<NDView<T, Ndim>, T, Ndim>::operator();
using NDIndexOps<NDView<T, Ndim>, T, Ndim>::operator[];
ssize_t size() const { return static_cast<ssize_t>(size_); }
@@ -129,16 +103,6 @@ class NDView : public ArrayExpr<NDView<T, Ndim>, Ndim> {
/**
* @brief Access element at index i.
*/
T &operator[](ssize_t i) { return buffer_[i]; }
/**
* @brief Access element at index i.
*/
const T &operator[](ssize_t i) const { return buffer_[i]; }
bool operator==(const NDView &other) const {
if (size_ != other.size_)
return false;
@@ -270,4 +234,4 @@ template <typename T> NDView<T, 1> make_view(std::vector<T> &vec) {
return NDView<T, 1>(vec.data(), {static_cast<ssize_t>(vec.size())});
}
} // namespace aare
} // namespace aare

View File

@@ -552,3 +552,8 @@ TEST_CASE("Move construct from an array with Ndim + 1 throws on size mismatch")
REQUIRE_THROWS(NDArray<int, 2>(std::move(a)));
}
TEST_CASE("Assign an element in a 2D array"){
NDArray<int,2> a({{3,4}},0);
a(1,2) = 57;
REQUIRE(a(1,2) == 57);
}