diff --git a/tests/CalcBraggPredictionTest.cpp b/tests/CalcBraggPredictionTest.cpp index fee29973..0a49cb25 100644 --- a/tests/CalcBraggPredictionTest.cpp +++ b/tests/CalcBraggPredictionTest.cpp @@ -365,4 +365,53 @@ TEST_CASE("BraggPredictionGPU_backscattering") { } } +TEST_CASE("BraggPrediction_CPU_GPU_consistency_tilted") { + // Verify CPU and GPU implementations produce identical results with tilted detector + DiffractionExperiment experiment(DetJF4M()); + experiment.DetectorDistance_mm(100.0).BeamX_pxl(1500.0).BeamY_pxl(1000.0) + .PoniRot1_rad(0.04).PoniRot2_rad(-0.025) + .IncidentEnergy_keV(12.0); + + CrystalLattice lattice(Coord{30, 10, 0}, Coord{-15, 45, 0}, Coord{0, 0, 150}); + + BraggPredictionSettings settings{ + .high_res_A = 2.0, + .ewald_dist_cutoff = 0.0015, + .max_hkl = 30 + }; + + BraggPrediction cpu_pred; + BraggPredictionGPU gpu_pred; + + int cpu_count = cpu_pred.Calc(experiment, lattice, settings); + int gpu_count = gpu_pred.Calc(experiment, lattice, settings); + + REQUIRE(cpu_count > 0); + REQUIRE(gpu_count > 0); + + // Build map of GPU reflections by hkl + std::map, const Reflection*> gpu_refl_map; + for (int i = 0; i < gpu_count; ++i) { + const auto& r = gpu_pred.GetReflections().at(i); + gpu_refl_map[{r.h, r.k, r.l}] = &r; + } + + // Check that each CPU reflection has a matching GPU reflection + int matched = 0; + for (int i = 0; i < cpu_count; ++i) { + const auto& cpu_r = cpu_pred.GetReflections().at(i); + auto key = std::make_tuple(cpu_r.h, cpu_r.k, cpu_r.l); + auto it = gpu_refl_map.find(key); + if (it != gpu_refl_map.end()) { + const auto& gpu_r = *it->second; + CHECK(cpu_r.predicted_x == Catch::Approx(gpu_r.predicted_x).margin(0.1)); + CHECK(cpu_r.predicted_y == Catch::Approx(gpu_r.predicted_y).margin(0.1)); + CHECK(cpu_r.d == Catch::Approx(gpu_r.d).margin(0.01)); + matched++; + } + } + + // Most reflections should match (allow for some numerical differences at boundaries) + CHECK(matched > cpu_count * 0.95); +} #endif diff --git a/tests/DiffractionGeometryTest.cpp b/tests/DiffractionGeometryTest.cpp index 231613a3..ef307613 100644 --- a/tests/DiffractionGeometryTest.cpp +++ b/tests/DiffractionGeometryTest.cpp @@ -471,4 +471,56 @@ TEST_CASE("ResPhiToPxl_poni_rot") { out = geom.ResPhiToPxl(2.0, 0.7567); CHECK(geom.PxlToRes(out.first, out.second) == Catch::Approx(2.0)); CHECK(fabs(geom.Phi_rad(out.first, out.second) - 0.7567) < 0.001 ); +} + +TEST_CASE("DiffractionGeometry_DetectorToRecip_RecipToDetector_tilted") { + // Verify roundtrip consistency with non-zero rot1/rot2 + DiffractionGeometry geom; + geom.BeamX_pxl(1000).BeamY_pxl(1000).DetectorDistance_mm(150) + .PixelSize_mm(0.075).Wavelength_A(1.0) + .PoniRot1_rad(0.05).PoniRot2_rad(-0.03); + + // Test multiple points across the detector + std::vector> test_points = { + {500, 500}, {1500, 500}, {500, 1500}, {1500, 1500}, + {800, 1200}, {1200, 800}, {300, 1700}, {1700, 300} + }; + + for (const auto& [x, y] : test_points) { + Coord recip = geom.DetectorToRecip(x, y); + auto [proj_x, proj_y] = geom.RecipToDector(recip); + + CHECK(proj_x == Catch::Approx(x).margin(0.001)); + CHECK(proj_y == Catch::Approx(y).margin(0.001)); + } +} + +TEST_CASE("DiffractionGeometry_PONI_matrix_consistency") { + // Verify that the PONI rotation matrix gives consistent results + // when used for both forward and inverse transformations + DiffractionGeometry geom; + geom.BeamX_pxl(1000).BeamY_pxl(1000).DetectorDistance_mm(100) + .PixelSize_mm(0.075).Wavelength_A(1.0) + .PoniRot1_rad(0.04).PoniRot2_rad(-0.025); + + const auto& poni_rot = geom.GetPoniRotMatrix(); + const auto poni_rot_T = poni_rot.transpose(); + + // Test: poni_rot * poni_rot^T should be identity (orthogonal matrix) + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 3; ++j) { + Coord ei, ej; + ei[i] = 1.0f; + ej[j] = 1.0f; + float expected = (i == j) ? 1.0f : 0.0f; + CHECK((poni_rot * (poni_rot_T * ej))[i] == Catch::Approx(expected).margin(1e-6)); + } + } + + // Test: S0 vector transformation + Coord S0 = geom.GetScatteringVector(); + // For beam along z, S0 = (0, 0, 1/λ) + CHECK(S0.x == Catch::Approx(0.0f)); + CHECK(S0.y == Catch::Approx(0.0f)); + CHECK(S0.z == Catch::Approx(1.0f)); } \ No newline at end of file