// File: Jungfraujoch/resonet/NeuralNetResPredictor.cpp (94 lines, 3.1 KiB, C++)

// Copyright (2019-2024) Paul Scherrer Institute
#include "NeuralNetResPredictor.h"
#include <cmath>
#include "../common/JFJochException.h"
/// Creates a predictor bound to one experiment configuration.
///
/// Chooses the integer max-pooling factor from the detector size and, when built with
/// JFJOCH_USE_TORCH, loads the TorchScript model from GetNeuralNetModelPath() and moves it
/// to the CUDA device.
///
/// @param in_experiment  experiment whose detector geometry and model path are used
/// @throws JFJochException if the derived pooling factor is 0 ("too small") or above 8 ("too large")
NeuralNetResPredictor::NeuralNetResPredictor(const DiffractionExperiment &in_experiment)
        : experiment(in_experiment),
          model_input(512 * 512)
#ifdef JFJOCH_USE_TORCH
        , device(torch::kCUDA)
#endif
{
    // Half of the larger detector edge, rounded against the fixed 512-pixel model input,
    // gives the edge length of each max-pooling window.
    const auto larger_dimension = std::max(in_experiment.GetXPixelsNum(), in_experiment.GetYPixelsNum());
    const float half_extent = larger_dimension / 2.0;
    pool_factor = std::lround(half_extent / 512.0f);

    // Only pooling factors 1..8 are accepted
    if (pool_factor <= 0)
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Detector size is too small");
    if (pool_factor > 8)
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Detector size is too large");

#ifdef JFJOCH_USE_TORCH
    module = torch::jit::load(experiment.GetNeuralNetModelPath());
    module.to(device);
#endif
}
/// Builds the 512x512 network input from a raw detector image.
///
/// Starting at the (rounded) beam centre, the image is tiled with 512x512 windows of
/// pool_factor x pool_factor pixels. Each output value is floor(sqrt(m)), where m is the
/// window maximum clamped from below at 0. Windows are clipped at the detector edges, and
/// windows lying completely outside the detector yield 0. Output is stored row-major in
/// model_input (index 512 * y + x).
///
/// @tparam T     pixel type of the image (instantiated for int16_t and int32_t)
/// @param image  row-major image of GetXPixelsNum() x GetYPixelsNum() pixels
template<class T>
void NeuralNetResPredictor::PrepareInternal(const T* image) {
    const size_t xpixel = experiment.GetXPixelsNum();
    const size_t ypixel = experiment.GetYPixelsNum();
    // NOTE(review): a negative beam centre converts to a very large size_t. The unsigned
    // wrap-around appears to keep the window bounds consistent (such windows are simply
    // empty), but confirm this is intended.
    const size_t start_x = std::lround(experiment.GetBeamX_pxl());
    const size_t start_y = std::lround(experiment.GetBeamY_pxl());

    for (size_t y = 0; y < 512; y++) {
        const size_t y0 = y * pool_factor + start_y;
        const size_t min_yp = std::min(y0 + pool_factor, ypixel);
        for (size_t x = 0; x < 512; x++) {
            const size_t x0 = x * pool_factor + start_x;
            // Clip the window against the detector width (x extent), not its height
            const size_t min_xp = std::min(x0 + pool_factor, xpixel);

            // Keep the native pixel type: narrowing to int16_t would wrap 32-bit values.
            // Starting at 0 means negative pixel values never raise the maximum.
            T val = 0;
            for (size_t yp = y0; yp < min_yp; yp++) {
                const T *row = image + yp * xpixel;
                for (size_t xp = x0; xp < min_xp; xp++)
                    val = std::max(val, row[xp]);
            }
            model_input[512 * y + x] = floorf(sqrtf(static_cast<float>(val)));
        }
    }
}
void NeuralNetResPredictor::Prepare(const int16_t *image) {
PrepareInternal(image);
}
void NeuralNetResPredictor::Prepare(const int32_t *image) {
PrepareInternal(image);
}
size_t NeuralNetResPredictor::GetMaxPoolFactor() const {
return pool_factor;
}
float NeuralNetResPredictor::Inference(const void *image) {
if (experiment.GetPixelDepth() == 2)
Prepare((int16_t *) image);
else
Prepare((int32_t *) image);
#ifdef JFJOCH_USE_TORCH
auto options = torch::TensorOptions().dtype(at::kFloat);
auto model_input_tensor = torch::from_blob(model_input.data(), {1,1,512,512}, options).to(device);
std::vector<torch::jit::IValue> pixels;
pixels.emplace_back(std::move(model_input_tensor));
auto output = module.forward(pixels).toTensor();
auto tensor_output = output[0].item<float>();
float two_theta = atanf(((2.0f * experiment.GetPixelSize_mm() / experiment.GetDetectorDistance_mm()) * tensor_output));
float stheta = sinf(two_theta * 0.5f);
float resolution = experiment.GetWavelength_A() / (2.0f * stheta);
#else
float resolution = 50.0;
#endif
return resolution;
}
const std::vector<float> &NeuralNetResPredictor::GetModelInput() const {
return model_input;
}