Files
Jungfraujoch/image_analysis/indexing/CUDAMemHelpers.h
2025-09-21 19:27:51 +02:00

239 lines
7.7 KiB
C++

// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute <filip.leonarski@psi.ch>
// SPDX-License-Identifier: GPL-3.0-only
#ifndef JFJOCH_CUDAMEMHELPERS_H
#define JFJOCH_CUDAMEMHELPERS_H
#include <cuda_runtime.h>
#include <cufft.h>
#include <limits>
#include <stdexcept>
#include <vector>
#include "../common/JFJochException.h"
class CudaStream {
cudaStream_t stream_ = nullptr;
public:
CudaStream(unsigned int flags = cudaStreamDefault) {
if (cudaStreamCreateWithFlags(&stream_, flags) != cudaSuccess)
throw JFJochException(JFJochExceptionCategory::GPUCUDAError,
"Failed to create CUDA stream");
}
~CudaStream() {
if (stream_) cudaStreamDestroy(stream_);
}
// Move-only type
CudaStream(CudaStream&& other) noexcept : stream_(other.stream_) { other.stream_ = nullptr; }
CudaStream& operator=(CudaStream&& other) noexcept {
if (this != &other) {
if (stream_) cudaStreamDestroy(stream_);
stream_ = other.stream_;
other.stream_ = nullptr;
}
return *this;
}
CudaStream(const CudaStream&) = delete;
CudaStream& operator=(const CudaStream&) = delete;
operator cudaStream_t() const { return stream_; }
cudaStream_t get() const { return stream_; }
};
class CudaFFTPlan {
cufftHandle plan_ = 0;
public:
CudaFFTPlan() = default;
CudaFFTPlan(
int rank,
const int* n,
const int* inembed, int istride, int idist,
const int* onembed, int ostride, int odist,
cufftType type, int batch)
{
if (cufftPlanMany(&plan_, rank, const_cast<int*>(n),
const_cast<int*>(inembed), istride, idist,
const_cast<int*>(onembed), ostride, odist,
type, batch) != CUFFT_SUCCESS)
throw JFJochException(JFJochExceptionCategory::GPUCUDAError,
"Failed to create CUFFT plan with cufftPlanMany");
}
// Convenience overload for vector input
CudaFFTPlan(
int rank,
const std::vector<int>& n,
const std::vector<int>& inembed, int istride, int idist,
const std::vector<int>& onembed, int ostride, int odist,
cufftType type, int batch)
: CudaFFTPlan(rank, n.data(), inembed.data(), istride, idist, onembed.data(), ostride, odist, type, batch)
{}
~CudaFFTPlan() {
if (plan_) cufftDestroy(plan_);
}
// Move-only type
CudaFFTPlan(CudaFFTPlan&& other) noexcept : plan_(other.plan_) { other.plan_ = 0; }
CudaFFTPlan& operator=(CudaFFTPlan&& other) noexcept {
if (this != &other) {
if (plan_) cufftDestroy(plan_);
plan_ = other.plan_;
other.plan_ = 0;
}
return *this;
}
CudaFFTPlan(const CudaFFTPlan&) = delete;
CudaFFTPlan& operator=(const CudaFFTPlan&) = delete;
operator cufftHandle() const { return plan_; }
cufftHandle get() const { return plan_; }
};
// Owning, move-only wrapper for a cudaMalloc allocation of `count` elements of T.
// The memory is released with cudaFree on destruction or move-assignment.
// A default-constructed or moved-from wrapper holds nullptr.
template <typename T>
class CudaDevicePtr {
    T* ptr_ = nullptr;
public:
    CudaDevicePtr() = default;

    // Allocates count * sizeof(T) bytes with cudaMalloc.
    // Throws JFJochException (GPUCUDAError) if the byte count would overflow size_t
    // or if cudaMalloc fails.
    explicit CudaDevicePtr(size_t count) {
        // Guard the multiplication: a wrapped byte count would silently produce a
        // buffer smaller than `count` elements.
        if (count > std::numeric_limits<size_t>::max() / sizeof(T))
            throw JFJochException(JFJochExceptionCategory::GPUCUDAError,
                                  "Failed to allocate device memory: size overflow");
        if (cudaMalloc(&ptr_, count * sizeof(T)) != cudaSuccess)
            throw JFJochException(JFJochExceptionCategory::GPUCUDAError,
                                  "Failed to allocate device memory");
    }

    ~CudaDevicePtr() {
        if (ptr_) cudaFree(ptr_);
    }

    // Move-only type
    CudaDevicePtr(CudaDevicePtr&& other) noexcept : ptr_(other.ptr_) { other.ptr_ = nullptr; }

    CudaDevicePtr& operator=(CudaDevicePtr&& other) noexcept {
        if (this != &other) {
            if (ptr_) cudaFree(ptr_);
            ptr_ = other.ptr_;
            other.ptr_ = nullptr;
        }
        return *this;
    }

    CudaDevicePtr(const CudaDevicePtr&) = delete;
    CudaDevicePtr& operator=(const CudaDevicePtr&) = delete;

    T* get() const { return ptr_; }
    operator T*() const { return ptr_; }
};
// Owning, move-only wrapper for a cudaMallocHost (pinned host) allocation of `count`
// elements of T. The memory is released with cudaFreeHost on destruction or
// move-assignment. A default-constructed or moved-from wrapper holds nullptr.
template <typename T>
class CudaHostPtr {
    T* ptr_ = nullptr;
public:
    CudaHostPtr() = default;

    // Allocates count * sizeof(T) bytes with cudaMallocHost.
    // Throws JFJochException (GPUCUDAError) if the byte count would overflow size_t
    // or if cudaMallocHost fails.
    explicit CudaHostPtr(size_t count) {
        // Guard the multiplication: a wrapped byte count would silently produce a
        // buffer smaller than `count` elements.
        if (count > std::numeric_limits<size_t>::max() / sizeof(T))
            throw JFJochException(JFJochExceptionCategory::GPUCUDAError,
                                  "Failed to allocate pinned host memory: size overflow");
        if (cudaMallocHost(&ptr_, count * sizeof(T)) != cudaSuccess)
            throw JFJochException(JFJochExceptionCategory::GPUCUDAError,
                                  "Failed to allocate pinned host memory");
    }

    ~CudaHostPtr() {
        if (ptr_) cudaFreeHost(ptr_);
    }

    // Move-only type
    CudaHostPtr(CudaHostPtr&& other) noexcept : ptr_(other.ptr_) { other.ptr_ = nullptr; }

    CudaHostPtr& operator=(CudaHostPtr&& other) noexcept {
        if (this != &other) {
            if (ptr_) cudaFreeHost(ptr_);
            ptr_ = other.ptr_;
            other.ptr_ = nullptr;
        }
        return *this;
    }

    CudaHostPtr(const CudaHostPtr&) = delete;
    CudaHostPtr& operator=(const CudaHostPtr&) = delete;

    T* get() const { return ptr_; }
    operator T*() const { return ptr_; }
};
template <typename T>
class CudaRegisteredVector {
std::vector<T>* vec_ = nullptr;
bool registered_ = false;
static void registerPtr(void* ptr, size_t bytes, unsigned int flags) {
cudaError_t err = cudaHostRegister(ptr, bytes, flags);
if (err != cudaSuccess)
throw JFJochException(JFJochExceptionCategory::GPUCUDAError, "cudaHostRegister failed");
}
static void unregisterPtr(void* ptr) {
cudaError_t err = cudaHostUnregister(ptr);
if (err != cudaSuccess)
throw JFJochException(JFJochExceptionCategory::GPUCUDAError, "cudaHostUnregister failed");
}
public:
// Non-owning wrapper. Does NOT provide accessors to the vector.
CudaRegisteredVector() = default;
CudaRegisteredVector(std::vector<T>& vec, unsigned int flags = cudaHostRegisterDefault)
: vec_(&vec)
{
if (!vec.empty()) {
registerPtr(vec.data(), vec.size() * sizeof(T), flags);
registered_ = true;
}
}
~CudaRegisteredVector() {
if (registered_ && vec_ && !vec_->empty()) {
unregisterPtr(vec_->data());
}
}
// Move-only
CudaRegisteredVector(CudaRegisteredVector&& other) noexcept
: vec_(other.vec_), registered_(other.registered_) {
other.vec_ = nullptr;
other.registered_ = false;
}
CudaRegisteredVector& operator=(CudaRegisteredVector&& other) noexcept {
if (this != &other) {
// Clean current registration
if (registered_ && vec_ && !vec_->empty()) {
unregisterPtr(vec_->data());
}
vec_ = other.vec_;
registered_ = other.registered_;
other.vec_ = nullptr;
other.registered_ = false;
}
return *this;
}
CudaRegisteredVector(const CudaRegisteredVector&) = delete;
CudaRegisteredVector& operator=(const CudaRegisteredVector&) = delete;
// Re-register after vector capacity/size change. Caller must ensure
// the vector is not registered at the moment of mutation.
void rebind(std::vector<T>& vec, unsigned int flags = cudaHostRegisterDefault) {
// Unregister previous if needed
if (registered_ && vec_ && !vec_->empty()) {
unregisterPtr(vec_->data());
}
vec_ = &vec;
if (!vec.empty()) {
registerPtr(vec.data(), vec.size() * sizeof(T), flags);
registered_ = true;
} else {
registered_ = false;
}
}
// Explicit unregister (optional).
void unregister() {
if (registered_ && vec_ && !vec_->empty()) {
unregisterPtr(vec_->data());
registered_ = false;
}
}
bool isRegistered() const { return registered_; }
};
#endif //JFJOCH_CUDAMEMHELPERS_H