// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute
// SPDX-License-Identifier: GPL-3.0-only

#ifndef JFJOCH_CUDAMEMHELPERS_H
#define JFJOCH_CUDAMEMHELPERS_H

#include <cstddef>
#include <limits>
#include <vector>

#include <cuda_runtime.h>
#include <cufft.h>

#include "../common/JFJochException.h"

// Move-only owner of a CUDA stream created with cudaStreamCreateWithFlags.
// The stream is destroyed when the owner is destroyed or move-assigned over.
class CudaStream {
    cudaStream_t stream_ = nullptr;
public:
    // flags: forwarded unchanged to cudaStreamCreateWithFlags.
    // Throws JFJochException (GPUCUDAError) if the stream cannot be created.
    CudaStream(unsigned int flags = cudaStreamDefault) {
        if (cudaStreamCreateWithFlags(&stream_, flags) != cudaSuccess)
            throw JFJochException(JFJochExceptionCategory::GPUCUDAError, "Failed to create CUDA stream");
    }

    // Destructors must not throw, so the cudaStreamDestroy result is ignored.
    ~CudaStream() {
        if (stream_)
            cudaStreamDestroy(stream_);
    }

    // Move-only type; a moved-from object holds nullptr.
    CudaStream(CudaStream&& other) noexcept : stream_(other.stream_) {
        other.stream_ = nullptr;
    }

    CudaStream& operator=(CudaStream&& other) noexcept {
        if (this != &other) {
            if (stream_)
                cudaStreamDestroy(stream_);
            stream_ = other.stream_;
            other.stream_ = nullptr;
        }
        return *this;
    }

    CudaStream(const CudaStream&) = delete;
    CudaStream& operator=(const CudaStream&) = delete;

    operator cudaStream_t() const { return stream_; }
    cudaStream_t get() const { return stream_; }
};

// Move-only owner of a cuFFT plan created with cufftPlanMany.
class CudaFFTPlan {
    cufftHandle plan_ = 0;
    // Ownership is tracked separately from the handle value: cuFFT does not
    // document any cufftHandle value as reserved/invalid, so 0 is not used as
    // the "no plan" sentinel.
    bool owned_ = false;

    // Validates a dimension vector for cufftPlanMany, which reads `rank`
    // entries from n/inembed/onembed. Throws if the vector is too short.
    // With allowEmpty, an empty vector maps to nullptr, which cuFFT treats as
    // the basic (contiguous) layout for inembed/onembed.
    static const int* dimsPtr(const std::vector<int>& dims, int rank, bool allowEmpty) {
        if (allowEmpty && dims.empty())
            return nullptr;
        if (rank > 0 && dims.size() < static_cast<size_t>(rank))
            throw JFJochException(JFJochExceptionCategory::GPUCUDAError,
                                  "CUFFT plan dimension vector is shorter than rank");
        return dims.data();
    }

    // Releases the owned plan (if any); never throws.
    void destroy() noexcept {
        if (owned_)
            cufftDestroy(plan_);
        plan_ = 0;
        owned_ = false;
    }

public:
    CudaFFTPlan() = default;

    // Thin wrapper over cufftPlanMany; arguments are forwarded unchanged.
    // Throws JFJochException (GPUCUDAError) if plan creation fails.
    CudaFFTPlan(int rank, const int* n,
                const int* inembed, int istride, int idist,
                const int* onembed, int ostride, int odist,
                cufftType type, int batch) {
        // cufftPlanMany's signature takes non-const int*; the arrays are inputs here.
        if (cufftPlanMany(&plan_, rank, const_cast<int*>(n),
                          const_cast<int*>(inembed), istride, idist,
                          const_cast<int*>(onembed), ostride, odist,
                          type, batch) != CUFFT_SUCCESS)
            throw JFJochException(JFJochExceptionCategory::GPUCUDAError,
                                  "Failed to create CUFFT plan with cufftPlanMany");
        owned_ = true;
    }

    // Convenience overload for vector input. n must hold at least `rank`
    // entries; inembed/onembed must be empty (basic layout) or hold at least
    // `rank` entries.
    CudaFFTPlan(int rank, const std::vector<int>& n,
                const std::vector<int>& inembed, int istride, int idist,
                const std::vector<int>& onembed, int ostride, int odist,
                cufftType type, int batch)
        : CudaFFTPlan(rank, dimsPtr(n, rank, false),
                      dimsPtr(inembed, rank, true), istride, idist,
                      dimsPtr(onembed, rank, true), ostride, odist,
                      type, batch) {}

    ~CudaFFTPlan() { destroy(); }

    // Move-only type; a moved-from object owns no plan.
    CudaFFTPlan(CudaFFTPlan&& other) noexcept : plan_(other.plan_), owned_(other.owned_) {
        other.plan_ = 0;
        other.owned_ = false;
    }

    CudaFFTPlan& operator=(CudaFFTPlan&& other) noexcept {
        if (this != &other) {
            destroy();
            plan_ = other.plan_;
            owned_ = other.owned_;
            other.plan_ = 0;
            other.owned_ = false;
        }
        return *this;
    }

    CudaFFTPlan(const CudaFFTPlan&) = delete;
    CudaFFTPlan& operator=(const CudaFFTPlan&) = delete;

    operator cufftHandle() const { return plan_; }
    cufftHandle get() const { return plan_; }
};

// Move-only owner of a cudaMalloc allocation of `count` elements of T.
template <typename T>
class CudaDevicePtr {
    T* ptr_ = nullptr;
public:
    CudaDevicePtr() = default;

    // Throws JFJochException (GPUCUDAError) if the allocation fails, including
    // when count * sizeof(T) would overflow size_t (instead of wrapping around
    // and allocating a too-small buffer).
    explicit CudaDevicePtr(size_t count) {
        if (count > std::numeric_limits<size_t>::max() / sizeof(T)
            || cudaMalloc(&ptr_, count * sizeof(T)) != cudaSuccess)
            throw JFJochException(JFJochExceptionCategory::GPUCUDAError, "Failed to allocate device memory");
    }

    ~CudaDevicePtr() {
        if (ptr_)
            cudaFree(ptr_);
    }

    // Move-only type; a moved-from object holds nullptr.
    CudaDevicePtr(CudaDevicePtr&& other) noexcept : ptr_(other.ptr_) {
        other.ptr_ = nullptr;
    }

    CudaDevicePtr& operator=(CudaDevicePtr&& other) noexcept {
        if (this != &other) {
            if (ptr_)
                cudaFree(ptr_);
            ptr_ = other.ptr_;
            other.ptr_ = nullptr;
        }
        return *this;
    }

    CudaDevicePtr(const CudaDevicePtr&) = delete;
    CudaDevicePtr& operator=(const CudaDevicePtr&) = delete;

    T* get() const { return ptr_; }
    operator T*() const { return ptr_; }
};

// Move-only owner of a pinned host allocation (cudaMallocHost) of `count`
// elements of T.
template <typename T>
class CudaHostPtr {
    T* ptr_ = nullptr;
public:
    CudaHostPtr() = default;

    // Throws JFJochException (GPUCUDAError) if the allocation fails, including
    // when count * sizeof(T) would overflow size_t.
    explicit CudaHostPtr(size_t count) {
        if (count > std::numeric_limits<size_t>::max() / sizeof(T)
            || cudaMallocHost(&ptr_, count * sizeof(T)) != cudaSuccess)
            throw JFJochException(JFJochExceptionCategory::GPUCUDAError, "Failed to allocate pinned host memory");
    }

    ~CudaHostPtr() {
        if (ptr_)
            cudaFreeHost(ptr_);
    }

    // Move-only type; a moved-from object holds nullptr.
    CudaHostPtr(CudaHostPtr&& other) noexcept : ptr_(other.ptr_) {
        other.ptr_ = nullptr;
    }

    CudaHostPtr& operator=(CudaHostPtr&& other) noexcept {
        if (this != &other) {
            if (ptr_)
                cudaFreeHost(ptr_);
            ptr_ = other.ptr_;
            other.ptr_ = nullptr;
        }
        return *this;
    }

    CudaHostPtr(const CudaHostPtr&) = delete;
    CudaHostPtr& operator=(const CudaHostPtr&) = delete;

    T* get() const { return ptr_; }
    operator T*() const { return ptr_; }
};

// Non-owning wrapper that registers (cudaHostRegister) the storage of an
// existing std::vector and unregisters it on destruction.
// Does NOT provide accessors to the vector and does not keep it alive: the
// vector must outlive the registration and must not be resized/reallocated
// while registered (call unregister() or rebind() around such mutations).
template <typename T>
class CudaRegisteredVector {
    // Exact address passed to cudaHostRegister, or nullptr when nothing is
    // registered. Remembering it (instead of re-reading vec.data() later)
    // guarantees the same address is unregistered even if the vector changed.
    void* registered_ptr_ = nullptr;

    static void registerPtr(void* ptr, size_t bytes, unsigned int flags) {
        cudaError_t err = cudaHostRegister(ptr, bytes, flags);
        if (err != cudaSuccess)
            throw JFJochException(JFJochExceptionCategory::GPUCUDAError, "cudaHostRegister failed");
    }

    static void unregisterPtr(void* ptr) {
        cudaError_t err = cudaHostUnregister(ptr);
        if (err != cudaSuccess)
            throw JFJochException(JFJochExceptionCategory::GPUCUDAError, "cudaHostUnregister failed");
    }

    // Registers vec's current storage; empty vectors are not registered.
    // Only sets registered_ptr_ once registration succeeded.
    void bind(std::vector<T>& vec, unsigned int flags) {
        if (!vec.empty()) {
            registerPtr(vec.data(), vec.size() * sizeof(T), flags);
            registered_ptr_ = vec.data();
        }
    }

    // Best-effort release for the destructor and move assignment, which must
    // not throw (a throw there would call std::terminate). A failure reported
    // by cudaHostUnregister is deliberately ignored on this path.
    void releaseNoThrow() noexcept {
        if (registered_ptr_) {
            cudaHostUnregister(registered_ptr_);
            registered_ptr_ = nullptr;
        }
    }

public:
    CudaRegisteredVector() = default;

    // Throws JFJochException (GPUCUDAError) if cudaHostRegister fails.
    CudaRegisteredVector(std::vector<T>& vec, unsigned int flags = cudaHostRegisterDefault) {
        bind(vec, flags);
    }

    ~CudaRegisteredVector() { releaseNoThrow(); }

    // Move-only; the registration is transferred to the new object.
    CudaRegisteredVector(CudaRegisteredVector&& other) noexcept
        : registered_ptr_(other.registered_ptr_) {
        other.registered_ptr_ = nullptr;
    }

    CudaRegisteredVector& operator=(CudaRegisteredVector&& other) noexcept {
        if (this != &other) {
            // Clean current registration without throwing (operator is noexcept).
            releaseNoThrow();
            registered_ptr_ = other.registered_ptr_;
            other.registered_ptr_ = nullptr;
        }
        return *this;
    }

    CudaRegisteredVector(const CudaRegisteredVector&) = delete;
    CudaRegisteredVector& operator=(const CudaRegisteredVector&) = delete;

    // Re-register after vector capacity/size change. Caller must ensure
    // the vector is not registered at the moment of mutation.
    // Throws if unregistering the previous buffer or registering the new one
    // fails; in either case the wrapper is left unregistered (consistent).
    void rebind(std::vector<T>& vec, unsigned int flags = cudaHostRegisterDefault) {
        unregister();
        bind(vec, flags);
    }

    // Explicit unregister (optional). State is cleared before the call so a
    // failure cannot lead to a second unregister attempt in the destructor.
    // Throws JFJochException (GPUCUDAError) if cudaHostUnregister fails.
    void unregister() {
        if (registered_ptr_) {
            void* ptr = registered_ptr_;
            registered_ptr_ = nullptr;
            unregisterPtr(ptr);
        }
    }

    bool isRegistered() const { return registered_ptr_ != nullptr; }
};

#endif //JFJOCH_CUDAMEMHELPERS_H