From 4c0cc7a1f26ac044bc36c5af915a38597fa62866 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Fr=C3=B6jdh?= Date: Fri, 12 Jun 2026 09:56:21 +0200 Subject: [PATCH] added support for subtracting pedestal from np.array --- include/aare/NDView.hpp | 3 +-- python/src/pedestal.hpp | 40 +++++++++++++++++++++++++++++++---- python/tests/test_Pedestal.py | 40 +++++++++++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 6 deletions(-) create mode 100644 python/tests/test_Pedestal.py diff --git a/include/aare/NDView.hpp b/include/aare/NDView.hpp index 403d203..8c2432d 100644 --- a/include/aare/NDView.hpp +++ b/include/aare/NDView.hpp @@ -90,8 +90,7 @@ class NDView : public ArrayExpr, Ndim> { NDView(T *buffer, std::array shape) : buffer_(buffer), strides_(c_strides(shape)), shape_(shape), - size_(std::accumulate(std::begin(shape), std::end(shape), 1, - std::multiplies<>())) {} + size_(num_elements(shape)) {} template std::enable_if_t operator()(Ix... index) { diff --git a/python/src/pedestal.hpp b/python/src/pedestal.hpp index 6f8fc06..2e3b2d0 100644 --- a/python/src/pedestal.hpp +++ b/python/src/pedestal.hpp @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -12,8 +13,10 @@ namespace py = pybind11; template void define_pedestal_bindings(py::module &m, const std::string &name) { - py::class_>(m, name.c_str()) - .def(py::init()) + auto pedestal = + py::class_>(m, name.c_str(), py::buffer_protocol()); + + pedestal.def(py::init()) .def(py::init()) .def("mean", [](Pedestal &self) { @@ -50,6 +53,23 @@ void define_pedestal_bindings(py::module &m, const std::string &name) { *std = self.std(); return return_image_data(std); }) + .def( + "__array_ufunc__", + [](py::object self, py::object ufunc, const std::string &method, + py::args inputs, py::kwargs kwargs) -> py::object { + if (method != "__call__" || inputs.size() != 2 || + inputs[1].ptr() != self.ptr() || + py::cast(ufunc.attr("__name__")) != + "subtract") { + return py::reinterpret_borrow( + Py_NotImplemented); + } + + auto mean = + py::module_::import("builtins").attr("memoryview")(self); + return ufunc(inputs[0], mean, **kwargs); + }, + "Support subtracting a Pedestal from a NumPy array.") .def("clear", py::overload_cast<>(&Pedestal::clear)) .def_property_readonly("rows", &Pedestal::rows) .def_property_readonly("cols", &Pedestal::cols) @@ -84,5 +104,17 @@ void define_pedestal_bindings(py::module &m, const std::string &name) { pedestal.push_no_update(v); }, py::arg().noconvert()) - .def("update_mean", &Pedestal::update_mean); -} \ No newline at end of file + .def("update_mean", &Pedestal::update_mean) + .def_buffer([](Pedestal &self) { + auto mean = self.view(); + return py::buffer_info( + const_cast(mean.data()), sizeof(SUM_TYPE), + py::format_descriptor::format(), 2, + {static_cast(mean.shape(0)), + static_cast(mean.shape(1))}, + {static_cast(mean.strides()[0] * sizeof(SUM_TYPE)), + static_cast(mean.strides()[1] * + sizeof(SUM_TYPE))}, + true); + }); +} diff --git a/python/tests/test_Pedestal.py b/python/tests/test_Pedestal.py new file mode 100644 index 0000000..2c94811 --- /dev/null +++ b/python/tests/test_Pedestal.py @@ -0,0 +1,40 @@ +import numpy as np +import pytest + +from aare import Pedestal_d, Pedestal_f + + +@pytest.mark.parametrize( + ("pedestal_type", "expected_dtype"), + [(Pedestal_d, np.float64), (Pedestal_f, np.float32)], +) +def test_numpy_array_minus_pedestal(pedestal_type, expected_dtype): + pedestal = pedestal_type(2, 3) + pedestal.push(np.array([[2, 4, 6], [8, 10, 12]], dtype=np.uint16)) + array = np.array([[12, 14, 16], [18, 20, 22]], dtype=np.uint16) + + result = array - pedestal + + np.testing.assert_array_equal( + result, np.array([[10, 10, 10], [10, 10, 10]], dtype=expected_dtype) + ) + assert result.dtype == expected_dtype + + +def test_numpy_array_minus_pedestal_rejects_incompatible_shape(): + pedestal = Pedestal_d(2, 3) + array = np.zeros((2, 2), dtype=np.float64) + + with pytest.raises(ValueError): + array - pedestal + + +def test_pedestal_exposes_mean_as_read_only_buffer(): + pedestal = Pedestal_d(2, 3) + pedestal.push(np.array([[2, 4, 6], [8, 10, 12]], dtype=np.uint16)) + + mean = np.asarray(pedestal) + + np.testing.assert_array_equal(mean, pedestal.view()) + assert np.shares_memory(mean, pedestal.view()) + assert not mean.flags.writeable