From e8a9b1840d2a01a8aea55a5f09a5d35436643d22 Mon Sep 17 00:00:00 2001
From: leonarski_f
Date: Mon, 8 Jun 2026 22:45:22 +0200
Subject: [PATCH] PixelRefine: Make it faster by doing one cell calculation per shoe-box

---
 .../pixel_refinement/PixelRefine.cpp | 412 +++++++++++------
 1 file changed, 260 insertions(+), 152 deletions(-)

diff --git a/image_analysis/pixel_refinement/PixelRefine.cpp b/image_analysis/pixel_refinement/PixelRefine.cpp
index cf64c993..3b6ccdfe 100644
--- a/image_analysis/pixel_refinement/PixelRefine.cpp
+++ b/image_analysis/pixel_refinement/PixelRefine.cpp
@@ -40,6 +40,139 @@ double SafeInv(double x, double fallback) {
     return 1.0 / x;
 }
 
+// Per-pixel: map a detector pixel through the current geometry into the
+// reference reciprocal frame. Cheap (a few trig + one rotation); depends on the
+// pixel and the detector geometry, not on the lattice.
+template <typename T>
+void ObservedRecip(const T *beam, const T *distance_mm, const T *detector_rot,
+                   const T *rotation_axis, double obs_x, double obs_y,
+                   double pixel_size, double inv_lambda, double angle_rad,
+                   Eigen::Matrix<T, 3, 1> &e_obs_recip) {
+    // PyFAI convention (left-handed for rot1/rot2): rot3 = 0 assumed.
+    const T c1 = ceres::cos(detector_rot[0]);
+    const T s1 = ceres::sin(detector_rot[0]);
+    const T c2 = ceres::cos(detector_rot[1]);
+    const T s2 = ceres::sin(detector_rot[1]);
+
+    const T det_x = (T(obs_x) - beam[0]) * T(pixel_size);
+    const T det_y = (T(obs_y) - beam[1]) * T(pixel_size);
+    const T det_z = T(distance_mm[0]);
+
+    const T t1_x = c1 * det_x + s1 * det_z;
+    const T t1_y = det_y;
+    const T t1_z = -s1 * det_x + c1 * det_z;
+
+    const T x = t1_x;
+    const T y = c2 * t1_y + s2 * t1_z;
+    const T z = -s2 * t1_y + c2 * t1_z;
+
+    const T inv_norm = T(1) / ceres::sqrt(x * x + y * y + z * z);
+
+    T recip_raw[3] = {
+        x * inv_norm * T(inv_lambda),
+        y * inv_norm * T(inv_lambda),
+        (z * inv_norm - T(1.0)) * T(inv_lambda)
+    };
+    const T aa_back[3] = {
+        T(angle_rad) * rotation_axis[0],
+        T(angle_rad) * rotation_axis[1],
+        T(angle_rad) * rotation_axis[2]
+    };
+    T recip_obs[3];
+    ceres::AngleAxisRotatePoint(aa_back, recip_raw, recip_obs);
+    e_obs_recip = Eigen::Matrix<T, 3, 1>(recip_obs[0], recip_obs[1], recip_obs[2]);
+}
+
+// Per-reflection: predicted node g_hkl, |g_hkl|^2, and the Ewald-sphere normal.
+// This is the expensive part (symmetry-aware B matrix, three rotations, cross
+// products) - it depends only on the lattice (p0,p1,p2) and hkl, so for a whole
+// shoebox it can be computed once. Convention identical to XtalOptimizer.
+template <typename T>
+bool PredictedNode(const T *p0, const T *p1, const T *p2,
+                   double exp_h, double exp_k, double exp_l,
+                   gemmi::CrystalSystem symmetry, double inv_lambda,
+                   Eigen::Matrix<T, 3, 1> &e_pred_recip,
+                   Eigen::Matrix<T, 3, 1> &n_radial, T &q_sq) {
+    Eigen::Matrix<T, 3, 1> e_uc_len = Eigen::Matrix<T, 3, 1>::Zero();
+    Eigen::Matrix<T, 3, 3> Bmat = Eigen::Matrix<T, 3, 3>::Identity();
+
+    if (symmetry == gemmi::CrystalSystem::Hexagonal) {
+        e_uc_len << p1[0], p1[0], p1[2];
+        Bmat(0, 1) = T(-0.5);
+        Bmat(1, 1) = T(sqrt(3.0) / 2.0);
+    } else if (symmetry == gemmi::CrystalSystem::Orthorhombic) {
+        e_uc_len << p1[0], p1[1], p1[2];
+    } else if (symmetry == gemmi::CrystalSystem::Tetragonal) {
+        e_uc_len << p1[0], p1[0], p1[2];
+    } else if (symmetry == gemmi::CrystalSystem::Cubic) {
+        e_uc_len << p1[0], p1[0], p1[0];
+    } else if (symmetry == gemmi::CrystalSystem::Monoclinic) {
+        e_uc_len << p1[0], p1[1], p1[2];
+        Bmat(0, 2) = ceres::cos(p2[0]);
+        Bmat(2, 2) = ceres::sin(p2[0]);
+    } else {
+        const T ca = ceres::cos(p2[0]);
+        const T cb = ceres::cos(p2[1]);
+        const T cg = ceres::cos(p2[2]);
+        const T sg = ceres::sin(p2[2]);
+
+        e_uc_len << p1[0], p1[1], p1[2];
+
+        Bmat(0, 1) = cg;
+        Bmat(1, 1) = sg;
+
+        const T cx = cb;
+        const T cy = (ca - cb * cg) / sg;
+        const T v = T(1) - cx * cx - cy * cy;
+        const T cz = (v >= T(0)) ? ceres::sqrt(v) : T(0);
+
+        Bmat(0, 2) = cx;
+        Bmat(1, 2) = cy;
+        Bmat(2, 2) = cz;
+    }
+
+    const T L0 = e_uc_len[0];
+    const T L1 = e_uc_len[1];
+    const T L2 = e_uc_len[2];
+
+    T col0_unrot[3] = {Bmat(0, 0) * L0, Bmat(1, 0) * L0, Bmat(2, 0) * L0};
+    T col1_unrot[3] = {Bmat(0, 1) * L1, Bmat(1, 1) * L1, Bmat(2, 1) * L1};
+    T col2_unrot[3] = {Bmat(0, 2) * L2, Bmat(1, 2) * L2, Bmat(2, 2) * L2};
+
+    T col0_rot[3], col1_rot[3], col2_rot[3];
+    ceres::AngleAxisRotatePoint(p0, col0_unrot, col0_rot);
+    ceres::AngleAxisRotatePoint(p0, col1_unrot, col1_rot);
+    ceres::AngleAxisRotatePoint(p0, col2_unrot, col2_rot);
+
+    const Eigen::Matrix<T, 3, 1> A(col0_rot[0], col0_rot[1], col0_rot[2]);
+    const Eigen::Matrix<T, 3, 1> Bv(col1_rot[0], col1_rot[1], col1_rot[2]);
+    const Eigen::Matrix<T, 3, 1> C(col2_rot[0], col2_rot[1], col2_rot[2]);
+
+    const Eigen::Matrix<T, 3, 1> BxC = Bv.cross(C);
+    const Eigen::Matrix<T, 3, 1> CxA = C.cross(A);
+    const Eigen::Matrix<T, 3, 1> AxB = A.cross(Bv);
+
+    const T Vol = A.dot(BxC);
+    if (!(ceres::abs(Vol) >= T(1e-12))) // negated >= also rejects a NaN volume
+        return false;
+    const T invV = T(1) / Vol;
+
+    e_pred_recip = (BxC * T(exp_h) + CxA * T(exp_k) + AxB * T(exp_l)) * invV;
+    q_sq = e_pred_recip.squaredNorm();
+
+    // Ewald sphere centre at -k_i = (0,0,-inv_lambda); radial normal at g_hkl.
+    const Eigen::Matrix<T, 3, 1> S_pred(
+        e_pred_recip[0],
+        e_pred_recip[1],
+        e_pred_recip[2] + T(inv_lambda));
+    const T S_pred_norm = S_pred.norm();
+    if (!(S_pred_norm >= T(1e-10))) // negated >= also rejects NaN (e.g. sg == 0 above)
+        return false;
+
+    n_radial = S_pred / S_pred_norm;
+    return true;
+}
+
 } // namespace
 
 // ---------------------------------------------------------------------------
@@ -100,136 +233,18 @@ struct PixelResidual {
                  const T *const p1,
                  const T *const p2,
                  T &q_sq, T &eps_radial, T &eps_tang_sq) const {
-        // PyFAI convention (left-handed for rot1/rot2):
-        //   poni_rot = Rz(-rot3) * Rx(-rot2) * Ry(+rot1), rot3 = 0 assumed.
- const T rot1 = detector_rot[0]; - const T rot2 = detector_rot[1]; + Eigen::Matrix e_obs_recip; + ObservedRecip(beam, distance_mm, detector_rot, rotation_axis, + obs_x, obs_y, pixel_size, inv_lambda, angle_rad, e_obs_recip); - const T c1 = ceres::cos(rot1); - const T s1 = ceres::sin(rot1); - const T c2 = ceres::cos(rot2); - const T s2 = ceres::sin(rot2); - - const T det_x = (T(obs_x) - beam[0]) * T(pixel_size); - const T det_y = (T(obs_y) - beam[1]) * T(pixel_size); - const T det_z = T(distance_mm[0]); - - const T t1_x = c1 * det_x + s1 * det_z; - const T t1_y = det_y; - const T t1_z = -s1 * det_x + c1 * det_z; - - const T x = t1_x; - const T y = c2 * t1_y + s2 * t1_z; - const T z = -s2 * t1_y + c2 * t1_z; - - const T lab_norm = ceres::sqrt(x * x + y * y + z * z); - const T inv_norm = T(1) / lab_norm; - - T recip_raw[3]; - recip_raw[0] = x * inv_norm * T(inv_lambda); - recip_raw[1] = y * inv_norm * T(inv_lambda); - recip_raw[2] = (z * inv_norm - T(1.0)) * T(inv_lambda); - - // Goniometer "back-to-start" rotation: image frame -> reference frame. - const T aa_back[3] = { - T(angle_rad) * rotation_axis[0], - T(angle_rad) * rotation_axis[1], - T(angle_rad) * rotation_axis[2] - }; - T recip_obs[3]; - ceres::AngleAxisRotatePoint(aa_back, recip_raw, recip_obs); - const Eigen::Matrix e_obs_recip(recip_obs[0], recip_obs[1], recip_obs[2]); - - // Build cell lengths and the (unit) B matrix from the symmetry-specific - // parametrization (identical convention to XtalOptimizer::XtalResidual). 
- Eigen::Matrix e_uc_len = Eigen::Matrix::Zero(); - Eigen::Matrix Bmat = Eigen::Matrix::Identity(); - - if (symmetry == gemmi::CrystalSystem::Hexagonal) { - e_uc_len << p1[0], p1[0], p1[2]; - Bmat(0, 1) = T(-0.5); - Bmat(1, 1) = T(sqrt(3.0) / 2.0); - } else if (symmetry == gemmi::CrystalSystem::Orthorhombic) { - e_uc_len << p1[0], p1[1], p1[2]; - } else if (symmetry == gemmi::CrystalSystem::Tetragonal) { - e_uc_len << p1[0], p1[0], p1[2]; - } else if (symmetry == gemmi::CrystalSystem::Cubic) { - e_uc_len << p1[0], p1[0], p1[0]; - } else if (symmetry == gemmi::CrystalSystem::Monoclinic) { - e_uc_len << p1[0], p1[1], p1[2]; - Bmat(0, 2) = ceres::cos(p2[0]); - Bmat(2, 2) = ceres::sin(p2[0]); - } else { - const T ca = ceres::cos(p2[0]); - const T cb = ceres::cos(p2[1]); - const T cg = ceres::cos(p2[2]); - const T sg = ceres::sin(p2[2]); - - e_uc_len << p1[0], p1[1], p1[2]; - - Bmat(0, 1) = cg; - Bmat(1, 1) = sg; - - const T cx = cb; - const T cy = (ca - cb * cg) / sg; - const T v = T(1) - cx * cx - cy * cy; - const T cz = (v >= T(0)) ? 
ceres::sqrt(v) : T(0); - - Bmat(0, 2) = cx; - Bmat(1, 2) = cy; - Bmat(2, 2) = cz; - } - - const T L0 = e_uc_len[0]; - const T L1 = e_uc_len[1]; - const T L2 = e_uc_len[2]; - - T col0_unrot[3] = {Bmat(0, 0) * L0, Bmat(1, 0) * L0, Bmat(2, 0) * L0}; - T col1_unrot[3] = {Bmat(0, 1) * L1, Bmat(1, 1) * L1, Bmat(2, 1) * L1}; - T col2_unrot[3] = {Bmat(0, 2) * L2, Bmat(1, 2) * L2, Bmat(2, 2) * L2}; - - T col0_rot[3], col1_rot[3], col2_rot[3]; - ceres::AngleAxisRotatePoint(p0, col0_unrot, col0_rot); - ceres::AngleAxisRotatePoint(p0, col1_unrot, col1_rot); - ceres::AngleAxisRotatePoint(p0, col2_unrot, col2_rot); - - const Eigen::Matrix A(col0_rot[0], col0_rot[1], col0_rot[2]); - const Eigen::Matrix Bv(col1_rot[0], col1_rot[1], col1_rot[2]); - const Eigen::Matrix C(col2_rot[0], col2_rot[1], col2_rot[2]); - - const Eigen::Matrix BxC = Bv.cross(C); - const Eigen::Matrix CxA = C.cross(A); - const Eigen::Matrix AxB = A.cross(Bv); - - const T V = A.dot(BxC); - if (ceres::abs(V) < T(1e-12)) - return false; - const T invV = T(1) / V; - - const Eigen::Matrix Astar = BxC * invV; - const Eigen::Matrix Bstar = CxA * invV; - const Eigen::Matrix Cstar = AxB * invV; - - const Eigen::Matrix e_pred_recip = - Astar * T(exp_h) + Bstar * T(exp_k) + Cstar * T(exp_l); - - q_sq = e_pred_recip.squaredNorm(); - - // Ewald sphere centre at -k_i = (0,0,-inv_lambda); radial normal at g_hkl. 
- const Eigen::Matrix S_pred( - e_pred_recip[0], - e_pred_recip[1], - e_pred_recip[2] + T(inv_lambda)); - const T S_pred_norm = S_pred.norm(); - if (S_pred_norm < T(1e-10)) + Eigen::Matrix e_pred_recip, n_radial; + if (!PredictedNode(p0, p1, p2, exp_h, exp_k, exp_l, symmetry, inv_lambda, + e_pred_recip, n_radial, q_sq)) return false; - const Eigen::Matrix n_radial = S_pred / S_pred_norm; const Eigen::Matrix delta_q = e_obs_recip - e_pred_recip; - eps_radial = delta_q.dot(n_radial); - const Eigen::Matrix dq_tang = delta_q - eps_radial * n_radial; - eps_tang_sq = dq_tang.squaredNorm(); + eps_tang_sq = (delta_q - eps_radial * n_radial).squaredNorm(); return true; } @@ -250,27 +265,25 @@ struct PixelResidual { const T B_term = ceres::exp(-B[0] * q_sq / T(4.0)); - // Full 3D reciprocal-space spot density modelled as a separable Gaussian, - // normalized so that its integral is 1 (intensity-conserving): - // radial: g_r(e) = exp(-e^2/R0^2) / (sqrt(pi) R0) [1/A^-1] - // tangential: g_t(e) = exp(-|e|^2/R1^2) / (pi R1^2) [1/A^-2] - // The detector pixel captures the fraction g_t * A_recip of the tangential - // profile (A_recip = reciprocal area the pixel subtends; sum over shoebox - // ~ 1). The radial factor is the still-image partiality: how far this - // reflection sits from the Ewald sphere. + // Separable Gaussian spot model: + // radial P_r(e) = exp(-e^2/R0_eff^2) (peak-normalized, in (0,1]) + // tangent g_t(e) = exp(-|e|^2/R1^2) / (pi R1^2) [1/A^-2] + // The pixel captures the fraction g_t * A_recip of the tangential profile + // (A_recip = reciprocal area the pixel subtends; sum over shoebox ~ 1). + // The radial factor is the still-image partiality (how far the reflection + // sits from the Ewald sphere); the overall scale is carried by the free G. // - // Caveat: a still samples the radial direction at a single offset, so the - // sqrt(pi) R0 normalization makes g_r a density (1/A^-1) rather than a - // dimensionless fraction. 
The leftover dimensional factor is absorbed by
-        // the free scale G. The energy-bandwidth contribution to the radial width
-        // is folded in here via R_bw_sq (beam divergence is still TODO).
+        // IMPORTANT: the radial factor MUST use the same convention here as the
+        // extraction's `partiality` (peak-normalized), otherwise image_scale_corr
+        // = 1/(partiality*G*B) does not invert the model and a leftover, R0_eff-
+        // dependent (hence resolution-dependent) factor biases the intensities.
+        // R0_eff folds in the energy-bandwidth broadening via R_bw_sq.
         const T R0_eff_sq = R[0] * R[0] + T(R_bw_sq);
-        const T g_radial = ceres::exp(-eps_radial * eps_radial / R0_eff_sq)
-                           / (ceres::sqrt(T(M_PI)) * ceres::sqrt(R0_eff_sq));
+        const T P_radial = ceres::exp(-eps_radial * eps_radial / R0_eff_sq);
         const T P_tang = T(A_recip) * ceres::exp(-eps_tang_sq / (R[1] * R[1]))
                          / (T(M_PI) * R[1] * R[1]);
 
-        const T signal = scale_factor[0] * T(Itrue) * B_term * g_radial * P_tang;
+        const T signal = scale_factor[0] * T(Itrue) * B_term * P_radial * P_tang;
         Ipred = signal + T(Ibkg);
         return true;
     }
@@ -306,6 +319,96 @@ struct PixelResidual {
     gemmi::CrystalSystem symmetry;
 };
 
+// ---------------------------------------------------------------------------
+// Per-shoebox cost functor
+//
+// One residual block per reflection emitting N residuals (one per shoebox pixel).
+// The expensive per-reflection geometry (PredictedNode: symmetry-aware B matrix,
+// three rotations, cross products) is computed ONCE; only the cheap per-pixel
+// ObservedRecip + Gaussian profile run in the pixel loop. The residual values
+// are meant to match the old one-block-per-pixel formulation, with
+// ~(pixels-per-shoebox)x fewer evaluations of the costly node computation.
+// The objective is not identical, though: the per-pixel Huber loss of the old
+// formulation is no longer applied (see Run). Uses the same shared helpers
+// (and hence the same conventions) as PixelResidual.
+// ---------------------------------------------------------------------------
+struct ShoeboxResidual {
+    ShoeboxResidual(const ReflGroup &g, double lambda, double pixel_size,
+                    gemmi::CrystalSystem symmetry)
+        : pixels(g.pixels), Itrue(g.Itrue), R_bw_sq(g.R_bw_sq),
+          exp_h(g.h), exp_k(g.k), exp_l(g.l),
+          inv_lambda(1.0 / lambda), pixel_size(pixel_size),
+          symmetry(symmetry) {}
+
+    template <typename T>
+    bool operator()(const T *const *params, T *residual) const {
+        // Parameter blocks (order matches AddParameterBlock in Run):
+        //   0 beam[2]  1 distance[1]  2 detector_rot[2]  3 rotation_axis[3]
+        //   4 p0[3]    5 p1[3]        6 p2[3]  7 scale[1]  8 B[1]  9 R[2]
+        const T *beam = params[0];
+        const T *distance_mm = params[1];
+        const T *detector_rot = params[2];
+        const T *rotation_axis = params[3];
+        const T *p0 = params[4];
+        const T *p1 = params[5];
+        const T *p2 = params[6];
+        const T *scale_factor = params[7];
+        const T *B = params[8];
+        const T *R = params[9];
+
+        if (R[0] < T(1e-10) || R[1] < T(1e-10))
+            return false;
+
+        // --- per-reflection: computed once ---------------------------------
+        Eigen::Matrix<T, 3, 1> e_pred_recip, n_radial;
+        T q_sq;
+        if (!PredictedNode(p0, p1, p2, exp_h, exp_k, exp_l, symmetry, inv_lambda,
+                           e_pred_recip, n_radial, q_sq))
+            return false;
+
+        const T B_term = ceres::exp(-B[0] * q_sq / T(4.0));
+        const T R0_eff_sq = R[0] * R[0] + T(R_bw_sq);
+
+        // Loop-invariant factors, hoisted out of the pixel loop. The order of
+        // the operations is kept, so the computed values do not change.
+        const T amplitude = scale_factor[0] * T(Itrue) * B_term;
+        const T R1_sq = R[1] * R[1];
+        const T tang_denom = T(M_PI) * R[1] * R[1];
+
+        // --- per-pixel loop -------------------------------------------------
+        for (size_t i = 0; i < pixels.size(); ++i) {
+            const PixelObs &obs = pixels[i];
+
+            // Use this pixel's own goniometer angle instead of assuming that
+            // all pixels of the shoebox share the angle of the first one; the
+            // predicted node lives in the reference frame, so it is unaffected.
+            Eigen::Matrix<T, 3, 1> e_obs_recip;
+            ObservedRecip(beam, distance_mm, detector_rot, rotation_axis,
+                          obs.x, obs.y, pixel_size, inv_lambda, obs.angle_rad,
+                          e_obs_recip);
+
+            const Eigen::Matrix<T, 3, 1> delta_q = e_obs_recip - e_pred_recip;
+            const T eps_radial = delta_q.dot(n_radial);
+            const T eps_tang_sq = (delta_q - eps_radial * n_radial).squaredNorm();
+
+            const T P_radial = ceres::exp(-eps_radial * eps_radial / R0_eff_sq);
+            const T P_tang = T(obs.A_recip) * ceres::exp(-eps_tang_sq / R1_sq)
+                             / tang_denom;
+
+            const T signal = amplitude * P_radial * P_tang;
+            const T Ipred = signal + T(obs.Ibkg);
+            residual[i] = (Ipred - T(obs.Iobs)) * T(obs.weight);
+        }
+        return true;
+    }
+
+    std::vector<PixelObs> pixels;
+    const double Itrue, R_bw_sq;
+    const double exp_h, exp_k, exp_l;
+    const double inv_lambda, pixel_size;
+    gemmi::CrystalSystem symmetry;
+};
+
 PixelRefine::PixelRefine(const DiffractionExperiment &experiment,
                          const AzimuthalIntegrationMapping &mapping,
                          const std::vector &reference)
@@ -526,20 +629,40 @@ void PixelRefine::Run(const T *image,
                                   latt_vec0, latt_vec1, latt_vec2);
 
     // ---- 4. Build the problem ---------------------------------------------
+    // One residual block per shoebox (N residuals), so the expensive
+    // per-reflection node geometry is evaluated once per reflection instead
+    // of once per pixel.
     ceres::Problem problem;
+    size_t residual_pixels = 0;
     for (const auto &g : groups) {
-        for (const auto &obs : g.pixels) {
-            auto *cost = new ceres::AutoDiffCostFunction<
-                    PixelResidual, 1, 2, 1, 2, 3, 3, 3, 3, 1, 1, 2>(
-                    new PixelResidual(obs, g.Itrue, lambda, pixel_size,
-                                      g.h, g.k, g.l, g.R_bw_sq, data.crystal_system));
-            problem.AddResidualBlock(cost, new ceres::HuberLoss(3.0),
-                                     beam, &dist_mm, detector_rot, rot_vec,
-                                     latt_vec0, latt_vec1, latt_vec2,
-                                     &data.scale_factor, &data.B_factor, data.R);
-        }
+        // A shoebox without pixels would create a block with zero residuals,
+        // which Ceres does not accept; it adds nothing to the cost, so skip it.
+        if (g.pixels.empty())
+            continue;
+        auto *cost = new ceres::DynamicAutoDiffCostFunction<ShoeboxResidual>(
+            new ShoeboxResidual(g, lambda, pixel_size, data.crystal_system));
+        cost->AddParameterBlock(2); // beam
+        cost->AddParameterBlock(1); // distance
+        cost->AddParameterBlock(2); // detector_rot
+        cost->AddParameterBlock(3); // rotation_axis
+        cost->AddParameterBlock(3); // p0 (orientation)
+        cost->AddParameterBlock(3); // p1 (lengths)
+        cost->AddParameterBlock(3); // p2 (angles)
+        cost->AddParameterBlock(1); // scale G
+        cost->AddParameterBlock(1); // B
+        cost->AddParameterBlock(2); // R
+        cost->SetNumResiduals(static_cast<int>(g.pixels.size()));
+        // No robust loss here: a per-block (whole-shoebox) Huber would act on
+        // the sum of ~N squared residuals and mis-scale, unlike the previous
+        // per-pixel Huber. Per-pixel sigma weighting is retained; per-pixel
+        // outlier rejection (zingers) is a TODO if needed.
+        problem.AddResidualBlock(cost, nullptr,
+                                 beam, &dist_mm, detector_rot, rot_vec,
+                                 latt_vec0, latt_vec1, latt_vec2,
+                                 &data.scale_factor, &data.B_factor, data.R);
+        residual_pixels += g.pixels.size();
     }
-    data.residual_count = problem.NumResidualBlocks();
+    data.residual_count = residual_pixels;
 
     // ---- 5. Constrain / bound parameter blocks ----------------------------
     if (!data.refine_orientation)