added tests and features to load full file

This commit is contained in:
Erik Frojdh 2024-04-02 17:19:10 +02:00
parent d196eb5a2e
commit 670d9415e6
9 changed files with 78 additions and 22 deletions

View File

@ -19,6 +19,7 @@ enum class endian {
};
class DType {
//TODO! support for non-native endianness?
static_assert(sizeof(long) == sizeof(int64_t), "long should be 64bits");
public:

View File

@ -5,11 +5,22 @@
#include <cstdint>
#include <limits>
#include <numeric>
#include <stdexcept>
#include <vector>
namespace aare {
// Fixed size shape: one signed extent per dimension.
template <ssize_t Ndim> using Shape = std::array<ssize_t, Ndim>;

//TODO! fix mismatch between signed and unsigned
// Convert a dynamically sized shape (unsigned extents) into a Shape<Ndim>
// (signed extents).
//
// shape   extents to convert, in order
// returns a Shape<Ndim> holding the same extents in the same order
// throws  std::runtime_error if shape does not hold exactly Ndim entries, or
//         if an extent is too large to be represented as ssize_t
template <ssize_t Ndim>
Shape<Ndim> make_shape(const std::vector<size_t>& shape){
    // Compare in the unsigned domain so that no implicit sign conversion happens.
    if (shape.size() != static_cast<size_t>(Ndim))
        throw std::runtime_error("Shape size mismatch");

    // Largest extent that survives the conversion to the signed element type.
    constexpr auto max_extent = static_cast<size_t>(std::numeric_limits<ssize_t>::max());

    Shape<Ndim> arr{};
    for (size_t i = 0; i < shape.size(); ++i) {
        // Without this check a huge extent would silently wrap to a negative value.
        if (shape[i] > max_extent)
            throw std::runtime_error("Shape extent too large");
        arr[i] = static_cast<ssize_t>(shape[i]);
    }
    return arr;
}
// Overload taking only the strides and no indices: the offset is always 0.
// Presumably the terminating case for the variadic overload that follows.
template <ssize_t Dim = 0, typename Strides>
ssize_t element_offset(const Strides & /*strides*/) {
    const ssize_t no_offset = 0;
    return no_offset;
}
template <ssize_t Dim = 0, typename Strides, typename... Ix>

View File

@ -72,11 +72,6 @@ struct RawFileConfig {
}
};
// Characters used when validating a numpy type string such as "<f4".
// The first character of the type string is checked against endian_chars.
constexpr char little_endian_char{'<'};
constexpr char big_endian_char{'>'};
constexpr char no_endian_char{'|'};
constexpr std::array<char, 3> endian_chars{
    little_endian_char,
    big_endian_char,
    no_endian_char,
};
// Presumably the accepted "kind" characters (second character of the type string).
constexpr std::array<char, 4> numtype_chars{'f', 'i', 'u', 'c'};
} // namespace aare

View File

@ -32,6 +32,7 @@ endif()
if(AARE_TESTS)
set(TestSources
${CMAKE_CURRENT_SOURCE_DIR}/test/NumpyFile.test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test/NumpyHelpers.test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test/RawFile.test.cpp
)

View File

@ -1,6 +1,7 @@
#pragma once
#include "aare/FileInterface.hpp"
#include "aare/NumpyHelpers.hpp"
#include "aare/DType.hpp"
#include "aare/defs.hpp"
#include <iostream>
#include <numeric>
@ -45,6 +46,18 @@ class NumpyFile : public FileInterface {
// Number of columns: the entry at index 2 of the shape stored in the header.
// NOTE(review): the index is not bounds checked, so this assumes the shape
// has at least three entries — confirm for 2D files.
ssize_t cols() const override {
    const auto &dims = m_header.shape;
    return dims[2];
}
// Bit depth as reported by the dtype stored in the header.
ssize_t bitdepth() const override {
    const auto &element_type = m_header.dtype;
    return element_type.bitdepth();
}
// Copy of the dtype stored in the header.
DType dtype() const {
    return DType(m_header.dtype);
}
// Copy of the shape stored in the header.
std::vector<size_t> shape() const {
    return std::vector<size_t>(m_header.shape);
}
// Load the full numpy file into an NDArray.
//
// The array is allocated with the shape stored in the header (converted with
// make_shape<NDim>, which throws if the number of dimensions differs from
// NDim). The data is then read in one go, starting header_size bytes from the
// beginning of the file.
//
// throws std::runtime_error if seeking to the data fails or if fewer elements
//        than the array holds could be read (e.g. truncated file).
// NOTE(review): T is not checked against the dtype in the header, so the
// caller has to request a matching element type — confirm whether a check
// should be added here.
template<typename T, size_t NDim>
NDArray<T,NDim> load(){
    NDArray<T,NDim> arr(make_shape<NDim>(m_header.shape));

    if (fseek(fp, header_size, SEEK_SET) != 0)
        throw std::runtime_error("Could not seek to the start of the data in the numpy file");

    // fread reports the number of complete elements read; anything less than
    // requested would leave part of the array uninitialized.
    const auto n_items = static_cast<size_t>(arr.size());
    const size_t n_read = fread(arr.data(), sizeof(T), n_items, fp);
    if (n_read != n_items)
        throw std::runtime_error("Could not read all elements from the numpy file");

    return arr;
}
~NumpyFile();
};

View File

@ -14,7 +14,7 @@
#include "aare/DType.hpp"
#include "aare/defs.hpp"
using shape_t = std::vector<uint64_t>;
using shape_t = std::vector<size_t>;
struct header_t {
header_t() : dtype(aare::DType(aare::DType::ERROR)), fortran_order(false), shape(shape_t()){};
@ -53,7 +53,7 @@ std::string get_value_from_map(const std::string &mapstr);
std::unordered_map<std::string, std::string> parse_dict(std::string in, const std::vector<std::string> &keys);
template <typename T, size_t N> inline bool in_array(T val, const std::array<T, N> &arr) {
// Return true if val compares equal to at least one element of arr,
// false otherwise (always false for an empty array).
template <typename T, size_t N> bool in_array(T val, const std::array<T, N> &arr) {
    for (const T &candidate : arr) {
        if (candidate == val) {
            return true;
        }
    }
    return false;
}
bool is_digits(const std::string &str);

View File

@ -139,9 +139,9 @@ void NumpyFile::load_metadata(){
}
// read header
auto buf_v = std::vector<char>(header_len);
fread(buf_v.data(), header_len,1,fp);
std::string header(buf_v.data(), header_len);
std::string header(header_len, '\0');
fread(header.data(), header_len,1,fp);
// parse header
std::vector<std::string> keys{"descr", "fortran_order", "shape"};

View File

@ -41,13 +41,13 @@ std::unordered_map<std::string, std::string> parse_dict(std::string in, const st
std::vector<std::pair<size_t, std::string>> positions;
for (auto const &value : keys) {
size_t pos = in.find("'" + value + "'");
for (auto const &key : keys) {
size_t pos = in.find("'" + key + "'");
if (pos == std::string::npos)
throw std::runtime_error("Missing '" + value + "' key.");
throw std::runtime_error("Missing '" + key + "' key.");
std::pair<size_t, std::string> position_pair{pos, value};
std::pair<size_t, std::string> position_pair{pos, key};
positions.push_back(position_pair);
}
@ -78,12 +78,19 @@ std::unordered_map<std::string, std::string> parse_dict(std::string in, const st
}
aare::DType parse_descr(std::string typestring) {
if (typestring.length() < 3) {
throw std::runtime_error("invalid typestring (length)");
}
char byteorder_c = typestring.at(0);
char kind_c = typestring.at(1);
constexpr char little_endian_char = '<';
constexpr char big_endian_char = '>';
constexpr char no_endian_char = '|';
constexpr std::array<char, 3> endian_chars = {little_endian_char, big_endian_char, no_endian_char};
constexpr std::array<char, 4> numtype_chars = {'f', 'i', 'u', 'c'};
const char byteorder_c = typestring[0];
const char kind_c = typestring[1];
std::string itemsize_s = typestring.substr(2);
if (!in_array(byteorder_c, endian_chars)) {
@ -97,7 +104,6 @@ aare::DType parse_descr(std::string typestring) {
if (!is_digits(itemsize_s)) {
throw std::runtime_error("invalid typestring (itemsize)");
}
// unsigned int itemsize = std::stoul(itemsize_s);
return aare::DType(typestring);
}
@ -107,8 +113,7 @@ bool parse_bool(const std::string &in) {
return true;
if (in == "False")
return false;
throw std::runtime_error("Invalid python boolan.");
throw std::runtime_error("Invalid python boolean.");
}
std::string get_value_from_map(const std::string &mapstr) {
@ -124,7 +129,7 @@ bool is_digits(const std::string &str) { return std::all_of(str.begin(), str.end
std::vector<std::string> parse_tuple(std::string in) {
std::vector<std::string> v;
const char seperator = ',';
const char separator = ',';
in = trim(in);
@ -135,7 +140,7 @@ std::vector<std::string> parse_tuple(std::string in) {
std::istringstream iss(in);
for (std::string token; std::getline(iss, token, seperator);) {
for (std::string token; std::getline(iss, token, separator);) {
v.push_back(token);
}
@ -150,7 +155,6 @@ std::string trim(const std::string &str) {
return "";
auto end = str.find_last_not_of(whitespace);
return str.substr(begin, end - begin + 1);
}

View File

@ -30,4 +30,35 @@ TEST_CASE("trim whitespace"){
REQUIRE(trim("hej ") == "hej");
REQUIRE(trim(" ") == "");
REQUIRE(trim(" \thej hej ") == "hej hej");
}
// parse_descr must map little endian numpy type strings to the matching DType.
TEST_CASE("parse data type descriptions"){
    using aare::DType;

    // signed integers
    REQUIRE(parse_descr("<i1") == DType::INT8);
    REQUIRE(parse_descr("<i2") == DType::INT16);
    REQUIRE(parse_descr("<i4") == DType::INT32);
    REQUIRE(parse_descr("<i8") == DType::INT64);

    // unsigned integers
    REQUIRE(parse_descr("<u1") == DType::UINT8);
    REQUIRE(parse_descr("<u2") == DType::UINT16);
    REQUIRE(parse_descr("<u4") == DType::UINT32);
    REQUIRE(parse_descr("<u8") == DType::UINT64);

    // floating point
    REQUIRE(parse_descr("<f4") == DType::FLOAT);
    REQUIRE(parse_descr("<f8") == DType::DOUBLE);
}
// in_array reports membership for present and absent values, including the
// single element and the empty array.
TEST_CASE("is element in array"){
    const std::array<int, 3> three{1, 2, 3};
    const std::array<int, 1> one{1};
    const std::array<int, 0> none{};

    REQUIRE(in_array(1, three));
    REQUIRE_FALSE(in_array(4, three));
    REQUIRE(in_array(1, one));
    REQUIRE_FALSE(in_array(1, none));
}
// parse_dict splits a numpy header dict into the raw (still unparsed) value
// text for each requested key.
TEST_CASE("Parse numpy dict"){
    const std::string header_dict = "{'descr': '<f4', 'fortran_order': False, 'shape': (3, 4)}";
    const std::vector<std::string> wanted_keys{"descr", "fortran_order", "shape"};

    auto parsed = parse_dict(header_dict, wanted_keys);

    REQUIRE(parsed["descr"] == "'<f4'");
    REQUIRE(parsed["fortran_order"] == "False");
    REQUIRE(parsed["shape"] == "(3, 4)");
}