// SPDX-FileCopyrightText: 2026 Filip Leonarski, Paul Scherrer Institute // SPDX-License-Identifier: GPL-3.0-only #include "JFJochProcess.h" #include #include #include #include #include #include #include #include #include #include "../reader/JFJochHDF5Reader.h" #include "../common/Logger.h" #include "../common/AzimuthalIntegrationMapping.h" #include "../common/AzimuthalIntegrationProfile.h" #include "../common/CUDAWrapper.h" #include "../common/time_utc.h" #include "../writer/FileWriter.h" #include "../image_analysis/MXAnalysisWithoutFPGA.h" #include "../image_analysis/IndexAndRefine.h" #include "../image_analysis/indexing/IndexerThreadPool.h" #include "../image_analysis/azint/AzIntEngineCPU.h" #include "../image_analysis/image_preprocessing/ImagePreprocessorCPU.h" #include "../image_analysis/image_preprocessing/ImagePreprocessorBuffer.h" #include "../image_analysis/scale_merge/Merge.h" #include "../image_analysis/scale_merge/ScaleOnTheFly.h" #include "../image_analysis/scale_merge/SearchSpaceGroup.h" #include "../image_analysis/scale_merge/Combine3D.h" #include "../image_analysis/WriteReflections.h" namespace { // Pick up to requested_images ordinals spread evenly across [0, images_to_process) for the // first pass of two-pass rotation indexing. std::vector select_equally_spaced_image_ordinals(int images_to_process, int requested_images) { std::vector ret; if (images_to_process <= 0 || requested_images <= 0) return ret; const int n = std::min(images_to_process, requested_images); if (n == 1) { ret.push_back(0); return ret; } std::set unique_ordinals; for (int i = 0; i < n; i++) unique_ordinals.insert(static_cast( std::llround(static_cast(i) * static_cast(images_to_process - 1) / static_cast(n - 1)))); ret.assign(unique_ordinals.begin(), unique_ordinals.end()); return ret; } // XDS-order scaling. The rot3d combine emits fulls with partiality == 1 (image_scale_corr == 1), // so they were only ever scaled as per-frame *partials* upstream - their per-frame scale is // entangled with the rocking-curve/partiality model. This refits a per-frame scale directly on // the complete reflections with the Unity model (no partiality term, G*Itrue - I_full), the way // XDS/DIALS scale 3D-integrated fulls. A pure post-correction: it updates image_scale_corr on the // fulls (1 -> 1/G) without re-combining. void ScaleFulls(const DiffractionExperiment &experiment, std::vector &fulls, int scaling_iter, size_t nthreads, Logger &logger) { DiffractionExperiment unity = experiment; ScalingSettings ss = unity.GetScalingSettings(); ss.SetPartialityModel(PartialityModel::Unity); unity.ImportScalingSettings(ss); for (int i = 0; i < scaling_iter; i++) { const auto reference = MergeAll(unity, fulls, false); ScaleOnTheFly(unity, reference).Scale(fulls, nthreads); } logger.Info("Scaled fulls (XDS order, Unity model)"); } } JFJochProcess::JFJochProcess(JFJochHDF5Reader &reader, DiffractionExperiment experiment, PixelMask pixel_mask, ProcessConfig config) : reader_(reader), experiment_(std::move(experiment)), pixel_mask_(std::move(pixel_mask)), config_(std::move(config)) {} ProcessResult JFJochProcess::Run(JFJochProcessObserver *observer) { Logger logger("JFJochProcess"); ProcessResult result; const auto dataset = reader_.GetDataset(); if (!dataset) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "No experiment dataset found in the input file"); if (config_.stride <= 0) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Image stride must be positive"); const auto total_images_in_file = static_cast(reader_.GetNumberOfImages()); int end_image = config_.end_image; if (end_image < 0 || end_image > total_images_in_file) end_image = total_images_in_file; const int start_image = config_.start_image; const int images_to_process = (end_image - start_image) / config_.stride; if (images_to_process <= 0) { logger.Warning("No images to process (start {}, end {}, stride {}, total {})", start_image, end_image, config_.stride, total_images_in_file); return result; } const bool full = (config_.mode == ProcessMode::FullAnalysis); const bool write_files = !config_.output_prefix.empty(); // Output/runtime invariants. Algorithm settings (indexing, scaling, integration, polarization, // space group, unit cell, ...) are configured on experiment_ by the caller. experiment_.BitDepthImage(32).PixelSigned(true); experiment_.Mode(DetectorMode::Standard); experiment_.OverwriteExistingFiles(true); experiment_.FilePrefix(config_.output_prefix.empty() ? "output" : config_.output_prefix); experiment_.SetFileWriterFormat(FileWriterFormat::NXmxLegacy); experiment_.ImagesPerTrigger(images_to_process); experiment_.NumTriggers(1); if (full) experiment_.Compression(CompressionAlgorithm::BSHUF_LZ4); // The pipeline indexes images 0..N-1 within this run; if we process a sub-range/strided // selection, shift the goniometer so local index i maps to the angle of original image // start+i*stride (keeping the per-image rotation wedge), otherwise rotation angles would be // wrong for any start_image != 0. if (const auto g = experiment_.GetGoniometer(); g.has_value() && (start_image != 0 || config_.stride != 1)) { const float incr = g->GetIncrement_deg(); GoniometerAxis shifted(g->GetName(), g->GetStart_deg() + incr * static_cast(start_image), incr * static_cast(config_.stride), g->GetAxis(), g->GetHelicalStep()); shifted.ScreeningWedge(g->GetScreeningWedge().value_or(incr)); experiment_.Goniometer(shifted); } AzimuthalIntegrationMapping mapping(experiment_, pixel_mask_); JFJochReceiverPlots plots; plots.Setup(experiment_, mapping); // Output file (NXmxIntegrated master that links back to the original images). StartMessage start_message; experiment_.FillMessage(start_message); start_message.arm_date = dataset->arm_date; start_message.az_int_bin_to_q = mapping.GetBinToQ(); start_message.az_int_bin_to_two_theta = mapping.GetBinToTwoTheta(); start_message.az_int_q_bin_count = mapping.GetQBinCount(); start_message.az_int_phi_bin_count = mapping.GetAzimuthalBinCount(); if (mapping.GetAzimuthalBinCount() > 1) start_message.az_int_bin_to_phi = mapping.GetBinToPhi(); start_message.pixel_mask["default"] = pixel_mask_.GetMask(experiment_); if (full) { start_message.rois = experiment_.ROI().ExportMetadata(); if (!experiment_.ROI().empty()) start_message.roi_map = experiment_.ExportROIMap(); start_message.max_spot_count = experiment_.GetMaxSpotCount(); } start_message.master_suffix = "process"; start_message.file_format = FileWriterFormat::NXmxIntegrated; start_message.write_master_file = true; start_message.write_images = false; start_message.hdf5_source_data = reader_.GetHDF5DataSource(start_image, images_to_process); std::unique_ptr writer; if (write_files) writer = std::make_unique(start_message); logger.Info("Processing {} images (range {}-{}, stride {}) using {} threads [{}]", images_to_process, start_image, end_image, config_.stride, config_.nthreads, full ? "full analysis" : "azimuthal integration"); if (observer) observer->OnPhase(full ? "Full analysis" : "Azimuthal integration"); // Full-analysis shared engines. std::unique_ptr indexer_pool; std::unique_ptr indexer; if (full) { indexer_pool = std::make_unique(experiment_.GetIndexingSettings()); indexer = std::make_unique(experiment_, indexer_pool.get()); if (!config_.reference_data.empty()) indexer->ReferenceIntensities(config_.reference_data); } const auto start_time = std::chrono::steady_clock::now(); // First pass of two-pass rotation indexing (full analysis only). if (full && config_.forced_rotation_lattice.has_value()) { indexer->ForceRotationIndexerLattice(*config_.forced_rotation_lattice); logger.Info("Rotation indexer lattice forced externally - skipping first pass"); } else if (full && config_.rotation_indexing && config_.two_pass_rotation) { if (observer) observer->OnPhase("Rotation indexing (first pass)"); const auto selected = select_equally_spaced_image_ordinals(images_to_process, config_.rotation_indexing_image_count); logger.Info("First-pass rotation indexing using {} images{}", selected.size(), config_.reuse_rotation_spots ? " and stored spots" : ""); for (const int ordinal: selected) { if (cancelled_) break; const int image_idx = start_image + ordinal * config_.stride; DataMessage msg{}; msg.number = ordinal; // index into the rotation indexer (0..images_to_process-1) msg.original_number = image_idx; try { if (config_.reuse_rotation_spots) { msg.spots = reader_.ReadSpots(image_idx); } else { auto img = reader_.GetRawImage(image_idx); if (!img) continue; MXAnalysisWithoutFPGA analysis(experiment_, mapping, pixel_mask_, *indexer); AzimuthalIntegrationProfile profile(mapping); auto first_pass = config_.spot_finding; first_pass.indexing = false; first_pass.quick_integration = false; msg.image = img->image; if (dataset->efficiency.size() > image_idx) msg.image_collection_efficiency = dataset->efficiency[image_idx]; analysis.Analyze(msg, profile, first_pass); } indexer->AddImageToRotationIndexer(msg); } catch (const std::exception &e) { logger.Warning("First-pass rotation indexing failed for image {}: {}", image_idx, e.what()); } } if (!cancelled_ && !indexer->FinalizeRotationIndexing().has_value()) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Two-pass rotation indexing failed"); if (!cancelled_) logger.Info("Two-pass rotation indexing found lattice"); } // Main per-image loop, spread over N worker threads pulling from a shared counter. HDF5 reads // are serialized by the global hdf5_mutex; the analysis runs in parallel. std::atomic next_ordinal = 0; std::atomic finished_count = 0; std::atomic total_uncompressed_bytes = 0; auto azint_worker = [&]() { std::vector decompression_buffer; ImagePreprocessorCPU preprocessor(experiment_, pixel_mask_); ImagePreprocessorBuffer buffer(experiment_.GetPixelsNum()); AzIntEngineCPU azint(mapping); AzimuthalIntegrationProfile profile(mapping); while (!cancelled_) { const int ordinal = next_ordinal.fetch_add(1); const int image_idx = start_image + ordinal * config_.stride; if (image_idx >= end_image) break; std::shared_ptr img; try { img = reader_.GetRawImage(image_idx); } catch (const std::exception &e) { logger.Error("Failed to load image {}: {}", image_idx, e.what()); continue; } if (!img) continue; DataMessage msg{}; msg.image = img->image; msg.number = ordinal; msg.original_number = image_idx; if (dataset->efficiency.size() > image_idx) msg.image_collection_efficiency = dataset->efficiency[image_idx]; total_uncompressed_bytes += msg.image.GetUncompressedSize(); const auto t0 = std::chrono::steady_clock::now(); try { const uint8_t *image_ptr = msg.image.GetUncompressedPtr(decompression_buffer); preprocessor.Analyze(buffer, image_ptr, msg.image.GetMode()); azint.Run(buffer, profile); } catch (const std::exception &e) { logger.Error("Error integrating image {}: {}", image_idx, e.what()); continue; } msg.azint_time_s = std::chrono::duration(std::chrono::steady_clock::now() - t0).count(); msg.processing_time_s = msg.azint_time_s; msg.az_int_profile = profile.GetResult(); msg.az_int_profile_count = profile.GetPixelCount(); msg.az_int_profile_std = profile.GetStd(); msg.bkg_estimate = profile.GetBkgEstimate(mapping.Settings()); msg.run_number = experiment_.GetRunNumber(); msg.run_name = experiment_.GetRunName(); plots.Add(msg, profile); if (writer) writer->Write(msg); if (observer) observer->OnImageProcessed(msg); const int done = finished_count.fetch_add(1) + 1; if (observer) observer->OnProgress(done, images_to_process); } }; auto full_worker = [&]() { pin_gpu(); // round-robin per worker thread; must precede engine construction MXAnalysisWithoutFPGA analysis(experiment_, mapping, pixel_mask_, *indexer); AzimuthalIntegrationProfile profile(mapping); while (!cancelled_) { const int ordinal = next_ordinal.fetch_add(1); const int image_idx = start_image + ordinal * config_.stride; if (image_idx >= end_image) break; std::shared_ptr img; try { img = reader_.GetRawImage(image_idx); } catch (const std::exception &e) { logger.Error("Failed to load image {}: {}", image_idx, e.what()); continue; } if (!img) continue; DataMessage msg{}; msg.image = img->image; msg.number = ordinal; msg.original_number = image_idx; if (dataset->efficiency.size() > image_idx) msg.image_collection_efficiency = dataset->efficiency[image_idx]; total_uncompressed_bytes += msg.image.GetUncompressedSize(); const auto t0 = std::chrono::steady_clock::now(); try { analysis.Analyze(msg, profile, config_.spot_finding); } catch (const std::exception &e) { logger.Error("Error analyzing image {}: {}", image_idx, e.what()); continue; } msg.processing_time_s = std::chrono::duration(std::chrono::steady_clock::now() - t0).count(); msg.run_number = experiment_.GetRunNumber(); msg.run_name = experiment_.GetRunName(); plots.Add(msg, profile); if (writer) writer->Write(msg); if (observer) observer->OnImageProcessed(msg); const int done = finished_count.fetch_add(1) + 1; if (observer) observer->OnProgress(done, images_to_process); } }; if (observer) observer->OnPhase("Processing images"); std::function worker = full ? std::function(full_worker) : std::function(azint_worker); std::vector > futures; futures.reserve(config_.nthreads); for (int i = 0; i < config_.nthreads; ++i) futures.push_back(std::async(std::launch::async, worker)); for (auto &f: futures) f.get(); const auto end_time = std::chrono::steady_clock::now(); result.cancelled = cancelled_; result.images_processed = finished_count.load(); result.processing_time_s = std::chrono::duration(end_time - start_time).count(); if (result.processing_time_s > 0.0) { result.frame_rate_hz = static_cast(result.images_processed) / result.processing_time_s; result.throughput_MBs = static_cast(total_uncompressed_bytes) / (result.processing_time_s * 1e6); } result.mean_processing_time = plots.GetMeanProcessingTime(); result.indexing_rate = plots.GetIndexingRate(); // End message (also written to the file). EndMessage end_msg; end_msg.max_image_number = result.images_processed; end_msg.images_collected_count = result.images_processed; end_msg.images_sent_to_write_count = result.images_processed; end_msg.end_date = time_UTC(std::chrono::system_clock::now()); end_msg.run_number = experiment_.GetRunNumber(); end_msg.run_name = experiment_.GetRunName(); end_msg.bkg_estimate = plots.GetBkgEstimate(); end_msg.az_int_result["dataset"] = plots.GetAzIntProfile(); end_msg.indexing_rate = result.indexing_rate; if (full && !cancelled_) { if (const auto rot = indexer->FinalizeRotationIndexing(); rot.has_value()) { end_msg.rotation_lattice = rot->lattice; end_msg.rotation_lattice_type = LatticeMessage{ .centering = rot->search_result.centering, .niggli_class = rot->search_result.niggli_class, .crystal_system = rot->search_result.system }; result.rotation_lattice_found = true; } result.consensus_cell = indexer->GetConsensusUnitCell(); end_msg.unit_cell = result.consensus_cell; } // Scaling and merging (full analysis only). if (full && !cancelled_ && result.indexing_rate.has_value() && result.indexing_rate > 0 && (config_.run_scaling || !config_.reference_data.empty())) { // Scaling/merging is a long post-pass; report each sub-step as a phase so the GUI progress // bar reflects what is happening instead of freezing on one label. Also time each phase // (logged on transition) so the bottlenecks are visible. auto t_phase = std::chrono::steady_clock::now(); std::string prev_phase; auto phase = [&](const std::string &p) { const auto now = std::chrono::steady_clock::now(); if (!prev_phase.empty()) logger.Info("[timing] {}: {:.2f} s", prev_phase, std::chrono::duration(now - t_phase).count()); t_phase = now; prev_phase = p; if (observer) observer->OnPhase(p); }; phase("Scaling and merging"); // ScaleOnTheFly self-scaling is only for the no-reference path; with a reference each image // is already scaled against it during the per-image pass, so we merge directly. if (config_.reference_data.empty()) { logger.Info("Running scaling ..."); ScalingResult scale_result(0); double t_merge_all = 0.0, t_scale = 0.0; for (int i = 0; i < config_.scaling_iter; i++) { phase("Scaling images (iteration " + std::to_string(i + 1) + "/" + std::to_string(config_.scaling_iter) + ")"); const auto a = std::chrono::steady_clock::now(); auto merge_result = MergeAll(experiment_, indexer->GetIntegrationOutcome(), false); const auto b = std::chrono::steady_clock::now(); scale_result = indexer->ScaleAllImages(merge_result); const auto c = std::chrono::steady_clock::now(); t_merge_all += std::chrono::duration(b - a).count(); t_scale += std::chrono::duration(c - b).count(); } logger.Info("[timing] scaling loop ({} iter): MergeAll(serial) {:.2f} s, ScaleAllImages(parallel) {:.2f} s", config_.scaling_iter, t_merge_all, t_scale); } // -P rot3d: weight-sum each reflection's per-frame partials into one full before merging, so // the error model sees counting statistics (high ISa) instead of rocking-curve slicing scatter. const bool rot3d = experiment_.GetScalingSettings().GetCombine3D(); std::vector combined; if (rot3d) { phase("Combining 3D partials"); combined = CombineRotationObservations(indexer->GetIntegrationOutcome(), experiment_, &logger, config_.observation_dump_path); } if (rot3d && experiment_.GetScalingSettings().GetScaleFulls()) { phase("Scaling fulls (XDS order)"); ScaleFulls(experiment_, combined, static_cast(config_.scaling_iter), config_.nthreads, logger); } const std::vector &merge_input = rot3d ? combined : indexer->GetIntegrationOutcome(); phase("Merging"); MergeOnTheFly merge_engine(experiment_); if (result.consensus_cell.has_value()) merge_engine.ReferenceCell(*result.consensus_cell); merge_engine.RefineErrorModel(merge_input); if (merge_engine.ErrorModelActive()) logger.Info("Error model: a={:.3f} b={:.3f} ISa={:.1f}", merge_engine.ErrorModelA(), merge_engine.ErrorModelB(), merge_engine.ErrorModelB() > 0 ? 1.0 / merge_engine.ErrorModelB() : 0.0); for (const auto &outcome: merge_input) merge_engine.AddImage(outcome); auto merged_reflections = merge_engine.ExportReflections(); phase("Computing statistics"); auto merged_statistics = merge_engine.MergeStats(merged_reflections, merge_input, config_.reference_data); logger.Info("Merge complete ({} unique reflections)", merged_reflections.size()); std::ostringstream stats_text; if (!experiment_.GetGemmiSpaceGroup().has_value()) { SearchSpaceGroupOptions sg_opts; sg_opts.crystal_system.reset(); sg_opts.centering = '\0'; sg_opts.merge_friedel = experiment_.GetScalingSettings().GetMergeFriedel(); sg_opts.d_min_limit_A = experiment_.GetScalingSettings().GetHighResolutionLimit_A().value_or(0.0); sg_opts.min_operator_cc = 0.80; sg_opts.min_pairs_per_operator = 20; sg_opts.min_total_compared = 100; sg_opts.test_systematic_absences = true; stats_text << SearchSpaceGroupResultToText(SearchSpaceGroup(merged_reflections, sg_opts)) << "\n\n"; } stats_text << merged_statistics; result.merge_statistics_text = stats_text.str(); result.has_merge_statistics = true; result.merge_statistics = merged_statistics; result.error_model_isa = merge_engine.ErrorModelB() > 0 ? 1.0 / merge_engine.ErrorModelB() : 0.0; result.has_reference = !config_.reference_data.empty(); if (result.consensus_cell && write_files) { phase("Writing reflections"); WriteReflections(merged_reflections, *result.consensus_cell, experiment_, config_.output_prefix); } phase(""); // flush the last phase's timing } if (writer) { writer->WriteHDF5(end_msg); writer->Finalize(); result.written_master_path = config_.output_prefix + "_process.h5"; } if (observer) observer->OnPhase(cancelled_ ? "Cancelled" : "Done"); logger.Info("{} {} images in {:.2f} s ({:.2f} Hz)", cancelled_ ? "Cancelled after" : "Processed", result.images_processed, result.processing_time_s, result.frame_rate_hz); return result; }