1.0.0-rc.66
This commit is contained in:
140
image_analysis/indexing/IndexerThreadPool.cpp
Normal file
140
image_analysis/indexing/IndexerThreadPool.cpp
Normal file
@@ -0,0 +1,140 @@
|
||||
// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute <filip.leonarski@psi.ch>
|
||||
// SPDX-License-Identifier: GPL-3.0-only
|
||||
|
||||
#include "IndexerThreadPool.h"
|
||||
#include "../common/CUDAWrapper.h"
|
||||
|
||||
#ifdef JFJOCH_USE_CUDA
|
||||
#include "FFBIDXIndexer.h"
|
||||
#include "FFTIndexer.h"
|
||||
#endif
|
||||
|
||||
IndexerThreadPool::IndexerThreadPool(const NUMAHWPolicy &numa_policy, const IndexingSettings& settings)
|
||||
: stop(false), workers_ready(settings.GetIndexingThreads()) {
|
||||
for (size_t i = 0; i < settings.GetIndexingThreads(); ++i)
|
||||
workers.emplace_back([this, i, numa_policy, settings] { Worker(i, numa_policy, settings); });
|
||||
workers_ready.wait();
|
||||
|
||||
if (failed_start) {
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(m);
|
||||
stop = true;
|
||||
}
|
||||
cond.notify_all();
|
||||
|
||||
for (std::thread &worker : workers) {
|
||||
if (worker.joinable())
|
||||
worker.join();
|
||||
}
|
||||
throw JFJochException(JFJochExceptionCategory::GPUCUDAError,
|
||||
"Cannot configure indexer (likely too many threads, not enough memory)");
|
||||
}
|
||||
}
|
||||
|
||||
IndexerThreadPool::~IndexerThreadPool() {
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(m);
|
||||
stop = true;
|
||||
}
|
||||
cond.notify_all();
|
||||
|
||||
for (std::thread &worker : workers) {
|
||||
if (worker.joinable())
|
||||
worker.join();
|
||||
}
|
||||
}
|
||||
|
||||
std::future<std::optional<CrystalLattice>> IndexerThreadPool::Run(const DiffractionExperiment& experiment,
|
||||
DataMessage& message,
|
||||
const std::vector<DiffractionSpot>& coord) {
|
||||
|
||||
// Create a promise/future pair
|
||||
auto promise = std::make_shared<std::promise<std::optional<CrystalLattice>>>();
|
||||
std::future<std::optional<CrystalLattice>> result = promise->get_future();
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(m);
|
||||
|
||||
// Don't allow enqueueing after stopping the pool
|
||||
if (stop) {
|
||||
throw std::runtime_error("Cannot enqueue on stopped thread pool");
|
||||
}
|
||||
|
||||
// Create a task package with the data message and coordinates
|
||||
taskQueue.emplace(TaskPackage{promise, &experiment, &message, &coord});
|
||||
}
|
||||
|
||||
cond.notify_one();
|
||||
return result;
|
||||
}
|
||||
|
||||
void IndexerThreadPool::Worker(size_t threadIndex, const NUMAHWPolicy &numa_policy, const IndexingSettings& settings) {
|
||||
try {
|
||||
numa_policy.Bind(threadIndex);
|
||||
} catch (...) {
|
||||
// NUMA policy errors are not critical and should be ignored for the time being.
|
||||
}
|
||||
|
||||
std::unique_ptr<Indexer> fft_indexer, ffbidx_indexer;
|
||||
|
||||
#ifdef JFJOCH_USE_CUDA
|
||||
try {
|
||||
if (get_gpu_count() > 0) {
|
||||
fft_indexer = std::make_unique<FFTIndexer>(settings);
|
||||
ffbidx_indexer = std::make_unique<FFBIDXIndexer>();
|
||||
}
|
||||
} catch (...) {
|
||||
failed_start = true;
|
||||
}
|
||||
#endif
|
||||
workers_ready.count_down();
|
||||
|
||||
while (true) {
|
||||
TaskPackage task;
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(m);
|
||||
|
||||
// Add a timeout to the wait to ensure we can exit even if no notification
|
||||
cond.wait_for(lock, std::chrono::seconds(1), [this] {
|
||||
return stop || !taskQueue.empty();
|
||||
});
|
||||
|
||||
// Check for exit conditions
|
||||
if (stop && taskQueue.empty())
|
||||
return; // Exit cleanly
|
||||
|
||||
if (!taskQueue.empty()) {
|
||||
task = std::move(taskQueue.front());
|
||||
taskQueue.pop();
|
||||
} else {
|
||||
continue; // No tasks, go back to waiting
|
||||
}
|
||||
}
|
||||
try {
|
||||
std::optional<CrystalLattice> result;
|
||||
|
||||
auto algorithm = task.experiment->GetIndexingAlgorithm();
|
||||
Indexer *indexer = nullptr;
|
||||
|
||||
if (algorithm == IndexingAlgorithmEnum::FFT && fft_indexer) {
|
||||
indexer = fft_indexer.get();
|
||||
} else if (algorithm == IndexingAlgorithmEnum::FFBIDX && ffbidx_indexer) {
|
||||
indexer = ffbidx_indexer.get();
|
||||
}
|
||||
|
||||
if (indexer) {
|
||||
indexer->Setup(*task.experiment);
|
||||
result = indexer->Run(*task.message, *task.coord);
|
||||
}
|
||||
|
||||
// Set the result via the promise
|
||||
if (task.promise) {
|
||||
task.promise->set_value(result);
|
||||
}
|
||||
} catch (std::exception &e) {
|
||||
if (task.promise)
|
||||
task.promise->set_exception(std::current_exception());
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user