// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute // SPDX-License-Identifier: GPL-3.0-only #include #include "XtalOptimizer.h" #include "ceres/ceres.h" #include "ceres/rotation.h" struct XtalResidual { XtalResidual(double x, double y, double lambda, double pixel_size, double angle_rad, double exp_h, double exp_k, double exp_l, gemmi::CrystalSystem symmetry) : obs_x(x), obs_y(y), lambda(lambda), pixel_size(pixel_size), exp_h(exp_h), exp_k(exp_k), exp_l(exp_l), angle_rad(angle_rad), symmetry(symmetry) { } template bool operator()(const T *const beam_x, const T *const beam_y, const T *const distance_mm, const T *const detector_rot, const T *const rotation_axis, const T *const p0, const T *const p1, const T *const p2, T *residual) const { // PyFAI convention (left-handed for rot1/rot2): // poni_rot = Rz(-rot3) * Rx(-rot2) * Ry(+rot1) // detector_rot[0] = rot1, detector_rot[1] = rot2 (rot3 = 0 assumed) const T rot1 = detector_rot[0]; const T rot2 = detector_rot[1]; // Ry(+rot1): rotation around Y-axis const T c1 = ceres::cos(rot1); const T s1 = ceres::sin(rot1); // Rx(-rot2): rotation around X-axis with inverted sign (PyFAI left-handed) const T c2 = ceres::cos(-rot2); const T s2 = ceres::sin(-rot2); // Detector coordinates in mm const T det_x = (T(obs_x) - beam_x[0]) * T(pixel_size); const T det_y = (T(obs_y) - beam_y[0]) * T(pixel_size); const T det_z = T(distance_mm[0]); // Apply Ry(rot1) first: rotate around Y const T t1_x = c1 * det_x + s1 * det_z; const T t1_y = det_y; const T t1_z = -s1 * det_x + c1 * det_z; // Then apply Rx(-rot2): rotate around X const T x = t1_x; const T y = c2 * t1_y - s2 * t1_z; const T z = s2 * t1_y + c2 * t1_z; // convert to recip space const T lab_norm = ceres::sqrt(x * x + y * y + z * z); const T inv_norm = T(1) / lab_norm; const T inv_lambda = T(1) / T(lambda); T recip_raw[3]; recip_raw[0] = x * inv_norm * inv_lambda; recip_raw[1] = y * inv_norm * inv_lambda; recip_raw[2] = (z * inv_norm - T(1.0)) * inv_lambda; // Apply goniometer "back-to-start" rotation: // brings observed reciprocal from image orientation into reference crystal frame const T aa_back[3] = { T(angle_rad) * rotation_axis[0], T(angle_rad) * rotation_axis[1], T(angle_rad) * rotation_axis[2] }; T recip_obs[3]; ceres::AngleAxisRotatePoint(aa_back, recip_raw, recip_obs); const Eigen::Matrix e_obs_recip(recip_obs[0], recip_obs[1], recip_obs[2]); // Build unit cell lengths and B (convention: columns are a, b, c prior to global rotation) Eigen::Matrix e_uc_len = Eigen::Matrix::Zero(); Eigen::Matrix B = Eigen::Matrix::Identity(); if (symmetry == gemmi::CrystalSystem::Hexagonal) { e_uc_len << p1[0], p1[0], p1[2]; B(0, 1) = T(-0.5); // cos(120) B(1, 1) = T(sqrt(3.0) / 2.0); // sin(120) } else if (symmetry == gemmi::CrystalSystem::Orthorhombic) { e_uc_len << p1[0], p1[1], p1[2]; } else if (symmetry == gemmi::CrystalSystem::Tetragonal) { e_uc_len << p1[0], p1[0], p1[2]; } else if (symmetry == gemmi::CrystalSystem::Cubic) { e_uc_len << p1[0], p1[0], p1[0]; } else if (symmetry == gemmi::CrystalSystem::Monoclinic) { // Unique axis b: alpha = gamma = 90°, beta free (angle between a and c) e_uc_len << p1[0], p1[1], p1[2]; B(0, 2) = ceres::cos(p2[0]); B(2, 2) = ceres::sin(p2[0]); } else { // Triclinic: p1 = (a,b,c), p2 = (alpha, beta, gamma) in radians const T ca = ceres::cos(p2[0]); const T cb = ceres::cos(p2[1]); const T cg = ceres::cos(p2[2]); const T sg = ceres::sin(p2[2]); e_uc_len << p1[0], p1[1], p1[2]; B(0, 0) = T(1); B(1, 0) = T(0); B(2, 0) = T(0); B(0, 1) = cg; B(1, 1) = sg; B(2, 1) = T(0); // c vector components: const T cx = cb; const T cy = (ca - cb * cg) / sg; const T v = T(1) - cx * cx - cy * cy; const T cz = (v >= T(0)) ? ceres::sqrt(v) : T(0); B(0, 2) = cx; B(1, 2) = cy; B(2, 2) = cz; } // Build unrotated direct lattice columns: (B * D), then rotate them by p0. // This avoids AngleAxisToRotationMatrix + matrix multiplications. const T L0 = e_uc_len[0]; const T L1 = e_uc_len[1]; const T L2 = e_uc_len[2]; T col0_unrot[3] = {B(0, 0) * L0, B(1, 0) * L0, B(2, 0) * L0}; T col1_unrot[3] = {B(0, 1) * L1, B(1, 1) * L1, B(2, 1) * L1}; T col2_unrot[3] = {B(0, 2) * L2, B(1, 2) * L2, B(2, 2) * L2}; T col0_rot[3], col1_rot[3], col2_rot[3]; ceres::AngleAxisRotatePoint(p0, col0_unrot, col0_rot); ceres::AngleAxisRotatePoint(p0, col1_unrot, col1_rot); ceres::AngleAxisRotatePoint(p0, col2_unrot, col2_rot); const Eigen::Matrix A(col0_rot[0], col0_rot[1], col0_rot[2]); const Eigen::Matrix Bv(col1_rot[0], col1_rot[1], col1_rot[2]); const Eigen::Matrix C(col2_rot[0], col2_rot[1], col2_rot[2]); const Eigen::Matrix BxC = Bv.cross(C); const Eigen::Matrix CxA = C.cross(A); const Eigen::Matrix AxB = A.cross(Bv); const T V = A.dot(BxC); const T invV = T(1) / V; const Eigen::Matrix Astar = BxC * invV; const Eigen::Matrix Bstar = CxA * invV; const Eigen::Matrix Cstar = AxB * invV; const T h = T(exp_h); const T k = T(exp_k); const T l = T(exp_l); const Eigen::Matrix e_pred_recip = Astar * h + Bstar * k + Cstar * l; residual[0] = e_obs_recip[0] - e_pred_recip[0]; residual[1] = e_obs_recip[1] - e_pred_recip[1]; residual[2] = e_obs_recip[2] - e_pred_recip[2]; return true; } const double obs_x, obs_y; const double lambda; const double pixel_size; const double exp_h; const double exp_k; const double exp_l; const double angle_rad; gemmi::CrystalSystem symmetry; }; inline void LatticeToRodriguesAndLengths_GS(const CrystalLattice &latt, double rod[3], double lengths[3]) { // Load lattice columns const Coord a = latt.Vec0(); const Coord b = latt.Vec1(); const Coord c = latt.Vec2(); Eigen::Vector3d A(a[0], a[1], a[2]); Eigen::Vector3d B(b[0], b[1], b[2]); Eigen::Vector3d C(c[0], c[1], c[2]); // Lengths = column norms (orthorhombic assumption) lengths[0] = A.norm(); lengths[1] = B.norm(); lengths[2] = C.norm(); auto safe_unit = [](const Eigen::Vector3d &v, double eps = 1e-15) -> Eigen::Vector3d { double n = v.norm(); return (n > eps) ? (v / n) : Eigen::Vector3d(1.0, 0.0, 0.0); }; // Gram–Schmidt with original order: x from A, y from B orthogonalized vs x Eigen::Vector3d e1 = safe_unit(A); Eigen::Vector3d y = B - (e1.dot(B)) * e1; Eigen::Vector3d e2 = safe_unit(y); // z from cross to ensure right-handed basis Eigen::Vector3d e3 = e1.cross(e2); double n3 = e3.norm(); if (n3 < 1e-15) { // Degenerate case: B nearly collinear with A → use C instead y = C - (e1.dot(C)) * e1; e2 = safe_unit(y); e3 = e1.cross(e2); n3 = e3.norm(); if (n3 < 1e-15) { // Still degenerate: pick any perpendicular to e1 e2 = safe_unit((std::abs(e1.x()) < 0.9) ? Eigen::Vector3d::UnitX().cross(e1) : Eigen::Vector3d::UnitY().cross(e1)); e3 = e1.cross(e2); } } else { e3 /= n3; } Eigen::Matrix3d R; R.col(0) = e1; R.col(1) = e2; R.col(2) = e3; // Convert rotation to Rodrigues (axis * angle) Eigen::AngleAxisd aa(R); Eigen::Vector3d r = aa.angle() * aa.axis(); rod[0] = r.x(); rod[1] = r.y(); rod[2] = r.z(); } void LatticeToRodriguesAndLengths_Hex(const CrystalLattice &latt, double rod[3], double ac[3]) { const Coord a = latt.Vec0(); const Coord b = latt.Vec1(); const Coord c = latt.Vec2(); Eigen::Vector3d A(a[0], a[1], a[2]); Eigen::Vector3d B(b[0], b[1], b[2]); Eigen::Vector3d C(c[0], c[1], c[2]); const double a_len = A.norm(); const double b_len = B.norm(); const double c_len = C.norm(); ac[0] = (a_len + b_len) / 2.0; ac[1] = (a_len + b_len) / 2.0; ac[2] = c_len; Eigen::Vector3d e1; Eigen::Vector3d e3; if (a_len > 0.0) e1 = A / a_len; else e1 = Eigen::Vector3d::UnitX(); if (c_len > 0.0) e3 = C / c_len; else e3 = Eigen::Vector3d::UnitZ(); Eigen::Vector3d e2 = e3.cross(e1); if (e2.norm() < 1e-15) { e2 = (std::abs(e1.x()) < 0.9) ? Eigen::Vector3d::UnitX().cross(e1) : Eigen::Vector3d::UnitY().cross(e1); } e2.normalize(); e3 = e1.cross(e2).normalized(); Eigen::Matrix3d R; R.col(0) = e1; R.col(1) = e2; R.col(2) = e3; Eigen::AngleAxisd aa(R); Eigen::Vector3d r = aa.angle() * aa.axis(); rod[0] = r.x(); rod[1] = r.y(); rod[2] = r.z(); } // Extract rotation (Rodrigues), lengths (a,b,c) and beta (rad) for monoclinic (unique axis b). // Frame choice: e2 aligned with b; e1 from a projected orthogonal to e2; e3 = e1 x e2. void LatticeToRodriguesLengthsBeta_Mono(const CrystalLattice &latt, double rod[3], double lengths[3], double &beta_rad) { const Coord a = latt.Vec0(); const Coord b = latt.Vec1(); const Coord c = latt.Vec2(); Eigen::Vector3d A(a[0], a[1], a[2]); Eigen::Vector3d Bv(b[0], b[1], b[2]); Eigen::Vector3d C(c[0], c[1], c[2]); const double a_len = A.norm(); const double b_len = Bv.norm(); const double c_len = C.norm(); lengths[0] = a_len; lengths[1] = b_len; lengths[2] = c_len; // beta = angle between a and c double cos_beta = 0.0; if (a_len > 0.0 && c_len > 0.0) cos_beta = std::max(-1.0, std::min(1.0, A.dot(C) / (a_len * c_len))); beta_rad = std::acos(cos_beta); // Recover R from the same forward model used in refinement: // L ≈ R * B(beta) * D(a,b,c) => R ≈ L * (B*D)^-1 Eigen::Matrix3d L; L.col(0) = A; L.col(1) = Bv; L.col(2) = C; Eigen::Matrix3d Bmono = Eigen::Matrix3d::Identity(); Bmono(0, 2) = std::cos(beta_rad); Bmono(2, 2) = std::sin(beta_rad); Eigen::DiagonalMatrix D(lengths[0], lengths[1], lengths[2]); Eigen::Matrix3d M = Bmono * D; Eigen::Matrix3d R_est = Eigen::Matrix3d::Identity(); if (std::abs(M.determinant()) > 1e-15) { R_est = L * M.inverse(); } Eigen::JacobiSVD svd(R_est, Eigen::ComputeFullU | Eigen::ComputeFullV); Eigen::Matrix3d R = svd.matrixU() * svd.matrixV().transpose(); if (R.determinant() < 0.0) { Eigen::Matrix3d U = svd.matrixU(); U.col(2) *= -1.0; R = U * svd.matrixV().transpose(); } Eigen::AngleAxisd aa(R); Eigen::Vector3d r = aa.angle() * aa.axis(); rod[0] = r.x(); rod[1] = r.y(); rod[2] = r.z(); } static inline Eigen::Matrix3d B_from_angles(double alpha_rad, double beta_rad, double gamma_rad) { const double ca = std::cos(alpha_rad); const double cb = std::cos(beta_rad); const double cg = std::cos(gamma_rad); const double sg = std::sin(gamma_rad); Eigen::Matrix3d B = Eigen::Matrix3d::Identity(); // a along x, b in x-y, c general B(0, 0) = 1.0; B(1, 0) = 0.0; B(2, 0) = 0.0; B(0, 1) = cg; B(1, 1) = sg; B(2, 1) = 0.0; // c vector components (standard crystallography construction) const double cx = cb; const double cy = (ca - cb * cg) / sg; const double cz = std::sqrt(std::max(0.0, 1.0 - cx * cx - cy * cy)); B(0, 2) = cx; B(1, 2) = cy; B(2, 2) = cz; return B; } CrystalLattice AngleAxisAndCellToLattice(const double rod[3], const double lengths[3], double alpha_rad, double beta_rad, double gamma_rad) { const Eigen::Vector3d r(rod[0], rod[1], rod[2]); const double angle = r.norm(); Eigen::Matrix3d R = Eigen::Matrix3d::Identity(); if (angle > 0.0) R = Eigen::AngleAxisd(angle, r / angle).toRotationMatrix(); const Eigen::DiagonalMatrix D(lengths[0], lengths[1], lengths[2]); const Eigen::Matrix3d B = B_from_angles(alpha_rad, beta_rad, gamma_rad); // IMPORTANT convention: L = R * B * D (scale columns by lengths) const Eigen::Matrix3d latt = R * B * D; return CrystalLattice(Coord(latt(0, 0), latt(1, 0), latt(2, 0)), Coord(latt(0, 1), latt(1, 1), latt(2, 1)), Coord(latt(0, 2), latt(1, 2), latt(2, 2))); } CrystalLattice AngleAxisAndLengthsToLattice(const double rod[3], const double lengths[3], bool hex) { if (!hex) { return AngleAxisAndCellToLattice(rod, lengths, /*alpha=*/M_PI / 2.0, /*beta =*/M_PI / 2.0, /*gamma=*/M_PI / 2.0); } // Hexagonal: caller must already enforce a=b in `lengths`. return AngleAxisAndCellToLattice(rod, lengths, /*alpha=*/M_PI / 2.0, /*beta =*/M_PI / 2.0, /*gamma=*/2.0 * M_PI / 3.0); } bool XtalOptimizerInternal(XtalOptimizerData &data, const std::vector &spots, const float tolerance) { try { Coord vec0 = data.latt.Vec0(); Coord vec1 = data.latt.Vec1(); Coord vec2 = data.latt.Vec2(); double beta = data.latt.GetUnitCell().beta; // Initial guess for the parameters double beam_x = data.geom.GetBeamX_pxl(); double beam_y = data.geom.GetBeamY_pxl(); double distance_mm = data.geom.GetDetectorDistance_mm(); double detector_rot[2] = {data.geom.GetPoniRot1_rad(), data.geom.GetPoniRot2_rad()}; ceres::Problem problem; double latt_vec0[3], latt_vec1[3], latt_vec2[3]; double rot_vec[3] = {1, 0, 0}; switch (data.crystal_system) { case gemmi::CrystalSystem::Orthorhombic: LatticeToRodriguesAndLengths_GS(data.latt, latt_vec0, latt_vec1); break; case gemmi::CrystalSystem::Tetragonal: LatticeToRodriguesAndLengths_GS(data.latt, latt_vec0, latt_vec1); latt_vec1[0] = (latt_vec1[0] + latt_vec1[1]) / 2.0; break; case gemmi::CrystalSystem::Cubic: LatticeToRodriguesAndLengths_GS(data.latt, latt_vec0, latt_vec1); latt_vec1[0] = (latt_vec1[0] + latt_vec1[1] + latt_vec1[2]) / 3.0; break; case gemmi::CrystalSystem::Hexagonal: LatticeToRodriguesAndLengths_Hex(data.latt, latt_vec0, latt_vec1); break; case gemmi::CrystalSystem::Monoclinic: LatticeToRodriguesLengthsBeta_Mono(data.latt, latt_vec0, latt_vec1, beta); latt_vec2[0] = beta; latt_vec2[1] = 0.0; latt_vec2[2] = 0.0; break; default: // Triclinic: initialize a,b,c and α,β,γ from current unit cell LatticeToRodriguesAndLengths_GS(data.latt, latt_vec0, latt_vec1); auto uc = data.latt.GetUnitCell(); latt_vec2[0] = uc.alpha * M_PI / 180.0; latt_vec2[1] = uc.beta * M_PI / 180.0; latt_vec2[2] = uc.gamma * M_PI / 180.0; break; } if (data.axis) { rot_vec[0] = data.axis->GetAxis().x; rot_vec[1] = data.axis->GetAxis().y; rot_vec[2] = data.axis->GetAxis().z; } const float tolerance_sq = tolerance * tolerance; // Add residuals for each point for (const auto &pt: spots) { if (!data.index_ice_rings && pt.ice_ring) continue; float angle_rad = 0.0; Coord recip = pt.ReciprocalCoord(data.geom); if (data.axis) { recip = data.axis->GetTransformationAngle(pt.phi) * recip; angle_rad = pt.phi * M_PI / 180.0; } double h_fp = recip * vec0; double k_fp = recip * vec1; double l_fp = recip * vec2; double h = std::round(h_fp); double k = std::round(k_fp); double l = std::round(l_fp); double norm_sq = (h - h_fp) * (h - h_fp) + (k - k_fp) * (k - k_fp) + (l - l_fp) * (l - l_fp); if (norm_sq > tolerance_sq) continue; problem.AddResidualBlock( new ceres::AutoDiffCostFunction( new XtalResidual(pt.x, pt.y, data.geom.GetWavelength_A(), data.geom.GetPixelSize_mm(), angle_rad, h, k, l, data.crystal_system)), nullptr, &beam_x, &beam_y, &distance_mm, detector_rot, rot_vec, latt_vec0, latt_vec1, latt_vec2 ); } if (problem.NumResidualBlocks() < data.min_spots) return false; if (!data.refine_distance_mm) problem.SetParameterBlockConstant(&distance_mm); else { const double dist_range = 0.1; problem.SetParameterLowerBound(&distance_mm, 0, distance_mm * (1.0 - dist_range)); problem.SetParameterUpperBound(&distance_mm, 0, distance_mm * (1.0 + dist_range)); } if (!data.refine_beam_center) { problem.SetParameterBlockConstant(&beam_x); problem.SetParameterBlockConstant(&beam_y); } if (!data.refine_detector_angles) { problem.SetParameterBlockConstant(detector_rot); } else { const double rot_range = 3.0 / 180.0 * M_PI; for (int i = 0; i < 2; ++i) { problem.SetParameterLowerBound(detector_rot, i, detector_rot[i] - rot_range); problem.SetParameterUpperBound(detector_rot, i, detector_rot[i] + rot_range); } } if (!data.refine_rotation_axis) { problem.SetParameterBlockConstant(rot_vec); } if (!data.refine_unit_cell) { problem.SetParameterBlockConstant(latt_vec1); problem.SetParameterBlockConstant(latt_vec2); } else { // Parameter bounds // Lengths for (int i = 0; i < 3; ++i) { problem.SetParameterLowerBound(latt_vec1, i, data.min_length_A); problem.SetParameterUpperBound(latt_vec1, i, data.max_length_A); } if (data.crystal_system == gemmi::CrystalSystem::Monoclinic) { const double beta_lo = std::max(1e-6, M_PI * (data.min_angle_deg / 180.0)); const double beta_hi = std::min(M_PI - 1e-6, M_PI * (data.max_angle_deg / 180.0)); problem.SetParameterLowerBound(latt_vec2, 0, beta_lo); problem.SetParameterUpperBound(latt_vec2, 0, beta_hi); } else if (data.crystal_system == gemmi::CrystalSystem::Triclinic) { // α, β, γ bounds (radians) const double alo = M_PI * (data.min_angle_deg / 180.0); const double ahi = M_PI * (data.max_angle_deg / 180.0); for (int i = 0; i < 3; ++i) { problem.SetParameterLowerBound(latt_vec2, i, alo); problem.SetParameterUpperBound(latt_vec2, i, ahi); } } } // Configure solver ceres::Solver::Options options; options.linear_solver_type = ceres::DENSE_QR; options.minimizer_progress_to_stdout = false; options.max_solver_time_in_seconds = data.max_time; options.logging_type = ceres::LoggingType::SILENT; options.num_threads = 1; // Fix threads to 1, as this runs in multi-threaded context ceres::Solver::Summary summary; // Run optimization ceres::Solve(options, &problem, &summary); if (data.refine_beam_center) { data.beam_corr_x = data.geom.GetBeamX_pxl() - beam_x; data.beam_corr_y = data.geom.GetBeamY_pxl() - beam_y; data.geom.BeamX_pxl(beam_x).BeamY_pxl(beam_y); } if (data.refine_distance_mm) data.geom.DetectorDistance_mm(distance_mm); if (data.refine_detector_angles) data.geom.PoniRot1_rad(detector_rot[0]).PoniRot2_rad(detector_rot[1]); if (data.axis && data.refine_rotation_axis) data.axis.value().Axis(Coord(rot_vec[0], rot_vec[1], rot_vec[2])); if (data.crystal_system == gemmi::CrystalSystem::Orthorhombic) data.latt = AngleAxisAndCellToLattice(latt_vec0, latt_vec1, M_PI / 2.0, M_PI / 2.0, M_PI / 2.0); else if (data.crystal_system == gemmi::CrystalSystem::Tetragonal) { latt_vec1[1] = latt_vec1[0]; data.latt = AngleAxisAndCellToLattice(latt_vec0, latt_vec1, M_PI / 2.0, M_PI / 2.0, M_PI / 2.0); } else if (data.crystal_system == gemmi::CrystalSystem::Cubic) { latt_vec1[1] = latt_vec1[0]; latt_vec1[2] = latt_vec1[0]; data.latt = AngleAxisAndCellToLattice(latt_vec0, latt_vec1, M_PI / 2.0, M_PI / 2.0, M_PI / 2.0); } else if (data.crystal_system == gemmi::CrystalSystem::Hexagonal) { latt_vec1[1] = latt_vec1[0]; data.latt = AngleAxisAndCellToLattice(latt_vec0, latt_vec1,M_PI / 2.0, M_PI / 2.0, 2.0 * M_PI / 3.0); } else if (data.crystal_system == gemmi::CrystalSystem::Monoclinic) { data.latt = AngleAxisAndCellToLattice(latt_vec0, latt_vec1, M_PI / 2.0, latt_vec2[0], M_PI / 2.0); } else { // Triclinic via the same generic builder data.latt = AngleAxisAndCellToLattice(latt_vec0, latt_vec1, latt_vec2[0], latt_vec2[1], latt_vec2[2]); } return true; } catch (...) { // Convergence problems, likely not updated return false; } } bool XtalOptimizer(XtalOptimizerData &data, const std::vector &spots) { if (!XtalOptimizerInternal(data, spots, 0.3)) return false; XtalOptimizerInternal(data, spots, 0.2); return XtalOptimizerInternal(data, spots, 0.1); }