36 lines
1018 B
Plaintext
36 lines
1018 B
Plaintext
// Copyright (2019-2023) Paul Scherrer Institute
|
|
|
|
#include "CUDAWrapper.h"
|
|
#include "JFJochException.h"
|
|
|
|
// Convert a CUDA runtime status code into a JFJochException.
// Returns normally when `val` is cudaSuccess. For any other code it throws a
// JFJochException with category GPUCUDAError, whose message is the runtime's
// textual description of the code (cudaGetErrorString).
inline void cuda_err(cudaError_t val) {
    if (val == cudaSuccess)
        return;

    const char *description = cudaGetErrorString(val);
    throw JFJochException(JFJochExceptionCategory::GPUCUDAError, description);
}
|
|
|
|
// Return the number of CUDA devices reported by cudaGetDeviceCount.
// Two failure codes are treated as "no GPUs" and yield 0 instead of an error:
//   - cudaErrorNoDevice           (no CUDA-capable device found)
//   - cudaErrorInsufficientDriver (installed driver too old for the runtime)
// Any other failure is thrown as a JFJochException with category GPUCUDAError.
int32_t get_gpu_count() {
    int count = 0;
    const cudaError_t status = cudaGetDeviceCount(&count);

    if (status == cudaSuccess)
        return count;

    const bool no_usable_gpu = (status == cudaErrorNoDevice)
                               || (status == cudaErrorInsufficientDriver);
    if (no_usable_gpu)
        return 0;

    throw JFJochException(JFJochExceptionCategory::GPUCUDAError, cudaGetErrorString(status));
}
|
|
|
|
// Make `dev_id` the current CUDA device for the calling host thread (cudaSetDevice).
// If get_gpu_count() reports no GPUs, the call does nothing.
// Otherwise `dev_id` must lie in [0, get_gpu_count()):
//   - a negative id or an id >= the device count throws JFJochException
//     (category InputParameterInvalid);
//   - a failure of cudaSetDevice itself is thrown via cuda_err
//     (category GPUCUDAError).
void set_gpu(int32_t dev_id) {
    const int32_t dev_count = get_gpu_count();

    // Ignore if no GPU present
    if (dev_count <= 0)
        return;

    // Separate checks so an id that is too large is not reported as negative.
    if (dev_id < 0)
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
                              "Device ID cannot be negative");
    if (dev_id >= dev_count)
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
                              "Device ID exceeds number of available GPUs");

    cuda_err(cudaSetDevice(dev_id));
}
|