Files
Jungfraujoch/image_analysis/NeuralNetResPredictor.h
2025-05-05 19:32:22 +02:00

39 lines
1.4 KiB
C++

// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute <filip.leonarski@psi.ch>
// SPDX-License-Identifier: GPL-3.0-only
#ifndef JUNGFRAUJOCH_NEURALNETRESPREDICTOR_H
#define JUNGFRAUJOCH_NEURALNETRESPREDICTOR_H
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

#ifdef JFJOCH_USE_TORCH
#include <torch/script.h>
#endif

#include "../common/DiffractionExperiment.h"
// Neural-network predictor applied to detector images. The "Res" in the name presumably
// stands for diffraction resolution — NOTE(review): confirm against NeuralNetResPredictor.cpp.
// When built with JFJOCH_USE_TORCH, the network is held as a TorchScript module
// (torch::jit::script::Module) together with the torch::Device it runs on. Without that
// macro, no Torch types are part of this class.
//
// Based on model described in:
// Mendez, D., Holton, J. M., Lyubimov, A. Y., Hollatz, S., Mathews, I. I., Cichosz, A., Martirosyan, V.,
// Zeng, T., Stofer, R., Liu, R., Song, J., McPhillips, S., Soltis, M. & Cohen, A. E. (2024).
// Acta Cryst. D80, 26-43.
class NeuralNetResPredictor {
    // Input buffer for the network, exposed read-only through GetModelInput().
    // Presumably filled by Prepare() — verify in the .cpp.
    std::vector<float> model_input;

    // Whether inference is enabled. It defaults to false so the flag is never read
    // uninitialized, whichever constructor path the .cpp takes.
    // NOTE(review): presumably set to true only once a model is available.
    bool enable = false;

#ifdef JFJOCH_USE_TORCH
    // Device the TorchScript module is executed on.
    torch::Device device;
    // Loaded TorchScript network.
    torch::jit::script::Module module;
#endif

    // Shared implementation behind the pixel-type-specific Prepare() overloads.
    template<class T>
    void PrepareInternal(const DiffractionExperiment& experiment, const T* image);

public:
    // Constructs the predictor from the model file at model_path. Behaviour when the file
    // cannot be loaded, or when built without JFJOCH_USE_TORCH, is defined in the .cpp.
    // NOTE(review): document that contract there.
    explicit NeuralNetResPredictor(const std::string& model_path);

    // Prepare the network input from an image with 16-, 32- or 8-bit signed pixels.
    // `experiment` presumably supplies the detector geometry/settings needed for the
    // conversion — confirm.
    void Prepare(const DiffractionExperiment& experiment, const int16_t* image);
    void Prepare(const DiffractionExperiment& experiment, const int32_t* image);
    void Prepare(const DiffractionExperiment& experiment, const int8_t* image);

    // Runs the network on `image`. The pixel type of the untyped pointer is presumably
    // determined from `experiment` — confirm. Returns an empty optional when no prediction
    // is produced (e.g. presumably when inference is disabled).
    std::optional<float> Inference(const DiffractionExperiment& experiment, const void* image);

    // Pooling factor applied when building the model input for this experiment
    // (presumably the max-pool down-sampling factor, as the name suggests — confirm).
    size_t GetMaxPoolFactor(const DiffractionExperiment& experiment) const;

    // Read-only access to the model input buffer (presumably the most recently prepared one).
    const std::vector<float> &GetModelInput() const;
};
#endif //JUNGFRAUJOCH_NEURALNETRESPREDICTOR_H