mirror of
https://github.com/slsdetectorgroup/aare.git
synced 2025-07-03 08:20:47 +02:00
Const element access and fixed comparing bug (#208)
- Added const element access - Added const data* - Fixed bug comparing two Views of same size but different shapes closes #207
This commit is contained in:
@ -67,20 +67,13 @@ class NDView : public ArrayExpr<NDView<T, Ndim>, Ndim> {
|
|||||||
size_(std::accumulate(std::begin(shape), std::end(shape), 1,
|
size_(std::accumulate(std::begin(shape), std::end(shape), 1,
|
||||||
std::multiplies<>())) {}
|
std::multiplies<>())) {}
|
||||||
|
|
||||||
// NDView(T *buffer, const std::vector<ssize_t> &shape)
|
|
||||||
// : buffer_(buffer),
|
|
||||||
// strides_(c_strides<Ndim>(make_array<Ndim>(shape))),
|
|
||||||
// shape_(make_array<Ndim>(shape)),
|
|
||||||
// size_(std::accumulate(std::begin(shape), std::end(shape), 1,
|
|
||||||
// std::multiplies<>())) {}
|
|
||||||
|
|
||||||
template <typename... Ix>
|
template <typename... Ix>
|
||||||
std::enable_if_t<sizeof...(Ix) == Ndim, T &> operator()(Ix... index) {
|
std::enable_if_t<sizeof...(Ix) == Ndim, T &> operator()(Ix... index) {
|
||||||
return buffer_[element_offset(strides_, index...)];
|
return buffer_[element_offset(strides_, index...)];
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename... Ix>
|
template <typename... Ix>
|
||||||
std::enable_if_t<sizeof...(Ix) == Ndim, T &> operator()(Ix... index) const {
|
const std::enable_if_t<sizeof...(Ix) == Ndim, T &> operator()(Ix... index) const {
|
||||||
return buffer_[element_offset(strides_, index...)];
|
return buffer_[element_offset(strides_, index...)];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -92,13 +85,17 @@ class NDView : public ArrayExpr<NDView<T, Ndim>, Ndim> {
|
|||||||
T *end() { return buffer_ + size_; }
|
T *end() { return buffer_ + size_; }
|
||||||
T const *begin() const { return buffer_; }
|
T const *begin() const { return buffer_; }
|
||||||
T const *end() const { return buffer_ + size_; }
|
T const *end() const { return buffer_ + size_; }
|
||||||
T &operator()(ssize_t i) const { return buffer_[i]; }
|
T &operator()(ssize_t i) { return buffer_[i]; }
|
||||||
T &operator[](ssize_t i) const { return buffer_[i]; }
|
T &operator[](ssize_t i) { return buffer_[i]; }
|
||||||
|
const T &operator()(ssize_t i) const { return buffer_[i]; }
|
||||||
|
const T &operator[](ssize_t i) const { return buffer_[i]; }
|
||||||
|
|
||||||
bool operator==(const NDView &other) const {
|
bool operator==(const NDView &other) const {
|
||||||
if (size_ != other.size_)
|
if (size_ != other.size_)
|
||||||
return false;
|
return false;
|
||||||
for (uint64_t i = 0; i != size_; ++i) {
|
if (shape_ != other.shape_)
|
||||||
|
return false;
|
||||||
|
for (size_t i = 0; i != size_; ++i) {
|
||||||
if (buffer_[i] != other.buffer_[i])
|
if (buffer_[i] != other.buffer_[i])
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -157,6 +154,7 @@ class NDView : public ArrayExpr<NDView<T, Ndim>, Ndim> {
|
|||||||
auto shape(ssize_t i) const { return shape_[i]; }
|
auto shape(ssize_t i) const { return shape_[i]; }
|
||||||
|
|
||||||
T *data() { return buffer_; }
|
T *data() { return buffer_; }
|
||||||
|
const T *data() const { return buffer_; }
|
||||||
void print_all() const;
|
void print_all() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -180,6 +178,7 @@ class NDView : public ArrayExpr<NDView<T, Ndim>, Ndim> {
|
|||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, ssize_t Ndim> void NDView<T, Ndim>::print_all() const {
|
template <typename T, ssize_t Ndim> void NDView<T, Ndim>::print_all() const {
|
||||||
for (auto row = 0; row < shape_[0]; ++row) {
|
for (auto row = 0; row < shape_[0]; ++row) {
|
||||||
for (auto col = 0; col < shape_[1]; ++col) {
|
for (auto col = 0; col < shape_[1]; ++col) {
|
||||||
|
@ -21,6 +21,57 @@ TEST_CASE("Element reference 1D") {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_CASE("Assign elements through () and []") {
|
||||||
|
std::vector<int> vec;
|
||||||
|
for (int i = 0; i != 10; ++i) {
|
||||||
|
vec.push_back(i);
|
||||||
|
}
|
||||||
|
NDView<int, 1> data(vec.data(), Shape<1>{10});
|
||||||
|
REQUIRE(vec.size() == static_cast<size_t>(data.size()));
|
||||||
|
|
||||||
|
data[3] = 187;
|
||||||
|
data(4) = 512;
|
||||||
|
|
||||||
|
|
||||||
|
REQUIRE(data(0) == 0);
|
||||||
|
REQUIRE(data[0] == 0);
|
||||||
|
REQUIRE(data(1) == 1);
|
||||||
|
REQUIRE(data[1] == 1);
|
||||||
|
REQUIRE(data(2) == 2);
|
||||||
|
REQUIRE(data[2] == 2);
|
||||||
|
REQUIRE(data(3) == 187);
|
||||||
|
REQUIRE(data[3] == 187);
|
||||||
|
REQUIRE(data(4) == 512);
|
||||||
|
REQUIRE(data[4] == 512);
|
||||||
|
REQUIRE(data(5) == 5);
|
||||||
|
REQUIRE(data[5] == 5);
|
||||||
|
REQUIRE(data(6) == 6);
|
||||||
|
REQUIRE(data[6] == 6);
|
||||||
|
REQUIRE(data(7) == 7);
|
||||||
|
REQUIRE(data[7] == 7);
|
||||||
|
REQUIRE(data(8) == 8);
|
||||||
|
REQUIRE(data[8] == 8);
|
||||||
|
REQUIRE(data(9) == 9);
|
||||||
|
REQUIRE(data[9] == 9);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("Element reference 1D with a const NDView") {
|
||||||
|
std::vector<int> vec;
|
||||||
|
for (int i = 0; i != 10; ++i) {
|
||||||
|
vec.push_back(i);
|
||||||
|
}
|
||||||
|
const NDView<int, 1> data(vec.data(), Shape<1>{10});
|
||||||
|
REQUIRE(vec.size() == static_cast<size_t>(data.size()));
|
||||||
|
for (int i = 0; i != 10; ++i) {
|
||||||
|
REQUIRE(data(i) == vec[i]);
|
||||||
|
REQUIRE(data[i] == vec[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
TEST_CASE("Element reference 2D") {
|
TEST_CASE("Element reference 2D") {
|
||||||
std::vector<int> vec(12);
|
std::vector<int> vec(12);
|
||||||
std::iota(vec.begin(), vec.end(), 0);
|
std::iota(vec.begin(), vec.end(), 0);
|
||||||
@ -56,7 +107,7 @@ TEST_CASE("Element reference 3D") {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("Plus and miuns with single value") {
|
TEST_CASE("Plus and minus with single value") {
|
||||||
std::vector<int> vec(12);
|
std::vector<int> vec(12);
|
||||||
std::iota(vec.begin(), vec.end(), 0);
|
std::iota(vec.begin(), vec.end(), 0);
|
||||||
NDView<int, 2> data(vec.data(), Shape<2>{3, 4});
|
NDView<int, 2> data(vec.data(), Shape<2>{3, 4});
|
||||||
@ -137,16 +188,9 @@ TEST_CASE("iterators") {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TEST_CASE("shape from vector") {
|
|
||||||
// std::vector<int> vec;
|
|
||||||
// for (int i = 0; i != 12; ++i) {
|
|
||||||
// vec.push_back(i);
|
|
||||||
// }
|
|
||||||
// std::vector<ssize_t> shape{3, 4};
|
|
||||||
// NDView<int, 2> data(vec.data(), shape);
|
|
||||||
// }
|
|
||||||
|
|
||||||
TEST_CASE("divide with another span") {
|
|
||||||
|
TEST_CASE("divide with another NDView") {
|
||||||
std::vector<int> vec0{9, 12, 3};
|
std::vector<int> vec0{9, 12, 3};
|
||||||
std::vector<int> vec1{3, 2, 1};
|
std::vector<int> vec1{3, 2, 1};
|
||||||
std::vector<int> result{3, 6, 3};
|
std::vector<int> result{3, 6, 3};
|
||||||
@ -183,6 +227,30 @@ TEST_CASE("compare two views") {
|
|||||||
REQUIRE((view1 == view2));
|
REQUIRE((view1 == view2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("Compare two views with different size"){
|
||||||
|
std::vector<int> vec1(12);
|
||||||
|
std::iota(vec1.begin(), vec1.end(), 0);
|
||||||
|
NDView<int, 2> view1(vec1.data(), Shape<2>{3, 4});
|
||||||
|
|
||||||
|
std::vector<int> vec2(8);
|
||||||
|
std::iota(vec2.begin(), vec2.end(), 0);
|
||||||
|
NDView<int, 2> view2(vec2.data(), Shape<2>{2, 4});
|
||||||
|
|
||||||
|
REQUIRE_FALSE(view1 == view2);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("Compare two views with same size but different shape"){
|
||||||
|
std::vector<int> vec1(12);
|
||||||
|
std::iota(vec1.begin(), vec1.end(), 0);
|
||||||
|
NDView<int, 2> view1(vec1.data(), Shape<2>{3, 4});
|
||||||
|
|
||||||
|
std::vector<int> vec2(12);
|
||||||
|
std::iota(vec2.begin(), vec2.end(), 0);
|
||||||
|
NDView<int, 2> view2(vec2.data(), Shape<2>{2, 6});
|
||||||
|
|
||||||
|
REQUIRE_FALSE(view1 == view2);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_CASE("Create a view over a vector") {
|
TEST_CASE("Create a view over a vector") {
|
||||||
std::vector<int> vec(12);
|
std::vector<int> vec(12);
|
||||||
std::iota(vec.begin(), vec.end(), 0);
|
std::iota(vec.begin(), vec.end(), 0);
|
||||||
|
Reference in New Issue
Block a user