// File: Jungfraujoch/image_analysis/scale_merge/Merge.cpp (C++)

// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute <filip.leonarski@psi.ch>
// SPDX-License-Identifier: GPL-3.0-only
#include "Merge.h"
#include <algorithm>
#include <cmath>
#include <limits>
#include <random>
#include <unordered_map>
#include <spdlog/fmt/fmt.h>
#include <gemmi/reciproc.hpp>
#include "../../common/CorrelationCoefficient.h"
#include "../../common/ResolutionShells.h"
#include "HKLKey.h"
// Snapshot the merging configuration from the experiment.
// The space group falls back to 1 when the experiment does not specify one.
// NOTE(review): several initializers below read scaling_settings / space_group_number;
// C++ initializes members in declaration order (not list order), so this relies on those
// members being declared before high_resolution_limit, image_cc_limit, min_partiality and
// generator in Merge.h — confirm there.
MergeOnTheFly::MergeOnTheFly(const DiffractionExperiment &x)
: space_group_number(x.GetSpaceGroupNumber().value_or(1)),
scaling_settings(x.GetScalingSettings()),
indexing_settings(x.GetIndexingSettings()),
high_resolution_limit(scaling_settings.GetHighResolutionLimit_A()),
image_cc_limit(scaling_settings.GetMinCCForImage()),
min_partiality(scaling_settings.GetMinPartiality()),
generator(scaling_settings.GetMergeFriedel(), space_group_number) {
}
/// Set (or clear, with an empty optional) the reference unit cell.
/// When set, Mask() rejects images whose cell is not close to it, and MergeStats()
/// uses it to count possible reflections per shell.
/// @return *this, to allow call chaining.
MergeOnTheFly &MergeOnTheFly::ReferenceCell(const std::optional<UnitCell> &cell) {
    this->reference_cell = cell;
    return *this;
}
void MergeOnTheFly::AddImage(const IntegrationOutcome &outcome, bool cc_mask) {
std::unique_lock ul(merged_mutex);
const int half = half_dist(rng);
if (Mask(outcome, cc_mask))
return;
for (const auto &r: outcome.reflections) {
if (generator.IsSystematicallyAbsent(r))
continue;
if (r.image_scale_corr <= 0.0 || !std::isfinite(r.image_scale_corr))
continue;
if (!AcceptReflection(r, high_resolution_limit))
continue;
if (r.partiality < min_partiality)
continue;
const float I_corr = r.I * r.image_scale_corr;
const float sigma_corr = r.sigma * r.image_scale_corr;
if (!std::isfinite(I_corr) || !std::isfinite(sigma_corr) || sigma_corr <= 0.0)
continue;
auto hkl = generator(r);
auto hkl_key = hkl.pack();
auto it = accumulator.find(hkl_key);
if (it == accumulator.end())
it = accumulator.emplace(hkl_key, MergeAccum{
.h = hkl.plus ? hkl.h : -hkl.h,
.k = hkl.plus ? hkl.k : -hkl.k,
.l = hkl.plus ? hkl.l : -hkl.l,
}).first;
const float w = 1.0f / (sigma_corr * sigma_corr);
const float wI = w * I_corr;
it->second.sum_wI += wI;
it->second.sum_w += w;
it->second.sum_wI_half[half] += wI;
it->second.sum_w_half[half] += w;
it->second.n_half[half]++;
if (!std::isfinite(it->second.d) && std::isfinite(r.d) && r.d > 0.0f)
it->second.d = r.d;
}
}
/// Decide whether an image must be excluded from merging / statistics.
/// @param outcome integration result of the image
/// @param cc_mask when true and a minimum image CC is configured, also reject images whose
///                image_scale_cc is missing, NaN, or below that minimum
/// @return true if the image should be skipped
bool MergeOnTheFly::Mask(const IntegrationOutcome &outcome, bool cc_mask) {
    // Reject images whose lattice is not close to the reference cell (if one is set).
    if (reference_cell.has_value()) {
        auto image_cell = outcome.latt.GetUnitCell();
        const bool cell_matches = image_cell.is_close(*reference_cell,
                                                      indexing_settings.GetUnitCellDistTolerance(),
                                                      indexing_settings.GetUnitCellAngleTolerance_deg());
        if (!cell_matches)
            return true;
    }

    // Without CC masking (or without a configured limit) the image is accepted.
    if (!cc_mask || !image_cc_limit)
        return false;

    const auto &cc = outcome.image_scale_cc;
    return !cc || std::isnan(*cc) || *cc < *image_cc_limit;
}
/// Export the merged data set.
/// For every accumulated unique reflection with positive total weight and a finite, positive
/// resolution this produces I = sum(w*I)/sum(w) and sigma = 1/sqrt(sum(w)); half-set values
/// are filled only when both halves received weight (otherwise they stay NaN).
/// When the configured R-free fraction is positive, reflections are binned into 20 resolution
/// shells and each one is flagged for the free set with probability equal to that fraction,
/// using a fixed seed and a Miller-index ordering so the choice is reproducible.
/// @return merged reflections (order follows the internal accumulator)
std::vector<MergedReflection> MergeOnTheFly::ExportReflections() {
    std::unique_lock ul(merged_mutex);

    float d_min = std::numeric_limits<float>::max();
    float d_max = 0.0f;

    std::vector<MergedReflection> out;
    out.reserve(accumulator.size());

    for (const auto &accum: accumulator | std::views::values) {
        if (accum.sum_w <= 0.0)
            continue;
        // Without a usable resolution the reflection can be neither binned nor used downstream.
        if (!std::isfinite(accum.d) || accum.d <= 0.0f)
            continue;

        MergedReflection mr{
            .h = accum.h,
            .k = accum.k,
            .l = accum.l,
            .I = static_cast<float>(accum.sum_wI / accum.sum_w),
            .sigma = 1.0f / std::sqrt(static_cast<float>(accum.sum_w)),
            .I_half = {NAN, NAN},
            .sigma_half = {NAN, NAN},
            .d = accum.d
        };
        // Half-set estimates only make sense (e.g. for CC1/2) when both halves contributed.
        if (accum.n_half[0] + accum.n_half[1] > 0 && accum.sum_w_half[0] > 0.0 && accum.sum_w_half[1] > 0.0) {
            for (int i = 0; i < 2; ++i) {
                mr.I_half[i] = static_cast<float>(accum.sum_wI_half[i] / accum.sum_w_half[i]);
                mr.sigma_half[i] = 1.0f / std::sqrt(static_cast<float>(accum.sum_w_half[i]));
            }
        }

        d_min = std::min(d_min, accum.d);
        d_max = std::max(d_max, accum.d);
        out.emplace_back(mr);
    }

    const double rfree_fraction = scaling_settings.GetRfreeFraction();
    // d_min <= d_max (not <): the 0.1% padding below keeps the shell range non-degenerate
    // even when every exported reflection has the same resolution.
    if (rfree_fraction > 0.0 && !out.empty() && d_min > 0.0f && d_min <= d_max) {
        constexpr int n_shells = 20;
        const float d_min_pad = d_min * 0.999f;
        const float d_max_pad = d_max * 1.001f;
        ResolutionShells shells(d_min_pad, d_max_pad, n_shells);

        std::vector<std::vector<size_t>> shell_groups(n_shells);
        for (size_t i = 0; i < out.size(); ++i) {
            const auto shell = shells.GetShell(out[i].d);
            if (!shell.has_value())
                continue;
            const int s = *shell;
            if (s >= 0 && s < n_shells)
                shell_groups[s].push_back(i);
        }

        // Draw the flags in Miller-index order. `out` follows the iteration order of
        // `accumulator`, which is not guaranteed to be the same from run to run (e.g. a hash
        // map filled by concurrent AddImage calls); without sorting, the fixed seed would not
        // give a reproducible free set for identical merged data.
        const auto hkl_less = [&out](size_t a, size_t b) {
            const auto &x = out[a];
            const auto &y = out[b];
            if (x.h != y.h)
                return x.h < y.h;
            if (x.k != y.k)
                return x.k < y.k;
            return x.l < y.l;
        };

        std::mt19937 rfree_rng(12345u);
        // std::bernoulli_distribution requires p in [0, 1].
        std::bernoulli_distribution rfree_dist(std::min(rfree_fraction, 1.0));
        for (auto &group: shell_groups) {
            std::sort(group.begin(), group.end(), hkl_less);
            for (const size_t idx: group)
                out[idx].rfree_flag = rfree_dist(rfree_rng);
        }
    }
    return out;
}
std::vector<MergedReflection> MergeAll(const DiffractionExperiment &x,
const std::vector<IntegrationOutcome> &integration_outcome,
bool mask) {
MergeOnTheFly merge(x);
for (const auto &outcome: integration_outcome)
merge.AddImage(outcome, mask);
return merge.ExportReflections();
}
// Per-resolution-shell accumulators filled by MergeStats().
struct ShellAccum {
    int total_obs{0};                 // individual observations passing the merge filters
    int unique{0};                    // merged reflections with finite I and positive sigma
    int possible{0};                  // possible unique reflections (only filled with a reference cell)
    double sum_i_over_sigma{0.0};     // running sum of I/sigma over merged reflections
    int n_i_over_sigma{0};            // number of terms in sum_i_over_sigma
    CorrelationCoefficient cc_half;   // half-set correlation (CC1/2)
    CorrelationCoefficient cc_ref;    // correlation against reference intensities
};
/// Count the reflections gemmi generates for `cell` / `space_group_number` between d_min and
/// d_max (unique = true), binning each by resolution and incrementing acc[shell].possible.
/// @param space_group_number space-group number passed to gemmi::find_spacegroup_by_number
/// @param cell               unit cell used both for generation and for computing d
/// @param d_min, d_max       resolution range in Å handed to gemmi::make_miller_vector
/// @param shells             binning used to map d to a shell index
/// @param acc                per-shell accumulators; indices outside [0, acc.size()) are ignored
/// @throws JFJochException (InputParameterInvalid) if gemmi does not know the space group
/// NOTE(review): with unique = true Friedel mates are presumably counted once; if merging keeps
/// Friedel pairs separate (merge_friedel == false), completeness may exceed 100% — confirm.
void CalcPossibleReflections(int space_group_number,
                             const UnitCell &cell,
                             double d_min,
                             double d_max,
                             const ResolutionShells &shells,
                             std::vector<ShellAccum> &acc) {
    gemmi::UnitCell gemmi_cell = cell;
    const gemmi::SpaceGroup *sg = gemmi::find_spacegroup_by_number(space_group_number);
    if (sg == nullptr)
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
                              "Invalid space group number " + std::to_string(space_group_number));

    // Generate unique reflections
    const std::vector<gemmi::Miller> possible_hkls =
            gemmi::make_miller_vector(gemmi_cell, sg, d_min, d_max, true);

    // d is computed from the project's reciprocal basis so it matches how observed d is binned.
    CrystalLattice lattice(cell);
    const auto astar = lattice.Astar();
    const auto bstar = lattice.Bstar();
    const auto cstar = lattice.Cstar();

    for (const auto &hkl: possible_hkls) {
        const auto q = hkl[0] * astar + hkl[1] * bstar + hkl[2] * cstar;
        const auto qlen = q.Length();
        if (qlen < 1e-6) // guards against a zero-length vector (e.g. 0,0,0)
            continue;
        const auto d = 1.0 / qlen;
        const auto shell = shells.GetShell(d);
        if (!shell.has_value())
            continue;
        const int s = *shell;
        // Cast after the sign check: avoids a signed/unsigned comparison with acc.size().
        if (s >= 0 && static_cast<size_t>(s) < acc.size())
            acc[s].possible++;
    }
}
/// Compute per-shell (10 shells) and overall merging statistics.
/// @param merged              merged reflections to evaluate (e.g. from ExportReflections())
/// @param integration_outcome per-image integration results; used only to count observations
/// @param reference           reference reflections for CCref; may be empty
/// @return shells plus overall: observations, unique/possible counts, <I/sigma>, CC1/2, CCref
/// @throws JFJochException (InputParameterInvalid) if no merged reflection has a finite,
///         positive resolution within the high-resolution limit, or (via
///         CalcPossibleReflections) if the space group is unknown.
MergeStatistics MergeOnTheFly::MergeStats(const std::vector<MergedReflection> &merged,
                                          const std::vector<IntegrationOutcome> &integration_outcome,
                                          const std::vector<MergedReflection> &reference) {
    constexpr int n_shells = 10;
    auto d_min_limit_A = scaling_settings.GetHighResolutionLimit_A();

    // Reference intensities keyed by the same generator(...).pack() key used for merged data.
    std::unordered_map<uint64_t, float> reference_intensities;
    reference_intensities.reserve(reference.size());
    for (const auto &r: reference) {
        if (!std::isfinite(r.I))
            continue;
        const auto hkl = generator(r);
        reference_intensities[hkl.pack()] = r.I;
    }

    // Resolution range of the merged data inside the high-resolution limit.
    float d_min = std::numeric_limits<float>::max();
    float d_max = 0.0f;
    for (const auto &m: merged) {
        if (!std::isfinite(m.d) || m.d <= 0.0f)
            continue;
        if (d_min_limit_A && m.d < d_min_limit_A)
            continue;
        d_min = std::min(d_min, m.d);
        d_max = std::max(d_max, m.d);
    }

    // Reject only an empty range; d_min == d_max is fine because of the padding below.
    if (!(d_min <= d_max && d_min > 0.0f))
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
                              "MergeStats: Error in resolution calculation");

    const float d_min_pad = d_min * 0.999f;
    const float d_max_pad = d_max * 1.001f;
    ResolutionShells shells(d_min_pad, d_max_pad, n_shells);
    const auto shell_mean_1_d2 = shells.GetShellMeanOneOverResSq();
    const auto shell_min_res = shells.GetShellMinRes();

    std::vector<ShellAccum> acc(n_shells);
    if (reference_cell.has_value())
        CalcPossibleReflections(space_group_number, reference_cell.value(),
                                d_min_pad, d_max_pad, shells, acc);

    // Unique reflections, <I/sigma>, CC1/2 and CCref from the merged data.
    CorrelationCoefficient cc_half_overall;
    CorrelationCoefficient cc_ref_overall;
    for (const auto &m: merged) {
        const auto shell = shells.GetShell(m.d);
        if (!shell.has_value())
            continue;
        const int s = *shell;
        if (s < 0 || s >= n_shells)
            continue;
        if (!std::isfinite(m.I) || !std::isfinite(m.sigma) || m.sigma <= 0.0)
            continue;

        acc[s].unique++;
        acc[s].sum_i_over_sigma += m.I / m.sigma;
        ++acc[s].n_i_over_sigma;

        if (!reference_intensities.empty()) {
            const auto hkl = generator(m);
            const auto ref_it = reference_intensities.find(hkl.pack());
            if (ref_it != reference_intensities.end() && std::isfinite(ref_it->second)) {
                acc[s].cc_ref.Add(m.I, ref_it->second);
                cc_ref_overall.Add(m.I, ref_it->second);
            }
        }
        if (std::isfinite(m.I_half[0]) && std::isfinite(m.I_half[1])) {
            acc[s].cc_half.Add(m.I_half[0], m.I_half[1]);
            cc_half_overall.Add(m.I_half[0], m.I_half[1]);
        }
    }

    // Count individual observations with the same per-reflection filters AddImage applies
    // (the resolution limit here and in AddImage both come from GetHighResolutionLimit_A()).
    // NOTE(review): Mask is always called with cc_mask = true, while the merge itself may have
    // been run without CC masking (MergeAll's `mask`); confirm this is intended.
    for (const auto &outcome: integration_outcome) {
        if (Mask(outcome, true))
            continue;
        for (const auto &r: outcome.reflections) {
            if (generator.IsSystematicallyAbsent(r))
                continue;
            if (r.image_scale_corr <= 0.0 || !std::isfinite(r.image_scale_corr))
                continue;
            if (!AcceptReflection(r, d_min_limit_A))
                continue;
            if (r.partiality < min_partiality)
                continue;
            const float I_corr = r.I * r.image_scale_corr;
            const float sigma_corr = r.sigma * r.image_scale_corr;
            if (!std::isfinite(I_corr) || !std::isfinite(sigma_corr) || sigma_corr <= 0.0f)
                continue;
            const auto shell = shells.GetShell(r.d);
            if (!shell.has_value())
                continue;
            const int s = *shell;
            if (s >= 0 && s < n_shells)
                acc[s].total_obs++;
        }
    }

    MergeStatistics out;
    out.shells.resize(n_shells);
    for (int s = 0; s < n_shells; ++s) {
        const auto &sa = acc[s];
        auto &ss = out.shells[s];
        ss.mean_one_over_d2 = shell_mean_1_d2[s];
        ss.d_min = shell_min_res[s];
        // Upper d bound of a shell is the lower bound of the previous one (shell 0: padded max).
        ss.d_max = s == 0 ? d_max_pad : shell_min_res[s - 1];
        ss.total_observations = sa.total_obs;
        ss.unique_reflections = sa.unique;
        ss.possible_unique_reflections = sa.possible;
        ss.mean_i_over_sigma = sa.n_i_over_sigma > 0
                               ? sa.sum_i_over_sigma / sa.n_i_over_sigma
                               : 0.0;
        ss.cc_half = sa.cc_half.GetCC();
        ss.cc_ref = sa.cc_ref.GetCC();
    }

    auto &overall = out.overall;
    overall.d_min = d_min;
    overall.d_max = d_max;
    int all_possible = 0;
    int all_unique = 0;
    double sum_i_over_sigma = 0.0;
    int n_i_over_sigma = 0;
    for (const auto &sa: acc) {
        overall.total_observations += sa.total_obs;
        all_unique += sa.unique;
        all_possible += sa.possible;
        sum_i_over_sigma += sa.sum_i_over_sigma;
        n_i_over_sigma += sa.n_i_over_sigma;
    }
    overall.possible_unique_reflections = all_possible;
    overall.unique_reflections = all_unique;
    overall.mean_i_over_sigma = n_i_over_sigma > 0 ? sum_i_over_sigma / n_i_over_sigma : 0.0;
    // Overall CCs are computed from all pairs, not averaged over shells.
    overall.cc_half = cc_half_overall.GetCC();
    overall.cc_ref = cc_ref_overall.GetCC();
    return out;
}
/// Print one statistics row (without the leading d_min column and without a newline):
/// N_obs, N_uniq, N_possib, completeness %, <I/sigma>, CC1/2 %, CCref %.
std::ostream &operator<<(std::ostream &output, const MergeStatisticsShell &in) {
    // Completeness is reported as 0 when the number of possible reflections is unknown (0).
    double completeness = 0.0;
    if (in.possible_unique_reflections > 0)
        completeness = static_cast<double>(in.unique_reflections) / in.possible_unique_reflections * 100.0;

    const auto row = fmt::format("{:8d} {:8d} {:8d} {:7.1f}% {:8.1f} {:7.1f}% {:7.1f}%",
                                 in.total_observations,
                                 in.unique_reflections,
                                 in.possible_unique_reflections,
                                 completeness,
                                 in.mean_i_over_sigma,
                                 in.cc_half * 100.0,
                                 in.cc_ref * 100.0);
    return output << row;
}
/// Print the merging-statistics table: header, one row per shell with at least one unique
/// reflection (prefixed by the shell's d_min), and an "Overall" row.
/// Lines end with '\n'; the stream is flushed once, at the end, instead of after every line.
std::ostream &operator<<(std::ostream &output, const MergeStatistics &in) {
    // Dashed rule spanning the eight 8-character columns.
    const std::string rule = fmt::format(" {:->8s} {:->8s} {:->8s} {:->8s} {:->8s} {:->8s} {:->8s} {:->8s}",
                                         "", "", "", "", "", "", "", "");

    output << '\n';
    output << fmt::format(" {:>8s} {:>8s} {:>8s} {:>8s} {:>8s} {:>8s} {:>8s} {:>8s}",
                          "d_min", "N_obs", "N_uniq", "N_possib", "Compl","<I/sig>", "CC1/2", "CCref")
           << '\n';
    output << rule << '\n';

    for (const auto &sh: in.shells) {
        // Shells without any unique reflection carry no information.
        if (sh.unique_reflections == 0)
            continue;
        output << fmt::format(" {:8.2f} ", sh.d_min) << sh << '\n';
    }

    output << rule << '\n';
    output << fmt::format(" {:>8s} ", "Overall") << in.overall << '\n';
    output << std::endl;
    return output;
}