Files
Jungfraujoch/process/JFJochProcess.cpp
T
leonarski_f 50ebf1694e Add global (joint) scaling path: jfjoch_process --global-scale
Refine all per-image scales, per-image mosaicities and the shared per-HKL
true intensities (plus the global wedge) together in one Ceres problem
(GlobalScale.{h,cpp}, GlobalScaleCeres), as an alternative to the existing
alternating merge-then-scale-each-image loop. The result feeds back into each
reflection's image_scale_corr exactly as ScaleOnTheFly does, so the rot3d
combine and MergeOnTheFly run unchanged and every metric stays comparable.

On HEWL crystal 2 the joint fit converges to the same place as the alternating
loop (anomalous S-peak vs XDS 0.53x -> 0.54x, ISa 9.4 both), confirming the
alternating path already reaches the joint optimum. Kept as a validated
alternative and the home for future global corrections.

A shared quadratic absorption surface in detector position was prototyped here
and dropped: it fit large non-physical coefficients (radially degenerate with
the per-HKL resolution structure), lowered the scaling-fit residual but raised
the error-model b (ISa 9.4 -> 7.3) and did not improve anomalous accuracy.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-25 20:43:04 +02:00

518 lines
24 KiB
C++

// SPDX-FileCopyrightText: 2026 Filip Leonarski, Paul Scherrer Institute <filip.leonarski@psi.ch>
// SPDX-License-Identifier: GPL-3.0-only
#include "JFJochProcess.h"
#include <algorithm>
#include <atomic>
#include <chrono>
#include <cmath>
#include <functional>
#include <future>
#include <numeric>
#include <set>
#include <sstream>
#include "../reader/JFJochHDF5Reader.h"
#include "../common/Logger.h"
#include "../common/AzimuthalIntegrationMapping.h"
#include "../common/AzimuthalIntegrationProfile.h"
#include "../common/CUDAWrapper.h"
#include "../common/time_utc.h"
#include "../writer/FileWriter.h"
#include "../image_analysis/MXAnalysisWithoutFPGA.h"
#include "../image_analysis/IndexAndRefine.h"
#include "../image_analysis/indexing/IndexerThreadPool.h"
#include "../image_analysis/azint/AzIntEngineCPU.h"
#include "../image_analysis/image_preprocessing/ImagePreprocessorCPU.h"
#include "../image_analysis/image_preprocessing/ImagePreprocessorBuffer.h"
#include "../image_analysis/scale_merge/Merge.h"
#include "../image_analysis/scale_merge/GlobalScale.h"
#include "../image_analysis/scale_merge/SearchSpaceGroup.h"
#include "../image_analysis/scale_merge/Combine3D.h"
#include "../image_analysis/WriteReflections.h"
namespace {
// Choose at most requested_images image ordinals from [0, images_to_process), spaced evenly
// and including both ends of the range whenever more than one ordinal is selected. The
// selection seeds the first pass of two-pass rotation indexing.
// Returns the ordinals in ascending order without duplicates; the result is empty when either
// argument is not positive.
std::vector<int> select_equally_spaced_image_ordinals(int images_to_process, int requested_images) {
    if (images_to_process <= 0 || requested_images <= 0)
        return {};

    const int count = std::min(images_to_process, requested_images);
    if (count == 1)
        return std::vector<int>(1, 0);

    const double last_ordinal = static_cast<double>(images_to_process - 1);
    const double last_slot = static_cast<double>(count - 1);

    std::vector<int> ordinals;
    ordinals.reserve(count);
    for (int slot = 0; slot < count; ++slot) {
        const auto rounded = std::llround(static_cast<double>(slot) * last_ordinal / last_slot);
        ordinals.push_back(static_cast<int>(rounded));
    }

    // The rounded positions are non-decreasing, so any repeated value can only sit next to its
    // twin; dropping adjacent duplicates leaves a sorted, unique list.
    ordinals.erase(std::unique(ordinals.begin(), ordinals.end()), ordinals.end());
    return ordinals;
}
// Global (joint) scaling: refine all per-image scales/mosaicities and the shared per-HKL
// intensities in one Ceres problem (GlobalScaleCeres), then write the result back into each
// reflection's image_scale_corr exactly as ScaleOnTheFly does - recompute the rotation
// partiality from the refined per-image mosaicity, then image_scale_corr = rlp/(partiality*G).
// The downstream rot3d combine and MergeOnTheFly are unchanged, so every metric stays
// comparable with the alternating path; only how G and mosaicity are found differs.
void RunGlobalScaling(const DiffractionExperiment &experiment,
std::vector<IntegrationOutcome> &outcomes, Logger &logger) {
const auto &s = experiment.GetScalingSettings();
const bool rotation = experiment.GetPartialityModel() == PartialityModel::Rotation;
std::vector<std::vector<Reflection>> observations;
observations.reserve(outcomes.size());
for (const auto &o: outcomes)
observations.push_back(o.reflections);
GlobalScaleOptions opt;
opt.merge_friedel = s.GetMergeFriedel();
if (const auto sg = experiment.GetGemmiSpaceGroup())
opt.space_group = *sg;
opt.d_min_limit_A = s.GetHighResolutionLimit_A().value_or(0.0);
opt.mosaicity_init_deg = s.GetDefaultMosaicity();
// The joint problem is large (per-image G + mosaicity + all Itrue + wedge); give the solver
// enough room to converge rather than the per-image default.
opt.max_solver_time_s = 300.0;
opt.max_num_iterations = 50;
double wedge = 0.0;
if (rotation) {
opt.partiality_model = GlobalScaleOptions::PartialityModel::Rotation;
wedge = experiment.GetRotationWedgeForScaling().value_or(0.0);
opt.wedge_deg = wedge;
opt.mosaicity_min_deg = s.GetMinMosaicity();
opt.mosaicity_max_deg = s.GetMaxMosaicity();
} else {
opt.partiality_model = GlobalScaleOptions::PartialityModel::Fixed;
}
const auto result = GlobalScaleCeres(observations, opt);
const double half_wedge = wedge / 2.0;
for (size_t i = 0; i < outcomes.size(); ++i) {
const double G = result.image_scale_g[i];
const double mos = result.mosaicity_deg[i];
for (auto &r: outcomes[i].reflections) {
if (rotation && std::isfinite(r.delta_phi_deg) && std::isfinite(r.zeta) && mos > 1e-6) {
const double c1 = r.zeta / std::sqrt(2.0);
r.partiality = static_cast<float>(
(std::erf((r.delta_phi_deg + half_wedge) * c1 / mos)
- std::erf((r.delta_phi_deg - half_wedge) * c1 / mos)) / 2.0);
}
const double denom = r.partiality * G;
r.image_scale_corr = (std::isfinite(r.rlp) && std::isfinite(denom) && denom > 0.0)
? static_cast<float>(r.rlp / denom) : NAN;
}
if (std::isfinite(G))
outcomes[i].image_scale_g = static_cast<float>(G);
if (rotation && std::isfinite(mos))
outcomes[i].mosaicity_deg = static_cast<float>(mos);
}
logger.Info("Global scaling complete ({} images)", outcomes.size());
}
}
// Captures everything Run() needs: reader_ is initialised from the given reader, while the
// experiment description, pixel mask and run configuration are taken by value and moved into
// the corresponding members. The constructor itself does no processing.
JFJochProcess::JFJochProcess(JFJochHDF5Reader &reader,
                             DiffractionExperiment experiment,
                             PixelMask pixel_mask,
                             ProcessConfig config)
    : reader_(reader),
      experiment_(std::move(experiment)),
      pixel_mask_(std::move(pixel_mask)),
      config_(std::move(config)) {
}
// Runs one complete processing pass over the configured image range.
//
// Steps, in order:
//   1. validate the dataset/stride, resolve the image range and configure experiment_;
//   2. build the start message and, when config_.output_prefix is set, open the file writer;
//   3. full analysis only: force an external rotation lattice, or run the first pass of
//      two-pass rotation indexing on an evenly spaced subset of images;
//   4. process every selected image on a set of std::async workers;
//   5. full analysis only: collect rotation lattice / consensus cell, then scale and merge;
//   6. write the end message, finalize the writer and return the summary.
//
// observer may be null. When set it receives phase notifications from this thread and
// per-image / progress notifications from the worker threads.
//
// Throws JFJochException when the input has no dataset, when the stride is not positive, or
// when two-pass rotation indexing finishes without a lattice. An empty image selection is not
// an error: a warning is logged and a default-constructed result is returned.
ProcessResult JFJochProcess::Run(JFJochProcessObserver *observer) {
    Logger logger("JFJochProcess");
    ProcessResult result;

    const auto dataset = reader_.GetDataset();
    if (!dataset)
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
                              "No experiment dataset found in the input file");
    if (config_.stride <= 0)
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Image stride must be positive");

    // A negative or too large end_image means "up to the end of the file".
    const auto total_images_in_file = static_cast<int>(reader_.GetNumberOfImages());
    int end_image = config_.end_image;
    if (end_image < 0 || end_image > total_images_in_file)
        end_image = total_images_in_file;
    const int start_image = config_.start_image;

    // Count of indices start, start+stride, start+2*stride, ... that are < end_image, i.e.
    // ceil((end - start) / stride). It has to round up: the worker loops below visit every such
    // index, and this count is what is declared to the experiment (ImagesPerTrigger), to the
    // HDF5 data source, to the first-pass ordinal selection and to the progress callbacks.
    // Rounding down would leave the last processed image outside the declared range whenever the
    // span is not a multiple of the stride.
    const int image_span = end_image - start_image;
    const int images_to_process = image_span > 0
                                      ? image_span / config_.stride + (image_span % config_.stride != 0 ? 1 : 0)
                                      : 0;
    if (images_to_process <= 0) {
        logger.Warning("No images to process (start {}, end {}, stride {}, total {})",
                       start_image, end_image, config_.stride, total_images_in_file);
        return result;
    }

    // Never run with fewer than one worker: a zero count would silently process nothing and a
    // negative one would be passed on to futures.reserve().
    const auto nthreads = config_.nthreads > 0 ? config_.nthreads : 1;

    const bool full = (config_.mode == ProcessMode::FullAnalysis);
    const bool write_files = !config_.output_prefix.empty();

    // Output/runtime invariants. Algorithm settings (indexing, scaling, integration, polarization,
    // space group, unit cell, ...) are configured on experiment_ by the caller.
    experiment_.BitDepthImage(32).PixelSigned(true);
    experiment_.Mode(DetectorMode::Standard);
    experiment_.OverwriteExistingFiles(true);
    experiment_.FilePrefix(config_.output_prefix.empty() ? "output" : config_.output_prefix);
    experiment_.SetFileWriterFormat(FileWriterFormat::NXmxLegacy);
    experiment_.ImagesPerTrigger(images_to_process);
    experiment_.NumTriggers(1);
    if (full)
        experiment_.Compression(CompressionAlgorithm::BSHUF_LZ4);

    // The pipeline indexes images 0..N-1 within this run; if we process a sub-range/strided
    // selection, shift the goniometer so local index i maps to the angle of original image
    // start+i*stride (keeping the per-image rotation wedge), otherwise rotation angles would be
    // wrong for any start_image != 0.
    if (const auto g = experiment_.GetGoniometer();
        g.has_value() && (start_image != 0 || config_.stride != 1)) {
        const float incr = g->GetIncrement_deg();
        GoniometerAxis shifted(g->GetName(),
                               g->GetStart_deg() + incr * static_cast<float>(start_image),
                               incr * static_cast<float>(config_.stride),
                               g->GetAxis(), g->GetHelicalStep());
        shifted.ScreeningWedge(g->GetScreeningWedge().value_or(incr));
        experiment_.Goniometer(shifted);
    }

    AzimuthalIntegrationMapping mapping(experiment_, pixel_mask_);
    JFJochReceiverPlots plots;
    plots.Setup(experiment_, mapping);

    // Output file (NXmxIntegrated master that links back to the original images).
    StartMessage start_message;
    experiment_.FillMessage(start_message);
    start_message.arm_date = dataset->arm_date;
    start_message.az_int_bin_to_q = mapping.GetBinToQ();
    start_message.az_int_bin_to_two_theta = mapping.GetBinToTwoTheta();
    start_message.az_int_q_bin_count = mapping.GetQBinCount();
    start_message.az_int_phi_bin_count = mapping.GetAzimuthalBinCount();
    if (mapping.GetAzimuthalBinCount() > 1)
        start_message.az_int_bin_to_phi = mapping.GetBinToPhi();
    start_message.pixel_mask["default"] = pixel_mask_.GetMask(experiment_);
    if (full) {
        start_message.rois = experiment_.ROI().ExportMetadata();
        if (!experiment_.ROI().empty())
            start_message.roi_map = experiment_.ExportROIMap();
        start_message.max_spot_count = experiment_.GetMaxSpotCount();
    }
    start_message.master_suffix = "process";
    start_message.file_format = FileWriterFormat::NXmxIntegrated;
    start_message.write_master_file = true;
    start_message.write_images = false;
    start_message.hdf5_source_data = reader_.GetHDF5DataSource(start_image, images_to_process);

    // The writer only exists when an output prefix was given; every later use is null-checked.
    std::unique_ptr<FileWriter> writer;
    if (write_files)
        writer = std::make_unique<FileWriter>(start_message);

    logger.Info("Processing {} images (range {}-{}, stride {}) using {} threads [{}]",
                images_to_process, start_image, end_image, config_.stride, nthreads,
                full ? "full analysis" : "azimuthal integration");
    if (observer)
        observer->OnPhase(full ? "Full analysis" : "Azimuthal integration");

    // Full-analysis shared engines.
    std::unique_ptr<IndexerThreadPool> indexer_pool;
    std::unique_ptr<IndexAndRefine> indexer;
    if (full) {
        indexer_pool = std::make_unique<IndexerThreadPool>(experiment_.GetIndexingSettings());
        indexer = std::make_unique<IndexAndRefine>(experiment_, indexer_pool.get());
        if (!config_.reference_data.empty())
            indexer->ReferenceIntensities(config_.reference_data);
    }

    const auto start_time = std::chrono::steady_clock::now();

    // First pass of two-pass rotation indexing (full analysis only).
    if (full && config_.forced_rotation_lattice.has_value()) {
        indexer->ForceRotationIndexerLattice(*config_.forced_rotation_lattice);
        logger.Info("Rotation indexer lattice forced externally - skipping first pass");
    } else if (full && config_.rotation_indexing && config_.two_pass_rotation) {
        if (observer)
            observer->OnPhase("Rotation indexing (first pass)");
        const auto selected = select_equally_spaced_image_ordinals(images_to_process,
                                                                   config_.rotation_indexing_image_count);
        logger.Info("First-pass rotation indexing using {} images{}", selected.size(),
                    config_.reuse_rotation_spots ? " and stored spots" : "");
        for (const int ordinal: selected) {
            if (cancelled_) break;
            const int image_idx = start_image + ordinal * config_.stride;
            DataMessage msg{};
            msg.number = ordinal; // index into the rotation indexer (0..images_to_process-1)
            msg.original_number = image_idx;
            try {
                if (config_.reuse_rotation_spots) {
                    // Spots stored in the input file are reused; no image data is read.
                    msg.spots = reader_.ReadSpots(image_idx);
                } else {
                    auto img = reader_.GetRawImage(image_idx);
                    if (!img) continue;
                    MXAnalysisWithoutFPGA analysis(experiment_, mapping, pixel_mask_, *indexer);
                    AzimuthalIntegrationProfile profile(mapping);
                    // Same spot-finding settings as the main pass, but with per-image indexing
                    // and quick integration switched off for this pass.
                    auto first_pass = config_.spot_finding;
                    first_pass.indexing = false;
                    first_pass.quick_integration = false;
                    msg.image = img->image;
                    if (dataset->efficiency.size() > image_idx)
                        msg.image_collection_efficiency = dataset->efficiency[image_idx];
                    analysis.Analyze(msg, profile, first_pass);
                }
                indexer->AddImageToRotationIndexer(msg);
            } catch (const std::exception &e) {
                // A single bad image does not abort the first pass.
                logger.Warning("First-pass rotation indexing failed for image {}: {}", image_idx, e.what());
            }
        }
        if (!cancelled_ && !indexer->FinalizeRotationIndexing().has_value())
            throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
                                  "Two-pass rotation indexing failed");
        if (!cancelled_)
            logger.Info("Two-pass rotation indexing found lattice");
    }

    // Main per-image loop, spread over N worker threads pulling from a shared counter. HDF5 reads
    // are serialized by the global hdf5_mutex; the analysis runs in parallel.
    std::atomic<int> next_ordinal = 0;
    std::atomic<int> finished_count = 0;
    std::atomic<uint64_t> total_uncompressed_bytes = 0;

    // Worker for the azimuthal-integration-only mode: preprocess + integrate each image.
    auto azint_worker = [&]() {
        std::vector<uint8_t> decompression_buffer;
        ImagePreprocessorCPU preprocessor(experiment_, pixel_mask_);
        ImagePreprocessorBuffer buffer(experiment_.GetPixelsNum());
        AzIntEngineCPU azint(mapping);
        AzimuthalIntegrationProfile profile(mapping);
        while (!cancelled_) {
            // Each worker claims the next unprocessed ordinal; the loop ends once the claimed
            // image lies past the end of the range.
            const int ordinal = next_ordinal.fetch_add(1);
            const int image_idx = start_image + ordinal * config_.stride;
            if (image_idx >= end_image) break;
            std::shared_ptr<JFJochReaderRawImage> img;
            try {
                img = reader_.GetRawImage(image_idx);
            } catch (const std::exception &e) {
                logger.Error("Failed to load image {}: {}", image_idx, e.what());
                continue;
            }
            if (!img) continue;
            DataMessage msg{};
            msg.image = img->image;
            msg.number = ordinal;
            msg.original_number = image_idx;
            if (dataset->efficiency.size() > image_idx)
                msg.image_collection_efficiency = dataset->efficiency[image_idx];
            total_uncompressed_bytes += msg.image.GetUncompressedSize();
            const auto t0 = std::chrono::steady_clock::now();
            try {
                const uint8_t *image_ptr = msg.image.GetUncompressedPtr(decompression_buffer);
                preprocessor.Analyze(buffer, image_ptr, msg.image.GetMode());
                azint.Run(buffer, profile);
            } catch (const std::exception &e) {
                // Failed images are logged and skipped; they do not count as finished.
                logger.Error("Error integrating image {}: {}", image_idx, e.what());
                continue;
            }
            msg.azint_time_s = std::chrono::duration<float>(std::chrono::steady_clock::now() - t0).count();
            msg.processing_time_s = msg.azint_time_s;
            msg.az_int_profile = profile.GetResult();
            msg.az_int_profile_count = profile.GetPixelCount();
            msg.az_int_profile_std = profile.GetStd();
            msg.bkg_estimate = profile.GetBkgEstimate(mapping.Settings());
            msg.run_number = experiment_.GetRunNumber();
            msg.run_name = experiment_.GetRunName();
            plots.Add(msg, profile);
            if (writer) writer->Write(msg);
            if (observer) observer->OnImageProcessed(msg);
            const int done = finished_count.fetch_add(1) + 1;
            if (observer) observer->OnProgress(done, images_to_process);
        }
    };

    // Worker for full analysis: the whole per-image pipeline runs inside analysis.Analyze().
    auto full_worker = [&]() {
        pin_gpu(); // round-robin per worker thread; must precede engine construction
        MXAnalysisWithoutFPGA analysis(experiment_, mapping, pixel_mask_, *indexer);
        AzimuthalIntegrationProfile profile(mapping);
        while (!cancelled_) {
            const int ordinal = next_ordinal.fetch_add(1);
            const int image_idx = start_image + ordinal * config_.stride;
            if (image_idx >= end_image) break;
            std::shared_ptr<JFJochReaderRawImage> img;
            try {
                img = reader_.GetRawImage(image_idx);
            } catch (const std::exception &e) {
                logger.Error("Failed to load image {}: {}", image_idx, e.what());
                continue;
            }
            if (!img) continue;
            DataMessage msg{};
            msg.image = img->image;
            msg.number = ordinal;
            msg.original_number = image_idx;
            if (dataset->efficiency.size() > image_idx)
                msg.image_collection_efficiency = dataset->efficiency[image_idx];
            total_uncompressed_bytes += msg.image.GetUncompressedSize();
            const auto t0 = std::chrono::steady_clock::now();
            try {
                analysis.Analyze(msg, profile, config_.spot_finding);
            } catch (const std::exception &e) {
                // Failed images are logged and skipped; they do not count as finished.
                logger.Error("Error analyzing image {}: {}", image_idx, e.what());
                continue;
            }
            msg.processing_time_s = std::chrono::duration<float>(std::chrono::steady_clock::now() - t0).count();
            msg.run_number = experiment_.GetRunNumber();
            msg.run_name = experiment_.GetRunName();
            plots.Add(msg, profile);
            if (writer) writer->Write(msg);
            if (observer) observer->OnImageProcessed(msg);
            const int done = finished_count.fetch_add(1) + 1;
            if (observer) observer->OnProgress(done, images_to_process);
        }
    };

    if (observer)
        observer->OnPhase("Processing images");

    std::function<void()> worker = full ? std::function<void()>(full_worker)
                                        : std::function<void()>(azint_worker);
    std::vector<std::future<void> > futures;
    futures.reserve(nthreads);
    for (int i = 0; i < nthreads; ++i)
        futures.push_back(std::async(std::launch::async, worker));
    // get() waits for each worker and rethrows anything that escaped it.
    for (auto &f: futures)
        f.get();

    const auto end_time = std::chrono::steady_clock::now();
    result.cancelled = cancelled_;
    result.images_processed = finished_count.load();
    result.processing_time_s = std::chrono::duration<double>(end_time - start_time).count();
    if (result.processing_time_s > 0.0) {
        result.frame_rate_hz = static_cast<double>(result.images_processed) / result.processing_time_s;
        result.throughput_MBs = static_cast<double>(total_uncompressed_bytes) / (result.processing_time_s * 1e6);
    }
    result.mean_processing_time = plots.GetMeanProcessingTime();
    result.indexing_rate = plots.GetIndexingRate();

    // End message (also written to the file).
    EndMessage end_msg;
    end_msg.max_image_number = result.images_processed;
    end_msg.images_collected_count = result.images_processed;
    end_msg.images_sent_to_write_count = result.images_processed;
    end_msg.end_date = time_UTC(std::chrono::system_clock::now());
    end_msg.run_number = experiment_.GetRunNumber();
    end_msg.run_name = experiment_.GetRunName();
    end_msg.bkg_estimate = plots.GetBkgEstimate();
    end_msg.az_int_result["dataset"] = plots.GetAzIntProfile();
    end_msg.indexing_rate = result.indexing_rate;

    if (full && !cancelled_) {
        if (const auto rot = indexer->FinalizeRotationIndexing(); rot.has_value()) {
            end_msg.rotation_lattice = rot->lattice;
            end_msg.rotation_lattice_type = LatticeMessage{
                .centering = rot->search_result.centering,
                .niggli_class = rot->search_result.niggli_class,
                .crystal_system = rot->search_result.system
            };
            result.rotation_lattice_found = true;
        }
        result.consensus_cell = indexer->GetConsensusUnitCell();
        end_msg.unit_cell = result.consensus_cell;
    }

    // Scaling and merging (full analysis only).
    if (full && !cancelled_ && result.indexing_rate.has_value() && result.indexing_rate > 0
        && (config_.run_scaling || !config_.reference_data.empty())) {
        if (observer)
            observer->OnPhase("Scaling and merging");
        const bool pixel_refine_path =
            experiment_.GetIndexingSettings().GetGeomRefinementAlgorithm() == GeomRefinementAlgorithmEnum::PixelRefine;

        // ScaleOnTheFly is only for the classical, no-reference path; with a reference (or
        // PixelRefine) each image is already scaled, so we merge directly.
        if (config_.reference_data.empty() && !pixel_refine_path) {
            if (config_.global_scaling) {
                logger.Info("Running global (joint) scaling ...");
                RunGlobalScaling(experiment_, indexer->GetIntegrationOutcome(), logger);
            } else {
                // Alternating path: merge, rescale every image against the merge, repeat.
                logger.Info("Running scaling ...");
                ScalingResult scale_result(0);
                for (int i = 0; i < config_.scaling_iter; i++) {
                    auto merge_result = MergeAll(experiment_, indexer->GetIntegrationOutcome(), false);
                    scale_result = indexer->ScaleAllImages(merge_result);
                }
            }
        }

        // -P rot3d: weight-sum each reflection's per-frame partials into one full before merging, so
        // the error model sees counting statistics (high ISa) instead of rocking-curve slicing scatter.
        const bool rot3d = experiment_.GetScalingSettings().GetCombine3D();
        std::vector<IntegrationOutcome> combined;
        if (rot3d)
            combined = CombineRotationObservations(indexer->GetIntegrationOutcome(), experiment_, &logger,
                                                   config_.observation_dump_path);
        const std::vector<IntegrationOutcome> &merge_input =
            rot3d ? combined : indexer->GetIntegrationOutcome();

        MergeOnTheFly merge_engine(experiment_);
        if (result.consensus_cell.has_value())
            merge_engine.ReferenceCell(*result.consensus_cell);
        merge_engine.RefineErrorModel(merge_input);
        if (merge_engine.ErrorModelActive())
            logger.Info("Error model: a={:.3f} b={:.3f} ISa={:.1f}", merge_engine.ErrorModelA(),
                        merge_engine.ErrorModelB(),
                        merge_engine.ErrorModelB() > 0 ? 1.0 / merge_engine.ErrorModelB() : 0.0);
        for (const auto &outcome: merge_input)
            merge_engine.AddImage(outcome);

        auto merged_reflections = merge_engine.ExportReflections();
        auto merged_statistics = merge_engine.MergeStats(merged_reflections, merge_input,
                                                         config_.reference_data);
        logger.Info("Merge complete ({} unique reflections)", merged_reflections.size());

        std::ostringstream stats_text;
        // Without a space group on the experiment, a space-group search report is prepended to
        // the merge statistics.
        if (!experiment_.GetGemmiSpaceGroup().has_value()) {
            SearchSpaceGroupOptions sg_opts;
            sg_opts.crystal_system.reset();
            sg_opts.centering = '\0';
            sg_opts.merge_friedel = experiment_.GetScalingSettings().GetMergeFriedel();
            sg_opts.d_min_limit_A = experiment_.GetScalingSettings().GetHighResolutionLimit_A().value_or(0.0);
            sg_opts.min_operator_cc = 0.80;
            sg_opts.min_pairs_per_operator = 20;
            sg_opts.min_total_compared = 100;
            sg_opts.test_systematic_absences = true;
            stats_text << SearchSpaceGroupResultToText(SearchSpaceGroup(merged_reflections, sg_opts)) << "\n\n";
        }
        stats_text << merged_statistics;
        result.merge_statistics_text = stats_text.str();

        if (result.consensus_cell && write_files)
            WriteReflections(merged_reflections, *result.consensus_cell, experiment_, config_.output_prefix);
    }

    if (writer) {
        writer->WriteHDF5(end_msg);
        writer->Finalize();
        result.written_master_path = config_.output_prefix + "_process.h5";
    }

    if (observer)
        observer->OnPhase(cancelled_ ? "Cancelled" : "Done");
    logger.Info("{} {} images in {:.2f} s ({:.2f} Hz)", cancelled_ ? "Cancelled after" : "Processed",
                result.images_processed, result.processing_time_s, result.frame_rate_hz);
    return result;
}