mirror of
https://github.com/slsdetectorgroup/aare.git
synced 2025-12-20 12:01:24 +01:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1f2d937d74 |
@@ -5,9 +5,49 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
namespace aare {
|
namespace aare {
|
||||||
|
|
||||||
|
template <ssize_t Dim = 0, typename Strides>
|
||||||
|
ssize_t element_offset(const Strides & /*unused*/) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <ssize_t Dim = 0, typename Strides, typename... Ix>
|
||||||
|
ssize_t element_offset(const Strides &strides, ssize_t i, Ix... index) {
|
||||||
|
return i * strides[Dim] + element_offset<Dim + 1>(strides, index...);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Derived, typename T, ssize_t Ndim>
|
||||||
|
class NDIndexOps {
|
||||||
|
public:
|
||||||
|
template <typename... Ix>
|
||||||
|
std::enable_if_t<sizeof...(Ix) == Ndim, T &> operator()(Ix... index) {
|
||||||
|
return derived().data()[element_offset(derived().strides(), index...)];
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... Ix>
|
||||||
|
std::enable_if_t<sizeof...(Ix) == Ndim, const T &> operator()(Ix... index) const {
|
||||||
|
return derived().data()[element_offset(derived().strides(), index...)];
|
||||||
|
}
|
||||||
|
|
||||||
|
T &operator()(ssize_t i) {
|
||||||
|
return derived().data()[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
const T &operator()(ssize_t i) const {
|
||||||
|
return derived().data()[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
T &operator[](ssize_t i) { return derived().data()[i]; }
|
||||||
|
const T &operator[](ssize_t i) const { return derived().data()[i]; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
Derived &derived() { return static_cast<Derived &>(*this); }
|
||||||
|
const Derived &derived() const { return static_cast<const Derived &>(*this); }
|
||||||
|
};
|
||||||
|
|
||||||
template <typename E, ssize_t Ndim> class ArrayExpr {
|
template <typename E, ssize_t Ndim> class ArrayExpr {
|
||||||
public:
|
public:
|
||||||
static constexpr bool is_leaf = false;
|
static constexpr bool is_leaf = false;
|
||||||
@@ -96,4 +136,4 @@ auto operator/(const ArrayExpr<A, Ndim> &arr1, const ArrayExpr<B, Ndim> &arr2) {
|
|||||||
return ArrayDiv<ArrayExpr<A, Ndim>, ArrayExpr<B, Ndim>, Ndim>(arr1, arr2);
|
return ArrayDiv<ArrayExpr<A, Ndim>, ArrayExpr<B, Ndim>, Ndim>(arr1, arr2);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace aare
|
} // namespace aare
|
||||||
|
|||||||
@@ -20,7 +20,8 @@
|
|||||||
namespace aare {
|
namespace aare {
|
||||||
|
|
||||||
template <typename T, ssize_t Ndim = 2>
|
template <typename T, ssize_t Ndim = 2>
|
||||||
class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim> {
|
class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim>,
|
||||||
|
public NDIndexOps<NDArray<T, Ndim>, T, Ndim> {
|
||||||
std::array<ssize_t, Ndim> shape_;
|
std::array<ssize_t, Ndim> shape_;
|
||||||
std::array<ssize_t, Ndim> strides_;
|
std::array<ssize_t, Ndim> strides_;
|
||||||
size_t size_{}; // TODO! do we need to store size when we have shape?
|
size_t size_{}; // TODO! do we need to store size when we have shape?
|
||||||
@@ -151,61 +152,15 @@ class NDArray : public ArrayExpr<NDArray<T, Ndim>, Ndim> {
|
|||||||
//
|
//
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
using NDIndexOps<NDArray<T, Ndim>, T, Ndim>::operator();
|
||||||
|
using NDIndexOps<NDArray<T, Ndim>, T, Ndim>::operator[];
|
||||||
|
|
||||||
auto *begin() { return data_; }
|
auto *begin() { return data_; }
|
||||||
const auto *begin() const { return data_; }
|
const auto *begin() const { return data_; }
|
||||||
|
|
||||||
auto *end() { return data_ + size_; }
|
auto *end() { return data_ + size_; }
|
||||||
const auto *end() const { return data_ + size_; }
|
const auto *end() const { return data_ + size_; }
|
||||||
|
|
||||||
/*
|
|
||||||
* @brief Access element at given multi-dimensional index.
|
|
||||||
* i.e. arr(i,j,k,...)
|
|
||||||
*
|
|
||||||
* @note The fast index is the last index. Please take care when iterating
|
|
||||||
* through the array.
|
|
||||||
*/
|
|
||||||
template <typename... Ix>
|
|
||||||
std::enable_if_t<sizeof...(Ix) == Ndim, T &> operator()(Ix... index) {
|
|
||||||
return data_[element_offset(strides_, index...)];
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
* @brief Access element at given multi-dimensional index (const version).
|
|
||||||
* i.e. arr(i,j,k,...)
|
|
||||||
*
|
|
||||||
* @note The fast index is the last index. Please take care when iterating
|
|
||||||
* through the array.
|
|
||||||
*/
|
|
||||||
template <typename... Ix>
|
|
||||||
std::enable_if_t<sizeof...(Ix) == Ndim, const T &>
|
|
||||||
operator()(Ix... index) const {
|
|
||||||
return data_[element_offset(strides_, index...)];
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
@brief Index the array as it would be a 1D array. To get a certain
|
|
||||||
pixel in a multidimensional array use the (i,j,k,...) operator instead.
|
|
||||||
*/
|
|
||||||
T &operator()(ssize_t i) { return data_[i]; }
|
|
||||||
|
|
||||||
/*
|
|
||||||
@brief Index the array as it would be a 1D array. To get a certain
|
|
||||||
pixel in a multidimensional array use the (i,j,k,...) operator instead.
|
|
||||||
*/
|
|
||||||
const T &operator()(ssize_t i) const { return data_[i]; }
|
|
||||||
|
|
||||||
/*
|
|
||||||
@brief Index the array as it would be a 1D array. To get a certain
|
|
||||||
pixel in a multidimensional array use the (i,j,k,...) operator instead.
|
|
||||||
*/
|
|
||||||
T &operator[](ssize_t i) { return data_[i]; }
|
|
||||||
|
|
||||||
/*
|
|
||||||
@brief Index the array as it would be a 1D array. To get a certain
|
|
||||||
pixel in a multidimensional array use the (i,j,k,...) operator instead.
|
|
||||||
*/
|
|
||||||
const T &operator[](ssize_t i) const { return data_[i]; }
|
|
||||||
|
|
||||||
/* @brief Return a raw pointer to the data */
|
/* @brief Return a raw pointer to the data */
|
||||||
T *data() { return data_; }
|
T *data() { return data_; }
|
||||||
|
|
||||||
|
|||||||
@@ -54,16 +54,6 @@ size_t num_elements(const Shape<Ndim> &shape) {
|
|||||||
std::multiplies<size_t>());
|
std::multiplies<size_t>());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <ssize_t Dim = 0, typename Strides>
|
|
||||||
ssize_t element_offset(const Strides & /*unused*/) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <ssize_t Dim = 0, typename Strides, typename... Ix>
|
|
||||||
ssize_t element_offset(const Strides &strides, ssize_t i, Ix... index) {
|
|
||||||
return i * strides[Dim] + element_offset<Dim + 1>(strides, index...);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <ssize_t Ndim>
|
template <ssize_t Ndim>
|
||||||
std::array<ssize_t, Ndim> c_strides(const std::array<ssize_t, Ndim> &shape) {
|
std::array<ssize_t, Ndim> c_strides(const std::array<ssize_t, Ndim> &shape) {
|
||||||
std::array<ssize_t, Ndim> strides{};
|
std::array<ssize_t, Ndim> strides{};
|
||||||
@@ -83,7 +73,8 @@ std::array<ssize_t, Ndim> make_array(const std::vector<ssize_t> &vec) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, ssize_t Ndim = 2>
|
template <typename T, ssize_t Ndim = 2>
|
||||||
class NDView : public ArrayExpr<NDView<T, Ndim>, Ndim> {
|
class NDView : public ArrayExpr<NDView<T, Ndim>, Ndim>,
|
||||||
|
public NDIndexOps<NDView<T, Ndim>, T, Ndim> {
|
||||||
public:
|
public:
|
||||||
NDView() = default;
|
NDView() = default;
|
||||||
~NDView() = default;
|
~NDView() = default;
|
||||||
@@ -94,26 +85,9 @@ class NDView : public ArrayExpr<NDView<T, Ndim>, Ndim> {
|
|||||||
: buffer_(buffer), strides_(c_strides<Ndim>(shape)), shape_(shape),
|
: buffer_(buffer), strides_(c_strides<Ndim>(shape)), shape_(shape),
|
||||||
size_(std::accumulate(std::begin(shape), std::end(shape), 1,
|
size_(std::accumulate(std::begin(shape), std::end(shape), 1,
|
||||||
std::multiplies<>())) {}
|
std::multiplies<>())) {}
|
||||||
|
|
||||||
template <typename... Ix>
|
|
||||||
std::enable_if_t<sizeof...(Ix) == Ndim, T &> operator()(Ix... index) {
|
|
||||||
return buffer_[element_offset(strides_, index...)];
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename... Ix>
|
using NDIndexOps<NDView<T, Ndim>, T, Ndim>::operator();
|
||||||
std::enable_if_t<sizeof...(Ix) == 1 && (Ndim > 1), NDView<T, Ndim - 1>> operator()(Ix... index) {
|
using NDIndexOps<NDView<T, Ndim>, T, Ndim>::operator[];
|
||||||
// return a view of the next dimension
|
|
||||||
std::array<ssize_t, Ndim - 1> new_shape{};
|
|
||||||
std::copy_n(shape_.begin() + 1, Ndim - 1, new_shape.begin());
|
|
||||||
return NDView<T, Ndim - 1>(&buffer_[element_offset(strides_, index...)],
|
|
||||||
new_shape);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename... Ix>
|
|
||||||
std::enable_if_t<sizeof...(Ix) == Ndim, const T &> operator()(Ix... index) const {
|
|
||||||
return buffer_[element_offset(strides_, index...)];
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
ssize_t size() const { return static_cast<ssize_t>(size_); }
|
ssize_t size() const { return static_cast<ssize_t>(size_); }
|
||||||
@@ -129,16 +103,6 @@ class NDView : public ArrayExpr<NDView<T, Ndim>, Ndim> {
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Access element at index i.
|
|
||||||
*/
|
|
||||||
T &operator[](ssize_t i) { return buffer_[i]; }
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Access element at index i.
|
|
||||||
*/
|
|
||||||
const T &operator[](ssize_t i) const { return buffer_[i]; }
|
|
||||||
|
|
||||||
bool operator==(const NDView &other) const {
|
bool operator==(const NDView &other) const {
|
||||||
if (size_ != other.size_)
|
if (size_ != other.size_)
|
||||||
return false;
|
return false;
|
||||||
@@ -270,4 +234,4 @@ template <typename T> NDView<T, 1> make_view(std::vector<T> &vec) {
|
|||||||
return NDView<T, 1>(vec.data(), {static_cast<ssize_t>(vec.size())});
|
return NDView<T, 1>(vec.data(), {static_cast<ssize_t>(vec.size())});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace aare
|
} // namespace aare
|
||||||
|
|||||||
@@ -552,3 +552,8 @@ TEST_CASE("Move construct from an array with Ndim + 1 throws on size mismatch")
|
|||||||
REQUIRE_THROWS(NDArray<int, 2>(std::move(a)));
|
REQUIRE_THROWS(NDArray<int, 2>(std::move(a)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("Assign an element in a 2D array"){
|
||||||
|
NDArray<int,2> a({{3,4}},0);
|
||||||
|
a(1,2) = 57;
|
||||||
|
REQUIRE(a(1,2) == 57);
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user