add python bindings for numpy

This commit is contained in:
Bechir Braham
2024-03-11 16:05:54 +01:00
parent 25d282717c
commit c4c88c50d1
5 changed files with 49 additions and 2 deletions

View File

@ -21,7 +21,7 @@ class File:
raise FileNotFoundError(f"File not found: {path}")
ext = os.path.splitext(path)[1]
if ext not in (".raw", ".json"):
if ext not in (".raw", ".json", ".npy"):
raise ValueError(f"Invalid file extension: {ext}")
if ext == ".json":
@ -33,6 +33,17 @@ class File:
bitdepth = 16
else:
bitdepth = master_data["Dynamic Range"]
elif ext == ".npy":
# TODO: find solution for this. maybe add a None detector type
detector = "Jungfrau"
with open(path, "rb") as fobj:
import numpy as np
version = np.lib.format.read_magic(fobj)
# find what function to call based on the version
func_name = 'read_array_header_' + '_'.join(str(v) for v in version)
func = getattr(np.lib.format, func_name)
header = func(fobj)
bitdepth = header[2].itemsize * 8
else:
NotImplementedError("Raw file not implemented yet")

View File

@ -1,13 +1,32 @@
import os
from pathlib import Path
from aare import File, Frame
import numpy as np
if __name__ == "__main__":
#get env variable
root_dir = Path(os.environ.get("PROJECT_ROOT_DIR"))
# read JSON master file
data_path = str(root_dir / "data"/"jungfrau_single_master_0.json")
file = File(data_path)
frame = file.get_frame(0)
print(frame.rows, frame.cols)
print(frame.get(0,0))
print(frame.get(0,0))
# read Numpy file
data_path = str(root_dir / "data"/"test_numpy_file.npy")
file = File(data_path)
frame = file.get_frame(0)
print(frame.rows, frame.cols)
print(frame.get(0,0))
arr = np.array(frame.get_array())
print(arr)
print(arr.shape)
print(np.array_equal(arr, np.load(data_path)[0]))

View File

@ -1,6 +1,7 @@
#include <cstdint>
#include <filesystem>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <string>
#include "aare/defs.hpp"
@ -27,10 +28,13 @@ PYBIND11_MODULE(_aare, m) {
py::class_<Frame<uint16_t>>(m, "_Frame16")
.def(py::init<std::byte*, ssize_t, ssize_t>())
.def("get", &Frame<uint16_t>::get)
.def("get_array", &Frame<uint16_t>::get_array)
.def_property_readonly("rows", &Frame<uint16_t>::rows)
.def_property_readonly("cols", &Frame<uint16_t>::cols)
.def_property_readonly("bitdepth", &Frame<uint16_t>::bitdepth);