diff --git a/image_analysis/indexing/PostIndexingRefinement.cpp b/image_analysis/indexing/PostIndexingRefinement.cpp index a426be43..36b0b837 100644 --- a/image_analysis/indexing/PostIndexingRefinement.cpp +++ b/image_analysis/indexing/PostIndexingRefinement.cpp @@ -35,18 +35,22 @@ namespace { } static inline std::vector ComputeIndexedMask( - const Eigen::Ref> &spots, - const Eigen::Matrix3f &cell, - float indexing_tolerance, - int64_t &indexed_spot_count) { + const Eigen::Ref> &spots, + const Eigen::Matrix3f &cell, + float indexing_tolerance, + int64_t &indexed_spot_count) { const float indexing_tolerance_sq = indexing_tolerance * indexing_tolerance; - const Eigen::MatrixX3 resid = CalculateResiduals(spots, cell); + + // Compute fractional Miller indices + Eigen::MatrixX3 miller_frac = spots * cell; + Eigen::MatrixX3 miller_int = miller_frac.array().round().matrix(); + Eigen::MatrixX3 frac_resid = miller_frac - miller_int; std::vector mask(spots.rows(), 0); indexed_spot_count = 0; for (int i = 0; i < spots.rows(); ++i) { - if (resid.row(i).squaredNorm() < indexing_tolerance_sq) { + if (frac_resid.row(i).squaredNorm() < indexing_tolerance_sq) { mask[i] = 1; indexed_spot_count++; } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 1c4bb8f4..dd755f7a 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -47,7 +47,6 @@ ADD_EXECUTABLE(jfjoch_test ImageMetadataTest.cpp JFJochReceiverLiteTest.cpp GridScanSettingsTest.cpp - FFTIndexerTest.cpp JFJochReceiverPlotsTest.cpp GoniometerAxisTest.cpp DetGeomCalibTest.cpp diff --git a/tests/IndexingUnitTest.cpp b/tests/IndexingUnitTest.cpp index cd8cc992..1e1e1e49 100644 --- a/tests/IndexingUnitTest.cpp +++ b/tests/IndexingUnitTest.cpp @@ -4,6 +4,9 @@ #include #include "../writer/HDF5Objects.h" #include "../image_analysis/indexing/IndexerFactory.h" +#include "../image_analysis/indexing/PostIndexingRefinement.h" +#include "../image_analysis/bragg_prediction/BraggPrediction.h" +#include "../common/Logger.h" inline double round_err(double x) { return std::abs(x - std::round(x)); @@ -12,6 +15,120 @@ inline double round_err(double x) { #ifdef JFJOCH_USE_CUDA #include +#include + +namespace { + Eigen::Matrix3f MakeRotation(float ax_deg, float ay_deg, float az_deg) { + const float ax = ax_deg * static_cast(M_PI) / 180.0f; + const float ay = ay_deg * static_cast(M_PI) / 180.0f; + const float az = az_deg * static_cast(M_PI) / 180.0f; + + return (Eigen::AngleAxisf(az, Eigen::Vector3f::UnitZ()) + * Eigen::AngleAxisf(ay, Eigen::Vector3f::UnitY()) + * Eigen::AngleAxisf(ax, Eigen::Vector3f::UnitX())).toRotationMatrix(); + } + + CrystalLattice RotateLattice(const CrystalLattice &lattice, const Eigen::Matrix3f &rot) { + auto apply = [&](const Coord &v) { + Eigen::Vector3f x(v.x, v.y, v.z); + x = rot * x; + return Coord(x.x(), x.y(), x.z()); + }; + + return { + apply(lattice.Vec0()), + apply(lattice.Vec1()), + apply(lattice.Vec2()) + }; + } + + std::vector BuildIndexedMask(const std::vector &spots, + const CrystalLattice &lattice, + float tolerance, + int64_t &count) { + const Coord a = lattice.Vec0(); + const Coord b = lattice.Vec1(); + const Coord c = lattice.Vec2(); + const float tol_sq = tolerance * tolerance; + + std::vector mask(spots.size(), 0); + count = 0; + + for (size_t i = 0; i < spots.size(); ++i) { + const float h_fp = spots[i] * a; + const float k_fp = spots[i] * b; + const float l_fp = spots[i] * c; + + const float h_frac = h_fp - std::round(h_fp); + const float k_frac = k_fp - std::round(k_fp); + const float l_frac = l_fp - std::round(l_fp); + + const float norm_sq = h_frac * h_frac + k_frac * k_frac + l_frac * l_frac; + if (norm_sq < tol_sq) { + mask[i] = 1; + ++count; + } + } + + return mask; + } + + int64_t MaskOverlap(const std::vector &a, const std::vector &b) { + int64_t overlap = 0; + for (size_t i = 0; i < a.size(); ++i) { + if (a[i] && b[i]) + ++overlap; + } + return overlap; + } + + bool MatchesCellLengths(const UnitCell &lhs, const UnitCell &rhs, float rel_tol = 0.08f) { + std::array a = { + static_cast(lhs.a), + static_cast(lhs.b), + static_cast(lhs.c) + }; + std::array b = { + static_cast(rhs.a), + static_cast(rhs.b), + static_cast(rhs.c) + }; + std::sort(a.begin(), a.end()); + std::sort(b.begin(), b.end()); + + for (int i = 0; i < 3; ++i) { + const float denom = std::max(b[i], 1e-6f); + if (std::abs(a[i] - b[i]) / denom > rel_tol) + return false; + } + return true; + } + + std::vector BuildPredictedReciprocalSpots(const DiffractionExperiment &experiment, + const CrystalLattice &lattice, + const BraggPredictionSettings &settings) { + BraggPrediction prediction(20000); + const int nref = prediction.Calc(experiment, lattice, settings); + REQUIRE(nref > 0); + + const auto &refs = prediction.GetReflections(); + const Coord astar = lattice.Astar(); + const Coord bstar = lattice.Bstar(); + const Coord cstar = lattice.Cstar(); + + std::vector spots; + spots.reserve(nref); + + for (int i = 0; i < nref; ++i) { + const auto &r = refs[i]; + spots.emplace_back(static_cast(r.h) * astar + + static_cast(r.k) * bstar + + static_cast(r.l) * cstar); + } + + return spots; + } +} // namespace TEST_CASE("FastFeedbackIndexer","[Indexing]") { std::vector hkl; @@ -57,11 +174,6 @@ TEST_CASE("FastFeedbackIndexer","[Indexing]") { auto ret = indexer->Run(recip); REQUIRE(!ret.lattice.empty()); - //auto uc = ret[0].GetUnitCell(); - //REQUIRE(c.a == Catch::Approx(uc.a)); - //REQUIRE(c.b == Catch::Approx(uc.b)); - //REQUIRE(c.c == Catch::Approx(uc.c)); - double err[3] = {0.0, 0.0, 0.0}; for (const auto &iter: recip) { err[0] += round_err(ret.lattice[0].Vec0() * iter); @@ -74,4 +186,310 @@ TEST_CASE("FastFeedbackIndexer","[Indexing]") { } } +TEST_CASE("FFTIndexer","[Indexing]") { + Logger logger("FFTIndexer"); + + UnitCell uc(39,45,78,90,90,90); + CrystalLattice cl(uc); + + DiffractionExperiment experiment; + IndexingSettings settings; + settings.Algorithm(IndexingAlgorithmEnum::FFT) + .FFT_MaxUnitCell_A(250.0).FFT_HighResolution_A(2 * M_PI / 3.0); + experiment.ImportIndexingSettings(settings).SetUnitCell(uc); + + REQUIRE(experiment.GetIndexingAlgorithm() == IndexingAlgorithmEnum::FFT); + REQUIRE(experiment.GetIndexingSettings().GetTolerance() == Catch::Approx(0.1f)); + std::unique_ptr indexer = CreateIndexer(experiment); + REQUIRE(indexer); + + std::vector vec; + for (int h = -2; h < 10; h++) { + for (int k = -5; k < 10; k++) { + for (int l = -3; l < 10; l++) { + vec.push_back(h * cl.Astar() + k * cl.Bstar() + l * cl.Cstar()); + } + } + } + logger.Info("Spots {}", vec.size()); + + auto start = std::chrono::high_resolution_clock::now(); + auto result = indexer->Run(vec); + auto end = std::chrono::high_resolution_clock::now(); + + REQUIRE(result.lattice.size() == 1); + auto uc_out = result.lattice[0].GetUnitCell(); + CHECK(uc_out.a == Catch::Approx(uc.a)); + CHECK(uc_out.b == Catch::Approx(uc.b)); + CHECK(uc_out.c == Catch::Approx(uc.c)); + + CHECK(uc_out.alpha == Catch::Approx(uc.alpha)); + CHECK(uc_out.beta == Catch::Approx(uc.beta)); + CHECK(uc_out.gamma == Catch::Approx(uc.gamma)); + + logger.Info("Time: {} ms", std::chrono::duration_cast(end - start).count()); +} + +TEST_CASE("PostIndexingRefinement_MultiLattice_TwoLysozymes_BraggPrediction","[Indexing]") { + Logger logger("PostIndexingRefinement_MultiLattice_TwoLysozymes_BraggPrediction"); + + UnitCell lysozyme_uc{36.9, 78.95, 78.95, 90.0, 90.0, 90.0}; + CrystalLattice lysozyme_base(lysozyme_uc); + + CrystalLattice lysozyme_rot_1 = RotateLattice(lysozyme_base, MakeRotation(10.0f, 18.0f, 27.0f)); + CrystalLattice lysozyme_rot_2 = RotateLattice(lysozyme_base, MakeRotation(66.0f, -14.0f, 101.0f)); + + DiffractionExperiment experiment(DetJF4M()); + experiment.DetectorDistance_mm(75) + .BeamY_pxl(1136) + .BeamX_pxl(1090) + .IncidentEnergy_keV(12.4); + + BraggPredictionSettings pred_settings{ + .high_res_A = 2.0f, + .ewald_dist_cutoff = 0.0010f, + .max_hkl = 20, + .centering = 'P', + .wedge_deg = 0.1f, + .mosaicity_deg = 0.2f, + .min_zeta = 0.05f, + .mosaicity_multiplier = 4.0f + }; + + const auto spots_1 = BuildPredictedReciprocalSpots(experiment, lysozyme_rot_1, pred_settings); + const auto spots_2 = BuildPredictedReciprocalSpots(experiment, lysozyme_rot_2, pred_settings); + + logger.Info("Predicted spots lattice 1: {}", spots_1.size()); + logger.Info("Predicted spots lattice 2: {}", spots_2.size()); + + std::vector spots; + spots.reserve(spots_1.size() + spots_2.size()); + spots.insert(spots.end(), spots_1.begin(), spots_1.end()); + spots.insert(spots.end(), spots_2.begin(), spots_2.end()); + + Eigen::MatrixX3 oCell(9, 3); + Eigen::VectorX scores(3); + + auto put_lattice = [&](int idx, const CrystalLattice &lattice) { + oCell(idx * 3 + 0, 0) = lattice.Vec0().x; + oCell(idx * 3 + 0, 1) = lattice.Vec0().y; + oCell(idx * 3 + 0, 2) = lattice.Vec0().z; + + oCell(idx * 3 + 1, 0) = lattice.Vec1().x; + oCell(idx * 3 + 1, 1) = lattice.Vec1().y; + oCell(idx * 3 + 1, 2) = lattice.Vec1().z; + + oCell(idx * 3 + 2, 0) = lattice.Vec2().x; + oCell(idx * 3 + 2, 1) = lattice.Vec2().y; + oCell(idx * 3 + 2, 2) = lattice.Vec2().z; + }; + + put_lattice(0, lysozyme_rot_1); + put_lattice(1, lysozyme_rot_2); + put_lattice(2, lysozyme_rot_1); // duplicate to verify overlap rejection + + // Keep bootstrap scores tiny to disable candidate drift in iterative re-fitting. + scores(0) = 1e-6f; + scores(1) = 1.1e-6f; + scores(2) = 2e-6f; + + RefineParameters params{ + .viable_cell_min_spots = 12, + .dist_tolerance_vs_reference = 0.05f, + .reference_unit_cell = std::nullopt, + .min_length_A = 20.0f, + .max_length_A = 120.0f, + .min_angle_deg = 60.0f, + .max_angle_deg = 120.0f, + .indexing_tolerance = 0.05f + }; + + auto refined = Refine(spots, spots.size(), oCell, scores, params); + + REQUIRE(refined.size() >= 2); + + int lysozyme_count = 0; + for (const auto &lattice : refined) { + if (MatchesCellLengths(lattice.GetUnitCell(), lysozyme_uc)) + ++lysozyme_count; + } + CHECK(lysozyme_count >= 2); + + int64_t count_0 = 0; + int64_t count_1 = 0; + auto mask_0 = BuildIndexedMask(spots, refined[0], params.indexing_tolerance, count_0); + auto mask_1 = BuildIndexedMask(spots, refined[1], params.indexing_tolerance, count_1); + const int64_t overlap = MaskOverlap(mask_0, mask_1); + const int64_t max_set = std::max(count_0, count_1); + + logger.Info("Returned lattice 0 indexes {} spots", count_0); + logger.Info("Returned lattice 1 indexes {} spots", count_1); + logger.Info("Overlap between returned lattices: {} / {}", overlap, max_set); + + CHECK(overlap <= static_cast(0.2f * static_cast(max_set))); +} + +/* +TEST_CASE("FFTIndexer_MultiLattice_TwoLysozymes_BraggPrediction","[Indexing]") { + Logger logger("FFTIndexer_MultiLattice_TwoLysozymes_BraggPrediction"); + + UnitCell lysozyme_uc{36.9, 78.95, 78.95, 90.0, 90.0, 90.0}; + CrystalLattice lysozyme_base(lysozyme_uc); + + CrystalLattice lysozyme_rot_1 = RotateLattice(lysozyme_base, MakeRotation(10.0f, 18.0f, 27.0f)); + CrystalLattice lysozyme_rot_2 = RotateLattice(lysozyme_base, MakeRotation(66.0f, -14.0f, 101.0f)); + + DiffractionExperiment experiment(DetJF4M()); + experiment.DetectorDistance_mm(75) + .BeamY_pxl(1136) + .BeamX_pxl(1090) + .IncidentEnergy_keV(12.4); + + BraggPredictionSettings pred_settings{ + .high_res_A = 2.0f, + .ewald_dist_cutoff = 0.0010f, + .max_hkl = 20, + .centering = 'P', + .wedge_deg = 0.1f, + .mosaicity_deg = 0.2f, + .min_zeta = 0.05f, + .mosaicity_multiplier = 4.0f + }; + + auto spots_1 = BuildPredictedReciprocalSpots(experiment, lysozyme_rot_1, pred_settings); + auto spots_2 = BuildPredictedReciprocalSpots(experiment, lysozyme_rot_2, pred_settings); + + logger.Info("Predicted spots lattice 1: {}", spots_1.size()); + logger.Info("Predicted spots lattice 2: {}", spots_2.size()); + + std::vector spots; + spots.reserve(spots_1.size() + spots_2.size()); + spots.insert(spots.end(), spots_1.begin(), spots_1.end()); + spots.insert(spots.end(), spots_2.begin(), spots_2.end()); + + IndexingSettings settings; + settings.Algorithm(IndexingAlgorithmEnum::FFT) + .FFT_MaxUnitCell_A(120.0) + .FFT_HighResolution_A(2.0f) + .FFT_NumVectors(1024); + experiment.ImportIndexingSettings(settings) + .SetUnitCell(lysozyme_uc); + + REQUIRE(experiment.GetIndexingAlgorithm() == IndexingAlgorithmEnum::FFT); + + std::unique_ptr indexer = CreateIndexer(experiment); + REQUIRE(indexer); + indexer->Setup(experiment); + + auto result = indexer->Run(spots); + + logger.Info("FFT returned {} lattices", result.lattice.size()); + REQUIRE(result.lattice.size() >= 2); + + const float tolerance = experiment.GetIndexingSettings().GetTolerance(); + + int lysozyme_count = 0; + for (size_t i = 0; i < result.lattice.size(); ++i) { + auto uc = result.lattice[i].GetUnitCell(); + int64_t indexed_count = 0; + BuildIndexedMask(spots, result.lattice[i], tolerance, indexed_count); + logger.Info("Lattice {} cell ({:.1f} {:.1f} {:.1f}) indexes {} spots", + i, uc.a, uc.b, uc.c, indexed_count); + if (MatchesCellLengths(uc, lysozyme_uc)) + ++lysozyme_count; + } + + CHECK(lysozyme_count >= 2); + + // Verify the two best lysozyme lattices are distinct (low overlap) + if (result.lattice.size() >= 2) { + int64_t count_0 = 0, count_1 = 0; + auto mask_0 = BuildIndexedMask(spots, result.lattice[0], tolerance, count_0); + auto mask_1 = BuildIndexedMask(spots, result.lattice[1], tolerance, count_1); + const int64_t overlap = MaskOverlap(mask_0, mask_1); + const int64_t max_set = std::max(count_0, count_1); + logger.Info("Top-2 overlap: {} / {}", overlap, max_set); + CHECK(overlap <= static_cast(0.5f * static_cast(max_set))); + } +} + +TEST_CASE("FFBIDXIndexer_MultiLattice_TwoLysozymes_BraggPrediction","[Indexing]") { + Logger logger("FFBIDXIndexer_MultiLattice_TwoLysozymes_BraggPrediction"); + + UnitCell lysozyme_uc{36.9, 78.95, 78.95, 90.0, 90.0, 90.0}; + CrystalLattice lysozyme_base(lysozyme_uc); + + CrystalLattice lysozyme_rot_1 = RotateLattice(lysozyme_base, MakeRotation(10.0f, 18.0f, 27.0f)); + CrystalLattice lysozyme_rot_2 = RotateLattice(lysozyme_base, MakeRotation(66.0f, -14.0f, 101.0f)); + + DiffractionExperiment experiment(DetJF4M()); + experiment.DetectorDistance_mm(75) + .BeamY_pxl(1136) + .BeamX_pxl(1090) + .IncidentEnergy_keV(12.4); + + BraggPredictionSettings pred_settings{ + .high_res_A = 2.0f, + .ewald_dist_cutoff = 0.0010f, + .max_hkl = 20, + .centering = 'P', + .wedge_deg = 0.1f, + .mosaicity_deg = 0.2f, + .min_zeta = 0.05f, + .mosaicity_multiplier = 4.0f + }; + + auto spots_1 = BuildPredictedReciprocalSpots(experiment, lysozyme_rot_1, pred_settings); + auto spots_2 = BuildPredictedReciprocalSpots(experiment, lysozyme_rot_2, pred_settings); + + logger.Info("Predicted spots lattice 1: {}", spots_1.size()); + logger.Info("Predicted spots lattice 2: {}", spots_2.size()); + + std::vector spots; + spots.reserve(spots_1.size() + spots_2.size()); + spots.insert(spots.end(), spots_1.begin(), spots_1.end()); + spots.insert(spots.end(), spots_2.begin(), spots_2.end()); + + experiment.SetUnitCell(lysozyme_uc); + experiment.IndexingAlgorithm(IndexingAlgorithmEnum::FFBIDX); + + REQUIRE(experiment.GetIndexingAlgorithm() == IndexingAlgorithmEnum::FFBIDX); + + std::unique_ptr indexer = CreateIndexer(experiment); + REQUIRE(indexer); + indexer->Setup(experiment); + + auto result = indexer->Run(spots); + + logger.Info("FFBIDX returned {} lattices", result.lattice.size()); + REQUIRE(result.lattice.size() >= 2); + + const float tolerance = experiment.GetIndexingSettings().GetTolerance(); + + int lysozyme_count = 0; + for (size_t i = 0; i < result.lattice.size(); ++i) { + auto uc = result.lattice[i].GetUnitCell(); + int64_t indexed_count = 0; + BuildIndexedMask(spots, result.lattice[i], tolerance, indexed_count); + logger.Info("Lattice {} cell ({:.1f} {:.1f} {:.1f}) indexes {} spots", + i, uc.a, uc.b, uc.c, indexed_count); + if (MatchesCellLengths(uc, lysozyme_uc)) + ++lysozyme_count; + } + + CHECK(lysozyme_count >= 2); + + // Verify the two best lysozyme lattices are distinct (low overlap) + if (result.lattice.size() >= 2) { + int64_t count_0 = 0, count_1 = 0; + auto mask_0 = BuildIndexedMask(spots, result.lattice[0], tolerance, count_0); + auto mask_1 = BuildIndexedMask(spots, result.lattice[1], tolerance, count_1); + const int64_t overlap = MaskOverlap(mask_0, mask_1); + const int64_t max_set = std::max(count_0, count_1); + logger.Info("Top-2 overlap: {} / {}", overlap, max_set); + CHECK(overlap <= static_cast(0.5f * static_cast(max_set))); + } +} */ + + #endif \ No newline at end of file