89 lines
2.1 KiB
Plaintext
89 lines
2.1 KiB
Plaintext
#ifndef H_CUDA_FFT
|
|
#define H_CUDA_FFT
|
|
|
|
#include <iostream>
|
|
#include <math.h>
|
|
#include <cuda_runtime.h>
|
|
#include <cufft.h>
|
|
#include "cublas_v2.h"
|
|
|
|
#include "../Algorithms/FFT.h"
|
|
#include "CudaBase.cuh"
|
|
|
|
/**
 * cuFFT-backed implementation of the FFT interface.
 *
 * Holds a set of default cuFFT plans (initialized by setupFFT() and destroyed
 * by destroyFFT()) plus a cuBLAS handle.
 * NOTE(review): confirm in the .cu implementation whether the destructor
 * calls destroyFFT() and whether it deletes m_base when base_create is set.
 */
class CudaFFT : public FFT {

private:

  // NOTE(review): presumably true when this object created m_base itself
  // (default constructor) and is responsible for freeing it - confirm.
  bool base_create;
  CudaBase *m_base;

  // Default plans, reusable for all FFTs of the same size and type.
  cufftHandle defaultPlanZ2Z;  // presumably complex-to-complex, double precision
  cufftHandle defaultPlanD2Z;  // presumably real-to-complex, double precision
  cufftHandle defaultPlanZ2D;  // presumably complex-to-real, double precision
  cublasHandle_t defaultCublasFFT;

  // Non-copyable: the members above are raw library handles (and possibly an
  // owned CudaBase), so an implicit member-wise copy would let two objects
  // release the same resources. Declared but intentionally never defined
  // (pre-C++11 idiom), so any accidental copy fails to compile or link.
  CudaFFT(const CudaFFT &);
  CudaFFT &operator=(const CudaFFT &);

public:

  /** Constructor with an existing CudaBase as argument. */
  CudaFFT(CudaBase *base);

  /** Default constructor. */
  CudaFFT();

  /** Destructor. */
  ~CudaFFT();

  /**
   * Info: initialize the default cuFFT plans, which can be reused for all
   *       FFTs of the same size and type.
   * Params: ndim - number of transform dimensions;
   *         N    - extent of each dimension (presumably only the first ndim
   *                entries are used - confirm).
   * Return: success or error code
   */
  int setupFFT(int ndim, int N[3]);

  /**
   * Info: no-op for this backend; ignores all arguments and always returns
   *       DKS_SUCCESS. NOTE(review): presumably the real-to-complex plan is
   *       prepared elsewhere (e.g. by setupFFT) - confirm.
   * Return: DKS_SUCCESS
   */
  int setupFFTRC(int ndim, int N[3], double scale = 1.0) { return DKS_SUCCESS; }

  /**
   * Info: no-op for this backend; ignores all arguments and always returns
   *       DKS_SUCCESS. NOTE(review): presumably the complex-to-real plan is
   *       prepared elsewhere (e.g. by setupFFT) - confirm.
   * Return: DKS_SUCCESS
   */
  int setupFFTCR(int ndim, int N[3], double scale = 1.0) { return DKS_SUCCESS; }

  /**
   * Info: destroy the default FFT plans.
   * Return: success or error code
   */
  int destroyFFT();

  /**
   * Info: execute a complex-to-complex double precision FFT using the cuFFT
   *       library.
   * Params: mem_ptr  - device buffer to transform (presumably in place, since
   *                    only one pointer is taken - confirm);
   *         ndim, N  - transform rank and extents;
   *         streamId - stream selector; presumably -1 means the default
   *                    stream - confirm;
   *         forward  - transform direction flag (true = forward).
   * Return: success or error code
   */
  int executeFFT(void * mem_ptr, int ndim, int N[3], int streamId = -1, bool forward = true);

  /**
   * Info: execute the inverse complex-to-complex FFT.
   * Params: same meaning as for executeFFT.
   * Return: success or error code
   */
  int executeIFFT(void * mem_ptr, int ndim, int N[3], int streamId = -1);

  /**
   * Info: normalize the result of a complex-to-complex inverse FFT using a
   *       CUDA kernel.
   * Return: success or error code
   */
  int normalizeFFT(void * mem_ptr, int ndim, int N[3], int streamId = -1);

  /**
   * Info: execute a real-to-complex double precision FFT.
   * Params: real_ptr - real input buffer (presumably device memory - confirm);
   *         comp_ptr - complex output buffer;
   *         ndim, N, streamId - as for executeFFT.
   * Return: success or error code
   */
  int executeRCFFT(void * real_ptr, void * comp_ptr, int ndim, int N[3], int streamId = -1);

  /**
   * Info: execute a complex-to-real double precision FFT.
   * Params: real_ptr - real output buffer (presumably device memory - confirm);
   *         comp_ptr - complex input buffer;
   *         ndim, N, streamId - as for executeFFT.
   * Return: success or error code
   */
  int executeCRFFT(void * real_ptr, void * comp_ptr, int ndim, int N[3], int streamId = -1);

  /**
   * Info: normalize the result of a complex-to-real inverse FFT.
   * Return: success or error code
   */
  int normalizeCRFFT(void *real_ptr, int ndim, int N[3], int streamId = -1);

};
|
|
|
|
#endif
|