// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute // SPDX-License-Identifier: GPL-3.0-only #include "FrenchWilson.h" #include "../../common/ResolutionShells.h" #include #include #include #include namespace { struct PosteriorMoments { double mean_I; // double mean_F; // <|F|> }; /// Numerically stable posterior integration via log-shift PosteriorMoments IntegratePosteriorStable(double I_obs, double sigma_obs, double mean_I_bin, bool acentric, int npts, double z_max) { if (mean_I_bin <= 0.0 || !std::isfinite(mean_I_bin)) mean_I_bin = 1.0; if (sigma_obs <= 0.0 || !std::isfinite(sigma_obs)) sigma_obs = std::max(std::abs(I_obs) * 0.1, 1e-6); const double inv_2sig2 = 1.0 / (2.0 * sigma_obs * sigma_obs); const double dz = z_max / static_cast(npts); // First pass: compute all log_w and find max std::vector log_w_arr(npts); double max_log_w = -std::numeric_limits::infinity(); for (int i = 0; i < npts; ++i) { const double z = (static_cast(i) + 0.5) * dz; const double I_true = z * mean_I_bin; const double diff = I_obs - I_true; double log_prior; if (acentric) { log_prior = -z; } else { if (z <= 1e-30) { log_w_arr[i] = -std::numeric_limits::infinity(); continue; } log_prior = -0.5 * std::log(2.0 * M_PI * z) - z / 2.0; } log_w_arr[i] = log_prior + (-diff * diff * inv_2sig2); if (log_w_arr[i] > max_log_w) max_log_w = log_w_arr[i]; } // Second pass: accumulate with shift double sum_w = 0.0; double sum_wI = 0.0; double sum_wF = 0.0; for (int i = 0; i < npts; ++i) { const double z = (static_cast(i) + 0.5) * dz; const double I_true = z * mean_I_bin; const double w = std::exp(log_w_arr[i] - max_log_w); if (!std::isfinite(w)) continue; sum_w += w; sum_wI += w * I_true; sum_wF += w * std::sqrt(I_true); } PosteriorMoments m{}; if (sum_w > 0.0) { m.mean_I = sum_wI / sum_w; m.mean_F = sum_wF / sum_w; } else { const double I_pos = std::max(I_obs, 0.0); m.mean_I = I_pos; m.mean_F = std::sqrt(I_pos); } return m; } } // namespace std::vector FrenchWilson(const std::vector& merged, const FrenchWilsonOptions& opts) { const size_t n = merged.size(); std::vector out(n); if (n == 0) return out; // --- Step 1: determine d-range and build ResolutionShells --- float d_min = std::numeric_limits::max(); float d_max = 0.0f; for (const auto& r : merged) { const auto d = static_cast(r.d); if (!std::isfinite(d) || d <= 0.0f) continue; if (d < d_min) d_min = d; if (d > d_max) d_max = d; } // Guard: if we couldn't determine a range, fall back to naive sqrt if (d_min >= d_max || d_min <= 0.0f) { for (size_t i = 0; i < n; ++i) { out[i].h = merged[i].h; out[i].k = merged[i].k; out[i].l = merged[i].l; out[i].sigmaI = merged[i].sigma; const double I_pos = std::max(merged[i].I, 0.0); out[i].I = I_pos; out[i].F = std::sqrt(I_pos); out[i].sigmaF = 0.0; } return out; } // Slight padding so that reflections exactly at d_min / d_max are included const float d_min_padded = d_min * 0.999f; const float d_max_padded = d_max * 1.001f; const int nshells = std::max(1, opts.num_shells); ResolutionShells shells(d_min_padded, d_max_padded, nshells); // --- Step 2: assign each reflection to a shell and compute per-shell --- std::vector shell_id(n, -1); std::vector shell_sum_I(nshells, 0.0); std::vector shell_count(nshells, 0); for (size_t i = 0; i < n; ++i) { const auto d = static_cast(merged[i].d); if (!std::isfinite(d) || d <= 0.0f) continue; auto s = shells.GetShell(d); if (!s.has_value()) continue; shell_id[i] = s.value(); if (std::isfinite(merged[i].I) && std::isfinite(merged[i].sigma)) { shell_sum_I[s.value()] += merged[i].I; shell_count[s.value()] += 1; } } std::vector shell_mean_I(nshells, 1.0); for (int s = 0; s < nshells; ++s) { if (shell_count[s] >= opts.min_reflections_per_shell) shell_mean_I[s] = std::max(shell_sum_I[s] / static_cast(shell_count[s]), 1e-10); } // --- Step 3: apply French-Wilson to each reflection --- for (size_t i = 0; i < n; ++i) { out[i].h = merged[i].h; out[i].k = merged[i].k; out[i].l = merged[i].l; out[i].sigmaI = merged[i].sigma; const double I_obs = merged[i].I; const double sigma = merged[i].sigma; // If no valid shell or bad data, naive fallback if (shell_id[i] < 0 || !std::isfinite(I_obs) || !std::isfinite(sigma) || sigma <= 0.0) { const double I_pos = std::max(I_obs, 0.0); out[i].I = I_pos; out[i].F = std::sqrt(std::max(I_pos, 0.0)); out[i].sigmaF = 0.0; continue; } const double meanI = shell_mean_I[shell_id[i]]; auto moments = IntegratePosteriorStable( I_obs, sigma, meanI, opts.acentric, opts.num_quadrature_points, opts.z_max); out[i].I = moments.mean_I; out[i].F = moments.mean_F; // sigma(F) = sqrt( - ² ) = sqrt( - ² ) const double var_F = moments.mean_I - moments.mean_F * moments.mean_F; out[i].sigmaF = (var_F > 0.0) ? std::sqrt(var_F) : 0.0; } return out; }