Files
aare/benchmarks/fit_benchmark.cpp
T
Khalil Ferjaoui a6afa45b3b
Build on RHEL8 / build (push) Successful in 3m6s
Build on RHEL9 / build (push) Successful in 3m20s
Run tests using data on local RHEL8 / build (push) Successful in 3m36s
Build on local RHEL8 / build (push) Successful in 2m21s
Feature/minuit2 wrapper (#279)
## Unified Minuit2 fitting framework with FitModel API

### Models (`Models.hpp`)
Consolidate all model structs (Gaussian, RisingScurve, FallingScurve) into a
single header. Each model provides `eval`, `eval_and_grad`, `is_valid`,
`estimate_par`, `compute_steps`, and `param_info` metadata. No Minuit2
dependency.

### Chi2 functors (`Chi2.hpp`)
Generic `Chi2Model1DGrad` (analytic gradient), templated on the model struct.
Replaces the separate Chi2Gaussian, Chi2GaussianGradient, Chi2Scurves, and
Chi2ScurvesGradient headers.

### FitModel (`FitModel.hpp`)
Configuration object wrapping `MnUserParameters`, strategy, tolerance, and
user-override tracking. User constraints (fixed parameters, start values,
limits) always take precedence over automatic data-driven estimates.

### Fit functions (`Fit.hpp`)
- `fit_pixel<Model, FCN>(model, x, y, y_err)` -> single-pixel, self-contained
- `fit_pixel<Model, FCN>(model, upar_local, x, y, y_err)` -> pre-cloned upar for hot loops
- `fit_3d<Model, FCN>(model, x, y, y_err, ..., n_threads)` -> row-parallel over pixel grid

### Python bindings
- `Pol1`, `Pol2`, `Gaussian`, `RisingScurve`, `FallingScurve` model classes
  with `FixParameter`, `SetParLimits`, `SetParameter`, and properties for
  `max_calls`, `tolerance`, `compute_errors`
- Single `fit(model, x, y, y_err, n_threads)` dispatch replacing the old
  `fit_gaus_minuit`, `fit_gaus_minuit_grad`, `fit_scurve_minuit_grad`, etc.

### Benchmarks
- Updated `fit_benchmark.cpp` (Google Benchmark) to use the new FitModel API
- Jupyter notebooks for 1D and 3D S-curve fitting (lmfit vs Minuit2 analytic)
- ~1.8x speedup over lmfit, near-linear thread scaling up to physical core
  count

---------

Co-authored-by: Erik Fröjdh <erik.frojdh@psi.ch>
2026-03-30 09:12:23 +02:00

162 lines
5.0 KiB
C++

// SPDX-License-Identifier: MPL-2.0
#include "aare/Fit.hpp"
#include "aare/Chi2.hpp"
#include "aare/Models.hpp"
#include "aare/FitModel.hpp"
#include <benchmark/benchmark.h>
#include <cmath>
#include <random>
#include <string>
#include <vector>
/// One benchmark scenario: the ground-truth Gaussian parameters and the noise
/// level from which the synthetic data set is generated.
///
/// The members are filled positionally by the brace initialisers in
/// get_test_cases(), so their order must not change.
struct TestCase {
// Human-readable label; the benchmarks attach it to the run via SetLabel().
std::string name;
// Amplitude of the noiseless Gaussian.
double true_A;
// Centre (mean) of the noiseless Gaussian.
double true_mu;
// Width (standard deviation) of the noiseless Gaussian.
double true_sig;
// Noise level as a fraction of true_A: the generator uses
// noise_frac * true_A as the standard deviation of the added noise.
double noise_frac;
};
static const std::vector<TestCase> &get_test_cases() {
static const std::vector<TestCase> cases = {
{"Clean_signal", 1000.0, 50.0, 5.0, 0.02},
{"Moderate_noise", 1000.0, 50.0, 5.0, 0.10},
{"High_noise", 1000.0, 50.0, 5.0, 0.30},
{"Narrow_peak", 500.0, 25.0, 1.0, 0.05},
{"Wide_peak", 200.0, 100.0, 20.0, 0.05},
{"Off_center_peak", 800.0, -15.0, 3.0, 0.05},
};
return cases;
}
// ----------------------------------------------------------------
// Synthetic data generation (deterministic per test case)
// ----------------------------------------------------------------
/// Number of samples in every synthetic data set.
static constexpr ssize_t N_POINTS{100};
/// Seed for the noise generator; fixed so the generated data is reproducible.
static constexpr unsigned SEED{42};
/// Owns the three 1-D arrays that make up one synthetic data set.
struct GeneratedData {
// Sample positions.
aare::NDArray<double, 1> x;
// Sample values (ideal curve plus noise, as filled by the generator).
aare::NDArray<double, 1> y;
// Per-point uncertainty assigned to y.
aare::NDArray<double, 1> y_err;
// Constructs all three arrays from {N_POINTS}
// (presumably the 1-D shape — confirm against the NDArray constructor).
GeneratedData() : x({N_POINTS}), y({N_POINTS}), y_err({N_POINTS}) {}
};
/// Builds one noisy Gaussian curve for the given test case.
///
/// x spans [true_mu - 5 * true_sig, true_mu + 5 * true_sig] in N_POINTS evenly
/// spaced samples. Each y is the ideal Gaussian value plus one draw from a
/// zero-mean normal distribution whose standard deviation is
/// noise_frac * true_A; that same standard deviation is stored as y_err for
/// every point. The generator is re-seeded with SEED on each call, so calling
/// this twice with the same test case yields the same arrays.
///
/// @param tc Scenario providing the true parameters and the noise fraction.
/// @return The generated x, y and y_err arrays.
static GeneratedData generate_gaussian_data(const TestCase &tc) {
    GeneratedData out;

    // Sampling grid: +/- 5 sigma around the peak centre.
    const double lower = tc.true_mu - 5.0 * tc.true_sig;
    const double upper = tc.true_mu + 5.0 * tc.true_sig;
    const double step = (upper - lower) / (N_POINTS - 1);

    // Noise model: absolute sigma derived from the amplitude.
    const double noise_sigma = tc.noise_frac * tc.true_A;
    std::normal_distribution<double> gauss_noise(0.0, noise_sigma);
    std::mt19937 engine(SEED);

    // Denominator of the Gaussian exponent; it does not depend on the sample.
    const double two_sigma_sq = 2.0 * std::pow(tc.true_sig, 2);

    for (ssize_t idx = 0; idx < N_POINTS; ++idx) {
        const double xi = lower + idx * step;
        out.x[idx] = xi;

        const double ideal =
            tc.true_A * std::exp(-std::pow(xi - tc.true_mu, 2) / two_sigma_sq);

        // Exactly one draw per sample, in index order.
        out.y[idx] = ideal + gauss_noise(engine);
        out.y_err[idx] = noise_sigma;
    }
    return out;
}
/// Publishes how far the fitted parameters are from the truth as benchmark
/// counters.
///
/// Reads the first three elements of `result` as amplitude, mean and sigma
/// (in that order) and stores the signed differences from the test case's
/// true values under the counter names "dA", "dMu" and "dSig".
///
/// @param state  Benchmark state whose counters are written.
/// @param tc     Scenario holding the true parameter values.
/// @param result Fit output; elements 0, 1 and 2 are read.
static void report_accuracy(benchmark::State &state,
                            const TestCase &tc,
                            const aare::NDArray<double, 1> &result) {
    auto &counters = state.counters;
    counters["dA"] = result(0) - tc.true_A;
    counters["dMu"] = result(1) - tc.true_mu;
    counters["dSig"] = result(2) - tc.true_sig;
}
// ----------
// Benchmarks
// ----------
// 1. lmcurve
/// Times aare::fit_gaus on the x/y views of one synthetic data set.
///
/// state.range(0) selects the scenario from get_test_cases(). The data is
/// generated once, outside the timed loop; only the fit call is timed. The
/// result of the last iteration is used for the accuracy counters.
static void BM_FitGausLm(benchmark::State &state) {
    const TestCase &scenario = get_test_cases()[state.range(0)];
    GeneratedData samples = generate_gaussian_data(scenario);
    auto x_view = samples.x.view();
    auto y_view = samples.y.view();

    aare::NDArray<double, 1> fitted;
    for (auto _ : state) {
        fitted = aare::fit_gaus(x_view, y_view);
        // Keep the optimiser from discarding the fit.
        benchmark::DoNotOptimize(fitted.data());
    }

    report_accuracy(state, scenario, fitted);
    state.SetLabel(scenario.name);
}
// 2. Minuit2, analytic gradient (no Hesse)
/// Times aare::fit_pixel with a Gaussian FitModel that has compute_errors
/// switched off.
///
/// state.range(0) selects the scenario from get_test_cases(). The data and
/// the model configuration are created once, outside the timed loop. The
/// result of the last iteration is used for the accuracy counters.
///
/// NOTE(review): BM_FitGausMinuitGradHesse passes a y_err view to fit_pixel
/// and this benchmark does not — confirm that the difference is intended.
static void BM_FitGausMinuitGrad(benchmark::State &state) {
    using Model = aare::model::Gaussian;
    using Chi2 = aare::func::Chi2Gaussian;

    const TestCase &scenario = get_test_cases()[state.range(0)];
    GeneratedData samples = generate_gaussian_data(scenario);
    auto x_view = samples.x.view();
    auto y_view = samples.y.view();

    const aare::FitModel<Model> fit_config(
        /*strategy = */0,
        /*max_calls = */500, // increase for noisy signals
        /*tolerance = */0.5,
        /*compute_errors = */false);

    aare::NDArray<double, 1> fitted;
    for (auto _ : state) {
        fitted = aare::fit_pixel<Model, Chi2>(fit_config, x_view, y_view);
        // Keep the optimiser from discarding the fit.
        benchmark::DoNotOptimize(fitted.data());
    }

    report_accuracy(state, scenario, fitted);
    state.SetLabel(scenario.name);
}
// 3. Minuit2, analytic gradient + Hesse
/// Times aare::fit_pixel with a Gaussian FitModel that has compute_errors
/// switched on, passing the x, y and y_err views.
///
/// state.range(0) selects the scenario from get_test_cases(). Besides the
/// accuracy counters, the elements at indices 3, 4 and 5 of the last result
/// are published as "errA", "errMu" and "errSig" when the result holds at
/// least six elements.
static void BM_FitGausMinuitGradHesse(benchmark::State &state) {
    using Model = aare::model::Gaussian;
    using Chi2 = aare::func::Chi2Gaussian;

    const TestCase &scenario = get_test_cases()[state.range(0)];
    GeneratedData samples = generate_gaussian_data(scenario);
    auto x_view = samples.x.view();
    auto y_view = samples.y.view();
    auto err_view = samples.y_err.view();

    // compute_errors = true -> runs Hesse and provides errors on fitted params
    const aare::FitModel<Model> fit_config(0, 500, 0.5, true);

    aare::NDArray<double, 1> fitted;
    for (auto _ : state) {
        fitted = aare::fit_pixel<Model, Chi2>(fit_config, x_view, y_view, err_view);
        // Keep the optimiser from discarding the fit.
        benchmark::DoNotOptimize(fitted.data());
    }

    // Expected layout: [A, mu, sig, err_A, err_mu, err_sig]
    report_accuracy(state, scenario, fitted);

    // Hesse uncertainties are reported only if the result is long enough.
    const bool has_uncertainties = fitted.size() >= 6;
    if (has_uncertainties) {
        state.counters["errA"] = fitted(3);
        state.counters["errMu"] = fitted(4);
        state.counters["errSig"] = fitted(5);
    }

    state.SetLabel(scenario.name);
}
// Benchmark registration.
// Each benchmark is run once per argument value in DenseRange(0, 5), i.e. for
// 0, 1, ..., 5. The benchmark bodies use that value (state.range(0)) as an
// index into get_test_cases(), which has six entries — keep the upper bound
// in sync with the table size. Timings are reported in microseconds.
BENCHMARK(BM_FitGausLm)
->DenseRange(0, 5)
->Unit(benchmark::kMicrosecond);
BENCHMARK(BM_FitGausMinuitGrad)
->DenseRange(0, 5)
->Unit(benchmark::kMicrosecond);
BENCHMARK(BM_FitGausMinuitGradHesse)
->DenseRange(0, 5)
->Unit(benchmark::kMicrosecond);
// Provides the program's main(), which runs the registered benchmarks.
BENCHMARK_MAIN();