snapshot of svn
This commit is contained in:
91
test/testMPIFFT.cpp
Normal file
91
test/testMPIFFT.cpp
Normal file
@ -0,0 +1,91 @@
|
||||
#include <iostream>
|
||||
#include <mpi.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "DKSBase.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
void printData(complex<double> *data, int N, int nprocs, const char *message = "") {
|
||||
if (strcmp(message, "") != 0)
|
||||
cout << message;
|
||||
|
||||
for (int i = 0; i < nprocs; i++) {
|
||||
for (int j = 0; j < N; j++)
|
||||
cout << data[i*N + j] << "\t";
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
void initData(complex<double> *data, int N, int rank) {
|
||||
for (int i = 0; i < N; i++)
|
||||
data[i] = complex<double>((double)rank+1.0, 0.0);
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
|
||||
int ierr;
|
||||
int rank, nprocs;
|
||||
|
||||
MPI_Init(&argc, &argv);
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &nprocs);
|
||||
|
||||
cout << "Rank " << (rank+1) << " from " << nprocs << endl;
|
||||
|
||||
int n = 8;
|
||||
|
||||
complex<double> *hdata_in = new complex<double>[n];
|
||||
complex<double> *hdata_out = new complex<double>[n];
|
||||
initData(hdata_in, n, rank);
|
||||
cout << "In data for process " << rank+1 << ":\t";
|
||||
printData(hdata_in, n, 1);
|
||||
|
||||
|
||||
DKSBase base = DKSBase();
|
||||
base.setAPI("Cuda", 4);
|
||||
base.setDevice("-gpu", 4);
|
||||
base.initDevice();
|
||||
|
||||
|
||||
if (rank == 0) {
|
||||
|
||||
complex<double> *hdata_out_all = new complex<double>[nprocs*n];
|
||||
void* mem_ptr;
|
||||
mem_ptr = base.allocateMemory< complex<double> >(nprocs*n, ierr);
|
||||
|
||||
|
||||
MPI_Gather(hdata_in, n, MPI_DOUBLE_COMPLEX, mem_ptr, n, MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD);
|
||||
|
||||
|
||||
int dimsize[3] = {n*nprocs, 1, 1};
|
||||
base.callFFT(mem_ptr, 1, dimsize);
|
||||
base.readData< complex<double> >(mem_ptr, hdata_out_all, n*nprocs);
|
||||
|
||||
MPI_Scatter(mem_ptr, n, MPI_DOUBLE_COMPLEX, hdata_out, n, MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD);
|
||||
|
||||
base.freeMemory< complex<double> >(mem_ptr, n*nprocs);
|
||||
|
||||
printData(hdata_out_all, n, nprocs, "Out data 1:\n");
|
||||
cout << "Scatter data for proces: " << rank + 1 << ": \t";
|
||||
printData(hdata_out, n, 1);
|
||||
} else {
|
||||
|
||||
MPI_Gather(hdata_in, n, MPI_DOUBLE_COMPLEX, NULL, NULL, NULL, 0, MPI_COMM_WORLD);
|
||||
|
||||
MPI_Scatter(NULL, NULL, NULL, hdata_out, n, MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD);
|
||||
|
||||
cout << "Scatter data for proces: " << rank + 1 << ": \t";
|
||||
printData(hdata_out, n, 1);
|
||||
|
||||
}
|
||||
|
||||
|
||||
MPI_Finalize();
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user