added support for subtracting pedestal from np.array

This commit is contained in:
Erik Fröjdh
2026-06-12 09:56:21 +02:00
parent ee7503082d
commit 4c0cc7a1f2
3 changed files with 77 additions and 6 deletions
+36 -4
View File
@@ -5,6 +5,7 @@
#include <cstdint>
#include <filesystem>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
@@ -12,8 +13,10 @@ namespace py = pybind11;
template <typename SUM_TYPE>
void define_pedestal_bindings(py::module &m, const std::string &name) {
py::class_<Pedestal<SUM_TYPE>>(m, name.c_str())
.def(py::init<int, int, int>())
auto pedestal =
py::class_<Pedestal<SUM_TYPE>>(m, name.c_str(), py::buffer_protocol());
pedestal.def(py::init<int, int, int>())
.def(py::init<int, int>())
.def("mean",
[](Pedestal<SUM_TYPE> &self) {
@@ -50,6 +53,23 @@ void define_pedestal_bindings(py::module &m, const std::string &name) {
*std = self.std();
return return_image_data(std);
})
.def(
"__array_ufunc__",
[](py::object self, py::object ufunc, const std::string &method,
py::args inputs, py::kwargs kwargs) -> py::object {
if (method != "__call__" || inputs.size() != 2 ||
inputs[1].ptr() != self.ptr() ||
py::cast<std::string>(ufunc.attr("__name__")) !=
"subtract") {
return py::reinterpret_borrow<py::object>(
Py_NotImplemented);
}
auto mean =
py::module_::import("builtins").attr("memoryview")(self);
return ufunc(inputs[0], mean, **kwargs);
},
"Support subtracting a Pedestal from a NumPy array.")
.def("clear", py::overload_cast<>(&Pedestal<SUM_TYPE>::clear))
.def_property_readonly("rows", &Pedestal<SUM_TYPE>::rows)
.def_property_readonly("cols", &Pedestal<SUM_TYPE>::cols)
@@ -84,5 +104,17 @@ void define_pedestal_bindings(py::module &m, const std::string &name) {
pedestal.push_no_update(v);
},
py::arg().noconvert())
.def("update_mean", &Pedestal<SUM_TYPE>::update_mean);
}
.def("update_mean", &Pedestal<SUM_TYPE>::update_mean)
.def_buffer([](Pedestal<SUM_TYPE> &self) {
auto mean = self.view();
return py::buffer_info(
const_cast<SUM_TYPE *>(mean.data()), sizeof(SUM_TYPE),
py::format_descriptor<SUM_TYPE>::format(), 2,
{static_cast<py::ssize_t>(mean.shape(0)),
static_cast<py::ssize_t>(mean.shape(1))},
{static_cast<py::ssize_t>(mean.strides()[0] * sizeof(SUM_TYPE)),
static_cast<py::ssize_t>(mean.strides()[1] *
sizeof(SUM_TYPE))},
true);
});
}
+40
View File
@@ -0,0 +1,40 @@
import numpy as np
import pytest
from aare import Pedestal_d, Pedestal_f
@pytest.mark.parametrize(
("pedestal_type", "expected_dtype"),
[(Pedestal_d, np.float64), (Pedestal_f, np.float32)],
)
def test_numpy_array_minus_pedestal(pedestal_type, expected_dtype):
pedestal = pedestal_type(2, 3)
pedestal.push(np.array([[2, 4, 6], [8, 10, 12]], dtype=np.uint16))
array = np.array([[12, 14, 16], [18, 20, 22]], dtype=np.uint16)
result = array - pedestal
np.testing.assert_array_equal(
result, np.array([[10, 10, 10], [10, 10, 10]], dtype=expected_dtype)
)
assert result.dtype == expected_dtype
def test_numpy_array_minus_pedestal_rejects_incompatible_shape():
pedestal = Pedestal_d(2, 3)
array = np.zeros((2, 2), dtype=np.float64)
with pytest.raises(ValueError):
array - pedestal
def test_pedestal_exposes_mean_as_read_only_buffer():
pedestal = Pedestal_d(2, 3)
pedestal.push(np.array([[2, 4, 6], [8, 10, 12]], dtype=np.uint16))
mean = np.asarray(pedestal)
np.testing.assert_array_equal(mean, pedestal.view())
assert np.shares_memory(mean, pedestal.view())
assert not mean.flags.writeable