Files
Jungfraujoch/image_analysis/geom_refinement/XtalOptimizer.cpp
2025-09-08 20:28:59 +02:00

392 lines
14 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute <filip.leonarski@psi.ch>
// SPDX-License-Identifier: GPL-3.0-only
#include "XtalOptimizer.h"

#include <cmath>

#include "ceres/ceres.h"
#include "ceres/rotation.h"
struct XtalResidual {
XtalResidual(double x, double y,
double beam_x, double beam_y,
double lambda,
double pixel_size, double distance_mm,
double rot1, double rot2,
double exp_h, double exp_k,
double exp_l, gemmi::CrystalSystem symmetry)
: obs_x(x), obs_y(y),
lambda(lambda),
pixel_size(pixel_size),
exp_h(exp_h),
exp_k(exp_k),
exp_l(exp_l),
distance(distance_mm),
rot1(rot1),
rot2(rot2),
beam_x(beam_x),
beam_y(beam_y),
symmetry(symmetry) {}
template<typename T>
bool operator()(const T *const corr_x,
const T *const corr_y,
const T *const p0,
const T *const p1,
const T *const p2,
T *residual) const {
T c1 = ceres::cos(T(rot1));
T c2 = ceres::cos(T(rot2));
T s1 = ceres::sin(T(rot1));
T s2 = ceres::sin(T(rot2));
// x_lab in mm
T x_lab = (T(obs_x) - beam_x - corr_x[0]) * T(pixel_size);
T y_lab = (T(obs_y) - beam_y - corr_y[0]) * T(pixel_size);
T z_lab = T(distance);
// apply rotations
T x = x_lab * c1 + z_lab * s1;
T y = y_lab * c2 + (-x_lab * s1 + z_lab * c1) * s2;
T z = -y_lab * s2 + (-x_lab * s1 + z_lab * c1) * c2;
// convert to recip space
T lab_norm = ceres::sqrt(x * x + y * y + z * z);
T recip[3];
recip[0] = x / (lab_norm * T(lambda));
recip[1] = y / (lab_norm * T(lambda));
recip[2] = (z / lab_norm - T(1.0)) / T(lambda);
Eigen::Map<const Eigen::Matrix<T, 3, 1>> e_obs_recip(recip);
Eigen::Matrix<T, 3, 1> e_pred;
Eigen::Matrix<T, 3, 3> e_latt;
if (symmetry == gemmi::CrystalSystem::Cubic
|| symmetry == gemmi::CrystalSystem::Tetragonal
|| symmetry == gemmi::CrystalSystem::Orthorhombic
|| symmetry == gemmi::CrystalSystem::Hexagonal) {
T uc_rot_matrix[9];
ceres::AngleAxisToRotationMatrix(p0, uc_rot_matrix);
Eigen::Map<const Eigen::Matrix<T, 3, 3>> e_uc_rot_matrix(uc_rot_matrix);
Eigen::Matrix<T, 3, 1> e_uc_len = Eigen::Matrix<T, 3, 1>::Zero();
Eigen::Matrix<T, 3, 3> Bhex = Eigen::Matrix<T, 3, 3>::Identity();
if (symmetry == gemmi::CrystalSystem::Hexagonal) {
e_uc_len << p1[0], p1[0], p1[2];
Bhex(0, 1) = T(-0.5); // cos(120)
Bhex(1, 1) = T(sqrt(3.0) / 2.0); // sin(120)
} else if (symmetry == gemmi::CrystalSystem::Orthorhombic)
e_uc_len << p1[0], p1[1], p1[2];
else if (symmetry == gemmi::CrystalSystem::Tetragonal)
e_uc_len << p1[0], p1[0], p1[2];
else if (symmetry == gemmi::CrystalSystem::Cubic)
e_uc_len << p1[0], p1[0], p1[0];
e_latt = e_uc_rot_matrix * e_uc_len.asDiagonal() * Bhex;
} else
e_latt << p0[0], p1[0], p2[0], p0[1], p1[1], p2[1], p0[2], p1[2], p2[2];
Eigen::Matrix<T, 3, 1> e_hkl;
e_hkl << T(exp_h), T(exp_k), T(exp_l);
auto e_pred_hkl = e_latt.transpose() * e_obs_recip;
residual[0] = exp_h - e_pred_hkl[0];
residual[1] = exp_k - e_pred_hkl[1];
residual[2] = exp_l - e_pred_hkl[2];
return true;
}
const double obs_x, obs_y;
const double lambda;
const double pixel_size;
const double exp_h;
const double exp_k;
const double exp_l;
const double distance;
const double rot1, rot2;
const double beam_x, beam_y;
gemmi::CrystalSystem symmetry;
};
inline void LatticeToRodriguesAndLengths_GS(const CrystalLattice &latt,
double rod[3],
double lengths[3]) {
// Load lattice columns
const Coord a = latt.Vec0();
const Coord b = latt.Vec1();
const Coord c = latt.Vec2();
Eigen::Vector3d A(a[0], a[1], a[2]);
Eigen::Vector3d B(b[0], b[1], b[2]);
Eigen::Vector3d C(c[0], c[1], c[2]);
// Lengths = column norms (orthorhombic assumption)
lengths[0] = A.norm();
lengths[1] = B.norm();
lengths[2] = C.norm();
auto safe_unit = [](const Eigen::Vector3d &v, double eps = 1e-15) -> Eigen::Vector3d {
double n = v.norm();
return (n > eps) ? (v / n) : Eigen::Vector3d(1.0, 0.0, 0.0);
};
// GramSchmidt with original order: x from A, y from B orthogonalized vs x
Eigen::Vector3d e1 = safe_unit(A);
Eigen::Vector3d y = B - (e1.dot(B)) * e1;
Eigen::Vector3d e2 = safe_unit(y);
// z from cross to ensure right-handed basis
Eigen::Vector3d e3 = e1.cross(e2);
double n3 = e3.norm();
if (n3 < 1e-15) {
// Degenerate case: B nearly collinear with A → use C instead
y = C - (e1.dot(C)) * e1;
e2 = safe_unit(y);
e3 = e1.cross(e2);
n3 = e3.norm();
if (n3 < 1e-15) {
// Still degenerate: pick any perpendicular to e1
e2 = safe_unit((std::abs(e1.x()) < 0.9)
? Eigen::Vector3d::UnitX().cross(e1)
: Eigen::Vector3d::UnitY().cross(e1));
e3 = e1.cross(e2);
}
} else {
e3 /= n3;
}
Eigen::Matrix3d R;
R.col(0) = e1;
R.col(1) = e2;
R.col(2) = e3;
// Convert rotation to Rodrigues (axis * angle)
Eigen::AngleAxisd aa(R);
Eigen::Vector3d r = aa.angle() * aa.axis();
rod[0] = r.x();
rod[1] = r.y();
rod[2] = r.z();
}
void LatticeToRodriguesAndLengths_Hex(const CrystalLattice &latt, double rod[3], double ac[3]) {
const Coord a = latt.Vec0();
const Coord b = latt.Vec1();
const Coord c = latt.Vec2();
Eigen::Vector3d A(a[0], a[1], a[2]);
Eigen::Vector3d B(b[0], b[1], b[2]);
Eigen::Vector3d C(c[0], c[1], c[2]);
const double a_len = A.norm();
const double b_len = B.norm();
const double c_len = C.norm();
ac[0] = (a_len + b_len) / 2.0;
ac[1] = (a_len + b_len) / 2.0;
ac[2] = c_len;
Eigen::Vector3d e1;
Eigen::Vector3d e3;
if (a_len > 0.0)
e1 = A / a_len;
else
e1 = Eigen::Vector3d::UnitX();
if (c_len > 0.0)
e3 = C / c_len;
else
e3 = Eigen::Vector3d::UnitZ();
Eigen::Vector3d e2 = e3.cross(e1);
if (e2.norm() < 1e-15) {
e2 = (std::abs(e1.x()) < 0.9) ? Eigen::Vector3d::UnitX().cross(e1)
: Eigen::Vector3d::UnitY().cross(e1);
}
e2.normalize();
e3 = e1.cross(e2).normalized();
Eigen::Matrix3d R;
R.col(0) = e1;
R.col(1) = e2;
R.col(2) = e3;
Eigen::AngleAxisd aa(R);
Eigen::Vector3d r = aa.angle() * aa.axis();
rod[0] = r.x();
rod[1] = r.y();
rod[2] = r.z();
}
// Builds a lattice from an orientation and three cell lengths.
//
// rod     : rotation as an angle-axis vector (axis * angle, radians); a zero (or
//           non-positive-norm) vector means no rotation.
// lengths : |a|, |b|, |c|.
// hex     : false -> before rotation a, b, c lie along x, y, z;
//           true  -> before rotation a lies along x, b in the xy-plane at 120 deg from a,
//                    c along z.
CrystalLattice AngleAxisAndLengthsToLattice(const double rod[3], const double lengths[3], bool hex) {
    const Eigen::Vector3d r(rod[0], rod[1], rod[2]);
    const double angle = r.norm();
    Eigen::Matrix3d R = Eigen::Matrix3d::Identity();
    if (angle > 0.0)
        R = Eigen::AngleAxisd(angle, r / angle).toRotationMatrix();

    // Unit-length cell axes as columns
    Eigen::Matrix3d shape = Eigen::Matrix3d::Identity();
    if (hex) {
        shape(0, 1) = -0.5;                  // cos(120 deg)
        shape(1, 1) = std::sqrt(3.0) / 2.0;  // sin(120 deg)
    }

    // Scale column j by lengths[j]. Column scaling (shape * D) keeps |b| == lengths[1]
    // for hex; the former row scaling (D * shape) was only correct when a == b.
    const Eigen::DiagonalMatrix<double, 3> D(lengths[0], lengths[1], lengths[2]);
    const Eigen::Matrix3d cell = shape * D;

    // Materialised as Matrix3d: an `auto` Eigen product expression would be
    // re-evaluated on every coefficient access below.
    const Eigen::Matrix3d latt = R * cell;
    return CrystalLattice(Coord(latt(0, 0), latt(1, 0), latt(2, 0)),
                          Coord(latt(0, 1), latt(1, 1), latt(2, 1)),
                          Coord(latt(0, 2), latt(1, 2), latt(2, 2)));
}
bool XtalOptimizerInternal(XtalOptimizerData &data,
const std::vector<SpotToSave> &spots,
const float tolerance) {
try {
if (data.centering != 'P')
data.latt = data.latt.ToPrimitive(data.centering);
if (data.crystal_system == gemmi::CrystalSystem::Tetragonal
|| data.crystal_system == gemmi::CrystalSystem::Hexagonal)
data.latt.ReorderABEqual();
Coord vec0 = data.latt.Vec0();
Coord vec1 = data.latt.Vec1();
Coord vec2 = data.latt.Vec2();
// Initial guess for the parameters
double corr_x = 0;
double corr_y = 0;
ceres::Problem problem;
double latt_vec0[3], latt_vec1[3], latt_vec2[3];
switch (data.crystal_system) {
case gemmi::CrystalSystem::Orthorhombic:
LatticeToRodriguesAndLengths_GS(data.latt, latt_vec0, latt_vec1);
break;
case gemmi::CrystalSystem::Tetragonal:
LatticeToRodriguesAndLengths_GS(data.latt, latt_vec0, latt_vec1);
latt_vec1[0] = (latt_vec1[0] + latt_vec1[1]) / 2.0;
break;
case gemmi::CrystalSystem::Cubic:
LatticeToRodriguesAndLengths_GS(data.latt, latt_vec0, latt_vec1);
latt_vec1[0] = (latt_vec1[0] + latt_vec1[1] + latt_vec1[2]) / 3.0;
break;
case gemmi::CrystalSystem::Hexagonal:
LatticeToRodriguesAndLengths_Hex(data.latt, latt_vec0, latt_vec1);
break;
default:
latt_vec0[0] = vec0.x; latt_vec0[1] = vec0.y; latt_vec0[2] = vec0.z;
latt_vec1[0] = vec1.x; latt_vec1[1] = vec1.y; latt_vec1[2] = vec1.z;
latt_vec2[0] = vec2.x; latt_vec2[1] = vec2.y; latt_vec2[2] = vec2.z;
break;
}
// Add residuals for each point
for (const auto &pt: spots) {
Coord recip = data.geom.DetectorToRecip(pt.x, pt.y);
double h_fp = recip * vec0;
double k_fp = recip * vec1;
double l_fp = recip * vec2;
double h = std::round(h_fp);
double k = std::round(k_fp);
double l = std::round(l_fp);
double norm_sq = (h - h_fp) * (h - h_fp) + (k -k_fp) * (k - k_fp) + (l - l_fp) * (l - l_fp);
if (norm_sq > tolerance * tolerance)
continue;
problem.AddResidualBlock(
new ceres::AutoDiffCostFunction<XtalResidual, 3, 1, 1, 3, 3, 3>(
new XtalResidual(pt.x, pt.y,
data.geom.GetBeamX_pxl(),
data.geom.GetBeamY_pxl(),
data.geom.GetWavelength_A(),
data.geom.GetPixelSize_mm(),
data.geom.GetDetectorDistance_mm(),
data.geom.GetPoniRot1_rad(),
data.geom.GetPoniRot2_rad(),
h, k, l,
data.crystal_system)),
nullptr,
&corr_x,
&corr_y,
latt_vec0,
latt_vec1,
latt_vec2
);
}
if (!data.refine_beam_center) {
problem.SetParameterBlockConstant(&corr_x);
problem.SetParameterBlockConstant(&corr_y);
}
if (problem.NumResidualBlocks() < data.min_spots)
return false;
// Configure solver
ceres::Solver::Options options;
options.linear_solver_type = ceres::DENSE_QR;
options.minimizer_progress_to_stdout = false;
options.logging_type = ceres::LoggingType::SILENT;
ceres::Solver::Summary summary;
// Run optimization
ceres::Solve(options, &problem, &summary);
if (data.refine_beam_center) {
data.beam_corr_x = corr_x;
data.beam_corr_y = corr_y;
data.geom.BeamX_pxl(data.geom.GetBeamX_pxl() + corr_x)
.BeamY_pxl(data.geom.GetBeamY_pxl() + corr_y);
}
if (data.crystal_system == gemmi::CrystalSystem::Orthorhombic)
data.latt = AngleAxisAndLengthsToLattice(latt_vec0, latt_vec1, false);
else if (data.crystal_system == gemmi::CrystalSystem::Tetragonal) {
latt_vec1[1] = latt_vec1[0];
data.latt = AngleAxisAndLengthsToLattice(latt_vec0, latt_vec1, false);
} else if (data.crystal_system == gemmi::CrystalSystem::Cubic) {
latt_vec1[1] = latt_vec1[0];
latt_vec1[2] = latt_vec1[0];
data.latt = AngleAxisAndLengthsToLattice(latt_vec0, latt_vec1, false);
} else if (data.crystal_system == gemmi::CrystalSystem::Hexagonal) {
latt_vec1[1] = latt_vec1[0];
data.latt = AngleAxisAndLengthsToLattice(latt_vec0, latt_vec1, true);
} else {
data.latt = CrystalLattice(Coord(latt_vec0[0], latt_vec0[1], latt_vec0[2]),
Coord(latt_vec1[0], latt_vec1[1], latt_vec1[2]),
Coord(latt_vec2[0], latt_vec2[1], latt_vec2[2]));
}
if (data.centering != 'P')
data.latt = data.latt.FromPrimitive(data.centering);
return true;
} catch (...) {
// Convergence problems, likely not updated
return false;
}
}
// Iteratively refines `data` against the spots with a shrinking index tolerance.
//
// The first XtalOptimizerInternal pass accepts spots whose fractional Miller indices
// lie within 0.3 of integers; if it fails, false is returned. Five further passes
// follow, the tolerance being multiplied by 0.8 before each one. Their return values
// are not checked, so the function returns true whenever the first pass succeeded.
bool XtalOptimizer(XtalOptimizerData &data, const std::vector<SpotToSave> &spots) {
    constexpr int extra_passes = 5;
    float tolerance = 0.3f;

    if (XtalOptimizerInternal(data, spots, tolerance)) {
        for (int pass = 0; pass != extra_passes; ++pass) {
            tolerance *= 0.8;
            XtalOptimizerInternal(data, spots, tolerance);
        }
        return true;
    }
    return false;
}