Indexer: Add more general Run() call with reciprocal coordinates to accommodate more demanding schemes

This commit is contained in:
2025-11-29 13:59:52 +01:00
parent b6b6bab064
commit 4225173daa
12 changed files with 61 additions and 52 deletions
+17 -8
View File
@@ -41,13 +41,22 @@ void SpotAnalyze(const DiffractionExperiment &experiment,
output.spots = spots_out;
if ((indexer != nullptr) && spot_finding_settings.indexing) {
auto latt_f = indexer->Run(experiment, output);
auto latt = latt_f.get();
std::vector<Coord> recip;
recip.reserve(output.spots.size());
bool index_ice_rings = experiment.GetIndexingSettings().GetIndexIceRings();
for (const auto &i: output.spots) {
if (index_ice_rings || !i.ice_ring)
recip.push_back(i.ReciprocalCoord(geom));
}
if (!latt)
auto latt_f = indexer->Run(experiment, recip);
auto indexer_result = latt_f.get();
output.indexing_time_s = indexer_result.indexing_time_s;
if (indexer_result.lattice.empty())
output.indexing_result = false;
else {
auto uc = latt->GetUnitCell();
auto latt = indexer_result.lattice[0];
bool beam_center_updated = false;
@@ -71,7 +80,7 @@ void SpotAnalyze(const DiffractionExperiment &experiment,
};
}
} else {
auto sym_result = LatticeSearch(latt.value());
auto sym_result = LatticeSearch(latt);
symmetry = LatticeMessage{
.centering = sym_result.centering,
.niggli_class = sym_result.niggli_class,
@@ -85,7 +94,7 @@ void SpotAnalyze(const DiffractionExperiment &experiment,
DiffractionExperiment experiment_copy(experiment);
XtalOptimizerData data{
.geom = experiment_copy.GetDiffractionGeometry(),
.latt = latt.value(),
.latt = latt,
.crystal_system = symmetry.crystal_system,
.min_spots = experiment.GetIndexingSettings().GetViableCellMinSpots(),
};
@@ -105,7 +114,7 @@ void SpotAnalyze(const DiffractionExperiment &experiment,
break;
}
if (AnalyzeIndexing(output, experiment_copy, latt.value())) {
if (AnalyzeIndexing(output, experiment_copy, latt)) {
float ewald_dist_cutoff = 0.001f;
if (output.profile_radius)
@@ -120,7 +129,7 @@ void SpotAnalyze(const DiffractionExperiment &experiment,
}
if (spot_finding_settings.quick_integration) {
auto res = BraggIntegrate2D(experiment_copy, image, latt.value(),
auto res = BraggIntegrate2D(experiment_copy, image, latt,
prediction, ewald_dist_cutoff, output.number, symmetry.centering);
constexpr size_t kMaxReflections = 10000;
+1 -1
View File
@@ -18,7 +18,7 @@ void FFBIDXIndexer::SetupUnitCell(const std::optional<UnitCell> &cell) {
l1.Vec2().x, l1.Vec2().y, l1.Vec2().z;
}
std::vector<CrystalLattice> FFBIDXIndexer::Run(const std::vector<Coord> &coord, size_t nspots) {
std::vector<CrystalLattice> FFBIDXIndexer::RunInternal(const std::vector<Coord> &coord, size_t nspots) {
std::vector<CrystalLattice> ret;
if (nspots > coord.size())
+1 -1
View File
@@ -38,7 +38,7 @@ public:
FFBIDXIndexer(const FFBIDXIndexer &i) = delete;
const FFBIDXIndexer& operator=(const FFBIDXIndexer &i) = delete;
std::vector<CrystalLattice> Run(const std::vector<Coord> &coord, size_t nspots) override;
std::vector<CrystalLattice> RunInternal(const std::vector<Coord> &coord, size_t nspots) override;
};
#endif //JUNGFRAUJOCH_INDEXERWRAPPER_H
+1 -1
View File
@@ -159,7 +159,7 @@ std::vector<Coord> FFTIndexer::FilterFFTResults() const {
}
std::vector<CrystalLattice> FFTIndexer::Run(const std::vector<Coord> &coord, size_t nspots) {
std::vector<CrystalLattice> FFTIndexer::RunInternal(const std::vector<Coord> &coord, size_t nspots) {
if (nspots > coord.size())
nspots = coord.size();
+1 -1
View File
@@ -43,7 +43,7 @@ public:
explicit FFTIndexer(const IndexingSettings& settings);
~FFTIndexer() override = default;
std::vector<CrystalLattice> Run(const std::vector<Coord> &coord, size_t nspots) override;
std::vector<CrystalLattice> RunInternal(const std::vector<Coord> &coord, size_t nspots) override;
};
#endif //JFJOCH_FFTINDEXER_H
+5 -17
View File
@@ -12,25 +12,13 @@ void Indexer::Setup(const DiffractionExperiment& experiment) {
SetupUnitCell(experiment.GetUnitCell());
}
std::optional<CrystalLattice> Indexer::Run(DataMessage &message) {
IndexerResult Indexer::Run(const std::vector<Coord> &coord) {
IndexerResult ret;
auto start = std::chrono::high_resolution_clock::now();
std::vector<Coord> recip;
recip.reserve(message.spots.size());
for (const auto &i: message.spots) {
if (index_ice_rings || !i.ice_ring)
recip.push_back(i.ReciprocalCoord(geom));
}
auto ret = Run(recip, recip.size());
ret.lattice = RunInternal(coord, coord.size());
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<float> duration = end - start;
message.indexing_time_s = duration.count();
if (ret.empty())
return {};
return ret[0];
ret.indexing_time_s = duration.count();
return ret;
}
+6 -2
View File
@@ -12,6 +12,10 @@
#include "../../common/JFJochMessages.h"
#include "../../common/SpotToSave.h"
struct IndexerResult {
std::vector<CrystalLattice> lattice;
float indexing_time_s;
};
class Indexer {
protected:
@@ -24,11 +28,11 @@ protected:
std::optional<UnitCell> reference_unit_cell;
virtual void SetupUnitCell(const std::optional<UnitCell>& cell) = 0;
virtual std::vector<CrystalLattice> RunInternal(const std::vector<Coord> &coord, size_t nspots) = 0;
public:
virtual ~Indexer() = default;
void Setup(const DiffractionExperiment& experiment);
virtual std::vector<CrystalLattice> Run(const std::vector<Coord> &coord, size_t nspots) = 0;
std::optional<CrystalLattice> Run(DataMessage &message);
IndexerResult Run(const std::vector<Coord> &coord);
};
#endif //JFJOCH_INDEXER_H
@@ -47,11 +47,11 @@ IndexerThreadPool::~IndexerThreadPool() { {
}
}
std::future<std::optional<CrystalLattice> > IndexerThreadPool::Run(const DiffractionExperiment &experiment,
DataMessage &message) {
std::future<IndexerResult> IndexerThreadPool::Run(const DiffractionExperiment &experiment,
const std::vector<Coord>& recip) {
// Create a promise/future pair
auto promise = std::make_shared<std::promise<std::optional<CrystalLattice> > >();
std::future<std::optional<CrystalLattice> > result = promise->get_future(); {
auto promise = std::make_shared<std::promise<IndexerResult > >();
std::future<IndexerResult> result = promise->get_future(); {
std::unique_lock<std::mutex> lock(m);
// Don't allow enqueueing after stopping the pool
@@ -60,7 +60,7 @@ std::future<std::optional<CrystalLattice> > IndexerThreadPool::Run(const Diffrac
}
// Create a task package with the data message and coordinates
taskQueue.emplace(TaskPackage{promise, &experiment, &message});
taskQueue.emplace(TaskPackage{promise, &experiment, &recip});
}
cond.notify_one();
@@ -128,7 +128,7 @@ void IndexerThreadPool::Worker(int32_t threadIndex, const NUMAHWPolicy &numa_pol
}
}
try {
std::optional<CrystalLattice> result;
IndexerResult result;
auto algorithm = task.experiment->GetIndexingAlgorithm();
Indexer *indexer = nullptr;
@@ -143,7 +143,7 @@ void IndexerThreadPool::Worker(int32_t threadIndex, const NUMAHWPolicy &numa_pol
if (indexer) {
indexer->Setup(*task.experiment);
result = indexer->Run(*task.message);
result = indexer->Run(*task.recip);
}
// Set the result via the promise
+4 -4
View File
@@ -26,9 +26,9 @@ class IndexerThreadPool {
std::atomic<bool> failed_start = false;
struct TaskPackage {
std::shared_ptr<std::promise<std::optional<CrystalLattice>>> promise;
const DiffractionExperiment* experiment;
DataMessage* message;
std::shared_ptr<std::promise<IndexerResult>> promise;
const DiffractionExperiment *experiment;
const std::vector<Coord> *recip;
};
std::vector<std::thread> workers;
@@ -44,7 +44,7 @@ public:
IndexerThreadPool(const IndexingSettings& settings, const NUMAHWPolicy &numa_policy = NUMAHWPolicy());
~IndexerThreadPool();
std::future<std::optional<CrystalLattice>> Run(const DiffractionExperiment& experiment, DataMessage& message);
std::future<IndexerResult> Run(const DiffractionExperiment& experiment, const std::vector<Coord>& recip);
};
+3 -3
View File
@@ -36,11 +36,11 @@ TEST_CASE("FFTIndexer") {
logger.Info("Spots {}", vec.size());
auto start = std::chrono::high_resolution_clock::now();
auto result = indexer->Run(vec, vec.size());
auto result = indexer->Run(vec);
auto end = std::chrono::high_resolution_clock::now();
REQUIRE(result.size() == 1);
auto uc_out = result[0].GetUnitCell();
REQUIRE(result.lattice.size() == 1);
auto uc_out = result.lattice[0].GetUnitCell();
CHECK(uc_out.a == Catch::Approx(uc.a));;
CHECK(uc_out.b == Catch::Approx(uc.b));;
CHECK(uc_out.c == Catch::Approx(uc.c));;
+5 -5
View File
@@ -54,8 +54,8 @@ TEST_CASE("FastFeedbackIndexer","[Indexing]") {
experiment.SetUnitCell(c);
indexer->Setup(experiment);
auto ret = indexer->Run(recip, recip.size());
REQUIRE(!ret.empty());
auto ret = indexer->Run(recip);
REQUIRE(!ret.lattice.empty());
//auto uc = ret[0].GetUnitCell();
//REQUIRE(c.a == Catch::Approx(uc.a));
@@ -64,9 +64,9 @@ TEST_CASE("FastFeedbackIndexer","[Indexing]") {
double err[3] = {0.0, 0.0, 0.0};
for (const auto &iter: recip) {
err[0] += round_err(ret[0].Vec0() * iter);
err[1] += round_err(ret[0].Vec1() * iter);
err[2] += round_err(ret[0].Vec2() * iter);
err[0] += round_err(ret.lattice[0].Vec0() * iter);
err[1] += round_err(ret.lattice[0].Vec1() * iter);
err[2] += round_err(ret.lattice[0].Vec2() * iter);
}
REQUIRE (err[0] < 0.001 * recip.size());
REQUIRE (err[1] < 0.001 * recip.size());
+10 -2
View File
@@ -8,6 +8,7 @@
int main(int argc, char** argv) {
Logger logger("jfjoch_indexing_test");
/*
if (argc != 2) {
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Usage: ./jfjoch_indexing_test <file>");
}
@@ -21,6 +22,7 @@ int main(int argc, char** argv) {
experiment.IndexingAlgorithm(IndexingAlgorithmEnum::Auto);
// experiment.SetUnitCell(UnitCell(37,78,78,90,90,90));
auto indexer = CreateIndexer(experiment);
auto geom = experiment.GetDiffractionGeometry();
for (int i = 0; i < reader.GetNumberOfImages(); i++) {
auto img = reader.LoadImage(i);
@@ -29,7 +31,12 @@ int main(int argc, char** argv) {
continue;
DataMessage msg = img->ImageData();
std::vector<Coord> recip;
recip.reserve(msg.spots.size());
for (const auto &s: msg.spots) {
if (!s.ice_ring)
recip.emplace_back(s.ReciprocalCoord(geom));
}
auto output = indexer->Run(msg);
logger.Info("Result {} {}", msg.number, msg.indexing_result.value_or(0));
@@ -42,5 +49,6 @@ int main(int argc, char** argv) {
logger.Info("Lattice {:8.02f} {:8.02f} {:8.02f} {:8.02f} {:8.02f} {:8.02f} {:8.02f} {:8.02f} {:8.02f}",
a[0],a[1],a[2],b[0],b[1],b[2],c[0],c[1], c[2]);
}
}
} */
}