Files
Jungfraujoch/tests/IndexingUnitTest.cpp

495 lines
18 KiB
C++

// SPDX-FileCopyrightText: 2024 Filip Leonarski, Paul Scherrer Institute <filip.leonarski@psi.ch>
// SPDX-License-Identifier: GPL-3.0-only
#include <algorithm>
#include <array>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <memory>
#include <vector>
#include <catch2/catch_all.hpp>
#include "../writer/HDF5Objects.h"
#include "../image_analysis/indexing/IndexerFactory.h"
#include "../image_analysis/indexing/PostIndexingRefinement.h"
#include "../image_analysis/bragg_prediction/BraggPrediction.h"
#include "../common/Logger.h"
// Distance from x to its nearest integer (std::round rounds halves away from zero),
// so the result lies in [0, 0.5] for finite x.
inline double round_err(double x) {
    const double nearest_integer = std::round(x);
    return std::fabs(x - nearest_integer);
}
#ifdef JFJOCH_USE_CUDA
#include <Eigen/Dense>
#include <Eigen/Geometry>
namespace {
// Builds a rotation matrix from three angles in degrees.
// The composed rotation is Rz * Ry * Rx: applied to a column vector, the X rotation acts
// first, then Y, then Z (all about the fixed coordinate axes).
Eigen::Matrix3f MakeRotation(float ax_deg, float ay_deg, float az_deg) {
    // Degrees -> radians, computed in single precision like the resulting matrix.
    const auto to_rad = [](float deg) { return deg * static_cast<float>(M_PI) / 180.0f; };
    const Eigen::AngleAxisf rot_x(to_rad(ax_deg), Eigen::Vector3f::UnitX());
    const Eigen::AngleAxisf rot_y(to_rad(ay_deg), Eigen::Vector3f::UnitY());
    const Eigen::AngleAxisf rot_z(to_rad(az_deg), Eigen::Vector3f::UnitZ());
    return (rot_z * rot_y * rot_x).toRotationMatrix();
}
// Returns a copy of `lattice` with each of its three basis vectors (Vec0, Vec1, Vec2)
// left-multiplied by `rot`.
CrystalLattice RotateLattice(const CrystalLattice &lattice, const Eigen::Matrix3f &rot) {
    const auto rotate = [&rot](const Coord &v) {
        const Eigen::Vector3f rotated = rot * Eigen::Vector3f(v.x, v.y, v.z);
        return Coord(rotated.x(), rotated.y(), rotated.z());
    };
    const Coord new_a = rotate(lattice.Vec0());
    const Coord new_b = rotate(lattice.Vec1());
    const Coord new_c = rotate(lattice.Vec2());
    return {new_a, new_b, new_c};
}
std::vector<uint8_t> BuildIndexedMask(const std::vector<Coord> &spots,
const CrystalLattice &lattice,
float tolerance,
int64_t &count) {
const Coord a = lattice.Vec0();
const Coord b = lattice.Vec1();
const Coord c = lattice.Vec2();
const float tol_sq = tolerance * tolerance;
std::vector<uint8_t> mask(spots.size(), 0);
count = 0;
for (size_t i = 0; i < spots.size(); ++i) {
const float h_fp = spots[i] * a;
const float k_fp = spots[i] * b;
const float l_fp = spots[i] * c;
const float h_frac = h_fp - std::round(h_fp);
const float k_frac = k_fp - std::round(k_fp);
const float l_frac = l_fp - std::round(l_fp);
const float norm_sq = h_frac * h_frac + k_frac * k_frac + l_frac * l_frac;
if (norm_sq < tol_sq) {
mask[i] = 1;
++count;
}
}
return mask;
}
// Counts positions where both masks are non-zero.
// Only the common prefix of the two masks is compared. Masks of different lengths therefore
// never cause an out-of-bounds read; the original loop indexed b with a's size.
int64_t MaskOverlap(const std::vector<uint8_t> &a, const std::vector<uint8_t> &b) {
    const size_t common = std::min(a.size(), b.size());
    int64_t overlap = 0;
    for (size_t i = 0; i < common; ++i) {
        if (a[i] && b[i])
            ++overlap;
    }
    return overlap;
}
bool MatchesCellLengths(const UnitCell &lhs, const UnitCell &rhs, float rel_tol = 0.08f) {
std::array<float, 3> a = {
static_cast<float>(lhs.a),
static_cast<float>(lhs.b),
static_cast<float>(lhs.c)
};
std::array<float, 3> b = {
static_cast<float>(rhs.a),
static_cast<float>(rhs.b),
static_cast<float>(rhs.c)
};
std::sort(a.begin(), a.end());
std::sort(b.begin(), b.end());
for (int i = 0; i < 3; ++i) {
const float denom = std::max(b[i], 1e-6f);
if (std::abs(a[i] - b[i]) / denom > rel_tol)
return false;
}
return true;
}
std::vector<Coord> BuildPredictedReciprocalSpots(const DiffractionExperiment &experiment,
const CrystalLattice &lattice,
const BraggPredictionSettings &settings) {
BraggPrediction prediction(20000);
const int nref = prediction.Calc(experiment, lattice, settings);
REQUIRE(nref > 0);
const auto &refs = prediction.GetReflections();
const Coord astar = lattice.Astar();
const Coord bstar = lattice.Bstar();
const Coord cstar = lattice.Cstar();
std::vector<Coord> spots;
spots.reserve(nref);
for (int i = 0; i < nref; ++i) {
const auto &r = refs[i];
spots.emplace_back(static_cast<float>(r.h) * astar
+ static_cast<float>(r.k) * bstar
+ static_cast<float>(r.l) * cstar);
}
return spots;
}
} // namespace
// Generates exact reciprocal-space spots for three unit cells and checks that FFBIDX recovers
// a real-space basis projecting every spot onto near-integer Miller indices.
TEST_CASE("FastFeedbackIndexer","[Indexing]") {
    // Synthetic Miller indices: h in [1,6], k in [1,5], l in [1,3] (90 reflections).
    std::vector<Coord> hkl;
    for (int i = 1; i < 7; i++)
        for (int j = 1; j < 6; j++)
            for (int k = 1; k < 4; k++)
                hkl.emplace_back(i, j, k);

    // Orthorhombic, hexagonal (gamma = 120 deg) and monoclinic (beta = 82.5 deg) cells.
    std::vector<UnitCell> cells;
    cells.emplace_back(30, 40, 50, 90, 90, 90);
    cells.emplace_back(80, 80, 90, 90, 90, 120);
    cells.emplace_back(40, 45, 80, 90, 82.5, 90);

    DiffractionExperiment experiment;
    experiment.SetUnitCell(cells[0]);
    experiment.IndexingAlgorithm(IndexingAlgorithmEnum::FFBIDX);
    REQUIRE(experiment.GetIndexingAlgorithm() == IndexingAlgorithmEnum::FFBIDX);
    std::unique_ptr<Indexer> indexer = CreateIndexer(experiment);
    REQUIRE(indexer);

    for (auto &c: cells) {
        CrystalLattice l(c);
        // Real-space basis vectors as matrix rows.
        Eigen::Matrix3f m;
        m << l.Vec0().x, l.Vec0().y, l.Vec0().z,
             l.Vec1().x, l.Vec1().y, l.Vec1().z,
             l.Vec2().x, l.Vec2().y, l.Vec2().z;
        // Rows of (M^T)^-1 form the reciprocal basis (a_i . b_j = delta_ij, no 2*pi factor).
        // Evaluate into a concrete matrix. Binding the expression to `auto` keeps an unevaluated
        // Inverse expression (Eigen "common pitfalls"), which may be re-evaluated on each
        // coefficient access below.
        const Eigen::Matrix3f m1 = m.transpose().inverse();
        CrystalLattice recip_l(Coord(m1(0,0), m1(0,1), m1(0,2)),
                               Coord(m1(1,0), m1(1,1), m1(1,2)),
                               Coord(m1(2,0), m1(2,1), m1(2,2)));

        // Spot positions in reciprocal space: q = h*a* + k*b* + l*c*.
        std::vector<Coord> recip;
        recip.reserve(hkl.size());
        for (const auto &i: hkl)
            recip.emplace_back(i.x * recip_l.Vec0() + i.y * recip_l.Vec1() + i.z * recip_l.Vec2());

        experiment.SetUnitCell(c);
        indexer->Setup(experiment);
        auto ret = indexer->Run(recip);
        REQUIRE(!ret.lattice.empty());

        // Per axis, accumulate each spot's distance from an integer Miller index;
        // the mean must stay below 0.001.
        double err[3] = {0.0, 0.0, 0.0};
        for (const auto &iter: recip) {
            err[0] += round_err(ret.lattice[0].Vec0() * iter);
            err[1] += round_err(ret.lattice[0].Vec1() * iter);
            err[2] += round_err(ret.lattice[0].Vec2() * iter);
        }
        REQUIRE(err[0] < 0.001 * recip.size());
        REQUIRE(err[1] < 0.001 * recip.size());
        REQUIRE(err[2] < 0.001 * recip.size());
    }
}
// Checks that the FFT indexer recovers exactly one lattice matching an orthorhombic
// reference cell from a dense block of its reciprocal-lattice points, and logs the run time.
TEST_CASE("FFTIndexer","[Indexing]") {
    Logger logger("FFTIndexer");

    UnitCell uc(39,45,78,90,90,90);
    CrystalLattice cl(uc);

    DiffractionExperiment experiment;
    IndexingSettings settings;
    settings.Algorithm(IndexingAlgorithmEnum::FFT)
            .FFT_MaxUnitCell_A(250.0)
            .FFT_HighResolution_A(2 * M_PI / 3.0);
    experiment.ImportIndexingSettings(settings).SetUnitCell(uc);

    REQUIRE(experiment.GetIndexingAlgorithm() == IndexingAlgorithmEnum::FFT);
    REQUIRE(experiment.GetIndexingSettings().GetTolerance() == Catch::Approx(0.1f));

    std::unique_ptr<Indexer> indexer = CreateIndexer(experiment);
    REQUIRE(indexer);

    // Reciprocal-lattice points for h in [-2,9], k in [-5,9], l in [-3,9].
    const Coord a_star = cl.Astar();
    const Coord b_star = cl.Bstar();
    const Coord c_star = cl.Cstar();
    std::vector<Coord> spots;
    for (int ih = -2; ih < 10; ++ih)
        for (int ik = -5; ik < 10; ++ik)
            for (int il = -3; il < 10; ++il)
                spots.push_back(ih * a_star + ik * b_star + il * c_star);
    logger.Info("Spots {}", spots.size());

    const auto t_start = std::chrono::high_resolution_clock::now();
    auto result = indexer->Run(spots);
    const auto t_end = std::chrono::high_resolution_clock::now();

    REQUIRE(result.lattice.size() == 1);
    const auto found = result.lattice[0].GetUnitCell();
    CHECK(found.a == Catch::Approx(uc.a));
    CHECK(found.b == Catch::Approx(uc.b));
    CHECK(found.c == Catch::Approx(uc.c));
    CHECK(found.alpha == Catch::Approx(uc.alpha));
    CHECK(found.beta == Catch::Approx(uc.beta));
    CHECK(found.gamma == Catch::Approx(uc.gamma));

    const auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
    logger.Info("Time: {} ms", elapsed_ms);
}
// Two differently rotated copies of a lysozyme-like cell are simulated with BraggPrediction,
// and their reciprocal-space spots are concatenated. Refine() is seeded with three candidate
// bases: both true orientations plus a duplicate of the first. The test expects at least two
// returned lattices matching the reference cell lengths, and expects the first two returned
// lattices to index largely disjoint spot sets.
TEST_CASE("PostIndexingRefinement_MultiLattice_TwoLysozymes_BraggPrediction","[Indexing]") {
Logger logger("PostIndexingRefinement_MultiLattice_TwoLysozymes_BraggPrediction");
// Reference cell; the two simulated crystals are rotated copies of it.
UnitCell lysozyme_uc{36.9, 78.95, 78.95, 90.0, 90.0, 90.0};
CrystalLattice lysozyme_base(lysozyme_uc);
CrystalLattice lysozyme_rot_1 = RotateLattice(lysozyme_base, MakeRotation(10.0f, 18.0f, 27.0f));
CrystalLattice lysozyme_rot_2 = RotateLattice(lysozyme_base, MakeRotation(66.0f, -14.0f, 101.0f));
// JF4M detector geometry: 75 mm distance, beam centre (x=1090, y=1136) px, 12.4 keV.
DiffractionExperiment experiment(DetJF4M());
experiment.DetectorDistance_mm(75)
.BeamY_pxl(1136)
.BeamX_pxl(1090)
.IncidentEnergy_keV(12.4);
// Prediction limited to 2 A resolution and |h|,|k|,|l| <= 20 (presumably), primitive centering.
// NOTE(review): the remaining fields tune which reflections count as in diffraction
// condition; see BraggPredictionSettings for their exact meaning.
BraggPredictionSettings pred_settings{
.high_res_A = 2.0f,
.ewald_dist_cutoff = 0.0010f,
.max_hkl = 20,
.centering = 'P',
.wedge_deg = 0.1f,
.mosaicity_deg = 0.2f,
.min_zeta = 0.05f,
.mosaicity_multiplier = 4.0f
};
const auto spots_1 = BuildPredictedReciprocalSpots(experiment, lysozyme_rot_1, pred_settings);
const auto spots_2 = BuildPredictedReciprocalSpots(experiment, lysozyme_rot_2, pred_settings);
logger.Info("Predicted spots lattice 1: {}", spots_1.size());
logger.Info("Predicted spots lattice 2: {}", spots_2.size());
// Combined spot list: all lattice-1 spots followed by all lattice-2 spots.
std::vector<Coord> spots;
spots.reserve(spots_1.size() + spots_2.size());
spots.insert(spots.end(), spots_1.begin(), spots_1.end());
spots.insert(spots.end(), spots_2.begin(), spots_2.end());
// Candidate bases packed into a 9x3 matrix: rows 3*idx .. 3*idx+2 hold Vec0, Vec1, Vec2
// of candidate idx. There is one score per candidate.
Eigen::MatrixX3<float> oCell(9, 3);
Eigen::VectorX<float> scores(3);
auto put_lattice = [&](int idx, const CrystalLattice &lattice) {
oCell(idx * 3 + 0, 0) = lattice.Vec0().x;
oCell(idx * 3 + 0, 1) = lattice.Vec0().y;
oCell(idx * 3 + 0, 2) = lattice.Vec0().z;
oCell(idx * 3 + 1, 0) = lattice.Vec1().x;
oCell(idx * 3 + 1, 1) = lattice.Vec1().y;
oCell(idx * 3 + 1, 2) = lattice.Vec1().z;
oCell(idx * 3 + 2, 0) = lattice.Vec2().x;
oCell(idx * 3 + 2, 1) = lattice.Vec2().y;
oCell(idx * 3 + 2, 2) = lattice.Vec2().z;
};
put_lattice(0, lysozyme_rot_1);
put_lattice(1, lysozyme_rot_2);
put_lattice(2, lysozyme_rot_1); // duplicate to verify overlap rejection
// Keep bootstrap scores tiny to disable candidate drift in iterative re-fitting.
scores(0) = 1e-6f;
scores(1) = 1.1e-6f;
scores(2) = 2e-6f;
// NOTE(review): min/max length and angle bounds presumably filter implausible cells inside
// Refine(). indexing_tolerance is also reused below for the overlap masks.
RefineParameters params{
.viable_cell_min_spots = 12,
.dist_tolerance_vs_reference = 0.05f,
.reference_unit_cell = std::nullopt,
.min_length_A = 20.0f,
.max_length_A = 120.0f,
.min_angle_deg = 60.0f,
.max_angle_deg = 120.0f,
.indexing_tolerance = 0.05f
};
// NOTE(review): the second argument appears to be the number of spots to consider; confirm
// against the Refine() declaration.
auto refined = Refine(spots, spots.size(), oCell, scores, params);
REQUIRE(refined.size() >= 2);
// Count returned lattices whose sorted axis lengths match the reference cell
// (MatchesCellLengths with its default relative tolerance).
int lysozyme_count = 0;
for (const auto &lattice : refined) {
if (MatchesCellLengths(lattice.GetUnitCell(), lysozyme_uc))
++lysozyme_count;
}
CHECK(lysozyme_count >= 2);
// The two leading lattices must explain mostly different spots. Spots indexed by both may be
// at most 20% of the larger indexed set.
int64_t count_0 = 0;
int64_t count_1 = 0;
auto mask_0 = BuildIndexedMask(spots, refined[0], params.indexing_tolerance, count_0);
auto mask_1 = BuildIndexedMask(spots, refined[1], params.indexing_tolerance, count_1);
const int64_t overlap = MaskOverlap(mask_0, mask_1);
const int64_t max_set = std::max(count_0, count_1);
logger.Info("Returned lattice 0 indexes {} spots", count_0);
logger.Info("Returned lattice 1 indexes {} spots", count_1);
logger.Info("Overlap between returned lattices: {} / {}", overlap, max_set);
CHECK(overlap <= static_cast<int64_t>(0.2f * static_cast<float>(max_set)));
}
/*
TEST_CASE("FFTIndexer_MultiLattice_TwoLysozymes_BraggPrediction","[Indexing]") {
Logger logger("FFTIndexer_MultiLattice_TwoLysozymes_BraggPrediction");
UnitCell lysozyme_uc{36.9, 78.95, 78.95, 90.0, 90.0, 90.0};
CrystalLattice lysozyme_base(lysozyme_uc);
CrystalLattice lysozyme_rot_1 = RotateLattice(lysozyme_base, MakeRotation(10.0f, 18.0f, 27.0f));
CrystalLattice lysozyme_rot_2 = RotateLattice(lysozyme_base, MakeRotation(66.0f, -14.0f, 101.0f));
DiffractionExperiment experiment(DetJF4M());
experiment.DetectorDistance_mm(75)
.BeamY_pxl(1136)
.BeamX_pxl(1090)
.IncidentEnergy_keV(12.4);
BraggPredictionSettings pred_settings{
.high_res_A = 2.0f,
.ewald_dist_cutoff = 0.0010f,
.max_hkl = 20,
.centering = 'P',
.wedge_deg = 0.1f,
.mosaicity_deg = 0.2f,
.min_zeta = 0.05f,
.mosaicity_multiplier = 4.0f
};
auto spots_1 = BuildPredictedReciprocalSpots(experiment, lysozyme_rot_1, pred_settings);
auto spots_2 = BuildPredictedReciprocalSpots(experiment, lysozyme_rot_2, pred_settings);
logger.Info("Predicted spots lattice 1: {}", spots_1.size());
logger.Info("Predicted spots lattice 2: {}", spots_2.size());
std::vector<Coord> spots;
spots.reserve(spots_1.size() + spots_2.size());
spots.insert(spots.end(), spots_1.begin(), spots_1.end());
spots.insert(spots.end(), spots_2.begin(), spots_2.end());
IndexingSettings settings;
settings.Algorithm(IndexingAlgorithmEnum::FFT)
.FFT_MaxUnitCell_A(120.0)
.FFT_HighResolution_A(2.0f)
.FFT_NumVectors(1024);
experiment.ImportIndexingSettings(settings)
.SetUnitCell(lysozyme_uc);
REQUIRE(experiment.GetIndexingAlgorithm() == IndexingAlgorithmEnum::FFT);
std::unique_ptr<Indexer> indexer = CreateIndexer(experiment);
REQUIRE(indexer);
indexer->Setup(experiment);
auto result = indexer->Run(spots);
logger.Info("FFT returned {} lattices", result.lattice.size());
REQUIRE(result.lattice.size() >= 2);
const float tolerance = experiment.GetIndexingSettings().GetTolerance();
int lysozyme_count = 0;
for (size_t i = 0; i < result.lattice.size(); ++i) {
auto uc = result.lattice[i].GetUnitCell();
int64_t indexed_count = 0;
BuildIndexedMask(spots, result.lattice[i], tolerance, indexed_count);
logger.Info("Lattice {} cell ({:.1f} {:.1f} {:.1f}) indexes {} spots",
i, uc.a, uc.b, uc.c, indexed_count);
if (MatchesCellLengths(uc, lysozyme_uc))
++lysozyme_count;
}
CHECK(lysozyme_count >= 2);
// Verify the two best lysozyme lattices are distinct (low overlap)
if (result.lattice.size() >= 2) {
int64_t count_0 = 0, count_1 = 0;
auto mask_0 = BuildIndexedMask(spots, result.lattice[0], tolerance, count_0);
auto mask_1 = BuildIndexedMask(spots, result.lattice[1], tolerance, count_1);
const int64_t overlap = MaskOverlap(mask_0, mask_1);
const int64_t max_set = std::max(count_0, count_1);
logger.Info("Top-2 overlap: {} / {}", overlap, max_set);
CHECK(overlap <= static_cast<int64_t>(0.5f * static_cast<float>(max_set)));
}
}
TEST_CASE("FFBIDXIndexer_MultiLattice_TwoLysozymes_BraggPrediction","[Indexing]") {
Logger logger("FFBIDXIndexer_MultiLattice_TwoLysozymes_BraggPrediction");
UnitCell lysozyme_uc{36.9, 78.95, 78.95, 90.0, 90.0, 90.0};
CrystalLattice lysozyme_base(lysozyme_uc);
CrystalLattice lysozyme_rot_1 = RotateLattice(lysozyme_base, MakeRotation(10.0f, 18.0f, 27.0f));
CrystalLattice lysozyme_rot_2 = RotateLattice(lysozyme_base, MakeRotation(66.0f, -14.0f, 101.0f));
DiffractionExperiment experiment(DetJF4M());
experiment.DetectorDistance_mm(75)
.BeamY_pxl(1136)
.BeamX_pxl(1090)
.IncidentEnergy_keV(12.4);
BraggPredictionSettings pred_settings{
.high_res_A = 2.0f,
.ewald_dist_cutoff = 0.0010f,
.max_hkl = 20,
.centering = 'P',
.wedge_deg = 0.1f,
.mosaicity_deg = 0.2f,
.min_zeta = 0.05f,
.mosaicity_multiplier = 4.0f
};
auto spots_1 = BuildPredictedReciprocalSpots(experiment, lysozyme_rot_1, pred_settings);
auto spots_2 = BuildPredictedReciprocalSpots(experiment, lysozyme_rot_2, pred_settings);
logger.Info("Predicted spots lattice 1: {}", spots_1.size());
logger.Info("Predicted spots lattice 2: {}", spots_2.size());
std::vector<Coord> spots;
spots.reserve(spots_1.size() + spots_2.size());
spots.insert(spots.end(), spots_1.begin(), spots_1.end());
spots.insert(spots.end(), spots_2.begin(), spots_2.end());
experiment.SetUnitCell(lysozyme_uc);
experiment.IndexingAlgorithm(IndexingAlgorithmEnum::FFBIDX);
REQUIRE(experiment.GetIndexingAlgorithm() == IndexingAlgorithmEnum::FFBIDX);
std::unique_ptr<Indexer> indexer = CreateIndexer(experiment);
REQUIRE(indexer);
indexer->Setup(experiment);
auto result = indexer->Run(spots);
logger.Info("FFBIDX returned {} lattices", result.lattice.size());
REQUIRE(result.lattice.size() >= 2);
const float tolerance = experiment.GetIndexingSettings().GetTolerance();
int lysozyme_count = 0;
for (size_t i = 0; i < result.lattice.size(); ++i) {
auto uc = result.lattice[i].GetUnitCell();
int64_t indexed_count = 0;
BuildIndexedMask(spots, result.lattice[i], tolerance, indexed_count);
logger.Info("Lattice {} cell ({:.1f} {:.1f} {:.1f}) indexes {} spots",
i, uc.a, uc.b, uc.c, indexed_count);
if (MatchesCellLengths(uc, lysozyme_uc))
++lysozyme_count;
}
CHECK(lysozyme_count >= 2);
// Verify the two best lysozyme lattices are distinct (low overlap)
if (result.lattice.size() >= 2) {
int64_t count_0 = 0, count_1 = 0;
auto mask_0 = BuildIndexedMask(spots, result.lattice[0], tolerance, count_0);
auto mask_1 = BuildIndexedMask(spots, result.lattice[1], tolerance, count_1);
const int64_t overlap = MaskOverlap(mask_0, mask_1);
const int64_t max_set = std::max(count_0, count_1);
logger.Info("Top-2 overlap: {} / {}", overlap, max_set);
CHECK(overlap <= static_cast<int64_t>(0.5f * static_cast<float>(max_set)));
}
} */
#endif