diff --git a/include/aare/utils/batch.hpp b/include/aare/utils/batch.hpp new file mode 100644 index 0000000..c3122f6 --- /dev/null +++ b/include/aare/utils/batch.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MPL-2.0 +#pragma once +#include +#include +#include "aare/NDArray.hpp" + + +template +void pack_frame_batch(const std::vector>& frames, + size_t first_frame, + size_t n_frames, + std::vector& batch) +{ + if (n_frames == 0) return; + + const size_t rows = frames[first_frame].shape(0); + const size_t cols = frames[first_frame].shape(1); + const size_t image_size = rows * cols; + const size_t total_size = n_frames * image_size; + + if (batch.size() != total_size) { + batch.resize(total_size); + } + + for (size_t k = 0; k < n_frames; ++k) { + const FRAME_TYPE* src = frames[first_frame + k].data(); + FRAME_TYPE* dst = batch.data() + k * image_size; + std::memcpy(dst, src, image_size * sizeof(FRAME_TYPE)); + } +} \ No newline at end of file diff --git a/src/ClusterFinderCUDA.test.cu b/src/ClusterFinderCUDA.test.cu index d321895..cf725e2 100644 --- a/src/ClusterFinderCUDA.test.cu +++ b/src/ClusterFinderCUDA.test.cu @@ -445,12 +445,8 @@ int main(int argc, char* argv[]) { for (size_t bi = 0; bi < n_batches; ++bi) { const size_t offset = bi * BATCH_SIZE; const size_t actual_batch = std::min(BATCH_SIZE, use_data - offset); - - for (size_t k = 0; k < actual_batch; ++k) { - std::memcpy(batch_buffer.data() + k * ROWS * COLS, - frames[offset + k].data(), - ROWS * COLS * sizeof(FRAME_TYPE)); - } + + pack_frame_batch(frames, offset, actual_batch, batch_buffer); aare::NDView batch_view( batch_buffer.data(),