Const element access and fixed comparing bug (#208)

- Added const element access
- Added const data*
- Fixed bug comparing two Views of same size but different shapes

closes #207
This commit is contained in:
Erik Fröjdh
2025-06-27 14:13:51 +02:00
committed by GitHub
parent 6ec8fbee72
commit e3f4b34b72
2 changed files with 88 additions and 21 deletions

View File

@@ -67,20 +67,13 @@ class NDView : public ArrayExpr<NDView<T, Ndim>, Ndim> {
size_(std::accumulate(std::begin(shape), std::end(shape), 1,
std::multiplies<>())) {}
// NDView(T *buffer, const std::vector<ssize_t> &shape)
// : buffer_(buffer),
// strides_(c_strides<Ndim>(make_array<Ndim>(shape))),
// shape_(make_array<Ndim>(shape)),
// size_(std::accumulate(std::begin(shape), std::end(shape), 1,
// std::multiplies<>())) {}
template <typename... Ix>
std::enable_if_t<sizeof...(Ix) == Ndim, T &> operator()(Ix... index) {
return buffer_[element_offset(strides_, index...)];
}
template <typename... Ix>
std::enable_if_t<sizeof...(Ix) == Ndim, T &> operator()(Ix... index) const {
const std::enable_if_t<sizeof...(Ix) == Ndim, T &> operator()(Ix... index) const {
return buffer_[element_offset(strides_, index...)];
}
@@ -92,13 +85,17 @@ class NDView : public ArrayExpr<NDView<T, Ndim>, Ndim> {
T *end() { return buffer_ + size_; }
T const *begin() const { return buffer_; }
T const *end() const { return buffer_ + size_; }
T &operator()(ssize_t i) const { return buffer_[i]; }
T &operator[](ssize_t i) const { return buffer_[i]; }
T &operator()(ssize_t i) { return buffer_[i]; }
T &operator[](ssize_t i) { return buffer_[i]; }
const T &operator()(ssize_t i) const { return buffer_[i]; }
const T &operator[](ssize_t i) const { return buffer_[i]; }
bool operator==(const NDView &other) const {
if (size_ != other.size_)
return false;
for (uint64_t i = 0; i != size_; ++i) {
if (shape_ != other.shape_)
return false;
for (size_t i = 0; i != size_; ++i) {
if (buffer_[i] != other.buffer_[i])
return false;
}
@@ -157,6 +154,7 @@ class NDView : public ArrayExpr<NDView<T, Ndim>, Ndim> {
auto shape(ssize_t i) const { return shape_[i]; }
T *data() { return buffer_; }
const T *data() const { return buffer_; }
void print_all() const;
private:
@@ -180,6 +178,7 @@ class NDView : public ArrayExpr<NDView<T, Ndim>, Ndim> {
return *this;
}
};
template <typename T, ssize_t Ndim> void NDView<T, Ndim>::print_all() const {
for (auto row = 0; row < shape_[0]; ++row) {
for (auto col = 0; col < shape_[1]; ++col) {