Files
Jungfraujoch/image_analysis/indexing/CUDAMemHelpers.h
2025-07-18 11:42:39 +02:00

154 lines
4.9 KiB
C++

// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute <filip.leonarski@psi.ch>
// SPDX-License-Identifier: GPL-3.0-only
#ifndef JFJOCH_CUDAMEMHELPERS_H
#define JFJOCH_CUDAMEMHELPERS_H
#include <cuda_runtime.h>
#include <cufft.h>
#include <stdexcept>
#include <vector>
#include "../common/JFJochException.h"
class CudaStream {
cudaStream_t stream_ = nullptr;
public:
CudaStream(unsigned int flags = cudaStreamDefault) {
if (cudaStreamCreateWithFlags(&stream_, flags) != cudaSuccess)
throw JFJochException(JFJochExceptionCategory::GPUCUDAError,
"Failed to create CUDA stream");
}
~CudaStream() {
if (stream_) cudaStreamDestroy(stream_);
}
// Move-only type
CudaStream(CudaStream&& other) noexcept : stream_(other.stream_) { other.stream_ = nullptr; }
CudaStream& operator=(CudaStream&& other) noexcept {
if (this != &other) {
if (stream_) cudaStreamDestroy(stream_);
stream_ = other.stream_;
other.stream_ = nullptr;
}
return *this;
}
CudaStream(const CudaStream&) = delete;
CudaStream& operator=(const CudaStream&) = delete;
operator cudaStream_t() const { return stream_; }
cudaStream_t get() const { return stream_; }
};
/**
 * Owning, move-only wrapper around a cufftHandle created with cufftPlanMany().
 *
 * A handle value of 0 is used as the "no plan" marker: default-constructed
 * and moved-from objects hold 0 and destroy nothing.
 * NOTE(review): this assumes cuFFT never returns 0 as a valid plan handle — confirm.
 */
class CudaFFTPlan {
    cufftHandle plan_ = 0;

    // Returns n.data() after verifying that rank is positive and that n
    // supplies at least `rank` entries, so cuFFT cannot read past the vector.
    static const int* RequiredDims(const std::vector<int>& n, int rank) {
        if (rank <= 0 || n.size() < static_cast<size_t>(rank))
            throw JFJochException(JFJochExceptionCategory::GPUCUDAError,
                                  "CUFFT plan: rank must be positive and n must hold at least rank elements");
        return n.data();
    }

    // Maps an empty embed vector to nullptr (data() of an empty vector is not
    // guaranteed to be null) and verifies that a non-empty one supplies at
    // least `rank` entries. A non-positive rank is reported by RequiredDims().
    static const int* OptionalDims(const std::vector<int>& embed, int rank) {
        if (embed.empty())
            return nullptr;
        if (rank > 0 && embed.size() < static_cast<size_t>(rank))
            throw JFJochException(JFJochExceptionCategory::GPUCUDAError,
                                  "CUFFT plan: non-empty inembed/onembed must hold at least rank elements");
        return embed.data();
    }

public:
    /** Creates an empty wrapper that owns no plan. */
    CudaFFTPlan() = default;

    /**
     * Creates a plan with cufftPlanMany(); all arguments are forwarded as given.
     *
     * The const_casts are needed only because cufftPlanMany() takes non-const
     * int pointers.
     *
     * @throws JFJochException (category GPUCUDAError) if cufftPlanMany() does
     *         not return CUFFT_SUCCESS.
     */
    CudaFFTPlan(
            int rank,
            const int* n,
            const int* inembed, int istride, int idist,
            const int* onembed, int ostride, int odist,
            cufftType type, int batch)
    {
        if (cufftPlanMany(&plan_, rank, const_cast<int*>(n),
                          const_cast<int*>(inembed), istride, idist,
                          const_cast<int*>(onembed), ostride, odist,
                          type, batch) != CUFFT_SUCCESS)
            throw JFJochException(JFJochExceptionCategory::GPUCUDAError,
                                  "Failed to create CUFFT plan with cufftPlanMany");
    }

    /**
     * Convenience overload taking the dimension arrays as vectors.
     *
     * `n` must hold at least `rank` entries. `inembed` / `onembed` may be
     * empty, in which case nullptr is forwarded for that argument; when
     * non-empty they must also hold at least `rank` entries.
     *
     * @throws JFJochException (category GPUCUDAError) if a vector is too short,
     *         if rank is not positive, or if plan creation fails.
     */
    CudaFFTPlan(
            int rank,
            const std::vector<int>& n,
            const std::vector<int>& inembed, int istride, int idist,
            const std::vector<int>& onembed, int ostride, int odist,
            cufftType type, int batch)
            : CudaFFTPlan(rank, RequiredDims(n, rank),
                          OptionalDims(inembed, rank), istride, idist,
                          OptionalDims(onembed, rank), ostride, odist,
                          type, batch)
    {}

    ~CudaFFTPlan() {
        if (plan_) cufftDestroy(plan_);
    }

    // Move-only type
    /** Takes over the plan held by `other`, leaving `other` empty. */
    CudaFFTPlan(CudaFFTPlan&& other) noexcept : plan_(other.plan_) { other.plan_ = 0; }

    /**
     * Destroys the plan currently held (if any), then takes over the one held
     * by `other`. Self-assignment is a no-op.
     */
    CudaFFTPlan& operator=(CudaFFTPlan&& other) noexcept {
        if (this != &other) {
            if (plan_) cufftDestroy(plan_);
            plan_ = other.plan_;
            other.plan_ = 0;
        }
        return *this;
    }

    CudaFFTPlan(const CudaFFTPlan&) = delete;
    CudaFFTPlan& operator=(const CudaFFTPlan&) = delete;

    /** Implicit conversion so the wrapper can be passed where a cufftHandle is expected. */
    operator cufftHandle() const { return plan_; }

    /** Raw handle accessor; ownership stays with this object. */
    cufftHandle get() const { return plan_; }
};
/**
 * Owning, move-only wrapper around a buffer obtained from cudaMalloc().
 *
 * The buffer is released with cudaFree() in the destructor or when the
 * wrapper is overwritten by move assignment. Default-constructed and
 * moved-from objects hold nullptr and free nothing.
 *
 * @tparam T  element type; the buffer is sized as count * sizeof(T) bytes.
 */
template <typename T>
class CudaDevicePtr {
    T* ptr_ = nullptr;
public:
    /** Creates an empty wrapper that owns no memory. */
    CudaDevicePtr() = default;

    /**
     * Allocates room for `count` elements of T with cudaMalloc().
     *
     * @param count  number of elements (not bytes).
     * @throws JFJochException (category GPUCUDAError) if count * sizeof(T)
     *         does not fit in size_t, or if cudaMalloc() fails.
     */
    explicit CudaDevicePtr(size_t count) {
        // count * sizeof(T) is unsigned arithmetic and wraps silently; without
        // this guard a huge count would yield a small (or zero-byte) request
        // that succeeds and leaves the caller with an undersized buffer.
        if (count > static_cast<size_t>(-1) / sizeof(T))
            throw JFJochException(JFJochExceptionCategory::GPUCUDAError,
                                  "Requested device memory size overflows size_t");

        if (cudaMalloc(&ptr_, count * sizeof(T)) != cudaSuccess)
            throw JFJochException(JFJochExceptionCategory::GPUCUDAError,
                                  "Failed to allocate device memory");
    }

    ~CudaDevicePtr() {
        if (ptr_) cudaFree(ptr_);
    }

    // Move-only type
    /** Takes over the buffer held by `other`, leaving `other` empty. */
    CudaDevicePtr(CudaDevicePtr&& other) noexcept : ptr_(other.ptr_) { other.ptr_ = nullptr; }

    /**
     * Frees the buffer currently held (if any), then takes over the one held
     * by `other`. Self-assignment is a no-op.
     */
    CudaDevicePtr& operator=(CudaDevicePtr&& other) noexcept {
        if (this != &other) {
            if (ptr_) cudaFree(ptr_);
            ptr_ = other.ptr_;
            other.ptr_ = nullptr;
        }
        return *this;
    }

    CudaDevicePtr(const CudaDevicePtr&) = delete;
    CudaDevicePtr& operator=(const CudaDevicePtr&) = delete;

    /** Raw pointer accessor; ownership stays with this object. */
    T* get() const { return ptr_; }

    /** Implicit conversion so the wrapper can be passed where a T* is expected. */
    operator T*() const { return ptr_; }
};
/**
 * Owning, move-only wrapper around a host buffer obtained from cudaMallocHost().
 *
 * The buffer is released with cudaFreeHost() in the destructor or when the
 * wrapper is overwritten by move assignment. Default-constructed and
 * moved-from objects hold nullptr and free nothing.
 *
 * @tparam T  element type; the buffer is sized as count * sizeof(T) bytes.
 */
template <typename T>
class CudaHostPtr {
    T* ptr_ = nullptr;
public:
    /** Creates an empty wrapper that owns no memory. */
    CudaHostPtr() = default;

    /**
     * Allocates room for `count` elements of T with cudaMallocHost().
     *
     * @param count  number of elements (not bytes).
     * @throws JFJochException (category GPUCUDAError) if count * sizeof(T)
     *         does not fit in size_t, or if cudaMallocHost() fails.
     */
    explicit CudaHostPtr(size_t count) {
        // count * sizeof(T) is unsigned arithmetic and wraps silently; without
        // this guard a huge count would yield a small (or zero-byte) request
        // that succeeds and leaves the caller with an undersized buffer.
        if (count > static_cast<size_t>(-1) / sizeof(T))
            throw JFJochException(JFJochExceptionCategory::GPUCUDAError,
                                  "Requested pinned host memory size overflows size_t");

        if (cudaMallocHost(&ptr_, count * sizeof(T)) != cudaSuccess)
            throw JFJochException(JFJochExceptionCategory::GPUCUDAError,
                                  "Failed to allocate pinned host memory");
    }

    ~CudaHostPtr() {
        if (ptr_) cudaFreeHost(ptr_);
    }

    // Move-only type
    /** Takes over the buffer held by `other`, leaving `other` empty. */
    CudaHostPtr(CudaHostPtr&& other) noexcept : ptr_(other.ptr_) { other.ptr_ = nullptr; }

    /**
     * Frees the buffer currently held (if any), then takes over the one held
     * by `other`. Self-assignment is a no-op.
     */
    CudaHostPtr& operator=(CudaHostPtr&& other) noexcept {
        if (this != &other) {
            if (ptr_) cudaFreeHost(ptr_);
            ptr_ = other.ptr_;
            other.ptr_ = nullptr;
        }
        return *this;
    }

    CudaHostPtr(const CudaHostPtr&) = delete;
    CudaHostPtr& operator=(const CudaHostPtr&) = delete;

    /** Raw pointer accessor; ownership stays with this object. */
    T* get() const { return ptr_; }

    /** Implicit conversion so the wrapper can be passed where a T* is expected. */
    operator T*() const { return ptr_; }
};
#endif //JFJOCH_CUDAMEMHELPERS_H