FFT for OpenCL using clFFT library
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
#include <complex>
|
||||
#include <string>
|
||||
|
||||
#include "Utility/TimeStamp.h"
|
||||
#include "DKSBase.h"
|
||||
@@ -18,22 +19,30 @@ int main(int argc, char *argv[]) {
|
||||
int N = 16;
|
||||
char *api_name = new char[10];
|
||||
char *device_name = new char[10];
|
||||
if (argc == 2) {
|
||||
N = atoi(argv[1]);
|
||||
strcpy(api_name, "Cuda");
|
||||
strcpy(device_name, "-gpu");
|
||||
} else if (argc == 3) {
|
||||
N = atoi(argv[1]);
|
||||
strcpy(api_name, argv[2]);
|
||||
strcpy(device_name, "-gpu");
|
||||
} else if (argc == 4) {
|
||||
N = atoi(argv[1]);
|
||||
strcpy(api_name, argv[2]);
|
||||
strcpy(device_name, argv[3]);
|
||||
} else {
|
||||
N = 16;
|
||||
strcpy(api_name, "OpenCL");
|
||||
strcpy(device_name, "-gpu");
|
||||
|
||||
for (int i = 1; i < argc; i++) {
|
||||
if (argv[i] == string("-cuda")) {
|
||||
strcpy(api_name, "Cuda");
|
||||
strcpy(device_name, "-gpu");
|
||||
}
|
||||
|
||||
if (argv[i] == string("-opencl")) {
|
||||
strcpy(api_name, "OpenCL");
|
||||
strcpy(device_name, "-gpu");
|
||||
}
|
||||
|
||||
if (argv[i] == string("-mic")) {
|
||||
strcpy(api_name, "OpenMP");
|
||||
strcpy(device_name, "-mic");
|
||||
}
|
||||
|
||||
if (argv[i] == string("-cpu")) {
|
||||
strcpy(api_name, "OpenCL");
|
||||
strcpy(device_name, "-cpu");
|
||||
}
|
||||
|
||||
if (argv[i] == string("-N"))
|
||||
N = atoi(argv[i+1]);
|
||||
}
|
||||
|
||||
cout << "Use api: " << api_name << ", " << device_name << endl;
|
||||
@@ -74,9 +83,16 @@ int main(int argc, char *argv[]) {
|
||||
|
||||
/* write data to device */
|
||||
ierr = base.writeData< complex<double> >(mem_ptr, cdata, N*N*N);
|
||||
if (N < 5)
|
||||
printData3DN4(cdata, N, 3);
|
||||
|
||||
|
||||
/* execute fft */
|
||||
base.callFFT(mem_ptr, 3, dimsize);
|
||||
if (N < 5) {
|
||||
base.readData< complex<double> > (mem_ptr, cfft, N*N*N);
|
||||
printData3DN4(cfft, N, 3);
|
||||
}
|
||||
|
||||
/* execute ifft */
|
||||
base.callIFFT(mem_ptr, 3, dimsize);
|
||||
@@ -86,7 +102,9 @@ int main(int argc, char *argv[]) {
|
||||
|
||||
/* read data from device */
|
||||
base.readData< complex<double> >(mem_ptr, cifft, N*N*N);
|
||||
|
||||
if (N < 5)
|
||||
printData3DN4(cifft, N, 3);
|
||||
|
||||
/* free device memory */
|
||||
base.freeMemory< complex<double> >(mem_ptr, N*N*N);
|
||||
|
||||
@@ -130,7 +148,7 @@ void printData3DN4(complex<double>* &data, int N, int dim) {
|
||||
if (a < 10e-5 && a > -10e-5)
|
||||
a = 0;
|
||||
|
||||
cout << d << "; " << a << "\t";
|
||||
cout << "(" << d << "," << a << ") ";
|
||||
}
|
||||
}
|
||||
cout << endl;
|
||||
@@ -157,3 +175,5 @@ void compareData(complex<double>* &data1, complex<double>* &data2, int N, int di
|
||||
cout << "Size " << N << " CC <--> CC diff: " << sum << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user