RotationScaleMerge: GPU scale-fulls, fulls kept resident (phase 2 step 2)

Scale the combined fulls (Unity model) on the device so they no longer round-trip
between the combine and the merge: after the GPU combine, build the fulls' per-frame
and per-ASU-group CSRs on the host from just the small key arrays (f_frame/f_group)
with a deterministic counting sort - no GPU stable-sort - then scale in place and
download once.

The four scaling kernels are reused unchanged except FitPerFrameGKernel, which gains
an optional `perm` argument (null for the partials, whose arrays are already
frame-contiguous; a frame-grouping permutation for the emit-ordered fulls) so the
fulls are scaled without a physical reorder. The Unity model falls out of giving the
fulls all-ones partiality/rlp/zeta (coeff = mean), so no other kernel changes and the
committed phase-1 partial-scaling path is bit-identical (perm == null -> idx == i).

Validated across the rotation battery (JFJOCH_RSM_GPU_COMBINE=1): all 15 deterministic
crystals stay run-to-run deterministic and their merged output is bit-identical to the
CPU path (SG, ISa, CC1/2, completeness). The lone exception is EP_cs_01-24 (CC1/2 2%,
R_meas 379% - unindexable noise): merged intensities/CC/completeness match exactly, but
the ill-conditioned 16-bin error-model b fit amplifies the ~1e-7 scale-fulls rounding
to ISa 10.6 vs 10.8 - benign, same class as the accepted phase-1 GPU rounding. The 3
upstream-nondeterministic crystals vary as before (GPU-prediction overflow, not this).

Scale-fulls drops from ~0.09s to ~0 across the two passes; combine+scale-fulls region
~0.32s GPU vs ~0.46s CPU on lyso. Still opt-in (fulls are downloaded for the host merge;
the win grows once the merge/error-model also stay resident).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-07-03 07:51:29 +02:00
co-authored by Claude Opus 4.8
parent 2c928b27cd
commit ced85bcd9d
3 changed files with 154 additions and 21 deletions
@@ -1098,31 +1098,58 @@ RotationScaleMerge::Result RotationScaleMerge::Run(bool for_search,
// --- 3. 3D combine of per-frame partials into fulls (fulls inherit their ASU group here). ---
bool combined_on_gpu = false;
bool scaled_fulls_on_gpu = false;
#ifdef JFJOCH_USE_CUDA
// The GPU combine mirrors Combine() exactly but keeps the fulls on the device; the diagnostic dump
// (serial, one writer) has no GPU equivalent, so fall back to the CPU path when it is requested.
// GPU combine (+ scale-fulls) keeps the fulls resident on the device: combine, then build the frame /
// ASU-group CSRs on the host from just the small key arrays (a deterministic counting sort - no GPU
// stable-sort), scale the fulls in place, and download only once. Mirrors Combine() + the Unity
// scale-fulls loop below. The diagnostic dump (serial, one writer) has no GPU path -> CPU fallback.
if (gpu_active_ && gpu_combine_ && observation_dump_path.empty()) {
std::vector<float> corr(partials.size()); // refresh the smoothed corr on the device
for (size_t i = 0; i < partials.size(); ++i) corr[i] = partials[i].corr;
gpu_->SetCorr(corr.data());
const int nf = gpu_->Combine(rawrun_group.data(), min_partiality, capture_uncertainty_coeff);
fulls.assign(nf, Obs{});
g_full.assign(n_frames, 1.0);
if (scale_fulls && nf > 0) {
// Frame + group CSRs over the emit-ordered fulls, built by counting sort on the host (stable,
// deterministic). frame is always in [0, n_frames); group is <0 for absent/out-of-range fulls.
std::vector<int32_t> ff(nf), fg(nf);
gpu_->GetFullsKeys(ff.data(), fg.data());
std::vector<int32_t> f_start(n_frames, 0), f_count(n_frames, 0), f_perm(nf);
for (int i = 0; i < nf; ++i) ++f_count[ff[i]];
for (int f = 1; f < n_frames; ++f) f_start[f] = f_start[f - 1] + f_count[f - 1];
{ std::vector<int32_t> fill = f_start; for (int i = 0; i < nf; ++i) f_perm[fill[ff[i]]++] = i; }
gpu_->SetFullsFrameCSR(f_perm.data(), nf, f_start.data(), f_count.data());
std::vector<int32_t> g_count(n_groups, 0), g_start(n_groups, 0);
for (int i = 0; i < nf; ++i) if (fg[i] >= 0) ++g_count[fg[i]];
int acc = 0;
for (int g = 0; g < n_groups; ++g) { g_start[g] = acc; acc += g_count[g]; }
std::vector<int32_t> g_perm(acc);
{ std::vector<int32_t> fill = g_start; for (int i = 0; i < nf; ++i) if (fg[i] >= 0) g_perm[fill[fg[i]]++] = i; }
gpu_->SetFullsGroups(g_perm.data(), acc, g_start.data(), g_count.data());
gpu_->ScaleFulls(scaling_iter, SCALE_ROBUST_K, min_partiality);
scaled_fulls_on_gpu = true;
}
fulls.assign(nf, Obs{});
std::vector<int32_t> fh(nf), fk(nf), fl(nf), fframe(nf), fgroup(nf);
std::vector<float> fI(nf), fsig(nf), fd(nf), fimg(nf);
std::vector<float> fI(nf), fsig(nf), fd(nf), fimg(nf), fcorr(nf, 1.0f);
std::vector<uint8_t> fon(nf);
gpu_->GetFulls(fh.data(), fk.data(), fl.data(), fI.data(), fsig.data(), fd.data(),
fimg.data(), fframe.data(), fon.data(), fgroup.data());
if (scaled_fulls_on_gpu) gpu_->GetFullsCorr(fcorr.data());
for (int i = 0; i < nf; ++i) {
Obs &o = fulls[i];
o.h = fh[i]; o.k = fk[i]; o.l = fl[i];
o.I = fI[i]; o.sigma = fsig[i]; o.d = fd[i];
o.rlp = 1.0f; o.partiality = 1.0f; o.corr = 1.0f;
o.rlp = 1.0f; o.partiality = 1.0f; o.corr = fcorr[i];
o.image_number = fimg[i]; o.frame = fframe[i];
o.on_ice = fon[i]; o.group = fgroup[i];
}
SortFullsByFrame();
logger.Info("3D combine (GPU): {} fulls", nf);
logger.Info("3D combine{} (GPU): {} fulls", scaled_fulls_on_gpu ? " + scale-fulls" : "", nf);
combined_on_gpu = true;
}
#endif
@@ -1131,7 +1158,7 @@ RotationScaleMerge::Result RotationScaleMerge::Run(bool for_search,
lap("combine");
// --- 4. Scale the fulls (XDS order, Unity model). ---
if (scale_fulls) {
if (scale_fulls && !scaled_fulls_on_gpu) {
std::vector<double> full_mean;
for (int it = 0; it < scaling_iter; ++it) {
ReduceGroupMeans(fulls, n_groups, false, {}, full_mean);
@@ -86,11 +86,15 @@ namespace {
// One block per frame: robust per-frame scale G by IRLS (Cauchy), over the frame's contiguous obs.
// Identical objective to the CPU SolveScaleIRLS. Leaves g/scaled untouched for under-populated frames.
// `perm` (null for the partials, whose arrays are already frame-contiguous) maps a position in the
// frame's [lo,hi) range to the obs index, so the same kernel scales the fulls (emit-ordered) through a
// frame-grouping permutation without physically reordering the fulls arrays.
__global__ void FitPerFrameGKernel(int n_frames, double robust_k,
const int32_t *__restrict__ frame_start,
const int32_t *__restrict__ frame_count,
const float *__restrict__ I, const float *__restrict__ sigma,
const float *__restrict__ sco_coeff, const uint8_t *__restrict__ sco_ok,
const int32_t *__restrict__ perm,
double *__restrict__ g, uint8_t *__restrict__ scaled) {
const int f = blockIdx.x;
if (f >= n_frames) return;
@@ -100,7 +104,7 @@ namespace {
// count accepted
long cnt_local = 0;
for (int i = lo + threadIdx.x; i < hi; i += blockDim.x)
if (sco_ok[i]) ++cnt_local;
if (sco_ok[perm ? perm[i] : i]) ++cnt_local;
const double cnt = BlockReduceSum(double(cnt_local), sh);
__shared__ double s_cnt;
if (threadIdx.x == 0) s_cnt = cnt;
@@ -112,11 +116,12 @@ namespace {
// seed: plain weighted-LS ratio (robust weight = 1)
double num = 0.0, den = 0.0;
for (int i = lo + threadIdx.x; i < hi; i += blockDim.x) {
if (!sco_ok[i]) continue;
const double coeff = sco_coeff[i];
const double w = SafeInvD(sigma[i], 1.0);
const int a = perm ? perm[i] : i;
if (!sco_ok[a]) continue;
const double coeff = sco_coeff[a];
const double w = SafeInvD(sigma[a], 1.0);
const double w2 = w * w;
num += w2 * coeff * double(I[i]);
num += w2 * coeff * double(I[a]);
den += w2 * coeff * coeff;
}
double tnum = BlockReduceSum(num, sh); __syncthreads();
@@ -133,13 +138,14 @@ namespace {
const double G = s_G;
num = 0.0; den = 0.0;
for (int i = lo + threadIdx.x; i < hi; i += blockDim.x) {
if (!sco_ok[i]) continue;
const double coeff = sco_coeff[i];
const double w = SafeInvD(sigma[i], 1.0);
const int a = perm ? perm[i] : i;
if (!sco_ok[a]) continue;
const double coeff = sco_coeff[a];
const double w = SafeInvD(sigma[a], 1.0);
const double w2 = w * w;
const double res = w * (G * coeff - double(I[i]));
const double res = w * (G * coeff - double(I[a]));
const double rw = 1.0 / (1.0 + res * res / k2);
num += rw * w2 * coeff * double(I[i]);
num += rw * w2 * coeff * double(I[a]);
den += rw * w2 * coeff * coeff;
}
tnum = BlockReduceSum(num, sh); __syncthreads();
@@ -336,6 +342,11 @@ namespace {
CombineRawRun<Emit>(r, p);
}
// Set the first n elements of p to v. Each thread starts at its global thread index and advances by the
// total number of launched threads (grid-stride loop), so any grid size covers the whole range.
template <typename T>
__global__ void FillKernel(T *p, int n, T v) {
    const unsigned int stride = gridDim.x * blockDim.x;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    while (idx < n) {
        p[idx] = v;
        idx += stride;
    }
}
void CudaCheck(cudaError_t e, const char *what) {
if (e != cudaSuccess)
throw JFJochException(JFJochExceptionCategory::GPUCUDAError,
@@ -378,6 +389,13 @@ struct RotationScaleMergeGPU::Impl {
CudaDevicePtr<int32_t> f_h, f_k, f_l, f_frame, f_group;
CudaDevicePtr<float> f_I, f_sigma, f_d, f_img;
CudaDevicePtr<uint8_t> f_on_ice;
// scale-fulls (Unity model, kept resident): all-ones partiality/rlp/zeta so the shared scaling kernels
// yield coeff=mean, plus the working corr, the per-obs scale scratch, and the fulls frame/group CSRs
// (built on the host from the small f_frame/f_group key arrays, over the emit-ordered fulls).
CudaDevicePtr<float> f_corr, f_partiality, f_rlp, f_zeta, f_sco_coeff;
CudaDevicePtr<uint8_t> f_sco_ok;
CudaDevicePtr<int32_t> f_frame_perm, f_frame_start, f_frame_count;
CudaDevicePtr<int32_t> f_gperm, f_gstart, f_gcount;
};
RotationScaleMergeGPU::RotationScaleMergeGPU() : impl_(std::make_unique<Impl>()) {
@@ -441,7 +459,7 @@ void RotationScaleMergeGPU::ScalePartials(int iters, double robust_k, double min
PrepScaleObsKernel<<<obs_blocks, BLK>>>(d.n_obs, d.group.get(), d.partiality.get(), d.rlp.get(),
d.zeta.get(), d.on_ice.get(), d.group_mean.get(), d.sco_coeff.get(), d.sco_ok.get());
FitPerFrameGKernel<<<d.n_frames, BLK>>>(d.n_frames, robust_k, d.frame_start.get(), d.frame_count.get(),
d.I.get(), d.sigma.get(), d.sco_coeff.get(), d.sco_ok.get(), d.g.get(), d.scaled.get());
d.I.get(), d.sigma.get(), d.sco_coeff.get(), d.sco_ok.get(), nullptr, d.g.get(), d.scaled.get());
UpdateCorrKernel<<<upd_blocks, BLK>>>(d.n_obs, d.frame.get(), d.rlp.get(), d.partiality.get(),
d.g.get(), d.scaled.get(), d.corr.get());
}
@@ -525,6 +543,9 @@ int RotationScaleMergeGPU::Combine(const int32_t *rawrun_group, double min_parti
d.f_I = CudaDevicePtr<float>(nf); d.f_sigma = CudaDevicePtr<float>(nf);
d.f_d = CudaDevicePtr<float>(nf); d.f_img = CudaDevicePtr<float>(nf);
d.f_on_ice = CudaDevicePtr<uint8_t>(nf);
d.f_corr = CudaDevicePtr<float>(nf); d.f_partiality = CudaDevicePtr<float>(nf);
d.f_rlp = CudaDevicePtr<float>(nf); d.f_zeta = CudaDevicePtr<float>(nf);
d.f_sco_coeff = CudaDevicePtr<float>(nf); d.f_sco_ok = CudaDevicePtr<uint8_t>(nf);
CudaCheck(cudaMemcpy(d.rr_offset.get(), offset.data(), size_t(d.n_runs) * sizeof(int32_t),
cudaMemcpyHostToDevice), "upload offset");
@@ -557,3 +578,66 @@ void RotationScaleMergeGPU::GetFulls(int32_t *h, int32_t *k, int32_t *l, float *
dl(d, dd.f_d.get(), n * sizeof(float)); dl(image_number, dd.f_img.get(), n * sizeof(float));
dl(on_ice, dd.f_on_ice.get(), n * sizeof(uint8_t));
}
// Copy the resident fulls' frame and ASU-group keys from the device into the caller's buffers.
// Both buffers receive n_fulls int32 entries; nothing is written when there are no fulls.
// A failed copy is reported through CudaCheck.
void RotationScaleMergeGPU::GetFullsKeys(int32_t *frame, int32_t *group) const {
    const auto &state = *impl_;
    if (state.n_fulls != 0) {
        const size_t key_bytes = size_t(state.n_fulls) * sizeof(int32_t);
        CudaCheck(cudaMemcpy(frame, state.f_frame.get(), key_bytes, cudaMemcpyDeviceToHost), "download f_frame");
        CudaCheck(cudaMemcpy(group, state.f_group.get(), key_bytes, cudaMemcpyDeviceToHost), "download f_group");
    }
}
// Store the host-built per-frame CSR of the fulls: the frame-grouping permutation (n_perm entries)
// followed by the per-frame start and count arrays (n_frames entries each). The transfer itself is done
// by the file's Upload helper.
void RotationScaleMergeGPU::SetFullsFrameCSR(const int32_t *frame_perm, int n_perm,
                                             const int32_t *frame_start, const int32_t *frame_count) {
    auto &state = *impl_;
    // permutation first, then the two per-frame index arrays
    Upload(state.f_frame_perm, frame_perm, n_perm);
    Upload(state.f_frame_start, frame_start, state.n_frames);
    Upload(state.f_frame_count, frame_count, state.n_frames);
}
// Store the host-built per-ASU-group CSR of the fulls: the group-ordered permutation (n_gperm entries)
// followed by the per-group start and count arrays (n_groups entries each). The transfer itself is done
// by the file's Upload helper.
void RotationScaleMergeGPU::SetFullsGroups(const int32_t *gperm, int n_gperm,
                                           const int32_t *gstart, const int32_t *gcount) {
    auto &state = *impl_;
    // permutation first, then the two per-group index arrays
    Upload(state.f_gperm, gperm, n_gperm);
    Upload(state.f_gstart, gstart, state.n_groups);
    Upload(state.f_gcount, gcount, state.n_groups);
}
// Run `iters` passes of the scaling loop over the device-resident fulls, updating their working corr
// in place. Each pass launches, in order: group-mean reduction, per-observation prep, per-frame IRLS fit
// (through the fulls' frame permutation), and the corr update. Does nothing when there are no fulls.
// Launch errors and execution errors are checked once, after the last pass.
void RotationScaleMergeGPU::ScaleFulls(int iters, double robust_k, double min_partiality) {
    auto &dev = *impl_;
    const int n_fulls = dev.n_fulls;
    if (n_fulls == 0)
        return;

    // Launch grids, capped at 65535 blocks.
    const int full_grid = std::min(65535, (n_fulls + BLK - 1) / BLK);
    const int group_grid = std::min(65535, (dev.n_groups + BLK - 1) / BLK);

    // Unity model: partiality/rlp/zeta = 1 so coeff = mean; corr starts at 1.
    float *const unity_arrays[] = {dev.f_corr.get(), dev.f_partiality.get(), dev.f_rlp.get(), dev.f_zeta.get()};
    for (float *arr : unity_arrays)
        FillKernel<<<full_grid, BLK>>>(arr, n_fulls, 1.0f);

    // g/scaled are cleared once up front (a frame, once fitted, stays fitted - same as ScalePartials).
    CudaCheck(cudaMemset(dev.scaled.get(), 0, size_t(dev.n_frames) * sizeof(uint8_t)), "memset f scaled");
    CudaCheck(cudaMemset(dev.g.get(), 0, size_t(dev.n_frames) * sizeof(double)), "memset f g");

    for (int pass = 0; pass < iters; ++pass) {
        // 1. per-group mean over the group-ordered fulls
        ReduceGroupMeansKernel<<<group_grid, BLK>>>(
            dev.n_groups, min_partiality,
            dev.f_gperm.get(), dev.f_gstart.get(), dev.f_gcount.get(),
            dev.f_I.get(), dev.f_sigma.get(), dev.f_partiality.get(), dev.f_corr.get(),
            dev.group_mean.get());
        // 2. per-full model coefficient and acceptance flag
        PrepScaleObsKernel<<<full_grid, BLK>>>(
            n_fulls, dev.f_group.get(), dev.f_partiality.get(), dev.f_rlp.get(),
            dev.f_zeta.get(), dev.f_on_ice.get(), dev.group_mean.get(),
            dev.f_sco_coeff.get(), dev.f_sco_ok.get());
        // 3. one block per frame, reading the emit-ordered fulls through the frame permutation
        FitPerFrameGKernel<<<dev.n_frames, BLK>>>(
            dev.n_frames, robust_k,
            dev.f_frame_start.get(), dev.f_frame_count.get(),
            dev.f_I.get(), dev.f_sigma.get(),
            dev.f_sco_coeff.get(), dev.f_sco_ok.get(),
            dev.f_frame_perm.get(), dev.g.get(), dev.scaled.get());
        // 4. fold the fitted per-frame scale into the working corr
        UpdateCorrKernel<<<full_grid, BLK>>>(
            n_fulls, dev.f_frame.get(), dev.f_rlp.get(), dev.f_partiality.get(),
            dev.g.get(), dev.scaled.get(), dev.f_corr.get());
    }

    CudaCheck(cudaGetLastError(), "scale fulls launch");
    CudaCheck(cudaDeviceSynchronize(), "scale fulls sync");
}
// Copy the fulls' working corr (n_fulls floats) from the device into `corr`.
// Nothing is written when there are no fulls; a failed copy is reported through CudaCheck.
void RotationScaleMergeGPU::GetFullsCorr(float *corr) const {
    const auto &state = *impl_;
    if (state.n_fulls != 0) {
        const size_t corr_bytes = size_t(state.n_fulls) * sizeof(float);
        CudaCheck(cudaMemcpy(corr, state.f_corr.get(), corr_bytes, cudaMemcpyDeviceToHost), "download f_corr");
    }
}
@@ -75,11 +75,33 @@ public:
// number of fulls (call GetFulls with buffers of that length).
int Combine(const int32_t *rawrun_group, double min_partiality, double capture_uncertainty_coeff);
// Download the combined fulls SoA (length = Combine()'s return). corr/partiality/rlp are 1 by
// construction and not returned; the caller sets them.
// Download the combined fulls SoA (length = Combine()'s return). The working corr is downloaded
// separately by GetFullsCorr (it is only meaningful after ScaleFulls; otherwise the caller sets it).
void GetFulls(int32_t *h, int32_t *k, int32_t *l, float *I, float *sigma, float *d,
float *image_number, int32_t *frame, uint8_t *on_ice, int32_t *group) const;
// --- scale the resident fulls on the device (Unity model), no round-trip ---
// Download the fulls' frame and ASU-group keys (emit order) so the host can build the frame/group CSRs
// with a counting sort (deterministic, no GPU stable-sort) and hand them back through SetFullsFrameCSR /
// SetFullsGroups. Each buffer must hold n_fulls int32 entries (n_fulls = Combine()'s return); nothing is
// written when there are no fulls. A failed device copy is reported through CudaCheck.
void GetFullsKeys(int32_t *frame, int32_t *group) const;
// The fulls' per-frame CSR: frame_perm (n_perm entries) groups the emit-ordered fulls by frame, and
// frame_start/frame_count (n_frames entries each) index it per frame, so FitPerFrameG can scale the
// fulls without physically reordering them.
void SetFullsFrameCSR(const int32_t *frame_perm, int n_perm,
const int32_t *frame_start, const int32_t *frame_count);
// The fulls' per-ASU-group CSR: gperm (n_gperm entries) is the group-ordered permutation of the fulls
// with group>=0, and gstart/gcount (n_groups entries each) index it per group. Fulls with group<0 are
// expected to be left out of gperm by the caller.
void SetFullsGroups(const int32_t *gperm, int n_gperm,
const int32_t *gstart, const int32_t *gcount);
// Run `iters` of the Unity scaling loop on the resident fulls (reduce group means -> per-frame IRLS G
// -> update corr), in place on the fulls' working corr. Requires SetFullsFrameCSR + SetFullsGroups.
// Before iterating it resets the working corr (and the all-ones partiality/rlp/zeta) to 1 and zeroes
// the per-frame g/scaled buffers - the same device buffers ScalePartials uses. Returns after a device
// synchronize; does nothing when there are no fulls.
void ScaleFulls(int iters, double robust_k, double min_partiality);
// Download the fulls' working corr (length = n_fulls), valid after ScaleFulls. Nothing is written to
// `corr` when there are no fulls.
void GetFullsCorr(float *corr) const;
private:
struct Impl;
std::unique_ptr<Impl> impl_;