PixelRefine: Make it faster by doing one cell calculation per shoe-box
Build Packages / build:rpm (rocky8_nocuda) (push) Successful in 12m30s
Build Packages / build:rpm (rocky9_nocuda) (push) Successful in 13m56s
Build Packages / build:rpm (ubuntu2204_nocuda) (push) Successful in 12m10s
Build Packages / build:rpm (ubuntu2404_nocuda) (push) Successful in 10m24s
Build Packages / build:rpm (rocky8_sls9) (push) Successful in 13m6s
Build Packages / build:rpm (rocky9_sls9) (push) Successful in 14m4s
Build Packages / build:rpm (rocky8) (push) Successful in 13m6s
Build Packages / build:rpm (rocky9) (push) Successful in 11m44s
Build Packages / build:rpm (ubuntu2204) (push) Successful in 10m59s
Build Packages / build:rpm (ubuntu2404) (push) Successful in 10m9s
Build Packages / DIALS test (push) Successful in 12m4s
Build Packages / XDS test (durin plugin) (push) Successful in 8m50s
Build Packages / XDS test (JFJoch plugin) (push) Successful in 8m34s
Build Packages / XDS test (neggia plugin) (push) Successful in 8m28s
Build Packages / Generate python client (push) Successful in 17s
Build Packages / Build documentation (push) Successful in 36s
Build Packages / Create release (push) Skipped
Build Packages / Unit tests (push) Failing after 57m12s

This commit is contained in:
2026-06-08 22:45:22 +02:00
parent 05711a1077
commit e8a9b1840d
+260 -152
View File
@@ -40,6 +40,139 @@ double SafeInv(double x, double fallback) {
return 1.0 / x;
}
// Per-pixel: map a detector pixel through the current geometry into the
// reference reciprocal frame. Cheap (a few trig + one rotation); depends on the
// pixel and the detector geometry, not on the lattice.
template<typename T>
void ObservedRecip(const T *beam, const T *distance_mm, const T *detector_rot,
const T *rotation_axis, double obs_x, double obs_y,
double pixel_size, double inv_lambda, double angle_rad,
Eigen::Matrix<T, 3, 1> &e_obs_recip) {
// PyFAI convention (left-handed for rot1/rot2): rot3 = 0 assumed.
const T c1 = ceres::cos(detector_rot[0]);
const T s1 = ceres::sin(detector_rot[0]);
const T c2 = ceres::cos(detector_rot[1]);
const T s2 = ceres::sin(detector_rot[1]);
const T det_x = (T(obs_x) - beam[0]) * T(pixel_size);
const T det_y = (T(obs_y) - beam[1]) * T(pixel_size);
const T det_z = T(distance_mm[0]);
const T t1_x = c1 * det_x + s1 * det_z;
const T t1_y = det_y;
const T t1_z = -s1 * det_x + c1 * det_z;
const T x = t1_x;
const T y = c2 * t1_y + s2 * t1_z;
const T z = -s2 * t1_y + c2 * t1_z;
const T inv_norm = T(1) / ceres::sqrt(x * x + y * y + z * z);
T recip_raw[3] = {
x * inv_norm * T(inv_lambda),
y * inv_norm * T(inv_lambda),
(z * inv_norm - T(1.0)) * T(inv_lambda)
};
const T aa_back[3] = {
T(angle_rad) * rotation_axis[0],
T(angle_rad) * rotation_axis[1],
T(angle_rad) * rotation_axis[2]
};
T recip_obs[3];
ceres::AngleAxisRotatePoint(aa_back, recip_raw, recip_obs);
e_obs_recip = Eigen::Matrix<T, 3, 1>(recip_obs[0], recip_obs[1], recip_obs[2]);
}
// Per-reflection: predicted node g_hkl, |g_hkl|^2, and the Ewald-sphere normal.
// This is the expensive part (symmetry-aware B matrix, three rotations, cross
// products) - it depends only on the lattice (p0,p1,p2) and hkl, so for a whole
// shoebox it can be computed once. Convention identical to XtalOptimizer.
template<typename T>
bool PredictedNode(const T *p0, const T *p1, const T *p2,
double exp_h, double exp_k, double exp_l,
gemmi::CrystalSystem symmetry, double inv_lambda,
Eigen::Matrix<T, 3, 1> &e_pred_recip,
Eigen::Matrix<T, 3, 1> &n_radial, T &q_sq) {
Eigen::Matrix<T, 3, 1> e_uc_len = Eigen::Matrix<T, 3, 1>::Zero();
Eigen::Matrix<T, 3, 3> Bmat = Eigen::Matrix<T, 3, 3>::Identity();
if (symmetry == gemmi::CrystalSystem::Hexagonal) {
e_uc_len << p1[0], p1[0], p1[2];
Bmat(0, 1) = T(-0.5);
Bmat(1, 1) = T(sqrt(3.0) / 2.0);
} else if (symmetry == gemmi::CrystalSystem::Orthorhombic) {
e_uc_len << p1[0], p1[1], p1[2];
} else if (symmetry == gemmi::CrystalSystem::Tetragonal) {
e_uc_len << p1[0], p1[0], p1[2];
} else if (symmetry == gemmi::CrystalSystem::Cubic) {
e_uc_len << p1[0], p1[0], p1[0];
} else if (symmetry == gemmi::CrystalSystem::Monoclinic) {
e_uc_len << p1[0], p1[1], p1[2];
Bmat(0, 2) = ceres::cos(p2[0]);
Bmat(2, 2) = ceres::sin(p2[0]);
} else {
const T ca = ceres::cos(p2[0]);
const T cb = ceres::cos(p2[1]);
const T cg = ceres::cos(p2[2]);
const T sg = ceres::sin(p2[2]);
e_uc_len << p1[0], p1[1], p1[2];
Bmat(0, 1) = cg;
Bmat(1, 1) = sg;
const T cx = cb;
const T cy = (ca - cb * cg) / sg;
const T v = T(1) - cx * cx - cy * cy;
const T cz = (v >= T(0)) ? ceres::sqrt(v) : T(0);
Bmat(0, 2) = cx;
Bmat(1, 2) = cy;
Bmat(2, 2) = cz;
}
const T L0 = e_uc_len[0];
const T L1 = e_uc_len[1];
const T L2 = e_uc_len[2];
T col0_unrot[3] = {Bmat(0, 0) * L0, Bmat(1, 0) * L0, Bmat(2, 0) * L0};
T col1_unrot[3] = {Bmat(0, 1) * L1, Bmat(1, 1) * L1, Bmat(2, 1) * L1};
T col2_unrot[3] = {Bmat(0, 2) * L2, Bmat(1, 2) * L2, Bmat(2, 2) * L2};
T col0_rot[3], col1_rot[3], col2_rot[3];
ceres::AngleAxisRotatePoint(p0, col0_unrot, col0_rot);
ceres::AngleAxisRotatePoint(p0, col1_unrot, col1_rot);
ceres::AngleAxisRotatePoint(p0, col2_unrot, col2_rot);
const Eigen::Matrix<T, 3, 1> A(col0_rot[0], col0_rot[1], col0_rot[2]);
const Eigen::Matrix<T, 3, 1> Bv(col1_rot[0], col1_rot[1], col1_rot[2]);
const Eigen::Matrix<T, 3, 1> C(col2_rot[0], col2_rot[1], col2_rot[2]);
const Eigen::Matrix<T, 3, 1> BxC = Bv.cross(C);
const Eigen::Matrix<T, 3, 1> CxA = C.cross(A);
const Eigen::Matrix<T, 3, 1> AxB = A.cross(Bv);
const T Vol = A.dot(BxC);
if (ceres::abs(Vol) < T(1e-12))
return false;
const T invV = T(1) / Vol;
e_pred_recip = (BxC * T(exp_h) + CxA * T(exp_k) + AxB * T(exp_l)) * invV;
q_sq = e_pred_recip.squaredNorm();
// Ewald sphere centre at -k_i = (0,0,-inv_lambda); radial normal at g_hkl.
const Eigen::Matrix<T, 3, 1> S_pred(
e_pred_recip[0],
e_pred_recip[1],
e_pred_recip[2] + T(inv_lambda));
const T S_pred_norm = S_pred.norm();
if (S_pred_norm < T(1e-10))
return false;
n_radial = S_pred / S_pred_norm;
return true;
}
} // namespace
// ---------------------------------------------------------------------------
@@ -100,136 +233,18 @@ struct PixelResidual {
const T *const p1,
const T *const p2,
T &q_sq, T &eps_radial, T &eps_tang_sq) const {
// PyFAI convention (left-handed for rot1/rot2):
// poni_rot = Rz(-rot3) * Rx(-rot2) * Ry(+rot1), rot3 = 0 assumed.
const T rot1 = detector_rot[0];
const T rot2 = detector_rot[1];
Eigen::Matrix<T, 3, 1> e_obs_recip;
ObservedRecip(beam, distance_mm, detector_rot, rotation_axis,
obs_x, obs_y, pixel_size, inv_lambda, angle_rad, e_obs_recip);
const T c1 = ceres::cos(rot1);
const T s1 = ceres::sin(rot1);
const T c2 = ceres::cos(rot2);
const T s2 = ceres::sin(rot2);
const T det_x = (T(obs_x) - beam[0]) * T(pixel_size);
const T det_y = (T(obs_y) - beam[1]) * T(pixel_size);
const T det_z = T(distance_mm[0]);
const T t1_x = c1 * det_x + s1 * det_z;
const T t1_y = det_y;
const T t1_z = -s1 * det_x + c1 * det_z;
const T x = t1_x;
const T y = c2 * t1_y + s2 * t1_z;
const T z = -s2 * t1_y + c2 * t1_z;
const T lab_norm = ceres::sqrt(x * x + y * y + z * z);
const T inv_norm = T(1) / lab_norm;
T recip_raw[3];
recip_raw[0] = x * inv_norm * T(inv_lambda);
recip_raw[1] = y * inv_norm * T(inv_lambda);
recip_raw[2] = (z * inv_norm - T(1.0)) * T(inv_lambda);
// Goniometer "back-to-start" rotation: image frame -> reference frame.
const T aa_back[3] = {
T(angle_rad) * rotation_axis[0],
T(angle_rad) * rotation_axis[1],
T(angle_rad) * rotation_axis[2]
};
T recip_obs[3];
ceres::AngleAxisRotatePoint(aa_back, recip_raw, recip_obs);
const Eigen::Matrix<T, 3, 1> e_obs_recip(recip_obs[0], recip_obs[1], recip_obs[2]);
// Build cell lengths and the (unit) B matrix from the symmetry-specific
// parametrization (identical convention to XtalOptimizer::XtalResidual).
Eigen::Matrix<T, 3, 1> e_uc_len = Eigen::Matrix<T, 3, 1>::Zero();
Eigen::Matrix<T, 3, 3> Bmat = Eigen::Matrix<T, 3, 3>::Identity();
if (symmetry == gemmi::CrystalSystem::Hexagonal) {
e_uc_len << p1[0], p1[0], p1[2];
Bmat(0, 1) = T(-0.5);
Bmat(1, 1) = T(sqrt(3.0) / 2.0);
} else if (symmetry == gemmi::CrystalSystem::Orthorhombic) {
e_uc_len << p1[0], p1[1], p1[2];
} else if (symmetry == gemmi::CrystalSystem::Tetragonal) {
e_uc_len << p1[0], p1[0], p1[2];
} else if (symmetry == gemmi::CrystalSystem::Cubic) {
e_uc_len << p1[0], p1[0], p1[0];
} else if (symmetry == gemmi::CrystalSystem::Monoclinic) {
e_uc_len << p1[0], p1[1], p1[2];
Bmat(0, 2) = ceres::cos(p2[0]);
Bmat(2, 2) = ceres::sin(p2[0]);
} else {
const T ca = ceres::cos(p2[0]);
const T cb = ceres::cos(p2[1]);
const T cg = ceres::cos(p2[2]);
const T sg = ceres::sin(p2[2]);
e_uc_len << p1[0], p1[1], p1[2];
Bmat(0, 1) = cg;
Bmat(1, 1) = sg;
const T cx = cb;
const T cy = (ca - cb * cg) / sg;
const T v = T(1) - cx * cx - cy * cy;
const T cz = (v >= T(0)) ? ceres::sqrt(v) : T(0);
Bmat(0, 2) = cx;
Bmat(1, 2) = cy;
Bmat(2, 2) = cz;
}
const T L0 = e_uc_len[0];
const T L1 = e_uc_len[1];
const T L2 = e_uc_len[2];
T col0_unrot[3] = {Bmat(0, 0) * L0, Bmat(1, 0) * L0, Bmat(2, 0) * L0};
T col1_unrot[3] = {Bmat(0, 1) * L1, Bmat(1, 1) * L1, Bmat(2, 1) * L1};
T col2_unrot[3] = {Bmat(0, 2) * L2, Bmat(1, 2) * L2, Bmat(2, 2) * L2};
T col0_rot[3], col1_rot[3], col2_rot[3];
ceres::AngleAxisRotatePoint(p0, col0_unrot, col0_rot);
ceres::AngleAxisRotatePoint(p0, col1_unrot, col1_rot);
ceres::AngleAxisRotatePoint(p0, col2_unrot, col2_rot);
const Eigen::Matrix<T, 3, 1> A(col0_rot[0], col0_rot[1], col0_rot[2]);
const Eigen::Matrix<T, 3, 1> Bv(col1_rot[0], col1_rot[1], col1_rot[2]);
const Eigen::Matrix<T, 3, 1> C(col2_rot[0], col2_rot[1], col2_rot[2]);
const Eigen::Matrix<T, 3, 1> BxC = Bv.cross(C);
const Eigen::Matrix<T, 3, 1> CxA = C.cross(A);
const Eigen::Matrix<T, 3, 1> AxB = A.cross(Bv);
const T V = A.dot(BxC);
if (ceres::abs(V) < T(1e-12))
return false;
const T invV = T(1) / V;
const Eigen::Matrix<T, 3, 1> Astar = BxC * invV;
const Eigen::Matrix<T, 3, 1> Bstar = CxA * invV;
const Eigen::Matrix<T, 3, 1> Cstar = AxB * invV;
const Eigen::Matrix<T, 3, 1> e_pred_recip =
Astar * T(exp_h) + Bstar * T(exp_k) + Cstar * T(exp_l);
q_sq = e_pred_recip.squaredNorm();
// Ewald sphere centre at -k_i = (0,0,-inv_lambda); radial normal at g_hkl.
const Eigen::Matrix<T, 3, 1> S_pred(
e_pred_recip[0],
e_pred_recip[1],
e_pred_recip[2] + T(inv_lambda));
const T S_pred_norm = S_pred.norm();
if (S_pred_norm < T(1e-10))
Eigen::Matrix<T, 3, 1> e_pred_recip, n_radial;
if (!PredictedNode(p0, p1, p2, exp_h, exp_k, exp_l, symmetry, inv_lambda,
e_pred_recip, n_radial, q_sq))
return false;
const Eigen::Matrix<T, 3, 1> n_radial = S_pred / S_pred_norm;
const Eigen::Matrix<T, 3, 1> delta_q = e_obs_recip - e_pred_recip;
eps_radial = delta_q.dot(n_radial);
const Eigen::Matrix<T, 3, 1> dq_tang = delta_q - eps_radial * n_radial;
eps_tang_sq = dq_tang.squaredNorm();
eps_tang_sq = (delta_q - eps_radial * n_radial).squaredNorm();
return true;
}
@@ -250,27 +265,25 @@ struct PixelResidual {
const T B_term = ceres::exp(-B[0] * q_sq / T(4.0));
// Full 3D reciprocal-space spot density modelled as a separable Gaussian,
// normalized so that its integral is 1 (intensity-conserving):
// radial: g_r(e) = exp(-e^2/R0^2) / (sqrt(pi) R0) [1/A^-1]
// tangential: g_t(e) = exp(-|e|^2/R1^2) / (pi R1^2) [1/A^-2]
// The detector pixel captures the fraction g_t * A_recip of the tangential
// profile (A_recip = reciprocal area the pixel subtends; sum over shoebox
// ~ 1). The radial factor is the still-image partiality: how far this
// reflection sits from the Ewald sphere.
// Separable Gaussian spot model:
// radial P_r(e) = exp(-e^2/R0_eff^2) (peak-normalized, in (0,1])
// tangent g_t(e) = exp(-|e|^2/R1^2) / (pi R1^2) [1/A^-2]
// The pixel captures the fraction g_t * A_recip of the tangential profile
// (A_recip = reciprocal area the pixel subtends; sum over shoebox ~ 1).
// The radial factor is the still-image partiality (how far the reflection
// sits from the Ewald sphere); the overall scale is carried by the free G.
//
// Caveat: a still samples the radial direction at a single offset, so the
// sqrt(pi) R0 normalization makes g_r a density (1/A^-1) rather than a
// dimensionless fraction. The leftover dimensional factor is absorbed by
// the free scale G. The energy-bandwidth contribution to the radial width
// is folded in here via R_bw_sq (beam divergence is still TODO).
// IMPORTANT: the radial factor MUST use the same convention here as the
// extraction's `partiality` (peak-normalized), otherwise image_scale_corr
// = 1/(partiality*G*B) does not invert the model and a leftover, R0_eff-
// dependent (hence resolution-dependent) factor biases the intensities.
// R0_eff folds in the energy-bandwidth broadening via R_bw_sq.
const T R0_eff_sq = R[0] * R[0] + T(R_bw_sq);
const T g_radial = ceres::exp(-eps_radial * eps_radial / R0_eff_sq)
/ (ceres::sqrt(T(M_PI)) * ceres::sqrt(R0_eff_sq));
const T P_radial = ceres::exp(-eps_radial * eps_radial / R0_eff_sq);
const T P_tang = T(A_recip) * ceres::exp(-eps_tang_sq / (R[1] * R[1]))
/ (T(M_PI) * R[1] * R[1]);
const T signal = scale_factor[0] * T(Itrue) * B_term * g_radial * P_tang;
const T signal = scale_factor[0] * T(Itrue) * B_term * P_radial * P_tang;
Ipred = signal + T(Ibkg);
return true;
}
@@ -306,6 +319,85 @@ struct PixelResidual {
gemmi::CrystalSystem symmetry;
};
// ---------------------------------------------------------------------------
// Per-shoebox cost functor
//
// One residual block per reflection emitting N residuals (one per shoebox pixel).
// The expensive per-reflection geometry (PredictedNode: symmetry-aware B matrix,
// three rotations, cross products) is computed ONCE; only the cheap per-pixel
// ObservedRecip + Gaussian profile run in the pixel loop. This is identical in
// value to the old one-block-per-pixel formulation but ~(pixels-per-shoebox)x
// fewer evaluations of the costly node computation. Uses the same shared helpers
// (and hence the same conventions) as PixelResidual.
// ---------------------------------------------------------------------------
struct ShoeboxResidual {
ShoeboxResidual(const ReflGroup &g, double lambda, double pixel_size,
gemmi::CrystalSystem symmetry)
: pixels(g.pixels), Itrue(g.Itrue), R_bw_sq(g.R_bw_sq),
exp_h(g.h), exp_k(g.k), exp_l(g.l),
inv_lambda(1.0 / lambda), pixel_size(pixel_size),
angle_rad(g.pixels.empty() ? 0.0 : g.pixels.front().angle_rad),
symmetry(symmetry) {}
template<typename T>
bool operator()(const T *const *params, T *residual) const {
// Parameter blocks (order matches AddParameterBlock in Run):
// 0 beam[2] 1 distance[1] 2 detector_rot[2] 3 rotation_axis[3]
// 4 p0[3] 5 p1[3] 6 p2[3] 7 scale[1] 8 B[1] 9 R[2]
const T *beam = params[0];
const T *distance_mm = params[1];
const T *detector_rot = params[2];
const T *rotation_axis = params[3];
const T *p0 = params[4];
const T *p1 = params[5];
const T *p2 = params[6];
const T *scale_factor = params[7];
const T *B = params[8];
const T *R = params[9];
if (R[0] < T(1e-10) || R[1] < T(1e-10))
return false;
// --- per-reflection: computed once ---------------------------------
Eigen::Matrix<T, 3, 1> e_pred_recip, n_radial;
T q_sq;
if (!PredictedNode(p0, p1, p2, exp_h, exp_k, exp_l, symmetry, inv_lambda,
e_pred_recip, n_radial, q_sq))
return false;
const T B_term = ceres::exp(-B[0] * q_sq / T(4.0));
const T R0_eff_sq = R[0] * R[0] + T(R_bw_sq);
// --- per-pixel loop -------------------------------------------------
for (size_t i = 0; i < pixels.size(); ++i) {
const PixelObs &obs = pixels[i];
Eigen::Matrix<T, 3, 1> e_obs_recip;
ObservedRecip(beam, distance_mm, detector_rot, rotation_axis,
obs.x, obs.y, pixel_size, inv_lambda, angle_rad, e_obs_recip);
const Eigen::Matrix<T, 3, 1> delta_q = e_obs_recip - e_pred_recip;
const T eps_radial = delta_q.dot(n_radial);
const T eps_tang_sq = (delta_q - eps_radial * n_radial).squaredNorm();
const T P_radial = ceres::exp(-eps_radial * eps_radial / R0_eff_sq);
const T P_tang = T(obs.A_recip) * ceres::exp(-eps_tang_sq / (R[1] * R[1]))
/ (T(M_PI) * R[1] * R[1]);
const T signal = scale_factor[0] * T(Itrue) * B_term * P_radial * P_tang;
const T Ipred = signal + T(obs.Ibkg);
residual[i] = (Ipred - T(obs.Iobs)) * T(obs.weight);
}
return true;
}
std::vector<PixelObs> pixels;
const double Itrue, R_bw_sq;
const double exp_h, exp_k, exp_l;
const double inv_lambda, pixel_size, angle_rad;
gemmi::CrystalSystem symmetry;
};
PixelRefine::PixelRefine(const DiffractionExperiment &experiment,
const AzimuthalIntegrationMapping &mapping,
const std::vector<MergedReflection> &reference)
@@ -526,20 +618,36 @@ void PixelRefine::Run(const T *image,
latt_vec0, latt_vec1, latt_vec2);
// ---- 4. Build the problem ---------------------------------------------
// One residual block per shoebox (N residuals), so the expensive
// per-reflection node geometry is evaluated once per reflection instead
// of once per pixel.
ceres::Problem problem;
size_t residual_pixels = 0;
for (const auto &g : groups) {
for (const auto &obs : g.pixels) {
auto *cost = new ceres::AutoDiffCostFunction<
PixelResidual, 1, 2, 1, 2, 3, 3, 3, 3, 1, 1, 2>(
new PixelResidual(obs, g.Itrue, lambda, pixel_size,
g.h, g.k, g.l, g.R_bw_sq, data.crystal_system));
problem.AddResidualBlock(cost, new ceres::HuberLoss(3.0),
beam, &dist_mm, detector_rot, rot_vec,
latt_vec0, latt_vec1, latt_vec2,
&data.scale_factor, &data.B_factor, data.R);
}
auto *cost = new ceres::DynamicAutoDiffCostFunction<ShoeboxResidual>(
new ShoeboxResidual(g, lambda, pixel_size, data.crystal_system));
cost->AddParameterBlock(2); // beam
cost->AddParameterBlock(1); // distance
cost->AddParameterBlock(2); // detector_rot
cost->AddParameterBlock(3); // rotation_axis
cost->AddParameterBlock(3); // p0 (orientation)
cost->AddParameterBlock(3); // p1 (lengths)
cost->AddParameterBlock(3); // p2 (angles)
cost->AddParameterBlock(1); // scale G
cost->AddParameterBlock(1); // B
cost->AddParameterBlock(2); // R
cost->SetNumResiduals(static_cast<int>(g.pixels.size()));
// No robust loss here: a per-block (whole-shoebox) Huber would act on
// the sum of ~N squared residuals and mis-scale, unlike the previous
// per-pixel Huber. Per-pixel sigma weighting is retained; per-pixel
// outlier rejection (zingers) is a TODO if needed.
problem.AddResidualBlock(cost, nullptr,
beam, &dist_mm, detector_rot, rot_vec,
latt_vec0, latt_vec1, latt_vec2,
&data.scale_factor, &data.B_factor, data.R);
residual_pixels += g.pixels.size();
}
data.residual_count = problem.NumResidualBlocks();
data.residual_count = residual_pixels;
// ---- 5. Constrain / bound parameter blocks ----------------------------
if (!data.refine_orientation)