// Copyright (2019-2024) Paul Scherrer Institute #include "NeuralNetResPredictor.h" #include #include "../common/JFJochException.h" NeuralNetResPredictor::NeuralNetResPredictor(const std::string& model_path) : model_input(512*512), enable(!model_path.empty()) #ifdef JFJOCH_USE_TORCH , device(torch::kCUDA) #endif { #ifdef JFJOCH_USE_TORCH if (enable) { module = torch::jit::load(model_path); module.to(device); } #else enable = false; #endif } template void NeuralNetResPredictor::PrepareInternal(const DiffractionExperiment& experiment, const T* image) { size_t pool_factor = GetMaxPoolFactor(experiment); size_t xpixel = experiment.GetXPixelsNum(); size_t ypixel = experiment.GetYPixelsNum(); size_t start_x = std::lround(experiment.GetBeamX_pxl()); size_t start_y = std::lround(experiment.GetBeamY_pxl()); for (size_t y = 0; y < 512; y++) { size_t y0 = y * pool_factor + start_y; size_t min_yp = std::min(y0 + pool_factor, ypixel); for (size_t x = 0; x < 512; x++) { float val = 0.0; size_t x0 = x * pool_factor + start_x; size_t min_xp = std::min(x0 + pool_factor, ypixel); for (size_t yp = y0; yp < min_yp; yp++) { for (size_t xp = x0; xp < min_xp; xp++) { int16_t pxl = image[yp * xpixel + xp]; if (val < pxl) val = pxl; } } float max_pool = floorf(sqrtf(val)); model_input[512 * y + x] = max_pool; } } } void NeuralNetResPredictor::Prepare(const DiffractionExperiment& experiment, const int16_t *image) { PrepareInternal(experiment, image); } void NeuralNetResPredictor::Prepare(const DiffractionExperiment& experiment, const int32_t *image) { PrepareInternal(experiment, image); } size_t NeuralNetResPredictor::GetMaxPoolFactor(const DiffractionExperiment& experiment) const { float max_direction = std::max(experiment.GetXPixelsNum(), experiment.GetYPixelsNum()) / 2.0; size_t pool_factor = std::lround(max_direction / 512.0f); if (pool_factor <= 0) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Detector size is too small"); if (pool_factor > 8) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Detector size is too large"); return pool_factor; } std::optional NeuralNetResPredictor::Inference(const DiffractionExperiment& experiment, const void *image) { if (!enable) return {}; #ifdef JFJOCH_USE_TORCH if (experiment.GetPixelDepth() == 2) Prepare(experiment, (int16_t *) image); else Prepare(experiment, (int32_t *) image); auto options = torch::TensorOptions().dtype(at::kFloat); auto model_input_tensor = torch::from_blob(model_input.data(), {1,1,512,512}, options).to(device); std::vector pixels; pixels.emplace_back(std::move(model_input_tensor)); auto output = module.forward(pixels).toTensor(); auto tensor_output = output[0].item(); float two_theta = atanf(((2.0f * experiment.GetPixelSize_mm() / experiment.GetDetectorDistance_mm()) * tensor_output)); float stheta = sinf(two_theta * 0.5f); float resolution = experiment.GetWavelength_A() / (2.0f * stheta); return resolution; #else return {}; #endif } const std::vector &NeuralNetResPredictor::GetModelInput() const { return model_input; }