Const element access and fixed comparing bug (#208)

- Added const element access
- Added const data*
- Fixed bug comparing two Views of same size but different shapes

closes #207
This commit is contained in:
Erik Fröjdh
2025-06-27 14:13:51 +02:00
committed by GitHub
parent 6ec8fbee72
commit e3f4b34b72
2 changed files with 88 additions and 21 deletions

View File

@ -67,20 +67,13 @@ class NDView : public ArrayExpr<NDView<T, Ndim>, Ndim> {
size_(std::accumulate(std::begin(shape), std::end(shape), 1,
std::multiplies<>())) {}
// NDView(T *buffer, const std::vector<ssize_t> &shape)
// : buffer_(buffer),
// strides_(c_strides<Ndim>(make_array<Ndim>(shape))),
// shape_(make_array<Ndim>(shape)),
// size_(std::accumulate(std::begin(shape), std::end(shape), 1,
// std::multiplies<>())) {}
template <typename... Ix>
std::enable_if_t<sizeof...(Ix) == Ndim, T &> operator()(Ix... index) {
return buffer_[element_offset(strides_, index...)];
}
template <typename... Ix>
std::enable_if_t<sizeof...(Ix) == Ndim, T &> operator()(Ix... index) const {
const std::enable_if_t<sizeof...(Ix) == Ndim, T &> operator()(Ix... index) const {
return buffer_[element_offset(strides_, index...)];
}
@ -92,13 +85,17 @@ class NDView : public ArrayExpr<NDView<T, Ndim>, Ndim> {
T *end() { return buffer_ + size_; }
T const *begin() const { return buffer_; }
T const *end() const { return buffer_ + size_; }
T &operator()(ssize_t i) const { return buffer_[i]; }
T &operator[](ssize_t i) const { return buffer_[i]; }
T &operator()(ssize_t i) { return buffer_[i]; }
T &operator[](ssize_t i) { return buffer_[i]; }
const T &operator()(ssize_t i) const { return buffer_[i]; }
const T &operator[](ssize_t i) const { return buffer_[i]; }
bool operator==(const NDView &other) const {
if (size_ != other.size_)
return false;
for (uint64_t i = 0; i != size_; ++i) {
if (shape_ != other.shape_)
return false;
for (size_t i = 0; i != size_; ++i) {
if (buffer_[i] != other.buffer_[i])
return false;
}
@ -157,6 +154,7 @@ class NDView : public ArrayExpr<NDView<T, Ndim>, Ndim> {
auto shape(ssize_t i) const { return shape_[i]; }
T *data() { return buffer_; }
const T *data() const { return buffer_; }
void print_all() const;
private:
@ -180,6 +178,7 @@ class NDView : public ArrayExpr<NDView<T, Ndim>, Ndim> {
return *this;
}
};
template <typename T, ssize_t Ndim> void NDView<T, Ndim>::print_all() const {
for (auto row = 0; row < shape_[0]; ++row) {
for (auto col = 0; col < shape_[1]; ++col) {

View File

@ -21,6 +21,57 @@ TEST_CASE("Element reference 1D") {
}
}
TEST_CASE("Assign elements through () and []") {
std::vector<int> vec;
for (int i = 0; i != 10; ++i) {
vec.push_back(i);
}
NDView<int, 1> data(vec.data(), Shape<1>{10});
REQUIRE(vec.size() == static_cast<size_t>(data.size()));
data[3] = 187;
data(4) = 512;
REQUIRE(data(0) == 0);
REQUIRE(data[0] == 0);
REQUIRE(data(1) == 1);
REQUIRE(data[1] == 1);
REQUIRE(data(2) == 2);
REQUIRE(data[2] == 2);
REQUIRE(data(3) == 187);
REQUIRE(data[3] == 187);
REQUIRE(data(4) == 512);
REQUIRE(data[4] == 512);
REQUIRE(data(5) == 5);
REQUIRE(data[5] == 5);
REQUIRE(data(6) == 6);
REQUIRE(data[6] == 6);
REQUIRE(data(7) == 7);
REQUIRE(data[7] == 7);
REQUIRE(data(8) == 8);
REQUIRE(data[8] == 8);
REQUIRE(data(9) == 9);
REQUIRE(data[9] == 9);
}
TEST_CASE("Element reference 1D with a const NDView") {
std::vector<int> vec;
for (int i = 0; i != 10; ++i) {
vec.push_back(i);
}
const NDView<int, 1> data(vec.data(), Shape<1>{10});
REQUIRE(vec.size() == static_cast<size_t>(data.size()));
for (int i = 0; i != 10; ++i) {
REQUIRE(data(i) == vec[i]);
REQUIRE(data[i] == vec[i]);
}
}
TEST_CASE("Element reference 2D") {
std::vector<int> vec(12);
std::iota(vec.begin(), vec.end(), 0);
@ -56,7 +107,7 @@ TEST_CASE("Element reference 3D") {
}
}
TEST_CASE("Plus and miuns with single value") {
TEST_CASE("Plus and minus with single value") {
std::vector<int> vec(12);
std::iota(vec.begin(), vec.end(), 0);
NDView<int, 2> data(vec.data(), Shape<2>{3, 4});
@ -137,16 +188,9 @@ TEST_CASE("iterators") {
}
}
// TEST_CASE("shape from vector") {
// std::vector<int> vec;
// for (int i = 0; i != 12; ++i) {
// vec.push_back(i);
// }
// std::vector<ssize_t> shape{3, 4};
// NDView<int, 2> data(vec.data(), shape);
// }
TEST_CASE("divide with another span") {
TEST_CASE("divide with another NDView") {
std::vector<int> vec0{9, 12, 3};
std::vector<int> vec1{3, 2, 1};
std::vector<int> result{3, 6, 3};
@ -183,6 +227,30 @@ TEST_CASE("compare two views") {
REQUIRE((view1 == view2));
}
TEST_CASE("Compare two views with different size"){
std::vector<int> vec1(12);
std::iota(vec1.begin(), vec1.end(), 0);
NDView<int, 2> view1(vec1.data(), Shape<2>{3, 4});
std::vector<int> vec2(8);
std::iota(vec2.begin(), vec2.end(), 0);
NDView<int, 2> view2(vec2.data(), Shape<2>{2, 4});
REQUIRE_FALSE(view1 == view2);
}
TEST_CASE("Compare two views with same size but different shape"){
std::vector<int> vec1(12);
std::iota(vec1.begin(), vec1.end(), 0);
NDView<int, 2> view1(vec1.data(), Shape<2>{3, 4});
std::vector<int> vec2(12);
std::iota(vec2.begin(), vec2.end(), 0);
NDView<int, 2> view2(vec2.data(), Shape<2>{2, 6});
REQUIRE_FALSE(view1 == view2);
}
TEST_CASE("Create a view over a vector") {
std::vector<int> vec(12);
std::iota(vec.begin(), vec.end(), 0);