IndexerThreadPool: Bind to GPU NUMA node

This commit is contained in:
2025-10-04 09:22:43 +02:00
parent 0563f15a23
commit a15aa4eaa7
7 changed files with 83 additions and 2 deletions

View File

@@ -11,4 +11,8 @@ int32_t get_gpu_count() {
void set_gpu(int32_t dev_id) {}
int get_gpu_numa_node(int dev_id) {
return -1;
}
#endif

View File

@@ -1,6 +1,8 @@
// SPDX-FileCopyrightText: 2024 Filip Leonarski, Paul Scherrer Institute <filip.leonarski@psi.ch>
// SPDX-License-Identifier: GPL-3.0-only
#include <fstream>
#include "CUDAWrapper.h"
#include "JFJochException.h"
@@ -21,6 +23,7 @@ int32_t get_gpu_count() {
default:
throw JFJochException(JFJochExceptionCategory::GPUCUDAError, cudaGetErrorString(val));
}
}
void set_gpu(int32_t dev_id) {
@@ -34,3 +37,57 @@ void set_gpu(int32_t dev_id) {
cuda_err(cudaSetDevice(dev_id));
}
}
// Return CUDA device PCI Bus ID as "domain:bus:device.function", e.g., "0000:65:00.0"
static std::string get_cuda_device_pci_bus_id(int dev_id) {
// CUDA API provides cudaDeviceGetPCIBusId
char buf[64] = {0};
cudaDeviceProp prop;
cudaError_t st = cudaGetDeviceProperties(&prop, dev_id);
if (st != cudaSuccess) {
throw JFJochException(JFJochExceptionCategory::GPUCUDAError, cudaGetErrorString(st));
}
// Prefer cudaDeviceGetPCIBusId for full id including domain and function
cudaError_t st2 = cudaDeviceGetPCIBusId(buf, static_cast<int>(sizeof(buf)), dev_id);
if (st2 == cudaSuccess) {
return std::string(buf);
}
// Fallback: synthesize from properties (domain may be missing on very old drivers)
// Note: function is typically ".0"
char alt[64];
std::snprintf(alt, sizeof(alt), "%04x:%02x:%02x.%u",
prop.pciDomainID, prop.pciBusID, prop.pciDeviceID, 0u);
return std::string(alt);
}
// Resolve NUMA node from PCI address using Linux sysfs
// Returns:
// >=0 NUMA node index
// -1 if NUMA node is not available/unknown
int get_gpu_numa_node(int dev_id) {
auto dev_count = get_gpu_count();
if (dev_count <= 0) return -1;
if (dev_id < 0 || dev_id >= dev_count) {
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid CUDA device ID");
}
// We don't need to call cudaSetDevice here; querying by id is sufficient.
const std::string pci_bus_id = get_cuda_device_pci_bus_id(dev_id); // "dddd:bb:dd.f"
// sysfs path for PCI device. Examples:
// - /sys/bus/pci/devices/0000:65:00.0/numa_node
const std::string sysfs_path = std::string("/sys/bus/pci/devices/") + pci_bus_id + "/numa_node";
std::ifstream f(sysfs_path);
if (!f.is_open()) {
// On some systems, the symlink may be via /sys/class/drm or nvidia, but primary path should exist.
return -1;
}
int numa = -1;
f >> numa;
if (!f.good()) {
return -1;
}
return numa;
}

View File

@@ -8,5 +8,6 @@
int32_t get_gpu_count();
void set_gpu(int32_t dev_id);
int get_gpu_numa_node(int dev_id);
#endif //JUNGFRAUJOCH_CUDAWRAPPER_H

View File

@@ -108,6 +108,16 @@ void NUMAHWPolicy::SelectGPU(int32_t gpu) {
}
}
void NUMAHWPolicy::SelectGPUAndItsNUMA(int32_t gpu) {
int numa = get_gpu_numa_node(gpu);
if (numa >= 0) {
RunOnNode(numa);
MemOnNode(numa);
}
set_gpu(gpu);
}
const std::string &NUMAHWPolicy::GetName() const {
return name;
}

View File

@@ -30,6 +30,7 @@ public:
static void RunOnNode(int32_t cpu_node);
static void MemOnNode(int32_t mem_node);
static void SelectGPU(int32_t gpu);
static void SelectGPUAndItsNUMA(int32_t gpu);
};
#endif //JUNGFRAUJOCH_NUMAHWPOLICY_H

View File

@@ -67,9 +67,17 @@ std::future<std::optional<CrystalLattice> > IndexerThreadPool::Run(const Diffrac
return result;
}
void IndexerThreadPool::Worker(size_t threadIndex, const NUMAHWPolicy &numa_policy, const IndexingSettings &settings) {
void IndexerThreadPool::Worker(int32_t threadIndex, const NUMAHWPolicy &numa_policy, const IndexingSettings &settings) {
try {
#ifdef JFJOCH_USE_CUDA
auto gpu_count = get_gpu_count();
if (gpu_count > 0)
NUMAHWPolicy::SelectGPUAndItsNUMA(threadIndex % gpu_count);
else
numa_policy.Bind(threadIndex);
#else
numa_policy.Bind(threadIndex);
#endif
} catch (...) {
// NUMA policy errors are not critical and should be ignored for the time being.
}

View File

@@ -39,7 +39,7 @@ class IndexerThreadPool {
std::latch workers_ready;
bool stop;
void Worker(size_t threadIndex, const NUMAHWPolicy &numa_policy, const IndexingSettings& settings);
void Worker(int32_t threadIndex, const NUMAHWPolicy &numa_policy, const IndexingSettings& settings);
public:
IndexerThreadPool(const IndexingSettings& settings, const NUMAHWPolicy &numa_policy = NUMAHWPolicy());
~IndexerThreadPool();