// SPDX-FileCopyrightText: 2024 Filip Leonarski, Paul Scherrer Institute // SPDX-License-Identifier: GPL-3.0-only #include #include "../writer/HDF5Objects.h" #include "../image_analysis/indexing/IndexerFactory.h" #include "../image_analysis/indexing/PostIndexingRefinement.h" #include "../image_analysis/bragg_prediction/BraggPrediction.h" #include "../common/Logger.h" inline double round_err(double x) { return std::abs(x - std::round(x)); } #ifdef JFJOCH_USE_CUDA #include #include namespace { Eigen::Matrix3f MakeRotation(float ax_deg, float ay_deg, float az_deg) { const float ax = ax_deg * static_cast(M_PI) / 180.0f; const float ay = ay_deg * static_cast(M_PI) / 180.0f; const float az = az_deg * static_cast(M_PI) / 180.0f; return (Eigen::AngleAxisf(az, Eigen::Vector3f::UnitZ()) * Eigen::AngleAxisf(ay, Eigen::Vector3f::UnitY()) * Eigen::AngleAxisf(ax, Eigen::Vector3f::UnitX())).toRotationMatrix(); } CrystalLattice RotateLattice(const CrystalLattice &lattice, const Eigen::Matrix3f &rot) { auto apply = [&](const Coord &v) { Eigen::Vector3f x(v.x, v.y, v.z); x = rot * x; return Coord(x.x(), x.y(), x.z()); }; return { apply(lattice.Vec0()), apply(lattice.Vec1()), apply(lattice.Vec2()) }; } std::vector BuildIndexedMask(const std::vector &spots, const CrystalLattice &lattice, float tolerance, int64_t &count) { const Coord a = lattice.Vec0(); const Coord b = lattice.Vec1(); const Coord c = lattice.Vec2(); const float tol_sq = tolerance * tolerance; std::vector mask(spots.size(), 0); count = 0; for (size_t i = 0; i < spots.size(); ++i) { const float h_fp = spots[i] * a; const float k_fp = spots[i] * b; const float l_fp = spots[i] * c; const float h_frac = h_fp - std::round(h_fp); const float k_frac = k_fp - std::round(k_fp); const float l_frac = l_fp - std::round(l_fp); const float norm_sq = h_frac * h_frac + k_frac * k_frac + l_frac * l_frac; if (norm_sq < tol_sq) { mask[i] = 1; ++count; } } return mask; } int64_t MaskOverlap(const std::vector &a, const std::vector &b) { int64_t overlap = 0; for (size_t i = 0; i < a.size(); ++i) { if (a[i] && b[i]) ++overlap; } return overlap; } bool MatchesCellLengths(const UnitCell &lhs, const UnitCell &rhs, float rel_tol = 0.08f) { std::array a = { static_cast(lhs.a), static_cast(lhs.b), static_cast(lhs.c) }; std::array b = { static_cast(rhs.a), static_cast(rhs.b), static_cast(rhs.c) }; std::sort(a.begin(), a.end()); std::sort(b.begin(), b.end()); for (int i = 0; i < 3; ++i) { const float denom = std::max(b[i], 1e-6f); if (std::abs(a[i] - b[i]) / denom > rel_tol) return false; } return true; } std::vector BuildPredictedReciprocalSpots(const DiffractionExperiment &experiment, const CrystalLattice &lattice, const BraggPredictionSettings &settings) { BraggPrediction prediction(20000); const int nref = prediction.Calc(experiment, lattice, settings); REQUIRE(nref > 0); const auto &refs = prediction.GetReflections(); const Coord astar = lattice.Astar(); const Coord bstar = lattice.Bstar(); const Coord cstar = lattice.Cstar(); std::vector spots; spots.reserve(nref); for (int i = 0; i < nref; ++i) { const auto &r = refs[i]; spots.emplace_back(static_cast(r.h) * astar + static_cast(r.k) * bstar + static_cast(r.l) * cstar); } return spots; } } // namespace TEST_CASE("FastFeedbackIndexer","[Indexing]") { std::vector hkl; for (int i = 1; i < 7; i++) for (int j = 1; j<6; j++) for (int k = 1; k < 4; k++) hkl.emplace_back(i,j,k); std::vector cells; cells.emplace_back(30,40,50,90,90,90); cells.emplace_back(80,80,90,90,90,120); cells.emplace_back(40,45,80,90,82.5,90); DiffractionExperiment experiment; experiment.SetUnitCell(cells[0]); experiment.IndexingAlgorithm(IndexingAlgorithmEnum::FFBIDX); REQUIRE(experiment.GetIndexingAlgorithm() == IndexingAlgorithmEnum::FFBIDX); std::unique_ptr indexer = CreateIndexer(experiment); for (auto &c: cells) { CrystalLattice l(c); Eigen::Matrix3f m; m << l.Vec0().x, l.Vec0().y, l.Vec0().z, l.Vec1().x, l.Vec1().y, l.Vec1().z, l.Vec2().x, l.Vec2().y, l.Vec2().z; auto m1 = m.transpose().inverse(); CrystalLattice recip_l(Coord(m1(0,0), m1(0,1), m1(0,2)), Coord(m1(1,0), m1(1,1), m1(1,2)), Coord(m1(2,0), m1(2,1), m1(2,2))); std::vector recip; recip.reserve(hkl.size()); for (const auto &i: hkl) recip.emplace_back(i.x * recip_l.Vec0() + i.y * recip_l.Vec1() + i.z * recip_l.Vec2()); experiment.SetUnitCell(c); indexer->Setup(experiment); auto ret = indexer->Run(recip); REQUIRE(!ret.lattice.empty()); double err[3] = {0.0, 0.0, 0.0}; for (const auto &iter: recip) { err[0] += round_err(ret.lattice[0].Vec0() * iter); err[1] += round_err(ret.lattice[0].Vec1() * iter); err[2] += round_err(ret.lattice[0].Vec2() * iter); } REQUIRE (err[0] < 0.001 * recip.size()); REQUIRE (err[1] < 0.001 * recip.size()); REQUIRE (err[2] < 0.001 * recip.size()); } } TEST_CASE("FFTIndexer","[Indexing]") { Logger logger("FFTIndexer"); UnitCell uc(39,45,78,90,90,90); CrystalLattice cl(uc); DiffractionExperiment experiment; IndexingSettings settings; settings.Algorithm(IndexingAlgorithmEnum::FFT) .FFT_MaxUnitCell_A(250.0).FFT_HighResolution_A(2 * M_PI / 3.0); experiment.ImportIndexingSettings(settings).SetUnitCell(uc); REQUIRE(experiment.GetIndexingAlgorithm() == IndexingAlgorithmEnum::FFT); REQUIRE(experiment.GetIndexingSettings().GetTolerance() == Catch::Approx(0.1f)); std::unique_ptr indexer = CreateIndexer(experiment); REQUIRE(indexer); std::vector vec; for (int h = -2; h < 10; h++) { for (int k = -5; k < 10; k++) { for (int l = -3; l < 10; l++) { vec.push_back(h * cl.Astar() + k * cl.Bstar() + l * cl.Cstar()); } } } logger.Info("Spots {}", vec.size()); auto start = std::chrono::high_resolution_clock::now(); auto result = indexer->Run(vec); auto end = std::chrono::high_resolution_clock::now(); REQUIRE(result.lattice.size() == 1); auto uc_out = result.lattice[0].GetUnitCell(); CHECK(uc_out.a == Catch::Approx(uc.a)); CHECK(uc_out.b == Catch::Approx(uc.b)); CHECK(uc_out.c == Catch::Approx(uc.c)); CHECK(uc_out.alpha == Catch::Approx(uc.alpha)); CHECK(uc_out.beta == Catch::Approx(uc.beta)); CHECK(uc_out.gamma == Catch::Approx(uc.gamma)); logger.Info("Time: {} ms", std::chrono::duration_cast(end - start).count()); } TEST_CASE("PostIndexingRefinement_MultiLattice_TwoLysozymes_BraggPrediction","[Indexing]") { Logger logger("PostIndexingRefinement_MultiLattice_TwoLysozymes_BraggPrediction"); UnitCell lysozyme_uc{36.9, 78.95, 78.95, 90.0, 90.0, 90.0}; CrystalLattice lysozyme_base(lysozyme_uc); CrystalLattice lysozyme_rot_1 = RotateLattice(lysozyme_base, MakeRotation(10.0f, 18.0f, 27.0f)); CrystalLattice lysozyme_rot_2 = RotateLattice(lysozyme_base, MakeRotation(66.0f, -14.0f, 101.0f)); DiffractionExperiment experiment(DetJF4M()); experiment.DetectorDistance_mm(75) .BeamY_pxl(1136) .BeamX_pxl(1090) .IncidentEnergy_keV(12.4); BraggPredictionSettings pred_settings{ .high_res_A = 2.0f, .ewald_dist_cutoff = 0.0010f, .max_hkl = 20, .centering = 'P', .wedge_deg = 0.1f, .mosaicity_deg = 0.2f, .min_zeta = 0.05f, .mosaicity_multiplier = 4.0f }; const auto spots_1 = BuildPredictedReciprocalSpots(experiment, lysozyme_rot_1, pred_settings); const auto spots_2 = BuildPredictedReciprocalSpots(experiment, lysozyme_rot_2, pred_settings); logger.Info("Predicted spots lattice 1: {}", spots_1.size()); logger.Info("Predicted spots lattice 2: {}", spots_2.size()); std::vector spots; spots.reserve(spots_1.size() + spots_2.size()); spots.insert(spots.end(), spots_1.begin(), spots_1.end()); spots.insert(spots.end(), spots_2.begin(), spots_2.end()); Eigen::MatrixX3 oCell(9, 3); Eigen::VectorX scores(3); auto put_lattice = [&](int idx, const CrystalLattice &lattice) { oCell(idx * 3 + 0, 0) = lattice.Vec0().x; oCell(idx * 3 + 0, 1) = lattice.Vec0().y; oCell(idx * 3 + 0, 2) = lattice.Vec0().z; oCell(idx * 3 + 1, 0) = lattice.Vec1().x; oCell(idx * 3 + 1, 1) = lattice.Vec1().y; oCell(idx * 3 + 1, 2) = lattice.Vec1().z; oCell(idx * 3 + 2, 0) = lattice.Vec2().x; oCell(idx * 3 + 2, 1) = lattice.Vec2().y; oCell(idx * 3 + 2, 2) = lattice.Vec2().z; }; put_lattice(0, lysozyme_rot_1); put_lattice(1, lysozyme_rot_2); put_lattice(2, lysozyme_rot_1); // duplicate to verify overlap rejection // Keep bootstrap scores tiny to disable candidate drift in iterative re-fitting. scores(0) = 1e-6f; scores(1) = 1.1e-6f; scores(2) = 2e-6f; RefineParameters params{ .viable_cell_min_spots = 12, .dist_tolerance_vs_reference = 0.05f, .reference_unit_cell = std::nullopt, .min_length_A = 20.0f, .max_length_A = 120.0f, .min_angle_deg = 60.0f, .max_angle_deg = 120.0f, .indexing_tolerance = 0.05f }; auto refined = Refine(spots, spots.size(), oCell, scores, params); REQUIRE(refined.size() >= 2); int lysozyme_count = 0; for (const auto &lattice : refined) { if (MatchesCellLengths(lattice.GetUnitCell(), lysozyme_uc)) ++lysozyme_count; } CHECK(lysozyme_count >= 2); int64_t count_0 = 0; int64_t count_1 = 0; auto mask_0 = BuildIndexedMask(spots, refined[0], params.indexing_tolerance, count_0); auto mask_1 = BuildIndexedMask(spots, refined[1], params.indexing_tolerance, count_1); const int64_t overlap = MaskOverlap(mask_0, mask_1); const int64_t max_set = std::max(count_0, count_1); logger.Info("Returned lattice 0 indexes {} spots", count_0); logger.Info("Returned lattice 1 indexes {} spots", count_1); logger.Info("Overlap between returned lattices: {} / {}", overlap, max_set); CHECK(overlap <= static_cast(0.2f * static_cast(max_set))); } /* TEST_CASE("FFTIndexer_MultiLattice_TwoLysozymes_BraggPrediction","[Indexing]") { Logger logger("FFTIndexer_MultiLattice_TwoLysozymes_BraggPrediction"); UnitCell lysozyme_uc{36.9, 78.95, 78.95, 90.0, 90.0, 90.0}; CrystalLattice lysozyme_base(lysozyme_uc); CrystalLattice lysozyme_rot_1 = RotateLattice(lysozyme_base, MakeRotation(10.0f, 18.0f, 27.0f)); CrystalLattice lysozyme_rot_2 = RotateLattice(lysozyme_base, MakeRotation(66.0f, -14.0f, 101.0f)); DiffractionExperiment experiment(DetJF4M()); experiment.DetectorDistance_mm(75) .BeamY_pxl(1136) .BeamX_pxl(1090) .IncidentEnergy_keV(12.4); BraggPredictionSettings pred_settings{ .high_res_A = 2.0f, .ewald_dist_cutoff = 0.0010f, .max_hkl = 20, .centering = 'P', .wedge_deg = 0.1f, .mosaicity_deg = 0.2f, .min_zeta = 0.05f, .mosaicity_multiplier = 4.0f }; auto spots_1 = BuildPredictedReciprocalSpots(experiment, lysozyme_rot_1, pred_settings); auto spots_2 = BuildPredictedReciprocalSpots(experiment, lysozyme_rot_2, pred_settings); logger.Info("Predicted spots lattice 1: {}", spots_1.size()); logger.Info("Predicted spots lattice 2: {}", spots_2.size()); std::vector spots; spots.reserve(spots_1.size() + spots_2.size()); spots.insert(spots.end(), spots_1.begin(), spots_1.end()); spots.insert(spots.end(), spots_2.begin(), spots_2.end()); IndexingSettings settings; settings.Algorithm(IndexingAlgorithmEnum::FFT) .FFT_MaxUnitCell_A(120.0) .FFT_HighResolution_A(2.0f) .FFT_NumVectors(1024); experiment.ImportIndexingSettings(settings) .SetUnitCell(lysozyme_uc); REQUIRE(experiment.GetIndexingAlgorithm() == IndexingAlgorithmEnum::FFT); std::unique_ptr indexer = CreateIndexer(experiment); REQUIRE(indexer); indexer->Setup(experiment); auto result = indexer->Run(spots); logger.Info("FFT returned {} lattices", result.lattice.size()); REQUIRE(result.lattice.size() >= 2); const float tolerance = experiment.GetIndexingSettings().GetTolerance(); int lysozyme_count = 0; for (size_t i = 0; i < result.lattice.size(); ++i) { auto uc = result.lattice[i].GetUnitCell(); int64_t indexed_count = 0; BuildIndexedMask(spots, result.lattice[i], tolerance, indexed_count); logger.Info("Lattice {} cell ({:.1f} {:.1f} {:.1f}) indexes {} spots", i, uc.a, uc.b, uc.c, indexed_count); if (MatchesCellLengths(uc, lysozyme_uc)) ++lysozyme_count; } CHECK(lysozyme_count >= 2); // Verify the two best lysozyme lattices are distinct (low overlap) if (result.lattice.size() >= 2) { int64_t count_0 = 0, count_1 = 0; auto mask_0 = BuildIndexedMask(spots, result.lattice[0], tolerance, count_0); auto mask_1 = BuildIndexedMask(spots, result.lattice[1], tolerance, count_1); const int64_t overlap = MaskOverlap(mask_0, mask_1); const int64_t max_set = std::max(count_0, count_1); logger.Info("Top-2 overlap: {} / {}", overlap, max_set); CHECK(overlap <= static_cast(0.5f * static_cast(max_set))); } } TEST_CASE("FFBIDXIndexer_MultiLattice_TwoLysozymes_BraggPrediction","[Indexing]") { Logger logger("FFBIDXIndexer_MultiLattice_TwoLysozymes_BraggPrediction"); UnitCell lysozyme_uc{36.9, 78.95, 78.95, 90.0, 90.0, 90.0}; CrystalLattice lysozyme_base(lysozyme_uc); CrystalLattice lysozyme_rot_1 = RotateLattice(lysozyme_base, MakeRotation(10.0f, 18.0f, 27.0f)); CrystalLattice lysozyme_rot_2 = RotateLattice(lysozyme_base, MakeRotation(66.0f, -14.0f, 101.0f)); DiffractionExperiment experiment(DetJF4M()); experiment.DetectorDistance_mm(75) .BeamY_pxl(1136) .BeamX_pxl(1090) .IncidentEnergy_keV(12.4); BraggPredictionSettings pred_settings{ .high_res_A = 2.0f, .ewald_dist_cutoff = 0.0010f, .max_hkl = 20, .centering = 'P', .wedge_deg = 0.1f, .mosaicity_deg = 0.2f, .min_zeta = 0.05f, .mosaicity_multiplier = 4.0f }; auto spots_1 = BuildPredictedReciprocalSpots(experiment, lysozyme_rot_1, pred_settings); auto spots_2 = BuildPredictedReciprocalSpots(experiment, lysozyme_rot_2, pred_settings); logger.Info("Predicted spots lattice 1: {}", spots_1.size()); logger.Info("Predicted spots lattice 2: {}", spots_2.size()); std::vector spots; spots.reserve(spots_1.size() + spots_2.size()); spots.insert(spots.end(), spots_1.begin(), spots_1.end()); spots.insert(spots.end(), spots_2.begin(), spots_2.end()); experiment.SetUnitCell(lysozyme_uc); experiment.IndexingAlgorithm(IndexingAlgorithmEnum::FFBIDX); REQUIRE(experiment.GetIndexingAlgorithm() == IndexingAlgorithmEnum::FFBIDX); std::unique_ptr indexer = CreateIndexer(experiment); REQUIRE(indexer); indexer->Setup(experiment); auto result = indexer->Run(spots); logger.Info("FFBIDX returned {} lattices", result.lattice.size()); REQUIRE(result.lattice.size() >= 2); const float tolerance = experiment.GetIndexingSettings().GetTolerance(); int lysozyme_count = 0; for (size_t i = 0; i < result.lattice.size(); ++i) { auto uc = result.lattice[i].GetUnitCell(); int64_t indexed_count = 0; BuildIndexedMask(spots, result.lattice[i], tolerance, indexed_count); logger.Info("Lattice {} cell ({:.1f} {:.1f} {:.1f}) indexes {} spots", i, uc.a, uc.b, uc.c, indexed_count); if (MatchesCellLengths(uc, lysozyme_uc)) ++lysozyme_count; } CHECK(lysozyme_count >= 2); // Verify the two best lysozyme lattices are distinct (low overlap) if (result.lattice.size() >= 2) { int64_t count_0 = 0, count_1 = 0; auto mask_0 = BuildIndexedMask(spots, result.lattice[0], tolerance, count_0); auto mask_1 = BuildIndexedMask(spots, result.lattice[1], tolerance, count_1); const int64_t overlap = MaskOverlap(mask_0, mask_1); const int64_t max_set = std::max(count_0, count_1); logger.Info("Top-2 overlap: {} / {}", overlap, max_set); CHECK(overlap <= static_cast(0.5f * static_cast(max_set))); } } */ #endif