Files
Jungfraujoch/image_analysis/geom_refinement/XtalOptimizer.cpp
T

861 lines
32 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute <filip.leonarski@psi.ch>
// SPDX-License-Identifier: GPL-3.0-only
#include <Eigen/Dense>
#include "XtalOptimizer.h"
#include "ceres/ceres.h"
#include "ceres/rotation.h"
// Ceres auto-diff cost functor for joint refinement of detector geometry,
// goniometer axis and unit cell from a single indexed spot.
//
// The observed spot (pixel obs_x, obs_y) is mapped to a lab-frame direction
// using the beam centre, detector distance and PyFAI rot1/rot2, turned into a
// reciprocal-space vector on the Ewald sphere of radius inv_lambda, rotated by
// angle_rad about the goniometer axis, and compared with the predicted vector
// h*a* + k*b* + l*c* of the current lattice model. The three residual
// components are in units of 1/lambda.
//
// Parameter blocks (sizes must match the AutoDiffCostFunction instantiation):
//   beam[2]          beam centre in pixels
//   distance_mm[1]   sample-detector distance (same length unit as pixel_size)
//   detector_rot[2]  PyFAI rot1, rot2 in radians (rot3 assumed 0)
//   rotation_axis[3] goniometer axis. It is scaled by angle_rad as-is, so its
//                    norm acts as a scale on the rotation angle; it is NOT
//                    normalized here.
//   p0[3]            crystal orientation as an angle-axis (Rodrigues) vector
//   p1[3]            cell lengths; which entries are used depends on `symmetry`
//   p2[3]            cell angles in radians: p2[0] = beta for monoclinic,
//                    (alpha, beta, gamma) for triclinic and any system not
//                    handled explicitly; unused for orthorhombic, tetragonal,
//                    cubic and hexagonal
//
// The constructor throws JFJochException when lambda is (close to) zero.
struct XtalResidual {
    XtalResidual(double x, double y,
                 double lambda,
                 double pixel_size,
                 double angle_rad,
                 double exp_h, double exp_k,
                 double exp_l,
                 gemmi::CrystalSystem symmetry)
            : obs_x(x), obs_y(y),
              inv_lambda(1.0 / lambda),
              pixel_size(pixel_size),
              exp_h(exp_h),
              exp_k(exp_k),
              exp_l(exp_l),
              angle_rad(angle_rad),
              symmetry(symmetry) {
        if (std::fabs(lambda) < 1e-6)
            throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
                                  "Lambda cannot be close to zero");
    }

    template<typename T>
    bool operator()(const T *const beam,
                    const T *const distance_mm,
                    const T *const detector_rot,
                    const T *const rotation_axis,
                    const T *const p0,
                    const T *const p1,
                    const T *const p2,
                    T *residual) const {
        // PyFAI convention (left-handed for rot1/rot2):
        //   poni_rot = Rz(-rot3) * Rx(-rot2) * Ry(+rot1)
        // detector_rot[0] = rot1, detector_rot[1] = rot2 (rot3 = 0 assumed)
        const T rot1 = detector_rot[0];
        const T rot2 = detector_rot[1];

        // Ry(+rot1): rotation around the Y axis
        const T c1 = ceres::cos(rot1);
        const T s1 = ceres::sin(rot1);
        // Rx(-rot2): rotation around the X axis with inverted sign (PyFAI left-handed)
        const T c2 = ceres::cos(rot2);
        const T s2 = ceres::sin(rot2);

        // Detector coordinates relative to the beam centre, in pixel_size units
        const T det_x = (T(obs_x) - beam[0]) * T(pixel_size);
        const T det_y = (T(obs_y) - beam[1]) * T(pixel_size);
        const T det_z = T(distance_mm[0]);

        // Apply Ry(rot1) first
        const T t1_x = c1 * det_x + s1 * det_z;
        const T t1_y = det_y;
        const T t1_z = -s1 * det_x + c1 * det_z;

        // Then apply Rx(-rot2)
        const T x = t1_x;
        const T y = c2 * t1_y + s2 * t1_z;
        const T z = -s2 * t1_y + c2 * t1_z;

        // Unit scattered direction minus unit beam direction (+z), scaled by 1/lambda
        const T lab_norm = ceres::sqrt(x * x + y * y + z * z);
        const T inv_norm = T(1) / lab_norm;
        T recip_raw[3];
        recip_raw[0] = x * inv_norm * T(inv_lambda);
        recip_raw[1] = y * inv_norm * T(inv_lambda);
        recip_raw[2] = (z * inv_norm - T(1.0)) * T(inv_lambda);

        // Goniometer "back-to-start" rotation: brings the observed reciprocal
        // vector from the image orientation into the reference crystal frame
        const T aa_back[3] = {
                T(angle_rad) * rotation_axis[0],
                T(angle_rad) * rotation_axis[1],
                T(angle_rad) * rotation_axis[2]
        };
        T recip_obs[3];
        ceres::AngleAxisRotatePoint(aa_back, recip_raw, recip_obs);
        const Eigen::Matrix<T, 3, 1> e_obs_recip(recip_obs[0], recip_obs[1], recip_obs[2]);

        // Cell lengths and the unit-length basis B (columns are the directions
        // of a, b, c before the global orientation p0 is applied)
        Eigen::Matrix<T, 3, 1> e_uc_len = Eigen::Matrix<T, 3, 1>::Zero();
        Eigen::Matrix<T, 3, 3> B = Eigen::Matrix<T, 3, 3>::Identity();

        if (symmetry == gemmi::CrystalSystem::Hexagonal) {
            e_uc_len << p1[0], p1[0], p1[2];
            B(0, 1) = T(-0.5);                   // cos(120 deg)
            B(1, 1) = T(std::sqrt(3.0) / 2.0);   // sin(120 deg)
        } else if (symmetry == gemmi::CrystalSystem::Orthorhombic) {
            e_uc_len << p1[0], p1[1], p1[2];
        } else if (symmetry == gemmi::CrystalSystem::Tetragonal) {
            e_uc_len << p1[0], p1[0], p1[2];
        } else if (symmetry == gemmi::CrystalSystem::Cubic) {
            e_uc_len << p1[0], p1[0], p1[0];
        } else if (symmetry == gemmi::CrystalSystem::Monoclinic) {
            // Unique axis b: alpha = gamma = 90 deg, beta free (angle between a and c)
            e_uc_len << p1[0], p1[1], p1[2];
            B(0, 2) = ceres::cos(p2[0]);
            B(2, 2) = ceres::sin(p2[0]);
        } else {
            // Triclinic: p1 = (a, b, c), p2 = (alpha, beta, gamma) in radians
            const T ca = ceres::cos(p2[0]);
            const T cb = ceres::cos(p2[1]);
            const T cg = ceres::cos(p2[2]);
            const T sg = ceres::sin(p2[2]);

            e_uc_len << p1[0], p1[1], p1[2];

            B(0, 0) = T(1);
            B(1, 0) = T(0);
            B(2, 0) = T(0);

            B(0, 1) = cg;
            B(1, 1) = sg;
            B(2, 1) = T(0);

            // c vector components
            const T cx = cb;
            const T cy = (ca - cb * cg) / sg;
            const T v = T(1) - cx * cx - cy * cy;
            // Strictly positive: ceres::sqrt at exactly 0 has an infinite
            // derivative, which would put inf/NaN into the Jacobian.
            const T cz = (v > T(0)) ? ceres::sqrt(v) : T(0);

            B(0, 2) = cx;
            B(1, 2) = cy;
            B(2, 2) = cz;
        }

        // Unrotated direct-lattice columns (B * diag(lengths)), then rotated by p0.
        // Rotating each column avoids AngleAxisToRotationMatrix + matrix products.
        const T L0 = e_uc_len[0];
        const T L1 = e_uc_len[1];
        const T L2 = e_uc_len[2];

        T col0_unrot[3] = {B(0, 0) * L0, B(1, 0) * L0, B(2, 0) * L0};
        T col1_unrot[3] = {B(0, 1) * L1, B(1, 1) * L1, B(2, 1) * L1};
        T col2_unrot[3] = {B(0, 2) * L2, B(1, 2) * L2, B(2, 2) * L2};

        T col0_rot[3], col1_rot[3], col2_rot[3];
        ceres::AngleAxisRotatePoint(p0, col0_unrot, col0_rot);
        ceres::AngleAxisRotatePoint(p0, col1_unrot, col1_rot);
        ceres::AngleAxisRotatePoint(p0, col2_unrot, col2_rot);

        const Eigen::Matrix<T, 3, 1> A(col0_rot[0], col0_rot[1], col0_rot[2]);
        const Eigen::Matrix<T, 3, 1> Bv(col1_rot[0], col1_rot[1], col1_rot[2]);
        const Eigen::Matrix<T, 3, 1> C(col2_rot[0], col2_rot[1], col2_rot[2]);

        // Reciprocal basis: a* = (b x c) / V, b* = (c x a) / V, c* = (a x b) / V
        const Eigen::Matrix<T, 3, 1> BxC = Bv.cross(C);
        const Eigen::Matrix<T, 3, 1> CxA = C.cross(A);
        const Eigen::Matrix<T, 3, 1> AxB = A.cross(Bv);
        const T V = A.dot(BxC);
        const T invV = T(1) / V;

        const Eigen::Matrix<T, 3, 1> Astar = BxC * invV;
        const Eigen::Matrix<T, 3, 1> Bstar = CxA * invV;
        const Eigen::Matrix<T, 3, 1> Cstar = AxB * invV;

        const T h = T(exp_h);
        const T k = T(exp_k);
        const T l = T(exp_l);
        const Eigen::Matrix<T, 3, 1> e_pred_recip = Astar * h + Bstar * k + Cstar * l;

        residual[0] = e_obs_recip[0] - e_pred_recip[0];
        residual[1] = e_obs_recip[1] - e_pred_recip[1];
        residual[2] = e_obs_recip[2] - e_pred_recip[2];
        return true;
    }

    const double obs_x, obs_y;      // spot position in pixels
    const double inv_lambda;        // 1 / wavelength
    const double pixel_size;        // pixel -> length conversion
    const double exp_h;             // assigned Miller indices
    const double exp_k;
    const double exp_l;
    const double angle_rad;         // goniometer angle of the spot
    const gemmi::CrystalSystem symmetry;
};
// Cost functor refining only an extra orientation of a fixed reciprocal basis.
// The reciprocal vectors a*, b*, c* of `latt` and the observed reciprocal
// vector are captured at construction. The residual is
//   s_obs - R(rot_aa) * (h*a* + k*b* + l*c*),
// where R(rot_aa) is the rotation given by the angle-axis parameter block.
struct XtalResidualRotationOnlyPrecomp {
    XtalResidualRotationOnlyPrecomp(const Coord &recip_obs,
                                    const CrystalLattice &latt,
                                    double h, double k, double l)
            : s_obs(recip_obs),
              astar(latt.Astar()), bstar(latt.Bstar()), cstar(latt.Cstar()),
              h(h), k(k), l(l) {
    }

    template<typename T>
    bool operator()(const T *const rot_aa, T *residual) const {
        // Reference reciprocal basis, lifted to the Jet/scalar type
        const T a_ref[3] = {T(astar.x), T(astar.y), T(astar.z)};
        const T b_ref[3] = {T(bstar.x), T(bstar.y), T(bstar.z)};
        const T c_ref[3] = {T(cstar.x), T(cstar.y), T(cstar.z)};

        // Basis after the trial rotation
        T a_cur[3];
        T b_cur[3];
        T c_cur[3];
        ceres::AngleAxisRotatePoint(rot_aa, a_ref, a_cur);
        ceres::AngleAxisRotatePoint(rot_aa, b_ref, b_cur);
        ceres::AngleAxisRotatePoint(rot_aa, c_ref, c_cur);

        const T th(h);
        const T tk(k);
        const T tl(l);
        const T observed[3] = {T(s_obs.x), T(s_obs.y), T(s_obs.z)};

        // Observed minus predicted, component by component (reciprocal space)
        for (int i = 0; i < 3; ++i)
            residual[i] = observed[i] - (th * a_cur[i] + tk * b_cur[i] + tl * c_cur[i]);
        return true;
    }

    const Coord s_obs;
    const Coord astar, bstar, cstar;
    const double h, k, l;
};
// Regularizer: emits weight * rot_aa as three residuals. In Ceres' squared
// cost this adds weight^2 * ||rot_aa||^2, so the smallest rotation that
// explains the data is preferred. Choose the weight in the same units as the
// reciprocal-space residuals (1/A per radian). A value of roughly 0.01-0.1
// is typically enough to break degeneracy without biasing the solution.
struct RotationNormRegularizer {
    explicit RotationNormRegularizer(double weight) : weight(weight) {}

    template<typename T>
    bool operator()(const T *const rot_aa, T *residual) const {
        const T scale(weight);
        for (int i = 0; i < 3; ++i)
            residual[i] = scale * rot_aa[i];
        return true;
    }

    const double weight;
};
inline void LatticeToRodriguesAndLengths_GS(const CrystalLattice &latt,
double rod[3],
double lengths[3]) {
// Load lattice columns
const Coord a = latt.Vec0();
const Coord b = latt.Vec1();
const Coord c = latt.Vec2();
Eigen::Vector3d A(a[0], a[1], a[2]);
Eigen::Vector3d B(b[0], b[1], b[2]);
Eigen::Vector3d C(c[0], c[1], c[2]);
// Lengths = column norms (orthorhombic assumption)
lengths[0] = A.norm();
lengths[1] = B.norm();
lengths[2] = C.norm();
auto safe_unit = [](const Eigen::Vector3d &v, double eps = 1e-15) -> Eigen::Vector3d {
double n = v.norm();
return (n > eps) ? (v / n) : Eigen::Vector3d(1.0, 0.0, 0.0);
};
// GramSchmidt with original order: x from A, y from B orthogonalized vs x
Eigen::Vector3d e1 = safe_unit(A);
Eigen::Vector3d y = B - (e1.dot(B)) * e1;
Eigen::Vector3d e2 = safe_unit(y);
// z from cross to ensure right-handed basis
Eigen::Vector3d e3 = e1.cross(e2);
double n3 = e3.norm();
if (n3 < 1e-15) {
// Degenerate case: B nearly collinear with A → use C instead
y = C - (e1.dot(C)) * e1;
e2 = safe_unit(y);
e3 = e1.cross(e2);
n3 = e3.norm();
if (n3 < 1e-15) {
// Still degenerate: pick any perpendicular to e1
e2 = safe_unit((std::abs(e1.x()) < 0.9)
? Eigen::Vector3d::UnitX().cross(e1)
: Eigen::Vector3d::UnitY().cross(e1));
e3 = e1.cross(e2);
}
} else {
e3 /= n3;
}
Eigen::Matrix3d R;
R.col(0) = e1;
R.col(1) = e2;
R.col(2) = e3;
// Convert rotation to Rodrigues (axis * angle)
Eigen::AngleAxisd aa(R);
Eigen::Vector3d r = aa.angle() * aa.axis();
rod[0] = r.x();
rod[1] = r.y();
rod[2] = r.z();
}
void LatticeToRodriguesAndLengths_Hex(const CrystalLattice &latt, double rod[3], double ac[3]) {
const Coord a = latt.Vec0();
const Coord b = latt.Vec1();
const Coord c = latt.Vec2();
Eigen::Vector3d A(a[0], a[1], a[2]);
Eigen::Vector3d B(b[0], b[1], b[2]);
Eigen::Vector3d C(c[0], c[1], c[2]);
const double a_len = A.norm();
const double b_len = B.norm();
const double c_len = C.norm();
ac[0] = (a_len + b_len) / 2.0;
ac[1] = (a_len + b_len) / 2.0;
ac[2] = c_len;
Eigen::Vector3d e1;
Eigen::Vector3d e3;
if (a_len > 0.0)
e1 = A / a_len;
else
e1 = Eigen::Vector3d::UnitX();
if (c_len > 0.0)
e3 = C / c_len;
else
e3 = Eigen::Vector3d::UnitZ();
Eigen::Vector3d e2 = e3.cross(e1);
if (e2.norm() < 1e-15) {
e2 = (std::abs(e1.x()) < 0.9)
? Eigen::Vector3d::UnitX().cross(e1)
: Eigen::Vector3d::UnitY().cross(e1);
}
e2.normalize();
e3 = e1.cross(e2).normalized();
Eigen::Matrix3d R;
R.col(0) = e1;
R.col(1) = e2;
R.col(2) = e3;
Eigen::AngleAxisd aa(R);
Eigen::Vector3d r = aa.angle() * aa.axis();
rod[0] = r.x();
rod[1] = r.y();
rod[2] = r.z();
}
// Extract rotation (Rodrigues), lengths (a,b,c) and beta (rad) for monoclinic (unique axis b).
// Frame choice: e2 aligned with b; e1 from a projected orthogonal to e2; e3 = e1 x e2.
void LatticeToRodriguesLengthsBeta_Mono(const CrystalLattice &latt,
double rod[3],
double lengths[3],
double &beta_rad) {
const Coord a = latt.Vec0();
const Coord b = latt.Vec1();
const Coord c = latt.Vec2();
const Eigen::Vector3d A(a[0], a[1], a[2]);
const Eigen::Vector3d Bv(b[0], b[1], b[2]);
const Eigen::Vector3d C(c[0], c[1], c[2]);
// Unit cell lengths
const double a_len = A.norm();
const double b_len = Bv.norm();
const double c_len = C.norm();
lengths[0] = a_len;
lengths[1] = b_len;
lengths[2] = c_len;
// Monoclinic beta = angle(a, c)
double cos_beta = 0.0;
if (a_len > 1e-15 && c_len > 1e-15) {
cos_beta = A.dot(C) / (a_len * c_len);
cos_beta = std::clamp(cos_beta, -1.0, 1.0);
}
beta_rad = std::acos(cos_beta);
// Protect against singular construction
const double sin_beta = std::max(std::abs(std::sin(beta_rad)), 1e-12);
// Canonical monoclinic basis:
//
// B =
// [ 1 0 cos(beta) ]
// [ 0 1 0 ]
// [ 0 0 sin(beta) ]
//
Eigen::Matrix3d Bmono = Eigen::Matrix3d::Zero();
Bmono(0,0) = 1.0;
Bmono(1,1) = 1.0;
Bmono(0,2) = std::cos(beta_rad);
Bmono(2,2) = sin_beta;
// Scale by lengths
Eigen::DiagonalMatrix<double,3> D(a_len, b_len, c_len);
// Ideal body-frame lattice
const Eigen::Matrix3d M = Bmono * D;
// Observed lattice
Eigen::Matrix3d L;
L.col(0) = A;
L.col(1) = Bv;
L.col(2) = C;
// Estimate rotation:
// R ≈ L * M^{-1}
Eigen::Matrix3d R_est = L * M.inverse();
// Project to nearest proper rotation matrix
Eigen::JacobiSVD<Eigen::Matrix3d> svd(R_est, Eigen::ComputeFullU | Eigen::ComputeFullV);
Eigen::Matrix3d R = svd.matrixU() * svd.matrixV().transpose();
// Enforce det(R)=+1
if (R.determinant() < 0.0) {
Eigen::Matrix3d U = svd.matrixU();
U.col(2) *= -1.0;
R = U * svd.matrixV().transpose();
}
// Rodrigues vector
Eigen::AngleAxisd aa(R);
const Eigen::Vector3d r = aa.angle() * aa.axis();
rod[0] = r.x();
rod[1] = r.y();
rod[2] = r.z();
}
// Unit-length direct-lattice basis for cell angles alpha, beta, gamma (radians).
// The columns are the directions of a, b, c:
//   a along x,
//   b in the x-y plane at angle gamma from a,
//   c with cos(angle to a) = cos(beta), cos(angle to b) = cos(alpha).
// The z component of c is clamped at 0 when the angles are geometrically
// inconsistent.
static inline Eigen::Matrix3d B_from_angles(double alpha_rad, double beta_rad, double gamma_rad) {
    const double cos_a = std::cos(alpha_rad);
    const double cos_b = std::cos(beta_rad);
    const double cos_g = std::cos(gamma_rad);
    const double sin_g = std::sin(gamma_rad);

    // Components of c (standard crystallographic construction)
    const double c_y = (cos_a - cos_b * cos_g) / sin_g;
    const double c_z = std::sqrt(std::max(0.0, 1.0 - cos_b * cos_b - c_y * c_y));

    Eigen::Matrix3d basis;
    basis << 1.0, cos_g, cos_b,
             0.0, sin_g, c_y,
             0.0, 0.0,   c_z;
    return basis;
}
// Build a direct lattice from an orientation and a full cell. The inputs are:
//   rod            - Rodrigues vector (axis * angle) of the orientation R,
//   lengths        - a, b, c,
//   alpha/beta/gamma - cell angles in radians.
// Convention: L = R * B(angles) * diag(lengths); columns of L are a, b, c.
CrystalLattice AngleAxisAndCellToLattice(const double rod[3],
                                         const double lengths[3],
                                         double alpha_rad,
                                         double beta_rad,
                                         double gamma_rad) {
    const Eigen::Vector3d rodrigues(rod[0], rod[1], rod[2]);
    const double theta = rodrigues.norm();

    // A zero-length Rodrigues vector means no rotation
    const Eigen::Matrix3d rotation = (theta > 0.0)
                                     ? Eigen::Matrix3d(Eigen::AngleAxisd(theta, rodrigues / theta).toRotationMatrix())
                                     : Eigen::Matrix3d(Eigen::Matrix3d::Identity());

    const Eigen::Matrix3d unit_basis = B_from_angles(alpha_rad, beta_rad, gamma_rad);
    const Eigen::DiagonalMatrix<double, 3> scale(lengths[0], lengths[1], lengths[2]);
    const Eigen::Matrix3d cell = rotation * unit_basis * scale;

    const auto column = [&cell](int j) {
        return Coord(cell(0, j), cell(1, j), cell(2, j));
    };
    return CrystalLattice(column(0), column(1), column(2));
}
// Build a lattice with alpha = beta = 90 deg from an orientation and lengths.
// gamma is 90 deg, or 120 deg when `hex` is set. For hexagonal cells the
// caller must already have enforced a == b in `lengths`.
CrystalLattice AngleAxisAndLengthsToLattice(const double rod[3], const double lengths[3], bool hex) {
    const double right_angle = M_PI / 2.0;
    const double gamma = hex ? 2.0 * M_PI / 3.0 : right_angle;
    return AngleAxisAndCellToLattice(rod, lengths, right_angle, right_angle, gamma);
}
bool XtalOptimizerInternal(XtalOptimizerData &data,
const std::vector<SpotToSave> &spots,
const float tolerance) {
try {
Coord vec0 = data.latt.Vec0();
Coord vec1 = data.latt.Vec1();
Coord vec2 = data.latt.Vec2();
double beta = data.latt.GetUnitCell().beta;
// Initial guess for the parameters
double beam[2] = {data.geom.GetBeamX_pxl(), data.geom.GetBeamY_pxl()};
double distance_mm = data.geom.GetDetectorDistance_mm();
double detector_rot[2] = {data.geom.GetPoniRot1_rad(), data.geom.GetPoniRot2_rad()};
ceres::Problem problem;
double latt_vec0[3] = {0.0, 0.0, 0.0};
double latt_vec1[3] = {0.0, 0.0, 0.0};
double latt_vec2[3] = {0.0, 0.0, 0.0};
double rot_vec[3] = {1, 0, 0};
switch (data.crystal_system) {
case gemmi::CrystalSystem::Orthorhombic:
LatticeToRodriguesAndLengths_GS(data.latt, latt_vec0, latt_vec1);
break;
case gemmi::CrystalSystem::Tetragonal:
LatticeToRodriguesAndLengths_GS(data.latt, latt_vec0, latt_vec1);
latt_vec1[0] = (latt_vec1[0] + latt_vec1[1]) / 2.0;
break;
case gemmi::CrystalSystem::Cubic:
LatticeToRodriguesAndLengths_GS(data.latt, latt_vec0, latt_vec1);
latt_vec1[0] = (latt_vec1[0] + latt_vec1[1] + latt_vec1[2]) / 3.0;
break;
case gemmi::CrystalSystem::Hexagonal:
LatticeToRodriguesAndLengths_Hex(data.latt, latt_vec0, latt_vec1);
break;
case gemmi::CrystalSystem::Monoclinic:
LatticeToRodriguesLengthsBeta_Mono(data.latt, latt_vec0, latt_vec1, beta);
latt_vec2[0] = beta;
latt_vec2[1] = 0.0;
latt_vec2[2] = 0.0;
break;
default:
// Triclinic: initialize a,b,c and α,β,γ from current unit cell
LatticeToRodriguesAndLengths_GS(data.latt, latt_vec0, latt_vec1);
auto uc = data.latt.GetUnitCell();
latt_vec2[0] = uc.alpha * M_PI / 180.0;
latt_vec2[1] = uc.beta * M_PI / 180.0;
latt_vec2[2] = uc.gamma * M_PI / 180.0;
break;
}
if (data.axis) {
rot_vec[0] = data.axis->GetAxis().x;
rot_vec[1] = data.axis->GetAxis().y;
rot_vec[2] = data.axis->GetAxis().z;
}
const float tolerance_sq = tolerance * tolerance;
// Add residuals for each point
for (const auto &pt: spots) {
if (!data.index_ice_rings && pt.ice_ring)
continue;
float angle_rad = 0.0;
Coord recip = pt.ReciprocalCoord(data.geom);
if (data.axis) {
recip = data.axis->GetTransformationAngle(pt.phi) * recip;
angle_rad = pt.phi * M_PI / 180.0;
}
double h_fp = recip * vec0;
double k_fp = recip * vec1;
double l_fp = recip * vec2;
double h = std::round(h_fp);
double k = std::round(k_fp);
double l = std::round(l_fp);
double norm_sq = (h - h_fp) * (h - h_fp) + (k - k_fp) * (k - k_fp) + (l - l_fp) * (l - l_fp);
if (norm_sq > tolerance_sq)
continue;
problem.AddResidualBlock(
new ceres::AutoDiffCostFunction<XtalResidual, 3, 2, 1, 2, 3, 3, 3, 3>(
new XtalResidual(pt.x, pt.y,
data.geom.GetWavelength_A(),
data.geom.GetPixelSize_mm(),
angle_rad,
h, k, l,
data.crystal_system)),
nullptr,
beam,
&distance_mm,
detector_rot,
rot_vec,
latt_vec0,
latt_vec1,
latt_vec2
);
}
if (problem.NumResidualBlocks() < data.min_spots)
return false;
if (!data.refine_distance_mm)
problem.SetParameterBlockConstant(&distance_mm);
else {
const double dist_range = 0.1;
problem.SetParameterLowerBound(&distance_mm, 0, distance_mm * (1.0 - dist_range));
problem.SetParameterUpperBound(&distance_mm, 0, distance_mm * (1.0 + dist_range));
}
if (!data.refine_beam_center)
problem.SetParameterBlockConstant(beam);
if (!data.refine_detector_angles) {
problem.SetParameterBlockConstant(detector_rot);
} else {
const double rot_range = 3.0 / 180.0 * M_PI;
for (int i = 0; i < 2; ++i) {
problem.SetParameterLowerBound(detector_rot, i, detector_rot[i] - rot_range);
problem.SetParameterUpperBound(detector_rot, i, detector_rot[i] + rot_range);
}
}
if (!data.refine_rotation_axis) {
problem.SetParameterBlockConstant(rot_vec);
}
if (!data.refine_unit_cell) {
problem.SetParameterBlockConstant(latt_vec1);
problem.SetParameterBlockConstant(latt_vec2);
} else {
// Parameter bounds
// Lengths
for (int i = 0; i < 3; ++i) {
problem.SetParameterLowerBound(latt_vec1, i, data.min_length_A);
problem.SetParameterUpperBound(latt_vec1, i, data.max_length_A);
}
if (data.crystal_system == gemmi::CrystalSystem::Monoclinic) {
const double beta_lo = std::max(1e-6, M_PI * (data.min_angle_deg / 180.0));
const double beta_hi = std::min(M_PI - 1e-6, M_PI * (data.max_angle_deg / 180.0));
problem.SetParameterLowerBound(latt_vec2, 0, beta_lo);
problem.SetParameterUpperBound(latt_vec2, 0, beta_hi);
} else if (data.crystal_system == gemmi::CrystalSystem::Triclinic) {
// α, β, γ bounds (radians)
const double alo = M_PI * (data.min_angle_deg / 180.0);
const double ahi = M_PI * (data.max_angle_deg / 180.0);
for (int i = 0; i < 3; ++i) {
problem.SetParameterLowerBound(latt_vec2, i, alo);
problem.SetParameterUpperBound(latt_vec2, i, ahi);
}
} else {
// Orthorhombic / Tetragonal / Cubic / Hexagonal:
// latt_vec2 has no meaning for these systems — always freeze it.
problem.SetParameterBlockConstant(latt_vec2);
}
}
// Configure solver
ceres::Solver::Options options;
options.linear_solver_type = ceres::DENSE_QR;
options.minimizer_progress_to_stdout = false;
options.max_solver_time_in_seconds = data.max_time;
options.logging_type = ceres::LoggingType::SILENT;
options.num_threads = 1; // Fix threads to 1, as this runs in multi-threaded context
ceres::Solver::Summary summary;
// Run optimization
ceres::Solve(options, &problem, &summary);
if (data.refine_beam_center) {
data.beam_corr_x = data.geom.GetBeamX_pxl() - beam[0];
data.beam_corr_y = data.geom.GetBeamY_pxl() - beam[1];
data.geom.BeamX_pxl(beam[0]).BeamY_pxl(beam[1]);
}
if (data.refine_distance_mm)
data.geom.DetectorDistance_mm(distance_mm);
if (data.refine_detector_angles)
data.geom.PoniRot1_rad(detector_rot[0]).PoniRot2_rad(detector_rot[1]);
if (data.axis && data.refine_rotation_axis)
data.axis.value().Axis(Coord(rot_vec[0], rot_vec[1], rot_vec[2]));
if (data.crystal_system == gemmi::CrystalSystem::Orthorhombic)
data.latt = AngleAxisAndCellToLattice(latt_vec0, latt_vec1, M_PI / 2.0, M_PI / 2.0, M_PI / 2.0);
else if (data.crystal_system == gemmi::CrystalSystem::Tetragonal) {
latt_vec1[1] = latt_vec1[0];
data.latt = AngleAxisAndCellToLattice(latt_vec0, latt_vec1, M_PI / 2.0, M_PI / 2.0, M_PI / 2.0);
} else if (data.crystal_system == gemmi::CrystalSystem::Cubic) {
latt_vec1[1] = latt_vec1[0];
latt_vec1[2] = latt_vec1[0];
data.latt = AngleAxisAndCellToLattice(latt_vec0, latt_vec1, M_PI / 2.0, M_PI / 2.0, M_PI / 2.0);
} else if (data.crystal_system == gemmi::CrystalSystem::Hexagonal) {
latt_vec1[1] = latt_vec1[0];
data.latt = AngleAxisAndCellToLattice(latt_vec0, latt_vec1,M_PI / 2.0, M_PI / 2.0, 2.0 * M_PI / 3.0);
} else if (data.crystal_system == gemmi::CrystalSystem::Monoclinic) {
data.latt = AngleAxisAndCellToLattice(latt_vec0, latt_vec1, M_PI / 2.0, latt_vec2[0], M_PI / 2.0);
} else {
// Triclinic via the same generic builder
data.latt = AngleAxisAndCellToLattice(latt_vec0, latt_vec1, latt_vec2[0], latt_vec2[1], latt_vec2[2]);
}
return true;
} catch (...) {
// Convergence problems, likely not updated
return false;
}
}
// Three refinement passes with progressively tighter indexing tolerance
// (0.3, 0.2, 0.1 in fractional HKL units).
// The coarse pass must succeed, otherwise false is returned immediately. The
// middle pass is best-effort and its result is ignored. The final pass's
// result is returned.
bool XtalOptimizer(XtalOptimizerData &data, const std::vector<SpotToSave> &spots) {
    const bool coarse_ok = XtalOptimizerInternal(data, spots, 0.3);
    if (!coarse_ok)
        return false;

    static_cast<void>(XtalOptimizerInternal(data, spots, 0.2));

    return XtalOptimizerInternal(data, spots, 0.1);
}
// Refines only a small extra rotation of the crystal lattice. Geometry and
// unit cell stay fixed.
// Spots are indexed with the current data.latt (within `tolerance` of integer
// HKL). A single angle-axis parameter is then fitted so that the rotated
// reciprocal basis best predicts the observed reciprocal vectors. A small
// norm regularizer picks the smallest such rotation.
//
// On success the rotation is applied to data.latt, and data.angle_corr
// receives its magnitude in radians. data.angle_axis receives its unit axis,
// or is reset when the magnitude is <= 1e-6.
// Returns false (data untouched) when too few spots index, when Ceres reports
// an unusable solution, or when an exception is thrown.
bool XtalOptimizerRotationOnly(XtalOptimizerData &data,
                               const std::vector<SpotToSave> &spots,
                               const float tolerance) {
    try {
        // Parameter: angle-axis for the extra rotation. Identity == {0,0,0}.
        double rot_aa[3] = {0.0, 0.0, 0.0};

        // Spot selection by current indexing (same approach as XtalOptimizerInternal)
        const Coord a0 = data.latt.Vec0();
        const Coord b0 = data.latt.Vec1();
        const Coord c0 = data.latt.Vec2();

        const float tol_sq = tolerance * tolerance;

        ceres::Problem problem;

        for (const auto &pt : spots) {
            if (!data.index_ice_rings && pt.ice_ring)
                continue;

            // Fractional HKL using the CURRENT lattice
            Coord recip_index = pt.ReciprocalCoord(data.geom);
            if (data.axis.has_value())
                recip_index = data.axis->GetTransformationAngle(pt.phi) * recip_index;

            const double h_fp = static_cast<double>(recip_index * a0);
            const double k_fp = static_cast<double>(recip_index * b0);
            const double l_fp = static_cast<double>(recip_index * c0);

            const double h = std::round(h_fp);
            const double k = std::round(k_fp);
            const double l = std::round(l_fp);

            const double norm_sq =
                    (h - h_fp) * (h - h_fp) +
                    (k - k_fp) * (k - k_fp) +
                    (l - l_fp) * (l - l_fp);

            if (norm_sq > static_cast<double>(tol_sq))
                continue;

            // s_obs must be in the same reference frame as the predicted
            // reciprocal vector (h*a* + k*b* + l*c*), i.e. the phi = 0 crystal
            // frame: apply the same goniometer back-rotation used for indexing.
            Coord s_obs = data.geom.DetectorToRecip(pt.x, pt.y);
            if (data.axis.has_value())
                s_obs = data.axis->GetTransformationAngle(pt.phi) * s_obs;

            auto *cost =
                    new ceres::AutoDiffCostFunction<XtalResidualRotationOnlyPrecomp, 3, 3>(
                            new XtalResidualRotationOnlyPrecomp(s_obs, data.latt, h, k, l)
                    );

            problem.AddResidualBlock(cost, nullptr, rot_aa);
        }

        if (problem.NumResidualBlocks() < data.min_spots)
            return false;

        // Regularization: prefer the smallest rotation correction that fits
        // the data. This matters when spots are nearly coplanar in reciprocal
        // space (e.g. still images), where the rotation component
        // perpendicular to the scattering plane is otherwise underdetermined.
        // The weight is in 1/A per radian; tune relative to typical residuals.
        {
            const double reg_weight = 0.05;
            problem.AddResidualBlock(
                    new ceres::AutoDiffCostFunction<RotationNormRegularizer, 3, 3>(
                            new RotationNormRegularizer(reg_weight)),
                    nullptr, rot_aa);
        }

        ceres::Solver::Options options;
        options.linear_solver_type = ceres::DENSE_QR;
        options.minimizer_progress_to_stdout = false;
        options.max_solver_time_in_seconds = data.max_time;
        options.logging_type = ceres::LoggingType::SILENT;
        options.num_threads = 1;

        ceres::Solver::Summary summary;
        ceres::Solve(options, &problem, &summary);

        // Ceres reports failures through the summary, not by throwing; do not
        // apply a rotation that may be numerically meaningless.
        if (!summary.IsSolutionUsable())
            return false;

        // Apply the rotation to the direct-lattice vectors.
        // ceres::AngleAxisToRotationMatrix(const T*, T*) writes a COLUMN-major
        // 3x3 matrix. Map it with an explicitly column-major Eigen type, so the
        // result does not depend on Eigen's default storage order.
        //
        // For an orthogonal R, R^-T = R, so rotating the direct-lattice vectors
        // by R is equivalent to rotating the reciprocal vectors (as done in the
        // cost function) by the same R.
        double R_raw[9];
        ceres::AngleAxisToRotationMatrix(rot_aa, R_raw);
        const Eigen::Map<const Eigen::Matrix<double, 3, 3, Eigen::ColMajor>> R(R_raw);

        const Eigen::Vector3d A(a0.x, a0.y, a0.z);
        const Eigen::Vector3d B(b0.x, b0.y, b0.z);
        const Eigen::Vector3d C(c0.x, c0.y, c0.z);

        const Eigen::Vector3d A2 = R * A;
        const Eigen::Vector3d B2 = R * B;
        const Eigen::Vector3d C2 = R * C;

        data.latt = CrystalLattice(
                Coord(static_cast<float>(A2.x()), static_cast<float>(A2.y()), static_cast<float>(A2.z())),
                Coord(static_cast<float>(B2.x()), static_cast<float>(B2.y()), static_cast<float>(B2.z())),
                Coord(static_cast<float>(C2.x()), static_cast<float>(C2.y()), static_cast<float>(C2.z()))
        );

        // Report the correction as magnitude (radians) and unit axis
        const double theta = std::sqrt(rot_aa[0] * rot_aa[0] + rot_aa[1] * rot_aa[1] + rot_aa[2] * rot_aa[2]);
        data.angle_corr = theta;

        if (theta > 1e-6) {
            Coord rot;
            rot.x = rot_aa[0] / theta;
            rot.y = rot_aa[1] / theta;
            rot.z = rot_aa[2] / theta;
            data.angle_axis = rot;
        } else
            data.angle_axis.reset();

        return true;
    } catch (...) {
        return false;
    }
}