#pragma once #include "aare/ArrayExpr.hpp" #include "aare/defs.hpp" #include #include #include #include #include #include #include #include #include #include namespace aare { template using Shape = std::array; // TODO! fix mismatch between signed and unsigned template Shape make_shape(const std::vector &shape) { if (shape.size() != Ndim) throw std::runtime_error("Shape size mismatch"); Shape arr; std::copy_n(shape.begin(), Ndim, arr.begin()); return arr; } template ssize_t element_offset(const Strides & /*unused*/) { return 0; } template ssize_t element_offset(const Strides &strides, ssize_t i, Ix... index) { return i * strides[Dim] + element_offset(strides, index...); } template std::array c_strides(const std::array &shape) { std::array strides{}; std::fill(strides.begin(), strides.end(), 1); for (ssize_t i = Ndim - 1; i > 0; --i) { strides[i - 1] = strides[i] * shape[i]; } return strides; } template std::array make_array(const std::vector &vec) { assert(vec.size() == Ndim); std::array arr{}; std::copy_n(vec.begin(), Ndim, arr.begin()); return arr; } template class NDView : public ArrayExpr, Ndim> { public: NDView() = default; ~NDView() = default; NDView(const NDView &) = default; NDView(NDView &&) = default; NDView(T *buffer, std::array shape) : buffer_(buffer), strides_(c_strides(shape)), shape_(shape), size_(std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies<>())) {} // NDView(T *buffer, const std::vector &shape) // : buffer_(buffer), // strides_(c_strides(make_array(shape))), // shape_(make_array(shape)), // size_(std::accumulate(std::begin(shape), std::end(shape), 1, // std::multiplies<>())) {} template std::enable_if_t operator()(Ix... index) { return buffer_[element_offset(strides_, index...)]; } template std::enable_if_t operator()(Ix... index) const { return buffer_[element_offset(strides_, index...)]; } ssize_t size() const { return static_cast(size_); } size_t total_bytes() const { return size_ * sizeof(T); } std::array strides() const noexcept { return strides_; } T *begin() { return buffer_; } T *end() { return buffer_ + size_; } T const *begin() const { return buffer_; } T const *end() const { return buffer_ + size_; } T &operator()(ssize_t i) const { return buffer_[i]; } T &operator[](ssize_t i) const { return buffer_[i]; } bool operator==(const NDView &other) const { if (size_ != other.size_) return false; for (uint64_t i = 0; i != size_; ++i) { if (buffer_[i] != other.buffer_[i]) return false; } return true; } NDView &operator+=(const T val) { return elemenwise(val, std::plus()); } NDView &operator-=(const T val) { return elemenwise(val, std::minus()); } NDView &operator*=(const T val) { return elemenwise(val, std::multiplies()); } NDView &operator/=(const T val) { return elemenwise(val, std::divides()); } NDView &operator/=(const NDView &other) { return elemenwise(other, std::divides()); } template NDView &operator=(const std::array &arr) { if (size() != static_cast(arr.size())) throw std::runtime_error(LOCATION + "Array and NDView size mismatch"); std::copy(arr.begin(), arr.end(), begin()); return *this; } NDView &operator=(const T val) { for (auto it = begin(); it != end(); ++it) *it = val; return *this; } NDView &operator=(const NDView &other) { if (this == &other) return *this; shape_ = other.shape_; strides_ = other.strides_; size_ = other.size_; buffer_ = other.buffer_; return *this; } NDView &operator=(NDView &&other) noexcept { if (this == &other) return *this; shape_ = std::move(other.shape_); strides_ = std::move(other.strides_); size_ = other.size_; buffer_ = other.buffer_; other.buffer_ = nullptr; return *this; } auto &shape() const { return shape_; } auto shape(ssize_t i) const { return shape_[i]; } T *data() { return buffer_; } void print_all() const; private: T *buffer_{nullptr}; std::array strides_{}; std::array shape_{}; uint64_t size_{}; template NDView &elemenwise(T val, BinaryOperation op) { for (uint64_t i = 0; i != size_; ++i) { buffer_[i] = op(buffer_[i], val); } return *this; } template NDView &elemenwise(const NDView &other, BinaryOperation op) { for (uint64_t i = 0; i != size_; ++i) { buffer_[i] = op(buffer_[i], other.buffer_[i]); } return *this; } }; template void NDView::print_all() const { for (auto row = 0; row < shape_[0]; ++row) { for (auto col = 0; col < shape_[1]; ++col) { std::cout << std::setw(3); std::cout << (*this)(row, col) << " "; } std::cout << "\n"; } } template std::ostream &operator<<(std::ostream &os, const NDView &arr) { for (auto row = 0; row < arr.shape(0); ++row) { for (auto col = 0; col < arr.shape(1); ++col) { os << std::setw(3); os << arr(row, col) << " "; } os << "\n"; } return os; } template NDView make_view(std::vector &vec) { return NDView(vec.data(), {static_cast(vec.size())}); } } // namespace aare