mirror of
https://github.com/slsdetectorgroup/aare.git
synced 2026-04-21 05:04:36 +02:00
a6afa45b3b
## Unified Minuit2 fitting framework with FitModel API ### Models (`Models.hpp`) Consolidate all model structs (Gaussian, RisingScurve, FallingScurve) into a single header. Each model provides: `eval`, `eval_and_grad`, `is_valid`, `estimate_par`, `compute_steps`, and `param_info` metadata. No Minuit2 dependency. ### Chi2 functors (`Chi2.hpp`) Generic `Chi2Model1DGrad` (analytic gradient) templated on the model struct. Replaces the separate Chi2Gaussian, Chi2GaussianGradient, Chi2Scurves, and Chi2ScurvesGradient headers. ### FitModel (`FitModel.hpp`) Configuration object wrapping `MnUserParameters`, strategy, tolerance, and user-override tracking. User constraints (fixed parameters, start values, limits) always take precedence over automatic data-driven estimates. ### Fit functions (`Fit.hpp`) - `fit_pixel<Model, FCN>(model, x, y, y_err)` -> single-pixel, self-contained - `fit_pixel<Model, FCN>(model, upar_local, x, y, y_err)` -> pre-cloned upar for hot loops - `fit_3d<Model, FCN>(model, x, y, y_err, ..., n_threads)` -> row-parallel over pixel grid ### Python bindings - `Pol1`, `Pol2`, `Gaussian`, `RisingScurve`, `FallingScurve` model classes with `FixParameter`, `SetParLimits`, `SetParameter`, and properties for `max_calls`, `tolerance`, `compute_errors` - Single `fit(model, x, y, y_err, n_threads)` dispatch replacing the old `fit_gaus_minuit`, `fit_gaus_minuit_grad`, `fit_scurve_minuit_grad`, etc. ### Benchmarks - Updated `fit_benchmark.cpp` (Google Benchmark) to use the new FitModel API - Jupyter notebooks for 1D and 3D S-curve fitting (lmfit vs Minuit2 analytic) - ~1.8x speedup over lmfit, near-linear thread scaling up to physical core count --------- Co-authored-by: Erik Fröjdh <erik.frojdh@psi.ch>
340 KiB
340 KiB
In [1]:
import time import random import numpy as np import matplotlib.pyplot as plt from matplotlib.gridspec import GridSpec import sys sys.path.insert(0, '/home/ferjao_k/sw/aare/build') from aare import fit_gaus # lmfit from aare import Gaussian, fit # minuit2
In [2]:
# Problem size: a ROWS x COLS frame, every pixel sampled at N_SCAN scan points.
ROWS, COLS = 100, 100
N_SCAN = 100

# Noise standard deviation as a fraction of each pixel's amplitude, and the
# RNG seed used to make the synthetic data reproducible.
NOISE_FRAC = 0.05
SEED = 42

# Benchmark harness settings.
N_THREADS = 4      # thread count for the single-call comparison
N_REPEATS = 7      # timed calls per measurement
N_WARMUP = 3       # untimed iterations (icache + branch predictor warmup)
COOLDOWN = 2.0     # seconds between (method, thread_count) pairs
Data generator
In [3]:
def generate_3d_data(rows, cols, n_scan, noise_frac, seed, *,
                     x_range=(0, 100), A_range=(200, 1000),
                     mu_range=(20, 80), sigma_range=(3, 12)):
    """
    Generate a synthetic detector image stack with a Gaussian response per pixel.

    All pixels share one scan axis ``x``. Each pixel gets its own amplitude
    ``A``, centre ``mu`` and width ``sigma``, drawn uniformly from the given
    ranges. Additive Gaussian noise is then applied, with standard deviation
    ``noise_frac * A`` (constant along the scan for a given pixel).

    Parameters
    ----------
    rows, cols : int
        Frame size in pixels.
    n_scan : int
        Number of scan points.
    noise_frac : float
        Noise standard deviation as a fraction of each pixel's amplitude.
    seed
        Seed passed to ``np.random.default_rng``. The same seed (with the same
        sizes and ranges) reproduces the same data.
    x_range : (start, stop), keyword-only
        Endpoints of the evenly spaced scan axis (both included).
    A_range, mu_range, sigma_range : (low, high), keyword-only
        Bounds of the uniform distributions for the per-pixel amplitude, centre
        and width. The defaults reproduce the original hard-coded values.

    Returns
    -------
    x : ndarray, shape (n_scan,)
        Common scan axis.
    y : ndarray, shape (rows, cols, n_scan)
        Noisy signal.
    y_err : ndarray, shape (rows, cols, n_scan)
        Per-point noise standard deviation used to generate ``y``.
    A_true, mu_true, sig_true : ndarray, shape (rows, cols)
        Ground-truth parameters per pixel.
    """
    rng = np.random.default_rng(seed)

    # Ground-truth parameters, one value per pixel: shape (rows, cols).
    # The draw order (A, mu, sigma, then noise) is kept fixed so that a given
    # seed always yields the same dataset.
    A_true = rng.uniform(*A_range, size=(rows, cols))
    mu_true = rng.uniform(*mu_range, size=(rows, cols))
    sig_true = rng.uniform(*sigma_range, size=(rows, cols))

    # One common scan axis shared by all pixels: shape (n_scan,).
    x = np.linspace(*x_range, n_scan)

    # Broadcast (1, 1, n_scan) against (rows, cols, 1) -> (rows, cols, n_scan).
    exponent = -0.5 * ((x[None, None, :] - mu_true[:, :, None])
                       / sig_true[:, :, None]) ** 2
    y_clean = A_true[:, :, None] * np.exp(exponent)

    # The noise sigma is constant along the scan for each pixel. It is
    # materialized at full shape (rows, cols, n_scan) so it can double as the
    # returned per-point y_err.
    noise_sigma = noise_frac * A_true[:, :, None] * np.ones_like(y_clean)
    noise = rng.normal(0, noise_sigma)
    y = y_clean + noise
    y_err = noise_sigma.copy()

    return x, y, y_err, A_true, mu_true, sig_true
Benchmark timing helper
In [4]:
def bench(fn, n_warmup=N_WARMUP, n_repeats=N_REPEATS):
    """
    Warm up, then time ``fn`` over ``n_repeats`` calls.

    ``fn`` is called ``n_warmup`` times untimed, then ``n_repeats`` times with
    each call timed individually using ``time.perf_counter``.

    Returns (last_result, list_of_walltimes_in_seconds). ``last_result`` is the
    return value of the final call to ``fn``, or ``None`` if ``fn`` was never
    called (both counts <= 0). The list has one entry per timed call.
    """
    # Bind up front so the return cannot raise UnboundLocalError when both
    # loops are empty.
    res = None

    # warmup: primes icache, branch predictor, and lets CPU ramp to boost clock
    for _ in range(n_warmup):
        res = fn()

    times = []
    for _ in range(n_repeats):
        t0 = time.perf_counter()
        res = fn()
        times.append(time.perf_counter() - t0)
    return res, times
Quick check on a small (2x2) frame
In [5]:
# Generate 2 x 2 dataset of Gaussian-like profiles for each pixel, fit it with
# Minuit2, and print the fitted parameters next to the ground truth.
x2, y2, yerr2, true_A2, true_mu2, true_sig2 = generate_3d_data(
    2, 2, N_SCAN, NOISE_FRAC, SEED
)

model_g = Gaussian()
model_g.compute_errors = True
result = model_g.fit(x2, y2, yerr2)

print("== True Gaussian params == ")
print("A_true = \n", true_A2)
print("mu_true = \n", true_mu2)
print("sig_true = \n", true_sig2)
print("\n")

print("== Fit results ==")
# result['par'] is indexed as par[row, col, k], with k = 0, 1, 2 -> A, mu, sigma.
par = result['par']
A_fit = par[:, :, 0]
mu_fit = par[:, :, 1]
sig_fit = par[:, :, 2]
print("A_fit = \n", A_fit)
print("mu_fit = \n", mu_fit)
print("sig_fit = \n", sig_fit)
== True Gaussian params == A_true = [[819.16483884 551.1027518 ] [886.87833593 757.89442325]] mu_true = [[25.65064087 78.5373411 ] [65.66838212 67.16385832]] sig_true = [[ 4.15302269 7.05347344] [ 6.33718222 11.3408849 ]] == Fit results == A_fit = [[812.09277132 559.04069721] [899.09335849 759.24481682]] mu_fit = [[25.6598209 78.40461782] [65.52261318 66.84540995]] sig_fit = [[ 4.2778026 7.041045 ] [ 6.29190225 11.34233504]]
In [6]:
# Gaussians in 2x2 frame: raw data vs the Minuit2 fit, one panel per pixel.
fig, ax = plt.subplots(2, 2, figsize=(12, 8))
for (row, col), panel in np.ndenumerate(ax):
    fitted = model_g(x2, result['par'][row, col, :])
    panel.plot(x2, y2[row, col, :], label="data")
    panel.plot(x2, fitted, linewidth=1, color="green", label="minuit")
    panel.set_title(f"Gaussian Fit to data in pixel [{row}, {col}]")
    panel.legend()
Fit the full frame with different backends (lmfit vs Minuit2)
In [7]:
# ===============
# DATA GENERATION
# ===============
print(f"Generating synthetic data: {ROWS}x{COLS} pixels, "
      f"{N_SCAN} scan points, noise_frac={NOISE_FRAC}\n")
x, y, yerr, true_A, true_mu, true_sig = generate_3d_data(
    ROWS, COLS, N_SCAN, NOISE_FRAC, SEED
)

# Minuit2 model with default settings; report its fit configuration.
model = Gaussian()
print(f"model.max_calls = {model.max_calls}")
print(f"model.tolerance = {model.tolerance}")
print("model.compute_errors =", model.compute_errors)

# Benchmarked methods as (label, factory, color, line style).
# factory(nt) returns a zero-argument callable that fits the whole frame with
# nt threads, so the timing helper can invoke it repeatedly.
# NOTE(review): neither fitter is given `yerr`, so both fit unweighted data.
# Any chi2 reported by Minuit2 is therefore presumably on a unit-error
# scale — confirm.
METHOD_DEFS = [
    (
        "lmfit (LM)",
        lambda nt: lambda: fit_gaus(x, y, n_threads=nt),
        "#2196F3",
        {"linewidth": 3.0, "linestyle": "-"},
    ),
    (
        "Minuit2 (obj API)",
        lambda nt: lambda: model.fit(x, y, n_threads=nt),
        "#FF9800",
        {"linewidth": 2.5, "linestyle": ":"},
    ),
]

# Per-label lookup tables for the plotting cells below.
colors = {entry[0]: entry[2] for entry in METHOD_DEFS}
styles = {entry[0]: entry[3] for entry in METHOD_DEFS}
Generating synthetic data: 100x100 pixels, 100 scan points, noise_frac=0.05 model.max_calls = 100 model.tolerance = 0.5 model.compute_errors = False
In [8]:
# ====================================
# SINGLE-CALL BENCHMARK (at N_THREADS)
# ====================================
def extract_result(label, res):
    """
    Normalize return values across fitters into a common dict.

    ``label`` is not used; it is accepted so callers can pass the method name.

    A dict result must contain "par". It is copied into a new dict together
    with "par_err" and "chi2" when those keys are present. Any other return
    value is wrapped as ``{"par": res}``.
    """
    if not isinstance(res, dict):
        # fit_gaus without y_err returns a raw array
        return {"par": res}
    out = {"par": res["par"]}
    for key in ("par_err", "chi2"):
        if key in res:
            out[key] = res[key]
    return out


methods = {}
for label, factory, _, _ in METHOD_DEFS:
    time.sleep(COOLDOWN)  # pause between methods (see COOLDOWN)
    res, times = bench(factory(N_THREADS))
    entry = extract_result(label, res)
    entry["times"] = times
    methods[label] = entry

# ---- Print summary ----
print(f"{'Method':24s} {'time (ms)':>10s} {'med|dA|':>10s} {'med|dMu|':>10s} {'med|dSig|':>10s}")
print("-" * 80)
for name, m in methods.items():
    par = m["par"]
    med_t = np.median(m["times"]) * 1e3
    dA = np.median(np.abs(par[:, :, 0] - true_A))
    dMu = np.median(np.abs(par[:, :, 1] - true_mu))
    dSig = np.median(np.abs(par[:, :, 2] - true_sig))
    chi2_str = ""
    if "chi2" in m:
        # Degrees of freedom = scan points minus fitted parameters. Derive the
        # parameter count from the result instead of hard-coding 3.
        ndf = N_SCAN - par.shape[-1]
        chi2_str = f" chi2/ndf={np.median(m['chi2'] / ndf):.4f}"
    print(f"[{name:22s}] {med_t:8.2f} ms "
          f"{dA:10.3f} {dMu:10.4f} {dSig:10.4f}{chi2_str}")
Method time (ms) med|dA| med|dMu| med|dSig| -------------------------------------------------------------------------------- [lmfit (LM) ] 189.58 ms 6.272 0.0940 0.0949 [Minuit2 (obj API) ] 124.59 ms 6.272 0.0940 0.0949 chi2/ndf=880.9946
In [9]:
# ===============
# THREAD SCALING
# ===============
# For every thread count, time each method and record the median and the
# standard deviation of its wall time, both in milliseconds.
thread_counts = [1, 2, 4, 8, 16]
thread_times = {entry[0]: [] for entry in METHOD_DEFS}
ttimes_stddev = {entry[0]: [] for entry in METHOD_DEFS}
n_pixels = ROWS * COLS

for nt in thread_counts:
    # Fresh random method order per thread count, so slow thermal drift does
    # not consistently penalize the same method.
    run_order = METHOD_DEFS[:]
    random.shuffle(run_order)

    for label, factory, _, _ in run_order:
        time.sleep(COOLDOWN)
        _, times = bench(factory(nt))

        med = np.median(times) * 1e3   # ms
        std = np.std(times) * 1e3      # ms (numpy default ddof=0)
        thread_times[label].append(med)
        ttimes_stddev[label].append(std)

        # ms per frame -> μs per pixel
        per_px = med / n_pixels * 1e3
        per_px_std = std / n_pixels * 1e3
        print(f" {label:22s} n_threads={nt:2d} "
              f"{med:8.2f} ± {std:6.2f} ms "
              f"({per_px:.4f} ± {per_px_std:.4f} μs/pixel)")
    print("\n")
Minuit2 (obj API) n_threads= 1 456.26 ± 10.76 ms (45.6262 ± 1.0762 μs/pixel) lmfit (LM) n_threads= 1 752.77 ± 108.16 ms (75.2771 ± 10.8165 μs/pixel) Minuit2 (obj API) n_threads= 2 238.34 ± 26.57 ms (23.8345 ± 2.6566 μs/pixel) lmfit (LM) n_threads= 2 410.95 ± 84.96 ms (41.0945 ± 8.4959 μs/pixel) lmfit (LM) n_threads= 4 205.44 ± 35.26 ms (20.5445 ± 3.5259 μs/pixel) Minuit2 (obj API) n_threads= 4 139.22 ± 14.06 ms (13.9224 ± 1.4060 μs/pixel) Minuit2 (obj API) n_threads= 8 130.12 ± 3.27 ms (13.0118 ± 0.3269 μs/pixel) lmfit (LM) n_threads= 8 199.97 ± 6.88 ms (19.9968 ± 0.6877 μs/pixel) Minuit2 (obj API) n_threads=16 134.81 ± 10.13 ms (13.4807 ± 1.0130 μs/pixel) lmfit (LM) n_threads=16 188.32 ± 9.86 ms (18.8322 ± 0.9860 μs/pixel)
In [10]:
# =============================
# FIGURE 1: Residual histograms
# =============================
# One panel per fitted parameter, comparing (fitted - true) across methods.
param_names = ["A", "μ", "σ"]
param_truths = [true_A, true_mu, true_sig]

fig1, axes1 = plt.subplots(1, 3, figsize=(15, 5))
fig1.suptitle(f"Parameter Residuals — {ROWS}×{COLS} pixels, {N_SCAN} scan points",
              fontsize=14, fontweight="bold")

for col, (pname, truth) in enumerate(zip(param_names, param_truths)):
    ax = axes1[col]

    # collect residuals across all methods for shared bin edges
    res_by_method = {}
    all_res = []
    for mname, m in methods.items():
        residual = (m["par"][:, :, col] - truth).ravel()
        res_by_method[mname] = residual
        all_res.append(residual)
    all_res = np.concatenate(all_res)

    # Clip the histogram range to the central 99% of pooled residuals so a few
    # outliers do not squash the bulk of the distribution.
    lo, hi = np.percentile(all_res, [0.5, 99.5])
    edges = np.linspace(lo, hi, 101)

    for mname, r in res_by_method.items():
        ax.hist(r, bins=edges, histtype="step", label=mname,
                color=colors[mname],
                linewidth=styles[mname]["linewidth"],
                linestyle=styles[mname]["linestyle"])

    ax.axvline(0, color="k", linestyle="--", linewidth=1, alpha=0.7)
    ax.set_xlabel(f"Fitted {pname} − True {pname}")
    ax.set_ylabel("Pixel count")
    ax.set_title(f"Δ{pname}")
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3)

fig1.tight_layout()
# fig1.savefig("fig1_residual_histograms.png", dpi=150, bbox_inches="tight")
# print("\nSaved fig1_residual_histograms.png")

# ====================================================
# FIGURE 2: Performance — bar chart + thread scaling
# ====================================================
fig2 = plt.figure(figsize=(14, 5))
gs = GridSpec(1, 2, figure=fig2, width_ratios=[1, 1.3])

# -- Left: bar chart at N_THREADS --
ax2a = fig2.add_subplot(gs[0])
names = list(methods.keys())
medians = [np.median(methods[n]["times"]) * 1e3 for n in names]
bars = ax2a.barh(names, medians, color=[colors[n] for n in names],
                 edgecolor="white", height=0.5)
ax2a.set_xlabel("Median wall time (ms)")
ax2a.set_title(f"Single call — {ROWS}×{COLS} px, {N_THREADS} threads")
label_pad = max(medians) * 0.02  # gap between a bar's end and its value label
for bar, val in zip(bars, medians):
    ax2a.text(bar.get_width() + label_pad,
              bar.get_y() + bar.get_height() / 2,
              f"{val:.1f} ms", va="center", fontsize=10)
ax2a.grid(axis="x", alpha=0.3)
ax2a.set_xlim(0, max(medians) * 1.25)

# -- Right: thread scaling with error bars --
ax2b = fig2.add_subplot(gs[1])
for label, _, _, _ in METHOD_DEFS:
    tt = thread_times[label]
    sd = ttimes_stddev[label]
    speedup = [tt[0] / t for t in tt]
    # propagate uncertainty: S = t0/t → δS/S = sqrt((δt0/t0)² + (δt/t)²),
    # valid when t0 and t are independent measurements. The baseline point
    # is S = t0/t0 ≡ 1 (same measurement in numerator and denominator), so its
    # error is exactly 0. The independent-error formula would wrongly give
    # √2·δt0/t0 there.
    rel0 = sd[0] / tt[0]
    speedup_err = [0.0] + [
        s * np.sqrt(rel0**2 + (sd_i / t_i)**2)
        for s, sd_i, t_i in zip(speedup[1:], sd[1:], tt[1:])
    ]
    ax2b.errorbar(thread_counts, speedup, yerr=speedup_err, fmt="o-",
                  label=label, color=colors[label], linewidth=2,
                  markersize=7, capsize=4)

ax2b.plot(thread_counts, thread_counts, "k--", alpha=0.4, label="Ideal linear")
ax2b.set_xlabel("Number of threads")
ax2b.set_ylabel("Speedup vs 1 thread")
ax2b.set_title("Thread scaling")
ax2b.set_xticks(thread_counts)
ax2b.legend(fontsize=9)
ax2b.grid(alpha=0.3)

fig2.tight_layout()
# fig2.savefig("fig2_performance.png", dpi=150, bbox_inches="tight")
# print("Saved fig2_performance.png")

plt.show()
In [ ]: