Files
Jungfraujoch/image_analysis/scale_merge/ScaleAndMerge.cpp

451 lines
14 KiB
C++

// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute <filip.leonarski@psi.ch>
// SPDX-License-Identifier: GPL-3.0-only
#include "ScaleAndMerge.h"
#include <ceres/ceres.h>
#include <thread>
#include <iostream>
#include <algorithm>
#include <cmath>
#include <limits>
#include <stdexcept>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
namespace {
// Canonical Miller-index key used to group observations for merging.
// is_positive distinguishes the two Friedel mates when they are kept separate
// (only meaningful when opt.merge_friedel == false).
struct HKLKey {
    int64_t h = 0;
    int64_t k = 0;
    int64_t l = 0;
    bool is_positive = true; // only relevant if opt.merge_friedel == false
    bool operator==(const HKLKey& o) const noexcept {
        // Compare all four members at once; std::tie compares lexicographically.
        return std::tie(h, k, l, is_positive) == std::tie(o.h, o.k, o.l, o.is_positive);
    }
};
// Hash functor for HKLKey, suitable for std::unordered_map.
//
// Each component is run through a 64-bit avalanche mixer (the MurmurHash3
// "fmix64" finalizer constants), then the mixed words are XOR-combined with
// per-component shifts so that permutations of (h, k, l, sign) hash differently.
struct HKLKeyHash {
    size_t operator()(const HKLKey& key) const noexcept {
        // Avalanche step: every input bit influences every output bit.
        auto fmix64 = [](uint64_t v) {
            v ^= v >> 33;
            v *= 0xff51afd7ed558ccdULL;
            v ^= v >> 33;
            v *= 0xc4ceb9fe1a85ec53ULL;
            v ^= v >> 33;
            return v;
        };
        const uint64_t words[4] = {
            static_cast<uint64_t>(key.h),
            static_cast<uint64_t>(key.k),
            static_cast<uint64_t>(key.l),
            key.is_positive ? uint64_t{1} : uint64_t{0},
        };
        uint64_t combined = 0;
        for (int i = 0; i < 4; ++i)
            combined ^= fmix64(words[i]) << i; // shift keeps component order significant
        return static_cast<size_t>(combined);
    }
};
/// Bucket a (possibly fractional) image number into an integer id by rounding
/// image_number / rounding_step to the nearest integer.
///
/// @param image_number  raw image number from the observation
/// @param rounding_step bucket width; zero, negative or NaN falls back to 1.0
/// @return nearest-integer bucket index (half-away-from-zero rounding)
inline int RoundImageId(float image_number, double rounding_step) {
    // Guard against a zero/negative/NaN step, which would yield inf/NaN ids.
    if (!(rounding_step > 0.0))
        rounding_step = 1.0;
    // FIX: the original computed round(x) * step and then divided by step
    // again before llround — a pure no-op round-trip that only introduced
    // floating-point error. llround on the scaled value is equivalent.
    const double x = static_cast<double>(image_number) / rounding_step;
    return static_cast<int>(std::llround(x));
}
// Sanitize an observation sigma: non-finite or non-positive values fall back
// to the configured floor, and valid values are clamped from below by it.
inline double SafeSigma(double s, double min_sigma) {
    const bool usable = std::isfinite(s) && s > 0.0;
    return usable ? std::max(s, min_sigma) : min_sigma;
}
// Validate a d-spacing: anything that is not finite and strictly positive is
// signalled to the caller as quiet NaN (the caller then skips the observation).
inline double SafeD(double d) {
    const bool valid = std::isfinite(d) && d > 0.0;
    return valid ? d : std::numeric_limits<double>::quiet_NaN();
}
// Narrow an int64 Miller index to int (Gemmi's Miller type uses plain int).
// Throws std::out_of_range instead of silently truncating.
inline int SafeToInt(int64_t x) {
    constexpr int64_t lo = std::numeric_limits<int>::min();
    constexpr int64_t hi = std::numeric_limits<int>::max();
    if (x >= lo && x <= hi)
        return static_cast<int>(x);
    throw std::out_of_range("HKL index out of int range for Gemmi");
}
// Guarded reciprocal: 1/x for finite non-zero x, otherwise the given fallback.
inline double SafeInv(double x, double fallback) {
    return (std::isfinite(x) && x != 0.0) ? 1.0 / x : fallback;
}
/// Map a reflection's (h,k,l) onto the canonical key used for merging.
///
/// With a space group, the Miller index is reduced to the reciprocal-space
/// asymmetric unit via Gemmi; the returned sign distinguishes Friedel mates
/// unless merge_friedel collapses them. Without a space group, Friedel mates
/// (h,k,l)/(-h,-k,-l) are mapped onto a single lexicographic representative.
///
/// @throws std::out_of_range (from SafeToInt) if an index does not fit in int.
inline HKLKey CanonicalizeHKLKey(const Reflection& r, const ScaleMergeOptions& opt) {
    HKLKey key{};
    key.h = r.h;
    key.k = r.k;
    key.l = r.l;
    key.is_positive = true;
    if (!opt.space_group.has_value()) {
        // BUG FIX: the original only canonicalized when merge_friedel was
        // false, so with merge_friedel == true the raw (h,k,l) was returned
        // and Friedel mates received DIFFERENT keys, i.e. were never merged
        // (inconsistent with the space-group branch below). Always pick the
        // lexicographically larger of (h,k,l) and (-h,-k,-l) as the
        // representative; is_positive keeps the pair apart when Friedel
        // merging is disabled.
        const HKLKey neg{-r.h, -r.k, -r.l, true};
        const bool pos = std::tie(key.h, key.k, key.l) >= std::tie(neg.h, neg.k, neg.l);
        if (!pos) {
            key.h = -key.h;
            key.k = -key.k;
            key.l = -key.l;
            key.is_positive = false;
        }
        if (opt.merge_friedel)
            key.is_positive = true; // collapse the Friedel pair into one slot
        return key;
    }
    // Symmetry-aware path: reduce to the reciprocal-space ASU with Gemmi.
    const gemmi::SpaceGroup& sg = *opt.space_group;
    const gemmi::GroupOps gops = sg.operations();
    const gemmi::ReciprocalAsu rasu(&sg);
    const gemmi::Op::Miller in{{SafeToInt(r.h), SafeToInt(r.k), SafeToInt(r.l)}};
    const auto [asu_hkl, sign_plus] = rasu.to_asu_sign(in, gops);
    key.h = asu_hkl[0];
    key.k = asu_hkl[1];
    key.l = asu_hkl[2];
    key.is_positive = opt.merge_friedel ? true : sign_plus;
    return key;
}
/// CrystFEL-like log-scaling residual for one observation:
///
/// residual = w * [ ln(G) + ln(partiality) + ln(lp) + ln(I_true) - ln(I_obs) ]
///
/// (the overall sign is irrelevant for least squares; the code computes
/// ln(I_pred) - ln(I_obs), the opposite sign of the doc it replaced).
/// Only observations with I_obs > 0 should be fed in (the caller skips the rest).
/// G and I_true are constrained to be positive via Ceres lower bounds.
struct LogIntensityResidual {
// Precompute every observation-dependent constant:
//  - log_Iobs_: ln(I_obs), with I_obs clamped to >= 1e-30 so the log is finite
//  - weight_:   1 / (sigma/I) == I_obs / sigma_obs (relative weight in log space)
//  - log_lp_:   ln(1/rlp), the Lorentz/rlp correction in log form (fallback 1)
//  - c1_:       zeta / sqrt(2), scale of the erf argument in the partiality
//               model (presumably zeta encodes the reflection's rocking-curve
//               geometry — TODO confirm against the integration code)
LogIntensityResidual(const Reflection& r, double sigma_obs, double wedge_deg, bool refine_partiality)
: log_Iobs_(std::log(std::max(static_cast<double>(r.I), 1e-30))),
weight_(SafeInv(sigma_obs / std::max(static_cast<double>(r.I), 1e-30), 1.0)),
delta_phi_(r.delta_phi_deg),
log_lp_(std::log(std::max(SafeInv(static_cast<double>(r.rlp), 1.0), 1e-30))),
half_wedge_(wedge_deg / 2.0),
c1_(r.zeta / std::sqrt(2.0)),
partiality_(r.partiality),
refine_partiality_(refine_partiality) {}
// Parameter blocks (each size 1): G = per-image scale, mosaicity = per-image
// mosaic spread (deg), Itrue = merged intensity of this HKL.
template<typename T>
bool operator()(const T* const G,
const T* const mosaicity,
const T* const Itrue,
T* residual) const {
T partiality;
if (refine_partiality_ && mosaicity[0] != 0.0) {
// Gaussian rocking-curve model: partiality = fraction of the reflection
// swept over [delta_phi - w/2, delta_phi + w/2], as a difference of erfs.
// The != 0 guard avoids division by zero; NOTE(review): the model switches
// discontinuously to the stored partiality if mosaicity hits exactly 0 —
// confirm this is intended (a lower bound normally prevents it).
const T arg_plus = T(delta_phi_ + half_wedge_) * T(c1_) / mosaicity[0];
const T arg_minus = T(delta_phi_ - half_wedge_) * T(c1_) / mosaicity[0];
partiality = (ceres::erf(arg_plus) - ceres::erf(arg_minus)) / T(2.0);
} else {
// Fall back to the partiality estimated by the integration step.
partiality = T(partiality_);
}
// Clamp partiality away from zero so log is safe
const T min_p = T(1e-30);
if (partiality < min_p)
partiality = min_p;
// ln(I_pred) = ln(G) + ln(partiality) + ln(lp) + ln(Itrue)
const T log_Ipred = ceres::log(G[0]) + ceres::log(partiality) + T(log_lp_) + ceres::log(Itrue[0]);
residual[0] = (log_Ipred - T(log_Iobs_)) * T(weight_);
return true;
}
double log_Iobs_;
double weight_; // w_i ≈ I_obs / sigma_obs (relative weight in log-space)
double delta_phi_; // offset of spot centre from image centre, degrees
double log_lp_; // ln(1/rlp), precomputed
double half_wedge_; // half the per-image rotation range, degrees
double c1_; // zeta / sqrt(2)
double partiality_; // fallback partiality from integration
bool refine_partiality_;
};
/// Linear-intensity residual for one observation:
///
/// residual = (G * partiality * lp * I_true - I_obs) / sigma_obs
///
/// Unlike the log-space residual this accepts negative/zero I_obs, at the
/// cost of strong reflections dominating the fit.
struct IntensityResidual {
// Precompute observation constants. lp_ = 1/rlp (fallback 1 for zero or
// non-finite rlp); c1_ = zeta/sqrt(2) scales the erf partiality argument.
IntensityResidual(const Reflection& r, double sigma_obs, double wedge_deg, bool refine_partiality)
: Iobs_(static_cast<double>(r.I)),
inv_sigma_(SafeInv(sigma_obs, 1.0)),
delta_phi_(r.delta_phi_deg),
lp_(SafeInv(static_cast<double>(r.rlp), 1.0)),
half_wedge_(wedge_deg / 2.0),
c1_(r.zeta / std::sqrt(2.0)),
partiality_(r.partiality),
refine_partiality_(refine_partiality) {}
// Parameter blocks (each size 1): G = per-image scale, mosaicity = per-image
// mosaic spread (deg), Itrue = merged intensity of this HKL.
template<typename T>
bool operator()(const T* const G,
const T* const mosaicity,
const T* const Itrue,
T* residual) const {
T partiality;
if (refine_partiality_ && mosaicity[0] != 0.0) {
// Gaussian rocking-curve model: fraction of the reflection swept over
// [delta_phi - w/2, delta_phi + w/2], as a difference of erf terms.
// The != 0 guard avoids division by zero when mosaicity is exactly 0.
const T arg_plus = T(delta_phi_ + half_wedge_) * T(c1_) / mosaicity[0];
const T arg_minus = T(delta_phi_ - half_wedge_) * T(c1_) / mosaicity[0];
partiality = (ceres::erf(arg_plus) - ceres::erf(arg_minus)) / T(2.0);
} else
// Fall back to the partiality estimated by the integration step.
partiality = T(partiality_);
const T Ipred = G[0] * partiality * T(lp_) * Itrue[0];
residual[0] = (Ipred - T(Iobs_)) * T(inv_sigma_);
return true;
}
double Iobs_; // observed intensity (may be <= 0 here)
double inv_sigma_; // 1 / sigma_obs
double delta_phi_;
double lp_; // 1 / rlp
double half_wedge_; // half the per-image rotation range, degrees
double c1_; // zeta / sqrt(2)
double partiality_; // fallback partiality from integration
bool refine_partiality_;
};
/// Soft prior pulling a per-image scale factor k towards 1.0:
/// residual = (k - 1) / sigma_k. A zero or non-finite sigma_k disables the
/// weighting (the inverse sigma falls back to 1).
struct ScaleRegularizationResidual {
    explicit ScaleRegularizationResidual(double sigma_k)
        // Inlined guarded reciprocal (same behavior as SafeInv(sigma_k, 1.0)).
        : inv_sigma_((std::isfinite(sigma_k) && sigma_k != 0.0) ? 1.0 / sigma_k : 1.0) {}

    template <typename T>
    bool operator()(const T* const k, T* residual) const {
        residual[0] = (k[0] - T(1.0)) * T(inv_sigma_);
        return true;
    }

    double inv_sigma_; // 1 / sigma_k (or 1.0 fallback)
};
} // namespace
/// Scale and merge integrated reflection observations with Ceres.
///
/// Refines one multiplicative scale G per image (and, optionally, a per-image
/// mosaicity that drives an erf-based partiality model) together with one
/// "true" intensity per unique HKL, then reports per-image results and the
/// merged reflection list.
///
/// Observations with non-finite d/I/zeta, zeta <= 0, rlp == 0, or HKL indices
/// that do not fit into int are silently dropped.
ScaleMergeResult ScaleAndMergeReflectionsCeres(const std::vector<Reflection>& observations,
const ScaleMergeOptions& opt) {
ScaleMergeResult out;
// Filtered view of one usable observation plus its dense parameter slots.
struct ObsRef {
const Reflection* r = nullptr; // non-owning; points into `observations`
int img_id = 0; // rounded image number (user-visible id)
int img_slot = -1; // dense index into per-image parameter arrays
int hkl_slot = -1; // dense index into per-HKL parameter arrays
double sigma = 0.0; // sigma floored by opt.min_sigma
};
std::vector<ObsRef> obs;
obs.reserve(observations.size());
// Maps from external id / canonical HKL key to dense slots, first-seen order.
std::unordered_map<int, int> imgIdToSlot;
imgIdToSlot.reserve(256);
std::unordered_map<HKLKey, int, HKLKeyHash> hklToSlot;
hklToSlot.reserve(observations.size());
// --- Pass 1: filter observations and assign image / HKL slots ---
for (const auto& r : observations) {
const double d = SafeD(r.d);
if (!std::isfinite(d))
continue;
if (!std::isfinite(r.I))
continue;
if (!std::isfinite(r.zeta) || r.zeta <= 0.0f)
continue;
if (!std::isfinite(r.rlp) || r.rlp == 0.0f)
continue;
const double sigma = SafeSigma(static_cast<double>(r.sigma), opt.min_sigma);
const int img_id = RoundImageId(r.image_number, opt.image_number_rounding);
int img_slot;
{
// Assign (or look up) the dense slot for this image id.
auto it = imgIdToSlot.find(img_id);
if (it == imgIdToSlot.end()) {
img_slot = static_cast<int>(imgIdToSlot.size());
imgIdToSlot.emplace(img_id, img_slot);
} else {
img_slot = it->second;
}
}
int hkl_slot;
try {
const HKLKey key = CanonicalizeHKLKey(r, opt);
auto it = hklToSlot.find(key);
if (it == hklToSlot.end()) {
hkl_slot = static_cast<int>(hklToSlot.size());
hklToSlot.emplace(key, hkl_slot);
} else {
hkl_slot = it->second;
}
} catch (...) {
// SafeToInt threw: HKL index does not fit into int — skip this observation.
continue;
}
ObsRef o;
o.r = &r;
o.img_id = img_id;
o.img_slot = img_slot;
o.hkl_slot = hkl_slot;
o.sigma = sigma;
obs.push_back(o);
}
const int nimg = static_cast<int>(imgIdToSlot.size());
const int nhkl = static_cast<int>(hklToSlot.size());
out.image_scale_g.assign(nimg, 1.0);
out.image_ids.assign(nimg, 0);
for (const auto& kv : imgIdToSlot) {
out.image_ids[kv.second] = kv.first;
}
// --- Parameter arrays (one entry per dense slot). Their addresses are handed
// to Ceres below, so they must not be resized past this point. ---
std::vector<double> g(nimg, 1.0);
std::vector<double> Itrue(nhkl, 0.0);
// Mosaicity: always per-image
std::vector<double> mosaicity(nimg, opt.mosaicity_init_deg);
// Initialize Itrue from per-HKL median of observed intensities
{
std::vector<std::vector<double>> per_hkl_I(nhkl);
for (const auto& o : obs) {
per_hkl_I[o.hkl_slot].push_back(static_cast<double>(o.r->I));
}
for (int h = 0; h < nhkl; ++h) {
auto& v = per_hkl_I[h];
if (v.empty()) {
// No observation for this slot (cannot happen for slots created above,
// defensive default): seed with a small positive value.
Itrue[h] = std::max(opt.min_sigma, 1e-6);
continue;
}
// nth_element places the (upper) median at index size/2; the rest of the
// vector is only partially ordered, which is all we need here.
std::nth_element(v.begin(), v.begin() + static_cast<long>(v.size() / 2), v.end());
double med = v[v.size() / 2];
if (!std::isfinite(med) || med <= opt.min_sigma)
med = opt.min_sigma;
Itrue[h] = med;
}
}
// --- Build the Ceres problem: one residual block per observation, tying the
// observation's image parameters (g, mosaicity) to its HKL intensity. ---
ceres::Problem problem;
const bool refine_partiality = opt.wedge_deg > 0.0;
for (const auto& o : obs) {
if (opt.log_scaling_residual) {
// Log residual requires positive I_obs
if (o.r->I <= 0.0f)
continue;
// Ceres takes ownership of both the cost function and the functor.
auto* cost = new ceres::AutoDiffCostFunction<LogIntensityResidual, 1, 1, 1, 1>(
new LogIntensityResidual(*o.r, o.sigma, opt.wedge_deg, refine_partiality));
problem.AddResidualBlock(cost,
nullptr,
&g[o.img_slot],
&mosaicity[o.img_slot],
&Itrue[o.hkl_slot]);
} else {
auto* cost = new ceres::AutoDiffCostFunction<IntensityResidual, 1, 1, 1, 1>(
new IntensityResidual(*o.r, o.sigma, opt.wedge_deg, refine_partiality));
problem.AddResidualBlock(cost,
nullptr,
&g[o.img_slot],
&mosaicity[o.img_slot],
&Itrue[o.hkl_slot]);
}
}
// For log residual, G and Itrue must stay positive
if (opt.log_scaling_residual) {
for (int i = 0; i < nimg; ++i)
problem.SetParameterLowerBound(&g[i], 0, 1e-12);
for (int h = 0; h < nhkl; ++h)
problem.SetParameterLowerBound(&Itrue[h], 0, 1e-12);
}
// NOTE(review): ScaleRegularizationResidual is defined above but never added
// to the problem — the Kabsch-like regularization for the per-image scale
// appears to be unimplemented (or intentionally disabled); confirm.
// Mosaicity refinement + bounds
if (!opt.refine_mosaicity) {
for (int i = 0; i < nimg; ++i)
problem.SetParameterBlockConstant(&mosaicity[i]);
} else {
for (int i = 0; i < nimg; ++i) {
problem.SetParameterLowerBound(&mosaicity[i], 0, opt.mosaicity_min_deg);
problem.SetParameterUpperBound(&mosaicity[i], 0, opt.mosaicity_max_deg);
}
}
// use all available threads
unsigned int hw = std::thread::hardware_concurrency();
if (hw == 0)
hw = 1; // fallback
ceres::Solver::Options options;
options.linear_solver_type = ceres::SPARSE_NORMAL_CHOLESKY;
// NOTE(review): per-iteration progress is printed to stdout unconditionally;
// consider making this configurable for library use.
options.minimizer_progress_to_stdout = true;
options.max_num_iterations = opt.max_num_iterations;
options.max_solver_time_in_seconds = opt.max_solver_time_s;
options.num_threads = static_cast<int>(hw);
ceres::Solver::Summary summary;
ceres::Solve(options, &problem, &summary);
// --- Export per-image results ---
for (int i = 0; i < nimg; ++i)
out.image_scale_g[i] = g[i];
out.mosaicity_deg.resize(nimg);
for (int i = 0; i < nimg; ++i)
out.mosaicity_deg[i] = mosaicity[i];
// --- Compute goodness-of-fit (reduced chi-squared) ---
// NOTE(review): n_obs and n_params below are computed but never used — the
// reduced chi-squared promised here is never actually calculated or exported.
const int n_obs = static_cast<int>(obs.size());
// Count free parameters: nhkl Itrue + per-image (k + mosaicity) minus fixed ones
int n_params = nhkl;
for (int i = 0; i < nimg; ++i) {
n_params += 1; // k
if (opt.refine_mosaicity)
n_params += 1; // mosaicity
}
// --- Export merged reflections (invert the HKL slot map first) ---
std::vector<HKLKey> slotToHKL(nhkl);
for (const auto& kv : hklToSlot) {
slotToHKL[kv.second] = kv.first;
}
out.merged.resize(nhkl);
for (int h = 0; h < nhkl; ++h) {
out.merged[h].h = slotToHKL[h].h;
out.merged[h].k = slotToHKL[h].k;
out.merged[h].l = slotToHKL[h].l;
out.merged[h].I = Itrue[h];
// NOTE(review): merged sigma is left at 0 — no uncertainty estimate is
// propagated to the merged list.
out.merged[h].sigma = 0.0;
out.merged[h].d = 0.0;
}
// Populate d from median of observations per HKL
{
std::vector<std::vector<double>> per_hkl_d(nhkl);
for (const auto& o : obs) {
const double d_val = static_cast<double>(o.r->d);
if (std::isfinite(d_val) && d_val > 0.0)
per_hkl_d[o.hkl_slot].push_back(d_val);
}
for (int h = 0; h < nhkl; ++h) {
auto& v = per_hkl_d[h];
if (!v.empty()) {
// Same upper-median-via-nth_element trick as the intensity seeding above.
std::nth_element(v.begin(), v.begin() + static_cast<long>(v.size() / 2), v.end());
out.merged[h].d = v[v.size() / 2];
}
}
}
// Full solver report (termination reason, iterations, timings) to stdout.
std::cout << summary.FullReport() << std::endl;
return out;
}