392 lines
14 KiB
C++
392 lines
14 KiB
C++
// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute <filip.leonarski@psi.ch>
|
||
// SPDX-License-Identifier: GPL-3.0-only
|
||
|
||
#include "XtalOptimizer.h"
|
||
#include "ceres/ceres.h"
|
||
#include "ceres/rotation.h"
|
||
|
||
struct XtalResidual {
|
||
XtalResidual(double x, double y,
|
||
double beam_x, double beam_y,
|
||
double lambda,
|
||
double pixel_size, double distance_mm,
|
||
double rot1, double rot2,
|
||
double exp_h, double exp_k,
|
||
double exp_l, gemmi::CrystalSystem symmetry)
|
||
: obs_x(x), obs_y(y),
|
||
lambda(lambda),
|
||
pixel_size(pixel_size),
|
||
exp_h(exp_h),
|
||
exp_k(exp_k),
|
||
exp_l(exp_l),
|
||
distance(distance_mm),
|
||
rot1(rot1),
|
||
rot2(rot2),
|
||
beam_x(beam_x),
|
||
beam_y(beam_y),
|
||
symmetry(symmetry) {}
|
||
|
||
template<typename T>
|
||
bool operator()(const T *const corr_x,
|
||
const T *const corr_y,
|
||
const T *const p0,
|
||
const T *const p1,
|
||
const T *const p2,
|
||
T *residual) const {
|
||
T c1 = ceres::cos(T(rot1));
|
||
T c2 = ceres::cos(T(rot2));
|
||
T s1 = ceres::sin(T(rot1));
|
||
T s2 = ceres::sin(T(rot2));
|
||
|
||
// x_lab in mm
|
||
T x_lab = (T(obs_x) - beam_x - corr_x[0]) * T(pixel_size);
|
||
T y_lab = (T(obs_y) - beam_y - corr_y[0]) * T(pixel_size);
|
||
T z_lab = T(distance);
|
||
|
||
// apply rotations
|
||
T x = x_lab * c1 + z_lab * s1;
|
||
T y = y_lab * c2 + (-x_lab * s1 + z_lab * c1) * s2;
|
||
T z = -y_lab * s2 + (-x_lab * s1 + z_lab * c1) * c2;
|
||
|
||
// convert to recip space
|
||
T lab_norm = ceres::sqrt(x * x + y * y + z * z);
|
||
|
||
T recip[3];
|
||
recip[0] = x / (lab_norm * T(lambda));
|
||
recip[1] = y / (lab_norm * T(lambda));
|
||
recip[2] = (z / lab_norm - T(1.0)) / T(lambda);
|
||
|
||
Eigen::Map<const Eigen::Matrix<T, 3, 1>> e_obs_recip(recip);
|
||
|
||
Eigen::Matrix<T, 3, 1> e_pred;
|
||
Eigen::Matrix<T, 3, 3> e_latt;
|
||
|
||
if (symmetry == gemmi::CrystalSystem::Cubic
|
||
|| symmetry == gemmi::CrystalSystem::Tetragonal
|
||
|| symmetry == gemmi::CrystalSystem::Orthorhombic
|
||
|| symmetry == gemmi::CrystalSystem::Hexagonal) {
|
||
|
||
T uc_rot_matrix[9];
|
||
ceres::AngleAxisToRotationMatrix(p0, uc_rot_matrix);
|
||
|
||
Eigen::Map<const Eigen::Matrix<T, 3, 3>> e_uc_rot_matrix(uc_rot_matrix);
|
||
Eigen::Matrix<T, 3, 1> e_uc_len = Eigen::Matrix<T, 3, 1>::Zero();
|
||
Eigen::Matrix<T, 3, 3> Bhex = Eigen::Matrix<T, 3, 3>::Identity();
|
||
|
||
if (symmetry == gemmi::CrystalSystem::Hexagonal) {
|
||
e_uc_len << p1[0], p1[0], p1[2];
|
||
Bhex(0, 1) = T(-0.5); // cos(120)
|
||
Bhex(1, 1) = T(sqrt(3.0) / 2.0); // sin(120)
|
||
} else if (symmetry == gemmi::CrystalSystem::Orthorhombic)
|
||
e_uc_len << p1[0], p1[1], p1[2];
|
||
else if (symmetry == gemmi::CrystalSystem::Tetragonal)
|
||
e_uc_len << p1[0], p1[0], p1[2];
|
||
else if (symmetry == gemmi::CrystalSystem::Cubic)
|
||
e_uc_len << p1[0], p1[0], p1[0];
|
||
|
||
e_latt = e_uc_rot_matrix * e_uc_len.asDiagonal() * Bhex;
|
||
} else
|
||
e_latt << p0[0], p1[0], p2[0], p0[1], p1[1], p2[1], p0[2], p1[2], p2[2];
|
||
|
||
Eigen::Matrix<T, 3, 1> e_hkl;
|
||
e_hkl << T(exp_h), T(exp_k), T(exp_l);
|
||
|
||
auto e_pred_hkl = e_latt.transpose() * e_obs_recip;
|
||
|
||
residual[0] = exp_h - e_pred_hkl[0];
|
||
residual[1] = exp_k - e_pred_hkl[1];
|
||
residual[2] = exp_l - e_pred_hkl[2];
|
||
|
||
return true;
|
||
}
|
||
|
||
const double obs_x, obs_y;
|
||
const double lambda;
|
||
const double pixel_size;
|
||
const double exp_h;
|
||
const double exp_k;
|
||
const double exp_l;
|
||
const double distance;
|
||
const double rot1, rot2;
|
||
const double beam_x, beam_y;
|
||
gemmi::CrystalSystem symmetry;
|
||
};
|
||
|
||
inline void LatticeToRodriguesAndLengths_GS(const CrystalLattice &latt,
|
||
double rod[3],
|
||
double lengths[3]) {
|
||
// Load lattice columns
|
||
const Coord a = latt.Vec0();
|
||
const Coord b = latt.Vec1();
|
||
const Coord c = latt.Vec2();
|
||
|
||
Eigen::Vector3d A(a[0], a[1], a[2]);
|
||
Eigen::Vector3d B(b[0], b[1], b[2]);
|
||
Eigen::Vector3d C(c[0], c[1], c[2]);
|
||
|
||
// Lengths = column norms (orthorhombic assumption)
|
||
lengths[0] = A.norm();
|
||
lengths[1] = B.norm();
|
||
lengths[2] = C.norm();
|
||
|
||
auto safe_unit = [](const Eigen::Vector3d &v, double eps = 1e-15) -> Eigen::Vector3d {
|
||
double n = v.norm();
|
||
return (n > eps) ? (v / n) : Eigen::Vector3d(1.0, 0.0, 0.0);
|
||
};
|
||
|
||
// Gram–Schmidt with original order: x from A, y from B orthogonalized vs x
|
||
Eigen::Vector3d e1 = safe_unit(A);
|
||
Eigen::Vector3d y = B - (e1.dot(B)) * e1;
|
||
Eigen::Vector3d e2 = safe_unit(y);
|
||
|
||
// z from cross to ensure right-handed basis
|
||
Eigen::Vector3d e3 = e1.cross(e2);
|
||
double n3 = e3.norm();
|
||
if (n3 < 1e-15) {
|
||
// Degenerate case: B nearly collinear with A → use C instead
|
||
y = C - (e1.dot(C)) * e1;
|
||
e2 = safe_unit(y);
|
||
e3 = e1.cross(e2);
|
||
n3 = e3.norm();
|
||
if (n3 < 1e-15) {
|
||
// Still degenerate: pick any perpendicular to e1
|
||
e2 = safe_unit((std::abs(e1.x()) < 0.9)
|
||
? Eigen::Vector3d::UnitX().cross(e1)
|
||
: Eigen::Vector3d::UnitY().cross(e1));
|
||
e3 = e1.cross(e2);
|
||
}
|
||
} else {
|
||
e3 /= n3;
|
||
}
|
||
|
||
Eigen::Matrix3d R;
|
||
R.col(0) = e1;
|
||
R.col(1) = e2;
|
||
R.col(2) = e3;
|
||
|
||
// Convert rotation to Rodrigues (axis * angle)
|
||
Eigen::AngleAxisd aa(R);
|
||
Eigen::Vector3d r = aa.angle() * aa.axis();
|
||
rod[0] = r.x();
|
||
rod[1] = r.y();
|
||
rod[2] = r.z();
|
||
}
|
||
|
||
void LatticeToRodriguesAndLengths_Hex(const CrystalLattice &latt, double rod[3], double ac[3]) {
|
||
const Coord a = latt.Vec0();
|
||
const Coord b = latt.Vec1();
|
||
const Coord c = latt.Vec2();
|
||
|
||
Eigen::Vector3d A(a[0], a[1], a[2]);
|
||
Eigen::Vector3d B(b[0], b[1], b[2]);
|
||
Eigen::Vector3d C(c[0], c[1], c[2]);
|
||
|
||
const double a_len = A.norm();
|
||
const double b_len = B.norm();
|
||
const double c_len = C.norm();
|
||
|
||
ac[0] = (a_len + b_len) / 2.0;
|
||
ac[1] = (a_len + b_len) / 2.0;
|
||
ac[2] = c_len;
|
||
|
||
Eigen::Vector3d e1;
|
||
Eigen::Vector3d e3;
|
||
|
||
if (a_len > 0.0)
|
||
e1 = A / a_len;
|
||
else
|
||
e1 = Eigen::Vector3d::UnitX();
|
||
|
||
if (c_len > 0.0)
|
||
e3 = C / c_len;
|
||
else
|
||
e3 = Eigen::Vector3d::UnitZ();
|
||
|
||
Eigen::Vector3d e2 = e3.cross(e1);
|
||
if (e2.norm() < 1e-15) {
|
||
e2 = (std::abs(e1.x()) < 0.9) ? Eigen::Vector3d::UnitX().cross(e1)
|
||
: Eigen::Vector3d::UnitY().cross(e1);
|
||
}
|
||
e2.normalize();
|
||
e3 = e1.cross(e2).normalized();
|
||
|
||
Eigen::Matrix3d R;
|
||
R.col(0) = e1;
|
||
R.col(1) = e2;
|
||
R.col(2) = e3;
|
||
|
||
Eigen::AngleAxisd aa(R);
|
||
Eigen::Vector3d r = aa.angle() * aa.axis();
|
||
rod[0] = r.x();
|
||
rod[1] = r.y();
|
||
rod[2] = r.z();
|
||
}
|
||
|
||
// Rebuilds a lattice from an angle-axis (Rodrigues) rotation vector and three
// lengths. The lattice vectors Vec0/Vec1/Vec2 are the columns of
//     R * diag(lengths) * B,
// where R is the rotation encoded by rod and B is the identity, or, when `hex`
// is set, the matrix placing the second vector at 120 degrees from the first in
// the xy plane of the rotated frame. Because diag(lengths) scales B's rows, the
// 120-degree angle is exact only when lengths[0] == lengths[1].
CrystalLattice AngleAxisAndLengthsToLattice(const double rod[3], const double lengths[3], bool hex) {
    const Eigen::Vector3d r(rod[0], rod[1], rod[2]);
    const double angle = r.norm();

    // A zero vector means "no rotation"; this also avoids dividing by zero.
    Eigen::Matrix3d R = Eigen::Matrix3d::Identity();
    if (angle > 0.0)
        R = Eigen::AngleAxisd(angle, r / angle).toRotationMatrix();

    const Eigen::DiagonalMatrix<double, 3> D(lengths[0], lengths[1], lengths[2]);

    Eigen::Matrix3d Bhex = Eigen::Matrix3d::Identity();
    if (hex) {
        Bhex(0, 1) = -0.5;                  // cos(120 deg)
        Bhex(1, 1) = std::sqrt(3.0) / 2.0;  // sin(120 deg)
    }

    // Evaluate into a concrete matrix. Holding `R * D * Bhex` in an `auto` keeps
    // a lazy Eigen product expression that is re-evaluated on every access below.
    const Eigen::Matrix3d latt = R * D * Bhex;

    return CrystalLattice(Coord(latt(0, 0), latt(1, 0), latt(2, 0)),
                          Coord(latt(0, 1), latt(1, 1), latt(2, 1)),
                          Coord(latt(0, 2), latt(1, 2), latt(2, 2)));
}
|
||
|
||
// One Ceres refinement pass of the crystal lattice and, if
// data.refine_beam_center is set, the beam center.
//
// Each spot is indexed with the starting lattice. A spot is used only if its
// fractional hkl lies within `tolerance` (Euclidean distance) of the nearest
// integer triple. Refinement is done in the primitive setting, using the
// parameterisation selected by data.crystal_system, and the result is
// converted back to data.centering afterwards.
//
// data is modified only when the function returns true. On failure (no or too
// few indexed spots, an unusable solver result, or an exception), data.latt,
// data.geom and the beam correction fields are left exactly as on entry, so a
// failed pass never leaves the lattice in primitive or reordered form.
//
// Returns true when a refined solution was written back to data.
bool XtalOptimizerInternal(XtalOptimizerData &data,
                           const std::vector<SpotToSave> &spots,
                           const float tolerance) {
    try {
        // Work on a local copy; data.latt is only replaced once everything succeeded.
        CrystalLattice latt = data.latt;
        if (data.centering != 'P')
            latt = latt.ToPrimitive(data.centering);

        if (data.crystal_system == gemmi::CrystalSystem::Tetragonal
            || data.crystal_system == gemmi::CrystalSystem::Hexagonal)
            latt.ReorderABEqual();

        Coord vec0 = latt.Vec0();
        Coord vec1 = latt.Vec1();
        Coord vec2 = latt.Vec2();

        // Initial guess for the beam center correction
        double corr_x = 0;
        double corr_y = 0;

        ceres::Problem problem;

        // Lattice parameter blocks. For the constrained systems, vec0 holds the
        // angle-axis rotation, vec1 the lengths, and vec2 is not read by the
        // residual. Otherwise they hold the three lattice vectors. All are
        // zero-initialised so Ceres never sees uninitialised memory.
        double latt_vec0[3] = {0.0, 0.0, 0.0};
        double latt_vec1[3] = {0.0, 0.0, 0.0};
        double latt_vec2[3] = {0.0, 0.0, 0.0};
        bool constrained = true;

        switch (data.crystal_system) {
            case gemmi::CrystalSystem::Orthorhombic:
                LatticeToRodriguesAndLengths_GS(latt, latt_vec0, latt_vec1);
                break;
            case gemmi::CrystalSystem::Tetragonal:
                LatticeToRodriguesAndLengths_GS(latt, latt_vec0, latt_vec1);
                latt_vec1[0] = (latt_vec1[0] + latt_vec1[1]) / 2.0;
                break;
            case gemmi::CrystalSystem::Cubic:
                LatticeToRodriguesAndLengths_GS(latt, latt_vec0, latt_vec1);
                latt_vec1[0] = (latt_vec1[0] + latt_vec1[1] + latt_vec1[2]) / 3.0;
                break;
            case gemmi::CrystalSystem::Hexagonal:
                LatticeToRodriguesAndLengths_Hex(latt, latt_vec0, latt_vec1);
                break;
            default:
                constrained = false;
                latt_vec0[0] = vec0.x; latt_vec0[1] = vec0.y; latt_vec0[2] = vec0.z;
                latt_vec1[0] = vec1.x; latt_vec1[1] = vec1.y; latt_vec1[2] = vec1.z;
                latt_vec2[0] = vec2.x; latt_vec2[1] = vec2.y; latt_vec2[2] = vec2.z;
                break;
        }

        // Add one residual block per spot that indexes within tolerance
        for (const auto &pt: spots) {
            Coord recip = data.geom.DetectorToRecip(pt.x, pt.y);
            double h_fp = recip * vec0;
            double k_fp = recip * vec1;
            double l_fp = recip * vec2;

            double h = std::round(h_fp);
            double k = std::round(k_fp);
            double l = std::round(l_fp);

            double norm_sq = (h - h_fp) * (h - h_fp) + (k - k_fp) * (k - k_fp) + (l - l_fp) * (l - l_fp);

            if (norm_sq > tolerance * tolerance)
                continue;

            problem.AddResidualBlock(
                new ceres::AutoDiffCostFunction<XtalResidual, 3, 1, 1, 3, 3, 3>(
                    new XtalResidual(pt.x, pt.y,
                                     data.geom.GetBeamX_pxl(),
                                     data.geom.GetBeamY_pxl(),
                                     data.geom.GetWavelength_A(),
                                     data.geom.GetPixelSize_mm(),
                                     data.geom.GetDetectorDistance_mm(),
                                     data.geom.GetPoniRot1_rad(),
                                     data.geom.GetPoniRot2_rad(),
                                     h, k, l,
                                     data.crystal_system)),
                nullptr,
                &corr_x,
                &corr_y,
                latt_vec0,
                latt_vec1,
                latt_vec2
            );
        }

        // Must be checked before SetParameterBlockConstant. When no residual was
        // added the parameter blocks are unknown to the problem, and Ceres
        // CHECK-aborts the process, which catch(...) cannot intercept.
        if (problem.NumResidualBlocks() == 0 || problem.NumResidualBlocks() < data.min_spots)
            return false;

        if (!data.refine_beam_center) {
            problem.SetParameterBlockConstant(&corr_x);
            problem.SetParameterBlockConstant(&corr_y);
        }

        // The residual does not read p2 for the constrained systems
        if (constrained)
            problem.SetParameterBlockConstant(latt_vec2);

        // Configure solver
        ceres::Solver::Options options;
        options.linear_solver_type = ceres::DENSE_QR;
        options.minimizer_progress_to_stdout = false;
        options.logging_type = ceres::LoggingType::SILENT;
        ceres::Solver::Summary summary;

        // Run optimization
        ceres::Solve(options, &problem, &summary);

        if (!summary.IsSolutionUsable())
            return false;

        // Rebuild the (primitive) lattice from the refined parameters
        CrystalLattice refined = latt;
        switch (data.crystal_system) {
            case gemmi::CrystalSystem::Orthorhombic:
                refined = AngleAxisAndLengthsToLattice(latt_vec0, latt_vec1, false);
                break;
            case gemmi::CrystalSystem::Tetragonal:
                latt_vec1[1] = latt_vec1[0];
                refined = AngleAxisAndLengthsToLattice(latt_vec0, latt_vec1, false);
                break;
            case gemmi::CrystalSystem::Cubic:
                latt_vec1[1] = latt_vec1[0];
                latt_vec1[2] = latt_vec1[0];
                refined = AngleAxisAndLengthsToLattice(latt_vec0, latt_vec1, false);
                break;
            case gemmi::CrystalSystem::Hexagonal:
                latt_vec1[1] = latt_vec1[0];
                refined = AngleAxisAndLengthsToLattice(latt_vec0, latt_vec1, true);
                break;
            default:
                refined = CrystalLattice(Coord(latt_vec0[0], latt_vec0[1], latt_vec0[2]),
                                         Coord(latt_vec1[0], latt_vec1[1], latt_vec1[2]),
                                         Coord(latt_vec2[0], latt_vec2[1], latt_vec2[2]));
                break;
        }

        if (data.centering != 'P')
            refined = refined.FromPrimitive(data.centering);

        // Commit results only after everything above succeeded.
        // NOTE(review): beam_corr_x/y hold this pass's correction only, while
        // geom accumulates corrections across passes - confirm this is intended.
        if (data.refine_beam_center) {
            data.beam_corr_x = corr_x;
            data.beam_corr_y = corr_y;
            data.geom.BeamX_pxl(data.geom.GetBeamX_pxl() + corr_x)
                    .BeamY_pxl(data.geom.GetBeamY_pxl() + corr_y);
        }
        data.latt = refined;

        return true;
    } catch (...) {
        // Convergence or lattice-conversion problems; data was not modified
        return false;
    }
}
|
||
|
||
// Multi-pass lattice refinement with a shrinking indexing tolerance.
//
// Runs XtalOptimizerInternal six times: once at tolerance 0.3, then five more
// times, multiplying the tolerance by 0.8 before each. Only the first pass
// decides the result; if it fails, false is returned immediately. The return
// values of the later passes are ignored.
bool XtalOptimizer(XtalOptimizerData &data, const std::vector<SpotToSave> &spots) {
    constexpr int kExtraPasses = 5;
    constexpr double kShrink = 0.8;

    float tolerance = 0.3f;

    const bool first_pass_ok = XtalOptimizerInternal(data, spots, tolerance);
    if (!first_pass_ok)
        return false;

    // Best-effort tightening passes
    for (int pass = 1; pass <= kExtraPasses; ++pass) {
        tolerance *= kShrink;
        (void) XtalOptimizerInternal(data, spots, tolerance);
    }

    return true;
}
|