Files
Jungfraujoch/tools/jfjoch_resonet_test.cpp
2025-05-12 14:17:24 +02:00

173 lines
5.6 KiB
C++

// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute <filip.leonarski@psi.ch>
// SPDX-License-Identifier: GPL-3.0-only
#include <getopt.h>

#include <atomic>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <future>
#include <vector>

#include <httplib/httplib.h>

#include "../image_analysis/NeuralNetInferenceClient.h"
#include "../writer/HDF5Objects.h"
#include "../common/print_license.h"
// Run-time settings; defaults below may be overridden by parse_options().
bool verbose{false};     // --verbose/-v
int64_t niter{1000};     // --iterations/-I
size_t nthreads{1};      // --threads/-N
int nquads{4};           // --quads/-q
// Logger shared by every function in this tool.
Logger logger("jfjoch_resonet_test");
/// Prints command-line help through the global logger.
/// Lists the positional inference-host arguments and every option accepted by parse_options().
void print_usage() {
    logger.Info("Usage ./jfjoch_resonet_test {options} <inference host 1> ... <inference host N>");
    logger.Info("");
    logger.Info("If no inference host is given, localhost:8000 is used.");
    logger.Info("");
    logger.Info("Available options:");
    logger.Info("--verbose/-v Verbose");
    logger.Info("--iterations=<int>/-I<int> Number of iterations (default: 1000)");
    logger.Info("--threads=<int>/-N<int> Number of threads (default: 1)");
    logger.Info("--quads=<int>/-q<int> Number of quads 1-4 (default: 4)");
    logger.Info("--help/-h Print this help message");
    logger.Info("");
}
/// Parses command-line options into the global settings
/// (verbose, niter, nthreads, nquads).
///
/// Terminates the process with EXIT_FAILURE on an invalid value or unknown option.
/// Terminates with EXIT_SUCCESS after printing help for --help/-h.
///
/// @return index into argv of the first non-option argument (getopt's optind);
///         main() treats the remaining arguments as inference hosts.
int parse_options(int argc, char **argv) {
    static struct option long_options[] = {
            {"verbose", no_argument, 0, 'v'},
            {"iterations", required_argument, 0, 'I'},
            {"threads", required_argument, 0, 'N'},
            {"quads", required_argument, 0, 'q'},
            {"help", no_argument, 0, 'h'},
            {0, 0, 0, 0}
    };

    int option_index = 0;
    int opt;
    while ((opt = getopt_long(argc, argv, "q:I:N:vh", long_options, &option_index)) != -1) {
        switch (opt) {
            case 'q': {
                // Validate in the wide type before narrowing to int, so out-of-range
                // values cannot be truncated into the accepted range.
                const long value = atol(optarg);
                if ((value < 1) || (value > 4)) {
                    logger.Info("Number of quads must be in range from 1 to 4");
                    exit(EXIT_FAILURE);
                }
                nquads = static_cast<int>(value);
                break;
            }
            case 'I': {
                const long value = atol(optarg);
                if (value <= 0) {
                    logger.Info("Number of iterations requires positive number");
                    exit(EXIT_FAILURE);
                }
                niter = value;
                break;
            }
            case 'N': {
                // Parse as signed first. nthreads is unsigned, so a negative argument
                // would otherwise wrap to a huge thread count and pass the check.
                const long value = atol(optarg);
                if (value <= 0) {
                    logger.Info("Number of threads requires positive number");
                    exit(EXIT_FAILURE);
                }
                nthreads = static_cast<size_t>(value);
                break;
            }
            case 'v':
                verbose = true;
                break;
            case 'h':
                print_usage();
                exit(EXIT_SUCCESS);
            default:
                logger.Error("Unknown option {}", opt);
                exit(EXIT_FAILURE);
        }
    }
    return optind;
}
/// Benchmark entry point. Sends images from a local HDF5 test file to one or more
/// neural-network inference hosts and reports the achieved request rate.
///
/// Non-option arguments are passed to NeuralNetInferenceClient::AddHost().
/// If none are given, localhost:8000 is used.
int main(int argc, char **argv) {
    print_license("jfjoch_resonet_test");

    RegisterHDF5Filter();

    const int first_argc = parse_options(argc, argv);
    logger.Verbose(verbose);

    DiffractionExperiment experiment(DetJF4M());
    experiment.DetectorDistance_mm(75).IncidentEnergy_keV(12.4).BeamY_pxl(1136).BeamX_pxl(1090);

    NeuralNetInferenceClient predictor;
    for (int i = first_argc; i < argc; i++) {
        try {
            predictor.AddHost(argv[i]);
            logger.Info("Added host {}", argv[i]);
        } catch (std::exception &e) {
            logger.ErrorException(e);
            exit(EXIT_FAILURE);
        }
    }

    predictor.AddLogger(&logger);

    if (predictor.GetHostCount() == 0) {
        logger.Info("Default localhost:8000 for inference.");
        predictor.AddHost("localhost", 8000);
    }

    // NOTE(review): this path is resolved relative to the current working directory.
    HDF5ReadOnlyFile data("../../tests/test_data/compression_benchmark.h5");
    HDF5DataSet dataset(data, "/entry/data/data");
    HDF5DataSpace file_space(dataset);
    // The dataset is read below as {images, y, x}.
    const size_t xpixel = file_space.GetDimensions()[2];
    const size_t ypixel = file_space.GetDimensions()[1];
    const size_t nimages = file_space.GetDimensions()[0];

    // Guard the `i % nimages` below against division by zero.
    if (nimages == 0) {
        logger.Error("Test dataset contains no images");
        exit(EXIT_FAILURE);
    }

    std::vector<int16_t> image_conv(nimages * xpixel * ypixel);
    std::vector<hsize_t> start = {0, 0, 0};
    std::vector<hsize_t> file_size = {nimages, ypixel, xpixel};
    dataset.ReadVector(image_conv, start, file_size);

    logger.Info("Dimension {:d}x{:d} pxl, max pooling factor {:d}",
                xpixel, ypixel, predictor.GetMaxPoolFactor(experiment));

    // steady_clock: monotonic, so appropriate for measuring intervals.
    const auto start_time = std::chrono::steady_clock::now();

    std::atomic<int64_t> done{0};
    std::vector<std::future<void>> futures;
    futures.reserve(nthreads);

    for (size_t t = 0; t < nthreads; t++) {
        // Thread t handles iterations t, t + nthreads, t + 2 * nthreads, ...,
        // so the threads partition [0, niter) exactly and send different images.
        // t is captured by value because the loop variable keeps changing.
        futures.emplace_back(std::async(std::launch::async, [&, t] {
            for (auto i = static_cast<int64_t>(t); i < niter; i += static_cast<int64_t>(nthreads)) {
                const auto iter_start_time = std::chrono::steady_clock::now();
                const size_t image_number = static_cast<size_t>(i) % nimages;
                const auto image = image_conv.data() + image_number * xpixel * ypixel;
                auto val = predictor.Inference(experiment, image, nquads);
                const auto iter_end_time = std::chrono::steady_clock::now();
                const std::chrono::duration<double> iter_diff = iter_end_time - iter_start_time;
                const auto iter_diff_ms = std::lround(iter_diff.count() * 1000.0);
                if (val) {
                    done++;
                    logger.Debug("Res: {:.2f} Duration {:d} ms", val.value(), iter_diff_ms);
                } else
                    logger.Warning("Missing results");
            }
        }));
    }

    // get() also rethrows any exception raised inside a worker.
    for (auto &f: futures)
        f.get();

    const auto end_time = std::chrono::steady_clock::now();
    const std::chrono::duration<double> diff = end_time - start_time;

    // Percentage of iterations that returned a result. Multiply before any rounding,
    // otherwise the value collapses to 0% or 100%.
    logger.Info("Total performance {:.1f} Hz (iteration: {:d}; {:.1f}%)",
                static_cast<double>(niter) / diff.count(), niter,
                100.0 * static_cast<double>(done.load()) / static_cast<double>(niter));
}