diff --git a/include/aare/ArrayExpr.hpp b/include/aare/ArrayExpr.hpp index 368f691..4d274eb 100644 --- a/include/aare/ArrayExpr.hpp +++ b/include/aare/ArrayExpr.hpp @@ -5,9 +5,49 @@ #include #include #include +#include namespace aare { +template +ssize_t element_offset(const Strides & /*unused*/) { + return 0; +} + +template +ssize_t element_offset(const Strides &strides, ssize_t i, Ix... index) { + return i * strides[Dim] + element_offset(strides, index...); +} + +template +class NDIndexOps { + public: + template + std::enable_if_t operator()(Ix... index) { + return derived().data()[element_offset(derived().strides(), index...)]; + } + + template + std::enable_if_t operator()(Ix... index) const { + return derived().data()[element_offset(derived().strides(), index...)]; + } + + T &operator()(ssize_t i) { + return derived().data()[i]; + } + + const T &operator()(ssize_t i) const { + return derived().data()[i]; + } + + T &operator[](ssize_t i) { return derived().data()[i]; } + const T &operator[](ssize_t i) const { return derived().data()[i]; } + + private: + Derived &derived() { return static_cast(*this); } + const Derived &derived() const { return static_cast(*this); } +}; + template class ArrayExpr { public: static constexpr bool is_leaf = false; @@ -96,4 +136,4 @@ auto operator/(const ArrayExpr &arr1, const ArrayExpr &arr2) { return ArrayDiv, ArrayExpr, Ndim>(arr1, arr2); } -} // namespace aare \ No newline at end of file +} // namespace aare diff --git a/include/aare/NDArray.hpp b/include/aare/NDArray.hpp index 079915d..1b73991 100644 --- a/include/aare/NDArray.hpp +++ b/include/aare/NDArray.hpp @@ -20,7 +20,8 @@ namespace aare { template -class NDArray : public ArrayExpr, Ndim> { +class NDArray : public ArrayExpr, Ndim>, + public NDIndexOps, T, Ndim> { std::array shape_; std::array strides_; size_t size_{}; // TODO! do we need to store size when we have shape? @@ -151,61 +152,15 @@ class NDArray : public ArrayExpr, Ndim> { // /////////////////////////////////////////////////////////////////////////////// + using NDIndexOps, T, Ndim>::operator(); + using NDIndexOps, T, Ndim>::operator[]; + auto *begin() { return data_; } const auto *begin() const { return data_; } auto *end() { return data_ + size_; } const auto *end() const { return data_ + size_; } - /* - * @brief Access element at given multi-dimensional index. - * i.e. arr(i,j,k,...) - * - * @note The fast index is the last index. Please take care when iterating - * through the array. - */ - template - std::enable_if_t operator()(Ix... index) { - return data_[element_offset(strides_, index...)]; - } - - /* - * @brief Access element at given multi-dimensional index (const version). - * i.e. arr(i,j,k,...) - * - * @note The fast index is the last index. Please take care when iterating - * through the array. - */ - template - std::enable_if_t - operator()(Ix... index) const { - return data_[element_offset(strides_, index...)]; - } - - /* - @brief Index the array as it would be a 1D array. To get a certain - pixel in a multidimensional array use the (i,j,k,...) operator instead. - */ - T &operator()(ssize_t i) { return data_[i]; } - - /* - @brief Index the array as it would be a 1D array. To get a certain - pixel in a multidimensional array use the (i,j,k,...) operator instead. - */ - const T &operator()(ssize_t i) const { return data_[i]; } - - /* - @brief Index the array as it would be a 1D array. To get a certain - pixel in a multidimensional array use the (i,j,k,...) operator instead. - */ - T &operator[](ssize_t i) { return data_[i]; } - - /* - @brief Index the array as it would be a 1D array. To get a certain - pixel in a multidimensional array use the (i,j,k,...) operator instead. - */ - const T &operator[](ssize_t i) const { return data_[i]; } - /* @brief Return a raw pointer to the data */ T *data() { return data_; } diff --git a/include/aare/NDView.hpp b/include/aare/NDView.hpp index 42a83a5..46b6964 100644 --- a/include/aare/NDView.hpp +++ b/include/aare/NDView.hpp @@ -54,16 +54,6 @@ size_t num_elements(const Shape &shape) { std::multiplies()); } -template -ssize_t element_offset(const Strides & /*unused*/) { - return 0; -} - -template -ssize_t element_offset(const Strides &strides, ssize_t i, Ix... index) { - return i * strides[Dim] + element_offset(strides, index...); -} - template std::array c_strides(const std::array &shape) { std::array strides{}; @@ -83,7 +73,8 @@ std::array make_array(const std::vector &vec) { } template -class NDView : public ArrayExpr, Ndim> { +class NDView : public ArrayExpr, Ndim>, + public NDIndexOps, T, Ndim> { public: NDView() = default; ~NDView() = default; @@ -94,26 +85,9 @@ class NDView : public ArrayExpr, Ndim> { : buffer_(buffer), strides_(c_strides(shape)), shape_(shape), size_(std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies<>())) {} - - template - std::enable_if_t operator()(Ix... index) { - return buffer_[element_offset(strides_, index...)]; - } - template - std::enable_if_t 1), NDView> operator()(Ix... index) { - // return a view of the next dimension - std::array new_shape{}; - std::copy_n(shape_.begin() + 1, Ndim - 1, new_shape.begin()); - return NDView(&buffer_[element_offset(strides_, index...)], - new_shape); - - } - - template - std::enable_if_t operator()(Ix... index) const { - return buffer_[element_offset(strides_, index...)]; - } + using NDIndexOps, T, Ndim>::operator(); + using NDIndexOps, T, Ndim>::operator[]; ssize_t size() const { return static_cast(size_); } @@ -129,16 +103,6 @@ class NDView : public ArrayExpr, Ndim> { - /** - * @brief Access element at index i. - */ - T &operator[](ssize_t i) { return buffer_[i]; } - - /** - * @brief Access element at index i. - */ - const T &operator[](ssize_t i) const { return buffer_[i]; } - bool operator==(const NDView &other) const { if (size_ != other.size_) return false; @@ -270,4 +234,4 @@ template NDView make_view(std::vector &vec) { return NDView(vec.data(), {static_cast(vec.size())}); } -} // namespace aare \ No newline at end of file +} // namespace aare diff --git a/src/NDArray.test.cpp b/src/NDArray.test.cpp index a682800..c07c316 100644 --- a/src/NDArray.test.cpp +++ b/src/NDArray.test.cpp @@ -552,3 +552,8 @@ TEST_CASE("Move construct from an array with Ndim + 1 throws on size mismatch") REQUIRE_THROWS(NDArray(std::move(a))); } +TEST_CASE("Assign an element in a 2D array"){ + NDArray a({{3,4}},0); + a(1,2) = 57; + REQUIRE(a(1,2) == 57); +} \ No newline at end of file