Files
Jungfraujoch/image_analysis/indexing/FFTIndexerCPU.cpp
2025-09-21 19:27:51 +02:00

110 lines
3.3 KiB
C++

// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute <filip.leonarski@psi.ch>
// SPDX-License-Identifier: GPL-3.0-only
#include "FFTIndexerCPU.h"
#include "EigenRefine.h"
#include <cmath>
#include <algorithm>
#include <stdexcept>
#include <cassert>
#include <mutex>
// Process-wide mutex for this translation unit: the constructor and destructor
// in this file hold it while calling fftwf_plan_many_dft_r2c() and
// fftwf_destroy_plan(). Plan execution (fftwf_execute) is done without it.
// NOTE(review): presumably required because the FFTW planner is not
// thread-safe while plan execution is — confirm against the FFTW thread-safety docs.
static std::mutex fftw_plan_mutex;
// Absolute value of the scalar product of two 3D vectors, i.e. the unsigned
// length of the projection of one onto the other when one of them is a unit vector.
static inline double dot_abs(const Coord& a, const Coord& b) {
    const auto scalar_product = a.x * b.x + a.y * b.y + a.z * b.z;
    return std::fabs(scalar_product);
}
/// Allocates the host-side FFT buffers and creates the batched real-to-complex
/// FFTW plan that ExecuteFFT() later runs.
///
/// The plan describes nDirections independent 1D transforms of length
/// histogram_size. Transform k reads histogram_size contiguous floats starting
/// at h_input_fft[k * histogram_size] and writes (histogram_size / 2 + 1)
/// complex values starting at h_output_fft[k * (histogram_size / 2 + 1)].
///
/// @param settings indexing settings, forwarded to the FFTIndexer base class.
/// @throws std::runtime_error if a buffer is too small for the planned layout
///         or if FFTW cannot create the plan.
FFTIndexerCPU::FFTIndexerCPU(const IndexingSettings& settings)
    : FFTIndexer(settings) {
    const int H = static_cast<int>(histogram_size);
    const int out_len = (H / 2) + 1;

    // Allocate host buffers
    h_input_fft.resize(input_size, 0.0);
    h_output_fft.resize(output_size);

    // resize() either succeeds or throws, so comparing size() against the
    // requested size can never fail. What has to hold is that the buffers
    // cover everything the batched plan touches; otherwise fftwf_execute()
    // would read or write past the end of the vectors.
    const size_t batch = static_cast<size_t>(nDirections);
    const size_t required_input = batch * static_cast<size_t>(H);
    const size_t required_output = batch * static_cast<size_t>(out_len);
    if (h_input_fft.size() < required_input)
        throw std::runtime_error("FFTWIndexer: input buffer size mismatch");
    if (h_output_fft.size() < required_output)
        throw std::runtime_error("FFTWIndexer: output buffer size mismatch");

    int n[1] = { H };
    {
        // Plan creation is serialized through the file-level mutex.
        const std::lock_guard<std::mutex> planner_guard(fftw_plan_mutex);
        plan = fftwf_plan_many_dft_r2c(
            1, n, nDirections,                       // rank, transform length, batch size
            h_input_fft.data(), nullptr, 1, H,       // input: unit stride, H floats apart
            reinterpret_cast<fftwf_complex*>(h_output_fft.data()), nullptr, 1, out_len,
            FFTW_ESTIMATE);
    }
    if (!plan)
        throw std::runtime_error("fftw_plan_many_dft_r2c failed");
}
/// Releases the FFTW plan created in the constructor. The file-level planner
/// mutex is held for the duration of the fftwf_destroy_plan() call.
FFTIndexerCPU::~FFTIndexerCPU() {
    const std::lock_guard<std::mutex> planner_guard(fftw_plan_mutex);
    fftwf_destroy_plan(plan);
}
void FFTIndexerCPU::ExecuteFFT(const std::vector<Coord> &coord, size_t nspots) {
// Build histograms: one per direction
const int H = static_cast<int>(histogram_size);
const int D = nDirections;
const int out_len = (H / 2) + 1;
std::fill(h_input_fft.begin(), h_input_fft.end(), 0.0);
for (int d = 0; d < D; ++d) {
float* hist = h_input_fft.data() + static_cast<size_t>(d) * H;
for (size_t i = 0; i < nspots; i++) {
const auto& r = coord[i];
double dot = dot_abs(direction_vectors[d], r);
long long bin = static_cast<long long>(dot / histogram_spacing);
if (bin >= 0 && bin < H) {
hist[bin] += 1.0;
}
}
}
// Plan and execute batched R2C FFT with FFTW
fftwf_execute(plan);
// Post-process: pick peaks past min_length_A
const double len_coeff = 2.0 * static_cast<double>(max_length_A) / static_cast<double>(H);
for (int d = 0; d < D; ++d) {
const auto* spec = h_output_fft.data() + static_cast<size_t>(d) * out_len;
double best_mag = 0.0;
double best_len = -1.0;
for (int j = 0; j < out_len; ++j) {
double len = len_coeff * static_cast<double>(j);
if (len <= static_cast<double>(min_length_A)) continue;
const double re = spec[j][0];
const double im = spec[j][1];
const double mag = std::hypot(re, im);
if (mag > best_mag) {
best_mag = mag;
best_len = len;
}
}
result_fft[d] = FFTResult{
.magnitude = static_cast<float>(best_mag),
.direction = d,
.length = static_cast<float>(best_len)
};
}
}