From d2c66edd4553bea54d722a259aebb5e6eb68b5cb Mon Sep 17 00:00:00 2001 From: Filip Leonarski Date: Wed, 4 Mar 2026 16:13:14 +0100 Subject: [PATCH] TCP: Allow to get written image count --- image_pusher/HDF5FilePusher.cpp | 21 +-- image_pusher/HDF5FilePusher.h | 3 +- image_pusher/ImagePusher.h | 10 +- image_pusher/TCPStreamPusher.cpp | 13 +- image_pusher/TCPStreamPusher.h | 2 +- image_pusher/TCPStreamPusherSocket.h | 8 ++ image_pusher/ZMQStream2Pusher.cpp | 22 +++- image_pusher/ZMQStream2Pusher.h | 5 +- tests/JFJochReceiverProcessingTest.cpp | 8 +- tests/TCPImagePusherTest.cpp | 176 +------------------------ 10 files changed, 48 insertions(+), 220 deletions(-) diff --git a/image_pusher/HDF5FilePusher.cpp b/image_pusher/HDF5FilePusher.cpp index e1ad3183..cca402ef 100644 --- a/image_pusher/HDF5FilePusher.cpp +++ b/image_pusher/HDF5FilePusher.cpp @@ -11,7 +11,6 @@ void HDF5FilePusher::StartDataCollection(StartMessage &message) { writer = std::make_unique(message); writer_future = std::async(std::launch::async, &HDF5FilePusher::WriterThread, this); images_written = 0; - images_err = 0; } bool HDF5FilePusher::EndDataCollection(const EndMessage &message) { @@ -36,12 +35,8 @@ bool HDF5FilePusher::SendImage(const uint8_t *image_data, size_t image_size, int auto deserialized = CBORStream2Deserialize(image_data, image_size); if (deserialized->data_message) { - try { - writer->Write(*deserialized->data_message); - images_written++; - } catch (const JFJochException &e) { - images_err++; - } + writer->Write(*deserialized->data_message); + images_written++; } else throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "HDF5FilePusher::SendImage accepts only data image"); @@ -78,14 +73,6 @@ std::string HDF5FilePusher::PrintSetup() const { return "HDF5FilePusher: Images are written directly to file in base directory " + currentPath.string(); } -std::optional HDF5FilePusher::GetAckProgress() const { - uint64_t ack_ok = images_written; - uint64_t ack_bad = images_err; - uint64_t ack_total = ack_ok + ack_bad; - - return ImagePusherAckProgress{ - .data_acked_ok = ack_ok, - .data_acked_bad = ack_bad, - .data_acked_total = ack_total - }; +std::optional HDF5FilePusher::GetImagesWritten() const { + return images_written; } diff --git a/image_pusher/HDF5FilePusher.h b/image_pusher/HDF5FilePusher.h index 91c28c80..cfa3b0cd 100644 --- a/image_pusher/HDF5FilePusher.h +++ b/image_pusher/HDF5FilePusher.h @@ -19,7 +19,6 @@ class HDF5FilePusher : public ImagePusher { void WriterThread(); std::atomic images_written = 0; - std::atomic images_err = 0; public: // Thread safety: StartDataCollection, EndDataCollection and SendCalibration must run poorly in serial context // SendImage can be executed in parallel @@ -31,7 +30,7 @@ public: std::string PrintSetup() const override; - std::optional GetAckProgress() const override; + std::optional GetImagesWritten() const override; }; diff --git a/image_pusher/ImagePusher.h b/image_pusher/ImagePusher.h index bba1a8a7..e12d09f0 100644 --- a/image_pusher/ImagePusher.h +++ b/image_pusher/ImagePusher.h @@ -20,14 +20,6 @@ struct ImagePusherQueueElement { bool end; }; -struct ImagePusherAckProgress { - uint64_t data_sent = 0; - uint64_t data_acked_ok = 0; - uint64_t data_acked_bad = 0; - uint64_t data_acked_total = 0; - uint64_t data_ack_pending = 0; -}; - void PrepareCBORImage(DataMessage& message, const DiffractionExperiment &experiment, void *image, size_t image_size); @@ -43,7 +35,7 @@ public: virtual std::string GetWriterNotificationSocketAddress() const; virtual ~ImagePusher() = default; virtual std::string PrintSetup() const = 0; - virtual std::optional GetAckProgress() const { return std::nullopt; } + virtual std::optional GetImagesWritten() const { return std::nullopt; } }; diff --git a/image_pusher/TCPStreamPusher.cpp b/image_pusher/TCPStreamPusher.cpp index 5814e204..75066630 100644 --- a/image_pusher/TCPStreamPusher.cpp +++ b/image_pusher/TCPStreamPusher.cpp @@ -168,16 +168,11 @@ bool TCPStreamPusher::SendCalibration(const CompressedImage &message) { return socket[0]->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::CALIBRATION); } -std::optional TCPStreamPusher::GetAckProgress() const { - ImagePusherAckProgress out{ - }; +std::optional TCPStreamPusher::GetImagesWritten() const { + uint64_t ret = 0; for (const auto &s : socket) { auto p = s->GetDataAckProgress(); - out.data_sent += p.data_sent; - out.data_acked_ok += p.data_acked_ok; - out.data_acked_bad += p.data_acked_bad; - out.data_acked_total += p.data_acked_total; - out.data_ack_pending += p.data_ack_pending; + ret += p.data_acked_ok; } - return out; + return ret; } diff --git a/image_pusher/TCPStreamPusher.h b/image_pusher/TCPStreamPusher.h index c9c4d631..5451b714 100644 --- a/image_pusher/TCPStreamPusher.h +++ b/image_pusher/TCPStreamPusher.h @@ -30,5 +30,5 @@ public: std::string Finalize() override; std::string PrintSetup() const override; - std::optional GetAckProgress() const override; + std::optional GetImagesWritten() const override; }; \ No newline at end of file diff --git a/image_pusher/TCPStreamPusherSocket.h b/image_pusher/TCPStreamPusherSocket.h index a00be113..b128f204 100644 --- a/image_pusher/TCPStreamPusherSocket.h +++ b/image_pusher/TCPStreamPusherSocket.h @@ -16,6 +16,14 @@ #include "../common/Logger.h" #include "../common/JfjochTCP.h" +struct ImagePusherAckProgress { + uint64_t data_sent = 0; + uint64_t data_acked_ok = 0; + uint64_t data_acked_bad = 0; + uint64_t data_acked_total = 0; + uint64_t data_ack_pending = 0; +}; + class TCPStreamPusherSocket { struct InflightZC { ZeroCopyReturnValue *z = nullptr; diff --git a/image_pusher/ZMQStream2Pusher.cpp b/image_pusher/ZMQStream2Pusher.cpp index 77817691..b83db799 100644 --- a/image_pusher/ZMQStream2Pusher.cpp +++ b/image_pusher/ZMQStream2Pusher.cpp @@ -37,6 +37,10 @@ void ZMQStream2Pusher::SendImage(ZeroCopyReturnValue &z) { } void ZMQStream2Pusher::StartDataCollection(StartMessage& message) { + { + std::unique_lock ul(images_written_mutex); + images_written = std::nullopt; + } if (message.images_per_file < 1) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Images per file cannot be zero or negative"); @@ -93,6 +97,7 @@ std::string ZMQStream2Pusher::Finalize() { std::string ret; if (transmission_error) ret += "Timeout sending images (e.g., writer disabled during data collection);"; + uint64_t images = 0; if (writer_notification_socket) { for (int i = 0; i < socket.size(); i++) { auto n = writer_notification_socket->Receive(run_number, run_name); @@ -100,10 +105,18 @@ std::string ZMQStream2Pusher::Finalize() { ret += "Writer " + std::to_string(i) + ": no end notification received within 1 minute from collection end"; else if (n->socket_number >= socket.size()) ret += "Writer " + std::to_string(i) + ": mismatch in socket number"; - else if (!n->ok) - ret += "Writer " + std::to_string(i) + ": " + n->error; + else { + if (!n->ok) + ret += "Writer " + std::to_string(i) + ": " + n->error; + images += n->processed_images; + } } } + + { + std::unique_lock ul(images_written_mutex); + images_written = images; + } return ret; } @@ -125,3 +138,8 @@ std::string ZMQStream2Pusher::PrintSetup() const { output += s->GetEndpointName() + " "; return output; } + +std::optional ZMQStream2Pusher::GetImagesWritten() const { + std::unique_lock ul(images_written_mutex); + return images_written; +} \ No newline at end of file diff --git a/image_pusher/ZMQStream2Pusher.h b/image_pusher/ZMQStream2Pusher.h index 5007c5d4..23e046e8 100644 --- a/image_pusher/ZMQStream2Pusher.h +++ b/image_pusher/ZMQStream2Pusher.h @@ -22,6 +22,9 @@ class ZMQStream2Pusher : public ImagePusher { uint64_t run_number = 0; std::string run_name; std::atomic transmission_error = false; + + mutable std::mutex images_written_mutex; + std::optional images_written; public: explicit ZMQStream2Pusher(const std::vector& addr, std::optional send_buffer_high_watermark = {}, @@ -43,8 +46,8 @@ public: std::string Finalize() override; - std::string PrintSetup() const override; + std::optional GetImagesWritten() const override; }; #endif //JUNGFRAUJOCH_ZMQSTREAM2PUSHER_H diff --git a/tests/JFJochReceiverProcessingTest.cpp b/tests/JFJochReceiverProcessingTest.cpp index 1ef87746..ad8524d6 100644 --- a/tests/JFJochReceiverProcessingTest.cpp +++ b/tests/JFJochReceiverProcessingTest.cpp @@ -1548,11 +1548,7 @@ TEST_CASE("JFJochIntegrationTest_TCP_lysozyme_spot_and_index", "[JFJochReceiver] REQUIRE_NOTHROW(writer_future.get()); - auto ack = pusher.GetAckProgress(); + auto ack = pusher.GetImagesWritten(); REQUIRE(ack.has_value()); - CHECK(ack->data_acked_ok == experiment.GetImageNum()); - CHECK(ack->data_acked_bad == 0); - CHECK(ack->data_acked_total == experiment.GetImageNum()); - CHECK(ack->data_ack_pending == 0); - CHECK(ack->data_sent == experiment.GetImageNum()); + CHECK(ack == experiment.GetImageNum()); } diff --git a/tests/TCPImagePusherTest.cpp b/tests/TCPImagePusherTest.cpp index 9024c082..b7d1f454 100644 --- a/tests/TCPImagePusherTest.cpp +++ b/tests/TCPImagePusherTest.cpp @@ -331,22 +331,11 @@ TEST_CASE("TCPImageCommTest_GetAckProgress_Correct", "[TCP]") { REQUIRE(pusher.EndDataCollection(end)); - const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(5); - while (std::chrono::steady_clock::now() < deadline) { - auto progress = pusher.GetAckProgress(); - REQUIRE(progress.has_value()); - if (progress->data_acked_total == nframes) - break; - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } + std::this_thread::sleep_for(std::chrono::seconds(5)); - auto progress = pusher.GetAckProgress(); + auto progress = pusher.GetImagesWritten(); REQUIRE(progress.has_value()); - REQUIRE(progress->data_sent == nframes); - REQUIRE(progress->data_acked_ok == nframes / 2); - REQUIRE(progress->data_acked_bad == nframes / 2); - REQUIRE(progress->data_acked_total == nframes); - REQUIRE(progress->data_ack_pending == 0); + REQUIRE(progress == nframes / 2); }); std::vector receivers; @@ -419,162 +408,3 @@ TEST_CASE("TCPImageCommTest_GetAckProgress_Correct", "[TCP]") { for (auto &p : puller) p->Disconnect(); } - -TEST_CASE("TCPImageCommTest_GetAckProgress_InFlightPending", "[TCP]") { - const size_t nframes = 128; - const int64_t npullers = 2; - const int64_t images_per_file = 8; - - DiffractionExperiment x(DetJF(1)); - x.Raw(); - x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4) - .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION); - - std::mt19937 g1(321); - std::uniform_int_distribution dist; - std::vector image1(x.GetPixelsNum() * nframes); - for (auto &i : image1) i = dist(g1); - - std::vector addr{ - "tcp://127.0.0.1:19031", - "tcp://127.0.0.1:19032" - }; - - std::vector> puller; - for (int i = 0; i < npullers; i++) { - puller.push_back(std::make_unique(addr[i], 64 * 1024 * 1024)); - } - - TCPStreamPusher pusher( - addr, - 64 * 1024 * 1024, - 128 * 1024, - 8192 - ); - - std::atomic observed_pending{false}; - - std::thread sender([&] { - std::vector serialization_buffer(16 * 1024 * 1024); - CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size()); - - StartMessage start{ - .images_per_file = images_per_file, - .write_master_file = true - }; - EndMessage end{}; - - pusher.StartDataCollection(start); - - for (int64_t i = 0; i < static_cast(nframes); i++) { - DataMessage data_message; - data_message.number = i; - data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(), - x.GetPixelsNum() * sizeof(uint16_t), - x.GetXPixelsNum(), - x.GetYPixelsNum(), - x.GetImageMode(), - x.GetCompressionAlgorithm()); - serializer.SerializeImage(data_message); - REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i)); - - if (i >= 16 && !observed_pending.load()) { - auto progress = pusher.GetAckProgress(); - REQUIRE(progress.has_value()); - if (progress->data_sent > progress->data_acked_total && progress->data_ack_pending > 0) { - observed_pending = true; - } - } - } - - REQUIRE(pusher.EndDataCollection(end)); - REQUIRE(observed_pending.load()); - - const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(5); - while (std::chrono::steady_clock::now() < deadline) { - auto progress = pusher.GetAckProgress(); - REQUIRE(progress.has_value()); - if (progress->data_acked_total == nframes) - break; - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } - - auto progress = pusher.GetAckProgress(); - REQUIRE(progress.has_value()); - REQUIRE(progress->data_sent == nframes); - REQUIRE(progress->data_acked_ok == nframes); - REQUIRE(progress->data_acked_bad == 0); - REQUIRE(progress->data_acked_total == nframes); - REQUIRE(progress->data_ack_pending == 0); - }); - - std::vector receivers; - receivers.reserve(npullers); - - for (int w = 0; w < npullers; w++) { - receivers.emplace_back([&, w] { - bool seen_end = false; - uint64_t processed = 0; - - while (!seen_end) { - auto out = puller[w]->PollImage(std::chrono::seconds(10)); - REQUIRE(out.has_value()); - REQUIRE(out->cbor != nullptr); - REQUIRE(out->tcp_msg != nullptr); - - const auto &h = out->tcp_msg->header; - - if (out->cbor->start_message) { - PullerAckMessage ack; - ack.ack_for = TCPFrameType::START; - ack.ok = true; - ack.fatal = false; - ack.run_number = h.run_number; - ack.socket_number = h.socket_number; - ack.image_number = 0; - ack.processed_images = 0; - ack.error_code = TCPAckCode::None; - REQUIRE(puller[w]->SendAck(ack)); - continue; - } - - if (out->cbor->data_message) { - processed++; - std::this_thread::sleep_for(std::chrono::milliseconds(3)); - - PullerAckMessage ack; - ack.ack_for = TCPFrameType::DATA; - ack.ok = true; - ack.fatal = false; - ack.run_number = h.run_number; - ack.socket_number = h.socket_number; - ack.image_number = static_cast(out->cbor->data_message->number); - ack.processed_images = processed; - ack.error_code = TCPAckCode::None; - REQUIRE(puller[w]->SendAck(ack)); - continue; - } - - if (out->cbor->end_message) { - PullerAckMessage ack; - ack.ack_for = TCPFrameType::END; - ack.ok = true; - ack.fatal = false; - ack.run_number = h.run_number; - ack.socket_number = h.socket_number; - ack.image_number = 0; - ack.processed_images = processed; - ack.error_code = TCPAckCode::None; - REQUIRE(puller[w]->SendAck(ack)); - seen_end = true; - } - } - }); - } - - sender.join(); - for (auto &t : receivers) t.join(); - - for (auto &p : puller) - p->Disconnect(); -} \ No newline at end of file