diff --git a/include/aare/NDView.hpp b/include/aare/NDView.hpp index e7ad002..28f5371 100644 --- a/include/aare/NDView.hpp +++ b/include/aare/NDView.hpp @@ -67,20 +67,13 @@ class NDView : public ArrayExpr, Ndim> { size_(std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies<>())) {} - // NDView(T *buffer, const std::vector &shape) - // : buffer_(buffer), - // strides_(c_strides(make_array(shape))), - // shape_(make_array(shape)), - // size_(std::accumulate(std::begin(shape), std::end(shape), 1, - // std::multiplies<>())) {} - template std::enable_if_t operator()(Ix... index) { return buffer_[element_offset(strides_, index...)]; } template - std::enable_if_t operator()(Ix... index) const { + const std::enable_if_t operator()(Ix... index) const { return buffer_[element_offset(strides_, index...)]; } @@ -92,13 +85,17 @@ class NDView : public ArrayExpr, Ndim> { T *end() { return buffer_ + size_; } T const *begin() const { return buffer_; } T const *end() const { return buffer_ + size_; } - T &operator()(ssize_t i) const { return buffer_[i]; } - T &operator[](ssize_t i) const { return buffer_[i]; } + T &operator()(ssize_t i) { return buffer_[i]; } + T &operator[](ssize_t i) { return buffer_[i]; } + const T &operator()(ssize_t i) const { return buffer_[i]; } + const T &operator[](ssize_t i) const { return buffer_[i]; } bool operator==(const NDView &other) const { if (size_ != other.size_) return false; - for (uint64_t i = 0; i != size_; ++i) { + if (shape_ != other.shape_) + return false; + for (size_t i = 0; i != size_; ++i) { if (buffer_[i] != other.buffer_[i]) return false; } @@ -157,6 +154,7 @@ class NDView : public ArrayExpr, Ndim> { auto shape(ssize_t i) const { return shape_[i]; } T *data() { return buffer_; } + const T *data() const { return buffer_; } void print_all() const; private: @@ -180,6 +178,7 @@ class NDView : public ArrayExpr, Ndim> { return *this; } }; + template void NDView::print_all() const { for (auto row = 0; row < shape_[0]; ++row) { for (auto col = 0; col < shape_[1]; ++col) { diff --git a/src/NDView.test.cpp b/src/NDView.test.cpp index a65758c..04f56dd 100644 --- a/src/NDView.test.cpp +++ b/src/NDView.test.cpp @@ -21,6 +21,57 @@ TEST_CASE("Element reference 1D") { } } + +TEST_CASE("Assign elements through () and []") { + std::vector vec; + for (int i = 0; i != 10; ++i) { + vec.push_back(i); + } + NDView data(vec.data(), Shape<1>{10}); + REQUIRE(vec.size() == static_cast(data.size())); + + data[3] = 187; + data(4) = 512; + + + REQUIRE(data(0) == 0); + REQUIRE(data[0] == 0); + REQUIRE(data(1) == 1); + REQUIRE(data[1] == 1); + REQUIRE(data(2) == 2); + REQUIRE(data[2] == 2); + REQUIRE(data(3) == 187); + REQUIRE(data[3] == 187); + REQUIRE(data(4) == 512); + REQUIRE(data[4] == 512); + REQUIRE(data(5) == 5); + REQUIRE(data[5] == 5); + REQUIRE(data(6) == 6); + REQUIRE(data[6] == 6); + REQUIRE(data(7) == 7); + REQUIRE(data[7] == 7); + REQUIRE(data(8) == 8); + REQUIRE(data[8] == 8); + REQUIRE(data(9) == 9); + REQUIRE(data[9] == 9); + + +} + +TEST_CASE("Element reference 1D with a const NDView") { + std::vector vec; + for (int i = 0; i != 10; ++i) { + vec.push_back(i); + } + const NDView data(vec.data(), Shape<1>{10}); + REQUIRE(vec.size() == static_cast(data.size())); + for (int i = 0; i != 10; ++i) { + REQUIRE(data(i) == vec[i]); + REQUIRE(data[i] == vec[i]); + } +} + + TEST_CASE("Element reference 2D") { std::vector vec(12); std::iota(vec.begin(), vec.end(), 0); @@ -56,7 +107,7 @@ TEST_CASE("Element reference 3D") { } } -TEST_CASE("Plus and miuns with single value") { +TEST_CASE("Plus and minus with single value") { std::vector vec(12); std::iota(vec.begin(), vec.end(), 0); NDView data(vec.data(), Shape<2>{3, 4}); @@ -137,16 +188,9 @@ TEST_CASE("iterators") { } } -// TEST_CASE("shape from vector") { -// std::vector vec; -// for (int i = 0; i != 12; ++i) { -// vec.push_back(i); -// } -// std::vector shape{3, 4}; -// NDView data(vec.data(), shape); -// } -TEST_CASE("divide with another span") { + +TEST_CASE("divide with another NDView") { std::vector vec0{9, 12, 3}; std::vector vec1{3, 2, 1}; std::vector result{3, 6, 3}; @@ -183,6 +227,30 @@ TEST_CASE("compare two views") { REQUIRE((view1 == view2)); } +TEST_CASE("Compare two views with different size"){ + std::vector vec1(12); + std::iota(vec1.begin(), vec1.end(), 0); + NDView view1(vec1.data(), Shape<2>{3, 4}); + + std::vector vec2(8); + std::iota(vec2.begin(), vec2.end(), 0); + NDView view2(vec2.data(), Shape<2>{2, 4}); + + REQUIRE_FALSE(view1 == view2); +} + +TEST_CASE("Compare two views with same size but different shape"){ + std::vector vec1(12); + std::iota(vec1.begin(), vec1.end(), 0); + NDView view1(vec1.data(), Shape<2>{3, 4}); + + std::vector vec2(12); + std::iota(vec2.begin(), vec2.end(), 0); + NDView view2(vec2.data(), Shape<2>{2, 6}); + + REQUIRE_FALSE(view1 == view2); +} + TEST_CASE("Create a view over a vector") { std::vector vec(12); std::iota(vec.begin(), vec.end(), 0);