diff --git a/image_pusher/HDF5FilePusher.cpp b/image_pusher/HDF5FilePusher.cpp
index e3b68ee7..e1ad3183 100644
--- a/image_pusher/HDF5FilePusher.cpp
+++ b/image_pusher/HDF5FilePusher.cpp
@@ -12,7 +12,6 @@ void HDF5FilePusher::StartDataCollection(StartMessage &message) {
     writer_future = std::async(std::launch::async, &HDF5FilePusher::WriterThread, this);
     images_written = 0;
     images_err = 0;
-    last_processed_image = 0;
 }
 
 bool HDF5FilePusher::EndDataCollection(const EndMessage &message) {
@@ -40,8 +39,6 @@ bool HDF5FilePusher::SendImage(const uint8_t *image_data, size_t image_size, int
     try {
         writer->Write(*deserialized->data_message);
         images_written++;
-        if (image_number > last_processed_image)
-            last_processed_image = image_number;
     } catch (const JFJochException &e) {
         images_err++;
     }
@@ -85,12 +82,10 @@ std::optional<ImagePusherAckProgress> HDF5FilePusher::GetAckProgress() const {
     uint64_t ack_ok = images_written;
     uint64_t ack_bad = images_err;
     uint64_t ack_total = ack_ok + ack_bad;
-    uint64_t last = last_processed_image;
 
     return ImagePusherAckProgress{
             .data_acked_ok = ack_ok,
             .data_acked_bad = ack_bad,
-            .data_acked_total = ack_total,
-            .last_processed_images = last
+            .data_acked_total = ack_total
     };
 }
diff --git a/image_pusher/HDF5FilePusher.h b/image_pusher/HDF5FilePusher.h
index c475fcf8..91c28c80 100644
--- a/image_pusher/HDF5FilePusher.h
+++ b/image_pusher/HDF5FilePusher.h
@@ -20,7 +20,6 @@ class HDF5FilePusher : public ImagePusher {
     std::atomic<uint64_t> images_written = 0;
     std::atomic<uint64_t> images_err = 0;
-    std::atomic<uint64_t> last_processed_image = 0;
 
 public:
     // Thread safety: StartDataCollection, EndDataCollection and SendCalibration must only run in serial context
     // SendImage can be executed in parallel
diff --git a/image_pusher/ImagePusher.h b/image_pusher/ImagePusher.h
index 455dd032..bba1a8a7 100644
--- a/image_pusher/ImagePusher.h
+++ b/image_pusher/ImagePusher.h
@@ -26,7 +26,6 @@ struct ImagePusherAckProgress {
     uint64_t data_acked_bad = 0;
     uint64_t data_acked_total = 0;
     uint64_t data_ack_pending = 0;
-    uint64_t last_processed_images = 0;
 };
 
 void PrepareCBORImage(DataMessage& message,
diff --git a/image_pusher/TCPStreamPusher.cpp b/image_pusher/TCPStreamPusher.cpp
index b0714bea..5814e204 100644
--- a/image_pusher/TCPStreamPusher.cpp
+++ b/image_pusher/TCPStreamPusher.cpp
@@ -169,7 +169,8 @@ bool TCPStreamPusher::SendCalibration(const CompressedImage &message) {
 }
 
 std::optional<ImagePusherAckProgress> TCPStreamPusher::GetAckProgress() const {
-    ImagePusherAckProgress out;
+    ImagePusherAckProgress out{
+    };
     for (const auto &s : socket) {
         auto p = s->GetDataAckProgress();
         out.data_sent += p.data_sent;
diff --git a/image_pusher/TCPStreamPusherSocket.cpp b/image_pusher/TCPStreamPusherSocket.cpp
index d39ebd93..4aa9b9a3 100644
--- a/image_pusher/TCPStreamPusherSocket.cpp
+++ b/image_pusher/TCPStreamPusherSocket.cpp
@@ -460,13 +460,10 @@ void TCPStreamPusherSocket::AckThread() {
             last_ack_error = "CANCEL rejected";
         } else if (ack_for == TCPFrameType::DATA) {
             data_acked_total.fetch_add(1, std::memory_order_relaxed);
-            last_processed_images.store(h.ack_processed_images, std::memory_order_relaxed);
-
             if (ok && !fatal) {
                 data_acked_ok.fetch_add(1, std::memory_order_relaxed);
             } else {
                 data_acked_bad.fetch_add(1, std::memory_order_relaxed);
-                broken = true; // mandatory DATA ACK mode: bad DATA ACK breaks stream
                 if (error_text.empty()) last_ack_error = "DATA ACK failed";
                 logger.Error("Received failing DATA ACK on " + endpoint + ": " + last_ack_error);
@@ -497,7 +494,6 @@ void TCPStreamPusherSocket::StartWriterThread() {
     data_acked_ok.store(0, std::memory_order_relaxed);
     data_acked_bad.store(0, std::memory_order_relaxed);
     data_acked_total.store(0, std::memory_order_relaxed);
-    last_processed_images.store(0, std::memory_order_relaxed);
 
     active = true;
     send_future = std::async(std::launch::async, &TCPStreamPusherSocket::WriterThread, this);
@@ -598,6 +594,5 @@ ImagePusherAckProgress TCPStreamPusherSocket::GetDataAckProgress() const {
     p.data_acked_ok = data_acked_ok.load(std::memory_order_relaxed);
     p.data_acked_bad = data_acked_bad.load(std::memory_order_relaxed);
     p.data_acked_total = data_acked_total.load(std::memory_order_relaxed);
     p.data_ack_pending = (p.data_sent >= p.data_acked_total) ? (p.data_sent - p.data_acked_total) : 0;
-    p.last_processed_images = last_processed_images.load(std::memory_order_relaxed);
     return p;
 }
\ No newline at end of file
diff --git a/image_pusher/TCPStreamPusherSocket.h b/image_pusher/TCPStreamPusherSocket.h
index d15a23aa..a00be113 100644
--- a/image_pusher/TCPStreamPusherSocket.h
+++ b/image_pusher/TCPStreamPusherSocket.h
@@ -62,7 +62,6 @@ class TCPStreamPusherSocket {
     std::atomic<uint64_t> data_acked_ok{0};
     std::atomic<uint64_t> data_acked_bad{0};
     std::atomic<uint64_t> data_acked_total{0};
-    std::atomic<uint64_t> last_processed_images{0};
 
     void WriterThread();
     void CompletionThread();
diff --git a/tests/JFJochReceiverProcessingTest.cpp b/tests/JFJochReceiverProcessingTest.cpp
index 03427553..1ef87746 100644
--- a/tests/JFJochReceiverProcessingTest.cpp
+++ b/tests/JFJochReceiverProcessingTest.cpp
@@ -1547,4 +1547,12 @@ TEST_CASE("JFJochIntegrationTest_TCP_lysozyme_spot_and_index", "[JFJochReceiver]
     REQUIRE(!service.GetProgress().has_value());
 
     REQUIRE_NOTHROW(writer_future.get());
+
+    auto ack = pusher.GetAckProgress();
+    REQUIRE(ack.has_value());
+    CHECK(ack->data_acked_ok == experiment.GetImageNum());
+    CHECK(ack->data_acked_bad == 0);
+    CHECK(ack->data_acked_total == experiment.GetImageNum());
+    CHECK(ack->data_ack_pending == 0);
+    CHECK(ack->data_sent == experiment.GetImageNum());
 }
diff --git a/tests/TCPImagePusherTest.cpp b/tests/TCPImagePusherTest.cpp
index 0cb0a646..9024c082 100644
--- a/tests/TCPImagePusherTest.cpp
+++ b/tests/TCPImagePusherTest.cpp
@@ -268,6 +268,313 @@ TEST_CASE("TCPImageCommTest_DataFatalAck_PropagatesToPusher", "[TCP]") {
 
     sender.join();
     for (auto &t : receivers) t.join();
+
+    for (auto &p : puller)
+        p->Disconnect();
+}
+
+TEST_CASE("TCPImageCommTest_GetAckProgress_Correct", "[TCP]") {
+    const size_t nframes = 64;
+    const int64_t npullers = 2;
+    const int64_t images_per_file = 8;
+
+    DiffractionExperiment x(DetJF(1));
+    x.Raw();
+    x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4)
+            .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION);
+
+    std::mt19937 g1(123);
+    std::uniform_int_distribution<uint16_t> dist;
+    std::vector<uint16_t> image1(x.GetPixelsNum() * nframes);
+    for (auto &i : image1) i = dist(g1);
+
+    std::vector<std::string> addr{
+            "tcp://127.0.0.1:19021",
+            "tcp://127.0.0.1:19022"
+    };
+
+    std::vector> puller;
+    for (int i = 0; i < npullers; i++) {
+        puller.push_back(std::make_unique(addr[i], 64 * 1024 * 1024));
+    }
+
+    TCPStreamPusher pusher(
+            addr,
+            64 * 1024 * 1024,
+            128 * 1024,
+            8192
+    );
+
+    std::thread sender([&] {
+        std::vector<uint8_t> serialization_buffer(16 * 1024 * 1024);
+        CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size());
+
+        StartMessage start{
+                .images_per_file = images_per_file,
+                .write_master_file = true
+        };
+        EndMessage end{};
+
+        pusher.StartDataCollection(start);
+
+        for (int64_t i = 0; i < static_cast<int64_t>(nframes); i++) {
+            DataMessage data_message;
+            data_message.number = i;
+            data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(),
+                                                 x.GetPixelsNum() * sizeof(uint16_t),
+                                                 x.GetXPixelsNum(),
+                                                 x.GetYPixelsNum(),
+                                                 x.GetImageMode(),
+                                                 x.GetCompressionAlgorithm());
+            serializer.SerializeImage(data_message);
+            REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i));
+        }
+
+        REQUIRE(pusher.EndDataCollection(end));
+
+        const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(5);
+        while (std::chrono::steady_clock::now() < deadline) {
+            auto progress = pusher.GetAckProgress();
+            REQUIRE(progress.has_value());
+            if (progress->data_acked_total == nframes)
+                break;
+            std::this_thread::sleep_for(std::chrono::milliseconds(10));
+        }
+
+        auto progress = pusher.GetAckProgress();
+        REQUIRE(progress.has_value());
+        REQUIRE(progress->data_sent == nframes);
+        REQUIRE(progress->data_acked_ok == nframes / 2);
+        REQUIRE(progress->data_acked_bad == nframes / 2);
+        REQUIRE(progress->data_acked_total == nframes);
+        REQUIRE(progress->data_ack_pending == 0);
+    });
+
+    std::vector<std::thread> receivers;
+    receivers.reserve(npullers);
+
+    for (int w = 0; w < npullers; w++) {
+        receivers.emplace_back([&, w] {
+            bool seen_end = false;
+            uint64_t processed = 0;
+
+            while (!seen_end) {
+                auto out = puller[w]->PollImage(std::chrono::seconds(10));
+                REQUIRE(out.has_value());
+                REQUIRE(out->cbor != nullptr);
+                REQUIRE(out->tcp_msg != nullptr);
+
+                const auto &h = out->tcp_msg->header;
+
+                if (out->cbor->start_message) {
+                    PullerAckMessage ack;
+                    ack.ack_for = TCPFrameType::START;
+                    ack.ok = true;
+                    ack.fatal = false;
+                    ack.run_number = h.run_number;
+                    ack.socket_number = h.socket_number;
+                    ack.image_number = 0;
+                    ack.processed_images = 0;
+                    ack.error_code = TCPAckCode::None;
+                    REQUIRE(puller[w]->SendAck(ack));
+                    continue;
+                }
+
+                if (out->cbor->data_message) {
+                    auto number = out->cbor->data_message->number;
+
+                    processed++;
+                    PullerAckMessage ack;
+                    ack.ack_for = TCPFrameType::DATA;
+                    ack.ok = (number % 2 == 0);
+                    ack.fatal = false;
+                    ack.run_number = h.run_number;
+                    ack.socket_number = h.socket_number;
+                    ack.image_number = number;
+                    ack.processed_images = processed;
+                    ack.error_code = TCPAckCode::None;
+                    REQUIRE(puller[w]->SendAck(ack));
+                    continue;
+                }
+
+                if (out->cbor->end_message) {
+                    PullerAckMessage ack;
+                    ack.ack_for = TCPFrameType::END;
+                    ack.ok = true;
+                    ack.fatal = false;
+                    ack.run_number = h.run_number;
+                    ack.socket_number = h.socket_number;
+                    ack.image_number = 0;
+                    ack.processed_images = processed;
+                    ack.error_code = TCPAckCode::None;
+                    REQUIRE(puller[w]->SendAck(ack));
+                    seen_end = true;
+                }
+            }
+        });
+    }
+
+    sender.join();
+    for (auto &t : receivers) t.join();
+
+    for (auto &p : puller)
+        p->Disconnect();
+}
+
+TEST_CASE("TCPImageCommTest_GetAckProgress_InFlightPending", "[TCP]") {
+    const size_t nframes = 128;
+    const int64_t npullers = 2;
+    const int64_t images_per_file = 8;
+
+    DiffractionExperiment x(DetJF(1));
+    x.Raw();
+    x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4)
+            .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION);
+
+    std::mt19937 g1(321);
+    std::uniform_int_distribution<uint16_t> dist;
+    std::vector<uint16_t> image1(x.GetPixelsNum() * nframes);
+    for (auto &i : image1) i = dist(g1);
+
+    std::vector<std::string> addr{
+            "tcp://127.0.0.1:19031",
+            "tcp://127.0.0.1:19032"
+    };
+
+    std::vector> puller;
+    for (int i = 0; i < npullers; i++) {
+        puller.push_back(std::make_unique(addr[i], 64 * 1024 * 1024));
+    }
+
+    TCPStreamPusher pusher(
+            addr,
+            64 * 1024 * 1024,
+            128 * 1024,
+            8192
+    );
+
+    std::atomic<bool> observed_pending{false};
+
+    std::thread sender([&] {
+        std::vector<uint8_t> serialization_buffer(16 * 1024 * 1024);
+        CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size());
+
+        StartMessage start{
+                .images_per_file = images_per_file,
+                .write_master_file = true
+        };
+        EndMessage end{};
+
+        pusher.StartDataCollection(start);
+
+        for (int64_t i = 0; i < static_cast<int64_t>(nframes); i++) {
+            DataMessage data_message;
+            data_message.number = i;
+            data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(),
+                                                 x.GetPixelsNum() * sizeof(uint16_t),
+                                                 x.GetXPixelsNum(),
+                                                 x.GetYPixelsNum(),
+                                                 x.GetImageMode(),
+                                                 x.GetCompressionAlgorithm());
+            serializer.SerializeImage(data_message);
+            REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i));
+
+            if (i >= 16 && !observed_pending.load()) {
+                auto progress = pusher.GetAckProgress();
+                REQUIRE(progress.has_value());
+                if (progress->data_sent > progress->data_acked_total && progress->data_ack_pending > 0) {
+                    observed_pending = true;
+                }
+            }
+        }
+
+        REQUIRE(pusher.EndDataCollection(end));
+        REQUIRE(observed_pending.load());
+
+        const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(5);
+        while (std::chrono::steady_clock::now() < deadline) {
+            auto progress = pusher.GetAckProgress();
+            REQUIRE(progress.has_value());
+            if (progress->data_acked_total == nframes)
+                break;
+            std::this_thread::sleep_for(std::chrono::milliseconds(10));
+        }
+
+        auto progress = pusher.GetAckProgress();
+        REQUIRE(progress.has_value());
+        REQUIRE(progress->data_sent == nframes);
+        REQUIRE(progress->data_acked_ok == nframes);
+        REQUIRE(progress->data_acked_bad == 0);
+        REQUIRE(progress->data_acked_total == nframes);
+        REQUIRE(progress->data_ack_pending == 0);
+    });
+
+    std::vector<std::thread> receivers;
+    receivers.reserve(npullers);
+
+    for (int w = 0; w < npullers; w++) {
+        receivers.emplace_back([&, w] {
+            bool seen_end = false;
+            uint64_t processed = 0;
+
+            while (!seen_end) {
+                auto out = puller[w]->PollImage(std::chrono::seconds(10));
+                REQUIRE(out.has_value());
+                REQUIRE(out->cbor != nullptr);
+                REQUIRE(out->tcp_msg != nullptr);
+
+                const auto &h = out->tcp_msg->header;
+
+                if (out->cbor->start_message) {
+                    PullerAckMessage ack;
+                    ack.ack_for = TCPFrameType::START;
+                    ack.ok = true;
+                    ack.fatal = false;
+                    ack.run_number = h.run_number;
+                    ack.socket_number = h.socket_number;
+                    ack.image_number = 0;
+                    ack.processed_images = 0;
+                    ack.error_code = TCPAckCode::None;
+                    REQUIRE(puller[w]->SendAck(ack));
+                    continue;
+                }
+
+                if (out->cbor->data_message) {
+                    processed++;
+                    std::this_thread::sleep_for(std::chrono::milliseconds(3));
+
+                    PullerAckMessage ack;
+                    ack.ack_for = TCPFrameType::DATA;
+                    ack.ok = true;
+                    ack.fatal = false;
+                    ack.run_number = h.run_number;
+                    ack.socket_number = h.socket_number;
+                    ack.image_number = static_cast<uint64_t>(out->cbor->data_message->number);
+                    ack.processed_images = processed;
+                    ack.error_code = TCPAckCode::None;
+                    REQUIRE(puller[w]->SendAck(ack));
+                    continue;
+                }
+
+                if (out->cbor->end_message) {
+                    PullerAckMessage ack;
+                    ack.ack_for = TCPFrameType::END;
+                    ack.ok = true;
+                    ack.fatal = false;
+                    ack.run_number = h.run_number;
+                    ack.socket_number = h.socket_number;
+                    ack.image_number = 0;
+                    ack.processed_images = processed;
+                    ack.error_code = TCPAckCode::None;
+                    REQUIRE(puller[w]->SendAck(ack));
+                    seen_end = true;
+                }
+            }
+        });
+    }
+
+    sender.join();
+    for (auto &t : receivers) t.join();
+    for (auto &p : puller) p->Disconnect();
+}
\ No newline at end of file