/*
 * Files
 * DKS/test/testStockFFT3D.cpp
 * 2016-10-10 14:49:32 +02:00
 *
 * 181 lines
 * 4.3 KiB
 * C++
 */

#include <cmath>
#include <complex>
#include <cstdlib>
#include <iostream>
#include <vector>

#include "Utility/TimeStamp.h"
#include "DKSBase.h"

using namespace std;
// Print the real parts of an N x N x N grid (data[i*N*N + j*N + k]), one output
// line per j; values with magnitude <= 1e-4 are printed as 0. dim is unused.
void printData3DN4(complex<double>* &data, int N, int dim);
// Print "CC <--> CC diff: " followed by the sum over N^dim elements (dim capped
// at 3, values below 1 treated as 1) of |Re(a)-Re(b)| + |Im(a)-Im(b)|.
void compareData(complex<double>* &data1, complex<double>* &data2, int N, int dim);
/*
 * Benchmark and cross-check three 3D complex FFT paths exposed by DKSBase on
 * an N x N x N grid, N = 2^n (n from argv[1], default 2):
 *   1. OpenCL Stockham out-of-place FFT  (callFFTStockham)      -> cfft2
 *   2. CUDA FFT                          (callFFT, "Cuda")      -> cfft3
 *   3. OpenCL radix-2 in-place FFT       (callFFT, "OpenCL")    -> cfft
 * Each path is timed over 5 full round trips (allocate device memory, upload,
 * transform, download, free), then the three results are compared pairwise.
 * Returns 1 if n is out of range, 0 otherwise.
 */
int main(int argc, char *argv[]) {
  int n = 2;
  if (argc == 2)
    n = atoi(argv[1]);

  // Element indices below are plain ints, so N^3 = 2^(3n) must fit: n <= 10.
  // A negative n has no integer power of two and would yield an empty grid.
  if (n < 0 || n > 10) {
    cerr << "FFT size exponent must be in [0, 10], got " << n << endl;
    return 1;
  }
  int N = 1 << n; // exact integer power of two, no floating-point round trip

  cout << "Begin DKS Base tests" << endl;
  cout << "FFT size: " << N << endl;

  int dimsize[3] = {N, N, N};
  const int size = N * N * N;

  // Host buffers are owned by vectors, so they are released on every return
  // path, and every element is value-initialised to 0+0i. Raw aliases are
  // kept because the helper functions take complex<double>*& arguments.
  // size >= 1 here, so &buf[0] is valid.
  vector< complex<double> > cdataBuf(size), cfftBuf(size), cfft2Buf(size), cfft3Buf(size);
  complex<double> *cdata = &cdataBuf[0];
  complex<double> *cfft = &cfftBuf[0];
  complex<double> *cfft2 = &cfft2Buf[0];
  complex<double> *cfft3 = &cfft3Buf[0];

  // Input: real part is the fastest-varying index k, imaginary part is 0.
  for (int i = 0; i < N; i++) {
    for (int j = 0; j < N; j++) {
      for (int k = 0; k < N; k++) {
        //cdata[i*N*N + j*N + k] = complex<double>((double)k/(N*N*N), 0);
        cdata[i*N*N + j*N + k] = complex<double>(k, 0);
      }
    }
  }

  if (N == 4)
    printData3DN4(cdata, N, 3);

  /* init DKSBase */
  cout << "Init device and set function" << endl;
  // NOTE(review): ierr reported by allocateMemory is never checked below.
  int ierr;
  timestamp_t t0, t1;

  /* stockham radix-2 out-of-place fft */
  DKSBase base2;
  base2.setAPI("OpenCL", 6);
  base2.setDevice("-gpu", 4);
  base2.initDevice();
  cout << endl;
  void *src_ptr;
  for (int i = 0; i < 5; i++) {
    t0 = get_timestamp();
    src_ptr = base2.allocateMemory< complex<double> >(size, ierr);
    base2.writeData< complex<double> >(src_ptr, cdata, size);
    base2.callFFTStockham(src_ptr, 3, dimsize);
    base2.readData< complex<double> >(src_ptr, cfft2, size);
    base2.freeMemory< complex<double> >(src_ptr, size);
    t1 = get_timestamp();
    cout << "out-of-place FFT time: " << get_secs(t0, t1) << endl;
  }
  if (N == 4)
    printData3DN4(cfft2, N, 3);
  cout << endl;

  /* CUDA cufft */
  DKSBase base3;
  base3.setAPI("Cuda", 4);
  base3.setDevice("-gpu", 4);
  base3.initDevice();
  cout << endl;
  void *cuda_ptr;
  for (int i = 0; i < 5; i++) {
    t0 = get_timestamp();
    cuda_ptr = base3.allocateMemory< complex<double> >(size, ierr);
    base3.writeData< complex<double> >(cuda_ptr, cdata, size);
    base3.callFFT(cuda_ptr, 3, dimsize);
    base3.readData< complex<double> >(cuda_ptr, cfft3, size);
    base3.freeMemory< complex<double> >(cuda_ptr, size);
    t1 = get_timestamp();
    cout << "Cuda FFT time: " << get_secs(t0, t1) << endl;
  }
  if (N == 4)
    printData3DN4(cfft3, N, 3);
  cout << endl;

  /* radix-2 in place fft */
  DKSBase base;
  base.setAPI("OpenCL", 6);
  base.setDevice("-gpu", 4);
  base.initDevice();
  cout << endl;
  void *mem_ptr;
  for (int i = 0; i < 5; i++) {
    t0 = get_timestamp();
    mem_ptr = base.allocateMemory< complex<double> >(size, ierr);
    base.writeData< complex<double> >(mem_ptr, cdata, size);
    base.callFFT(mem_ptr, 3, dimsize);
    base.readData< complex<double> >(mem_ptr, cfft, size);
    base.freeMemory< complex<double> >(mem_ptr, size);
    t1 = get_timestamp();
    cout << "in-place FFT time: " << get_secs(t0, t1) << endl;
  }
  if (N == 4)
    printData3DN4(cfft, N, 3);
  cout << endl;

  /* compare results */
  cout << endl;
  cout << "Radix 2 vs Stockham: ";
  compareData(cfft, cfft2, N, 3);
  cout << "Radix 2 vs Cufft: ";
  compareData(cfft, cfft3, N, 3);
  cout << "Stockham vs Cufft: ";
  compareData(cfft2, cfft3, N, 3);
  return 0;
}
/*
 * Dump the real part of an N x N x N complex grid stored as
 * data[i*N*N + j*N + k]. Output line j holds row j of every i-slice placed
 * side by side (k varies fastest), followed by one trailing blank line.
 * Entries whose magnitude does not exceed 1e-4 are shown as 0 (presumably to
 * hide round-off noise). The dim parameter is not used.
 */
void printData3DN4(complex<double>* &data, int N, int dim) {
  const int plane = N * N;
  for (int row = 0; row < N; ++row) {
    for (int slice = 0; slice < N; ++slice) {
      const complex<double> *line = data + slice * plane + row * N;
      for (int col = 0; col < N; ++col) {
        const double re = line[col].real();
        const bool tiny = !(re > 10e-5) && !(re < -10e-5);
        if (tiny)
          cout << 0 << "\t";
        else
          cout << re << "\t";
      }
    }
    cout << endl;
  }
  cout << endl;
}
/*
 * Print the L1 distance between two complex grids: the sum over every element
 * of |Re(a)-Re(b)| + |Im(a)-Im(b)|, prefixed by "CC <--> CC diff: ".
 * The grid holds N^dim elements laid out row-major (last index fastest);
 * dim values above 3 are treated as 3 and values below 1 as 1.
 */
void compareData(complex<double>* &data1, complex<double>* &data2, int N, int dim) {
  const int depth = (dim > 2) ? N : 1;
  const int rows = (dim > 1) ? N : 1;
  const int cols = N;
  double total = 0;
  for (int i = 0; i < depth; ++i) {
    for (int j = 0; j < rows; ++j) {
      const int rowStart = (i * rows + j) * cols;
      for (int k = 0; k < cols; ++k) {
        const complex<double> &a = data1[rowStart + k];
        const complex<double> &b = data2[rowStart + k];
        // Real part first, then imaginary, to keep the accumulation order.
        total += fabs(a.real() - b.real());
        total += fabs(a.imag() - b.imag());
      }
    }
  }
  cout << "CC <--> CC diff: " << total << endl;
}