// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute // SPDX-License-Identifier: GPL-3.0-only #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "../image_pusher/TCPStreamPusher.h" #include "../image_puller/TCPImagePuller.h" #include "../image_puller/ZMQImagePuller.h" #include "../common/ImageBuffer.h" #include "../common/ZeroCopyReturnValue.h" TEST_CASE("TCPImageCommTest_2Writers_WithAck", "[TCP]") { const size_t nframes = 128; const int64_t npullers = 2; const int64_t images_per_file = 16; DiffractionExperiment x(DetJF(1)); x.Raw(); x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4) .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION); x.RunNumber(567); std::mt19937 g1(1387); std::uniform_int_distribution dist; std::vector image1(x.GetPixelsNum() * nframes); for (auto &i: image1) i = dist(g1); TCPStreamPusher pusher("tcp://127.0.0.1:*", npullers); std::vector > puller; for (int i = 0; i < npullers; i++) puller.push_back(std::make_unique(pusher.GetAddress()[0], 64 * 1024 * 1024)); // Wait for all pullers to connect before starting data collection for (int attempt = 0; attempt < 100 && pusher.GetConnectedWriters() < static_cast(npullers); ++attempt) std::this_thread::sleep_for(std::chrono::milliseconds(50)); REQUIRE(pusher.GetConnectedWriters() == static_cast(npullers)); std::vector received(npullers, 0); std::vector processed(npullers, 0); std::thread sender([&] { std::vector serialization_buffer(16 * 1024 * 1024); CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size()); StartMessage start{ .images_per_file = images_per_file, .run_number = x.GetRunNumber(), .write_master_file = true }; EndMessage end{}; pusher.StartDataCollection(start); for (int64_t i = 0; i < static_cast(nframes); i++) { DataMessage data_message; data_message.number = i; data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(), x.GetPixelsNum() * sizeof(uint16_t), x.GetXPixelsNum(), x.GetYPixelsNum(), x.GetImageMode(), x.GetCompressionAlgorithm()); serializer.SerializeImage(data_message); REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i)); } REQUIRE(pusher.EndDataCollection(end)); }); std::vector receivers; receivers.reserve(npullers); std::mutex counts_mutex; std::vector received_by_socket(npullers, 0); std::vector processed_by_socket(npullers, 0); for (int w = 0; w < npullers; w++) { receivers.emplace_back([&, w] { bool seen_start = false; bool seen_end = false; std::optional my_socket_number; while (!seen_end) { auto out = puller[w]->PollImage(std::chrono::seconds(10)); REQUIRE(out.has_value()); REQUIRE(out->cbor != nullptr); REQUIRE(out->tcp_msg != nullptr); const auto &h = out->tcp_msg->header; if (out->cbor->start_message) { my_socket_number = h.socket_number; PullerAckMessage ack; ack.ack_for = TCPFrameType::START; ack.ok = true; ack.fatal = false; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.image_number = 0; ack.processed_images = 0; ack.error_code = TCPAckCode::None; REQUIRE(ack.run_number == x.GetRunNumber()); REQUIRE(puller[w]->SendAck(ack)); seen_start = true; continue; } if (out->cbor->data_message) { REQUIRE(seen_start); REQUIRE(my_socket_number.has_value()); auto n = out->cbor->data_message->number; REQUIRE(((n / images_per_file) % npullers) == static_cast(*my_socket_number)); { std::lock_guard lg(counts_mutex); received_by_socket.at(*my_socket_number)++; processed_by_socket.at(*my_socket_number)++; } continue; } if (out->cbor->end_message) { REQUIRE(my_socket_number.has_value()); PullerAckMessage ack; ack.ack_for = TCPFrameType::END; ack.ok = true; ack.fatal = false; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.image_number = 0; { std::lock_guard lg(counts_mutex); ack.processed_images = processed_by_socket.at(*my_socket_number); } ack.error_code = TCPAckCode::None; REQUIRE(puller[w]->SendAck(ack)); seen_end = true; } } }); } sender.join(); for (auto &t: receivers) t.join(); REQUIRE(received_by_socket[0] == nframes / 2); REQUIRE(received_by_socket[1] == nframes / 2); for (auto &p: puller) p->Disconnect(); } TEST_CASE("TCPImageCommTest_DataFatalAck_PropagatesToPusher", "[TCP]") { const size_t nframes = 64; const int64_t npullers = 2; const int64_t images_per_file = 8; DiffractionExperiment x(DetJF(1)); x.Raw(); x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4) .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION); std::mt19937 g1(42); std::uniform_int_distribution dist; std::vector image1(x.GetPixelsNum() * nframes); for (auto &i: image1) i = dist(g1); TCPStreamPusher pusher("tcp://127.0.0.1:*", npullers); std::vector > puller; for (int i = 0; i < npullers; i++) puller.push_back(std::make_unique(pusher.GetAddress()[0], 64 * 1024 * 1024)); // Wait for all pullers to connect before starting data collection for (int attempt = 0; attempt < 100 && pusher.GetConnectedWriters() < static_cast(npullers); ++attempt) std::this_thread::sleep_for(std::chrono::milliseconds(50)); REQUIRE(pusher.GetConnectedWriters() == static_cast(npullers)); std::atomic sent_fatal{false}; std::thread sender([&] { std::vector serialization_buffer(16 * 1024 * 1024); CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size()); StartMessage start{ .images_per_file = images_per_file, .write_master_file = true }; EndMessage end{}; pusher.StartDataCollection(start); for (int64_t i = 0; i < static_cast(nframes); i++) { DataMessage data_message; data_message.number = i; data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(), x.GetPixelsNum() * sizeof(uint16_t), x.GetXPixelsNum(), x.GetYPixelsNum(), x.GetImageMode(), x.GetCompressionAlgorithm()); serializer.SerializeImage(data_message); (void) pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i); } REQUIRE(pusher.EndDataCollection(end)); const auto final_msg = pusher.Finalize(); REQUIRE_THAT(final_msg, Catch::Matchers::ContainsSubstring("quota")); }); std::vector receivers; receivers.reserve(npullers); for (int w = 0; w < npullers; w++) { receivers.emplace_back([&, w] { bool seen_end = false; bool local_fatal_sent = false; while (!seen_end) { auto out = puller[w]->PollImage(std::chrono::seconds(2)); if (!out.has_value()) { // Once this receiver has sent a fatal DATA ACK, no END is guaranteed on this stream. if (local_fatal_sent) break; REQUIRE(out.has_value()); } REQUIRE(out->cbor != nullptr); REQUIRE(out->tcp_msg != nullptr); const auto &h = out->tcp_msg->header; if (out->cbor->start_message) { PullerAckMessage ack; ack.ack_for = TCPFrameType::START; ack.ok = true; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.error_code = TCPAckCode::None; REQUIRE(puller[w]->SendAck(ack)); continue; } if (out->cbor->data_message) { if (w == 0 && !sent_fatal.exchange(true)) { PullerAckMessage ack; ack.ack_for = TCPFrameType::DATA; ack.ok = false; ack.fatal = true; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.image_number = static_cast(out->cbor->data_message->number); ack.error_code = TCPAckCode::DiskQuotaExceeded; ack.error_text = "quota exceeded"; REQUIRE(puller[w]->SendAck(ack)); local_fatal_sent = true; } continue; } if (out->cbor->end_message) { PullerAckMessage ack; ack.ack_for = TCPFrameType::END; ack.ok = true; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.error_code = TCPAckCode::None; REQUIRE(puller[w]->SendAck(ack)); seen_end = true; } } }); } sender.join(); for (auto &t: receivers) t.join(); for (auto &p: puller) p->Disconnect(); } TEST_CASE("TCPImageCommTest_GetAckProgress_Correct", "[TCP]") { const size_t nframes = 64; const int64_t npullers = 2; const int64_t images_per_file = 8; DiffractionExperiment x(DetJF(1)); x.Raw(); x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4) .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION); std::mt19937 g1(123); std::uniform_int_distribution dist; std::vector image1(x.GetPixelsNum() * nframes); for (auto &i: image1) i = dist(g1); TCPStreamPusher pusher("tcp://127.0.0.1:*", npullers); std::vector > puller; for (int i = 0; i < npullers; i++) puller.push_back(std::make_unique(pusher.GetAddress()[0], 64 * 1024 * 1024)); // Wait for all pullers to connect before starting data collection for (int attempt = 0; attempt < 100 && pusher.GetConnectedWriters() < static_cast(npullers); ++attempt) std::this_thread::sleep_for(std::chrono::milliseconds(50)); REQUIRE(pusher.GetConnectedWriters() == static_cast(npullers)); std::thread sender([&] { std::vector serialization_buffer(16 * 1024 * 1024); CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size()); StartMessage start{ .images_per_file = images_per_file, .write_master_file = true }; EndMessage end{}; pusher.StartDataCollection(start); for (int64_t i = 0; i < static_cast(nframes); i++) { DataMessage data_message; data_message.number = i; data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(), x.GetPixelsNum() * sizeof(uint16_t), x.GetXPixelsNum(), x.GetYPixelsNum(), x.GetImageMode(), x.GetCompressionAlgorithm()); serializer.SerializeImage(data_message); REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i)); } REQUIRE(pusher.EndDataCollection(end)); std::this_thread::sleep_for(std::chrono::seconds(5)); auto progress = pusher.GetImagesWritten(); REQUIRE(progress.has_value()); REQUIRE(progress == nframes / 2); }); std::vector receivers; receivers.reserve(npullers); for (int w = 0; w < npullers; w++) { receivers.emplace_back([&, w] { bool seen_end = false; uint64_t processed = 0; while (!seen_end) { auto out = puller[w]->PollImage(std::chrono::seconds(10)); REQUIRE(out.has_value()); REQUIRE(out->cbor != nullptr); REQUIRE(out->tcp_msg != nullptr); const auto &h = out->tcp_msg->header; if (out->cbor->start_message) { PullerAckMessage ack; ack.ack_for = TCPFrameType::START; ack.ok = true; ack.fatal = false; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.image_number = 0; ack.processed_images = 0; ack.error_code = TCPAckCode::None; REQUIRE(puller[w]->SendAck(ack)); continue; } if (out->cbor->data_message) { auto number = out->cbor->data_message->number; processed++; PullerAckMessage ack; ack.ack_for = TCPFrameType::DATA; ack.ok = (number % 2 == 0) ? true : false; ack.fatal = false; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.image_number = number; ack.processed_images = processed; ack.error_code = TCPAckCode::None; REQUIRE(puller[w]->SendAck(ack)); continue; } if (out->cbor->end_message) { PullerAckMessage ack; ack.ack_for = TCPFrameType::END; ack.ok = true; ack.fatal = false; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.image_number = 0; ack.processed_images = processed; ack.error_code = TCPAckCode::None; REQUIRE(puller[w]->SendAck(ack)); seen_end = true; } } }); } sender.join(); for (auto &t: receivers) t.join(); for (auto &p: puller) p->Disconnect(); } TEST_CASE("TCPImageCommTest_AutoPort_StarBind", "[TCP]") { const size_t nframes = 8; const int64_t images_per_file = 4; DiffractionExperiment x(DetJF(1)); x.Raw(); x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4) .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION); std::vector image1(x.GetPixelsNum() * nframes, 7u); TCPStreamPusher pusher("tcp://127.0.0.1:*", 1); TCPImagePuller puller(pusher.GetAddress()[0], 64 * 1024 * 1024); std::this_thread::sleep_for(std::chrono::seconds(2)); REQUIRE(pusher.GetConnectedWriters() == 1); std::future receiver = std::async(std::launch::async, [&] { bool seen_end = false; uint64_t processed = 0; while (!seen_end) { auto out = puller.PollImage(std::chrono::seconds(10)); REQUIRE(out.has_value()); REQUIRE(out->cbor != nullptr); REQUIRE(out->tcp_msg != nullptr); const auto &h = out->tcp_msg->header; if (out->cbor->start_message) { PullerAckMessage ack{ .ack_for = TCPFrameType::START, .ok = true, .run_number = h.run_number, .socket_number = h.socket_number, .error_code = TCPAckCode::None }; REQUIRE(puller.SendAck(ack)); } else if (out->cbor->data_message) { processed++; } else if (out->cbor->end_message) { PullerAckMessage ack{ .ack_for = TCPFrameType::END, .ok = true, .run_number = h.run_number, .socket_number = h.socket_number, .processed_images = processed, .error_code = TCPAckCode::None }; REQUIRE(puller.SendAck(ack)); seen_end = true; } } }); std::vector serialization_buffer(16 * 1024 * 1024); CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size()); StartMessage start{.images_per_file = images_per_file, .write_master_file = true}; EndMessage end{}; pusher.StartDataCollection(start); for (int64_t i = 0; i < static_cast(nframes); i++) { DataMessage data_message; data_message.number = i; data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(), x.GetPixelsNum() * sizeof(uint16_t), x.GetXPixelsNum(), x.GetYPixelsNum(), x.GetImageMode(), x.GetCompressionAlgorithm()); serializer.SerializeImage(data_message); REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i)); } REQUIRE(pusher.EndDataCollection(end)); REQUIRE_NOTHROW(receiver.get()); puller.Disconnect(); } TEST_CASE("TCPImageCommTest_DisconnectMidWrite_NoHang", "[TCP]") { const size_t nframes = 256; const int64_t images_per_file = 16; DiffractionExperiment x(DetJF(1)); x.Raw(); x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4) .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION); std::vector image1(x.GetPixelsNum() * nframes, 11u); TCPStreamPusher pusher("tcp://127.0.0.1:*", 1); TCPImagePuller puller(pusher.GetAddress()[0], 64 * 1024 * 1024); std::thread receiver([&] { bool disconnected = false; while (!disconnected) { auto out = puller.PollImage(std::chrono::seconds(10)); REQUIRE(out.has_value()); REQUIRE(out->cbor != nullptr); REQUIRE(out->tcp_msg != nullptr); const auto &h = out->tcp_msg->header; if (out->cbor->start_message) { PullerAckMessage ack{ .ack_for = TCPFrameType::START, .ok = true, .run_number = h.run_number, .socket_number = h.socket_number, .error_code = TCPAckCode::None }; REQUIRE(puller.SendAck(ack)); } else if (out->cbor->data_message) { puller.Disconnect(); // simulate puller disappearing mid-stream disconnected = true; } } }); auto sender = std::async(std::launch::async, [&] { std::vector serialization_buffer(16 * 1024 * 1024); CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size()); StartMessage start{.images_per_file = images_per_file, .write_master_file = true}; EndMessage end{}; pusher.StartDataCollection(start); for (int64_t i = 0; i < static_cast(nframes); i++) { DataMessage data_message; data_message.number = i; data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(), x.GetPixelsNum() * sizeof(uint16_t), x.GetXPixelsNum(), x.GetYPixelsNum(), x.GetImageMode(), x.GetCompressionAlgorithm()); serializer.SerializeImage(data_message); (void) pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i); } return pusher.EndDataCollection(end); }); REQUIRE(sender.wait_for(std::chrono::seconds(20)) == std::future_status::ready); CHECK(sender.get() == false); receiver.join(); } TEST_CASE("TCPImageCommTest_RepubToZMQ", "[TCP][ZeroMQ]") { // Chain: TCPStreamPusher --TCP--> TCPImagePuller --ZMQ repub--> ZMQImagePuller const size_t nframes = 64; const int64_t images_per_file = 8; DiffractionExperiment x(DetJF(1)); x.Raw(); x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4) .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION); std::mt19937 g1(9999); std::uniform_int_distribution dist; std::vector image1(x.GetPixelsNum() * nframes); for (auto &i : image1) i = dist(g1); // 1. Create TCP pusher on an auto-assigned port TCPStreamPusher pusher("tcp://127.0.0.1:*", 1); // 2. Create TCP puller with repub over ZMQ (ipc, auto-assigned) const std::string repub_addr = "ipc://*"; // Need to figure out the actual repub endpoint after bind — ZMQ ipc://* picks a temp path. // However, ZMQSocket::Bind with "ipc://*" is used in project; the repub socket binds internally, // so we need a known address. Use a tcp address instead for testability. const std::string repub_bind_addr = "tcp://127.0.0.1:19010"; TCPImagePuller tcp_puller(pusher.GetAddress()[0], {}, repub_bind_addr); // 3. Create ZMQ puller that connects to the repub address ZMQImagePuller zmq_puller(repub_bind_addr); // Wait for TCP connection for (int attempt = 0; attempt < 100 && pusher.GetConnectedWriters() < 1; ++attempt) std::this_thread::sleep_for(std::chrono::milliseconds(50)); REQUIRE(pusher.GetConnectedWriters() == 1); // Sender thread: push frames over TCP std::thread sender([&] { std::vector serialization_buffer(16 * 1024 * 1024); CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size()); StartMessage start{ .images_per_file = images_per_file, .write_master_file = true }; EndMessage end{}; pusher.StartDataCollection(start); for (int64_t i = 0; i < static_cast(nframes); i++) { DataMessage data_message; data_message.number = i; data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(), x.GetPixelsNum() * sizeof(uint16_t), x.GetXPixelsNum(), x.GetYPixelsNum(), x.GetImageMode(), x.GetCompressionAlgorithm()); serializer.SerializeImage(data_message); REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i)); } REQUIRE(pusher.EndDataCollection(end)); }); // TCP puller consumer: drains the TCP side (with ACKs) so data keeps flowing std::thread tcp_consumer([&] { bool seen_end = false; while (!seen_end) { auto out = tcp_puller.PollImage(std::chrono::seconds(10)); REQUIRE(out.has_value()); REQUIRE(out->cbor != nullptr); REQUIRE(out->tcp_msg != nullptr); const auto &h = out->tcp_msg->header; if (out->cbor->start_message) { PullerAckMessage ack{}; ack.ack_for = TCPFrameType::START; ack.ok = true; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.error_code = TCPAckCode::None; REQUIRE(tcp_puller.SendAck(ack)); } else if (out->cbor->end_message) { PullerAckMessage ack{}; ack.ack_for = TCPFrameType::END; ack.ok = true; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.error_code = TCPAckCode::None; REQUIRE(tcp_puller.SendAck(ack)); seen_end = true; } // data frames: no ack needed for this test } }); // ZMQ puller consumer: verify the republished stream size_t zmq_nimages = 0; size_t zmq_errors = 0; bool zmq_seen_start = false; bool zmq_seen_end = false; std::thread zmq_consumer([&] { auto timeout = std::chrono::seconds(30); // First message should be START auto img = zmq_puller.PollImage(timeout); if (!img || !img->cbor || !img->cbor->start_message) { zmq_errors++; return; } zmq_seen_start = true; // Republished START should have writer_notification_zmq_addr cleared if (!img->cbor->start_message->writer_notification_zmq_addr.empty()) { zmq_errors++; } // Consume data and END img = zmq_puller.PollImage(timeout); while (img && img->cbor && !img->cbor->end_message) { if (img->cbor->data_message) { auto n = img->cbor->data_message->number; if (img->cbor->data_message->image.GetCompressedSize() != x.GetPixelsNum() * sizeof(uint16_t)) zmq_errors++; else if (memcmp(img->cbor->data_message->image.GetCompressed(), image1.data() + n * x.GetPixelsNum(), x.GetPixelsNum() * sizeof(uint16_t)) != 0) zmq_errors++; zmq_nimages++; } img = zmq_puller.PollImage(timeout); } if (img && img->cbor && img->cbor->end_message) zmq_seen_end = true; }); sender.join(); tcp_consumer.join(); zmq_consumer.join(); tcp_puller.Disconnect(); zmq_puller.Disconnect(); // The repub uses non-blocking Put for data, so some frames *could* be dropped // under extreme back-pressure, but with only 64 frames we expect all of them. REQUIRE(zmq_seen_start); REQUIRE(zmq_seen_end); REQUIRE(zmq_nimages == nframes); REQUIRE(zmq_errors == 0); } namespace { // Controllable TCP "writer" peer for backpressure tests. Connects to the pusher, ACKs // START, then *stalls* (stops draining the socket) until Release() is called, while a // background thread keeps sending BUSY heartbeats — i.e. a writer that is alive but // wedged (e.g. on a slow filesystem at high frame rate). Catch2 assertion macros are not // thread-safe, so the worker threads only touch atomics; the test thread asserts. class StallableWriterDouble { public: StallableWriterDouble(const std::string &tcp_addr, int rcvbuf_bytes) { auto [host, port] = ParseHostPort(tcp_addr); fd_ = ::socket(AF_INET, SOCK_STREAM, 0); if (fd_ < 0) return; setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &rcvbuf_bytes, sizeof(rcvbuf_bytes)); sockaddr_in sin{}; sin.sin_family = AF_INET; sin.sin_port = htons(port); inet_pton(AF_INET, host.c_str(), &sin.sin_addr); if (::connect(fd_, reinterpret_cast(&sin), sizeof(sin)) != 0) { ::close(fd_); fd_ = -1; return; } busy_thread_ = std::thread([this] { BusyLoop(); }); reader_thread_ = std::thread([this] { ReaderLoop(); }); } ~StallableWriterDouble() { stop_ = true; Release(); if (fd_ >= 0) ::shutdown(fd_, SHUT_RDWR); if (reader_thread_.joinable()) reader_thread_.join(); if (busy_thread_.joinable()) busy_thread_.join(); if (fd_ >= 0) ::close(fd_); } [[nodiscard]] bool Connected() const { return fd_ >= 0; } // Stop stalling: let the reader drain DATA and ACK END. void Release() { { std::lock_guard lg(mtx_); released_ = true; } cv_.notify_all(); } [[nodiscard]] size_t DataFramesReceived() const { return data_frames_.load(); } [[nodiscard]] bool EndAcked() const { return end_acked_.load(); } private: static std::pair ParseHostPort(const std::string &addr) { const std::string prefix = "tcp://"; const auto hp = addr.substr(prefix.size()); const auto p = hp.find_last_of(':'); return {hp.substr(0, p), static_cast(std::stoi(hp.substr(p + 1)))}; } bool SendHeader(TCPFrameType type, TCPFrameType ack_for, uint64_t run, uint32_t sock, uint32_t flags) { TcpFrameHeader h{}; h.type = static_cast(type); h.ack_for = static_cast(ack_for); h.run_number = run; h.socket_number = sock; h.flags = flags; h.payload_size = 0; std::lock_guard lg(send_mtx_); if (fd_ < 0) return false; return ::send(fd_, &h, sizeof(h), MSG_NOSIGNAL) == static_cast(sizeof(h)); } bool ReadExact(void *buf, size_t len) { auto *p = static_cast(buf); size_t got = 0; while (got < len) { const ssize_t rc = ::recv(fd_, p + got, len - got, 0); if (rc <= 0) return false; got += static_cast(rc); } return true; } void BusyLoop() { // Heartbeat keeps the pusher's peer-liveness fresh even while we are not draining. while (!stop_) { SendHeader(TCPFrameType::BUSY, TCPFrameType::DATA, run_.load(), sock_.load(), 0); for (int i = 0; i < 5 && !stop_; i++) std::this_thread::sleep_for(std::chrono::milliseconds(50)); } } void ReaderLoop() { std::vector discard; while (!stop_) { TcpFrameHeader h{}; if (!ReadExact(&h, sizeof(h))) return; if (h.magic != JFJOCH_TCP_MAGIC || h.version != JFJOCH_TCP_VERSION) return; if (h.payload_size > 0) { discard.resize(h.payload_size); if (!ReadExact(discard.data(), discard.size())) return; } switch (static_cast(h.type)) { case TCPFrameType::START: run_.store(h.run_number); sock_.store(h.socket_number); SendHeader(TCPFrameType::ACK, TCPFrameType::START, h.run_number, h.socket_number, TCP_ACK_FLAG_OK); { // Stall: stop reading until released. std::unique_lock ul(mtx_); cv_.wait(ul, [this] { return released_ || stop_; }); } break; case TCPFrameType::DATA: data_frames_.fetch_add(1); break; case TCPFrameType::END: SendHeader(TCPFrameType::ACK, TCPFrameType::END, h.run_number, h.socket_number, TCP_ACK_FLAG_OK); end_acked_.store(true); return; default: break; // ignore KEEPALIVE etc. } } } int fd_ = -1; std::thread reader_thread_; std::thread busy_thread_; std::atomic stop_{false}; std::atomic run_{0}; std::atomic sock_{0}; std::atomic data_frames_{0}; std::atomic end_acked_{false}; std::mutex send_mtx_; std::mutex mtx_; std::condition_variable cv_; bool released_ = false; }; } // namespace TEST_CASE("TCPImageCommTest_StalledWriter_SurvivesViaHeartbeat", "[TCP]") { // A writer that is alive (still heartbeating) but has stopped draining — e.g. wedged // on a slow filesystem at high frame rate — must NOT be dropped mid-run. The pusher // rides out the backpressure on the production zero-copy queue path until the writer // recovers. Regression for the queue-path send giving up on a fixed deadline, and for // the BUSY heartbeat keeping the connection alive past the peer-liveness window. constexpr int64_t N = 1000; // > queue depth + socket buffers constexpr auto liveness = std::chrono::milliseconds(2000); constexpr auto stall = std::chrono::milliseconds(4000); // > liveness AND > old send deadline // Small SO_SNDBUF/SO_RCVBUF so backpressure reaches the queue after few images. TCPStreamPusher pusher("tcp://127.0.0.1:*", 1, 16 * 1024); pusher.SetPeerLivenessTimeout(liveness); StallableWriterDouble writer(pusher.GetAddress()[0], 16 * 1024); REQUIRE(writer.Connected()); for (int attempt = 0; attempt < 200 && pusher.GetConnectedWriters() < 1; ++attempt) std::this_thread::sleep_for(std::chrono::milliseconds(25)); REQUIRE(pusher.GetConnectedWriters() == 1); ImageBuffer image_buffer(16 * 1024 * 1024); image_buffer.StartMeasurement(static_cast(4096)); StartMessage start{.images_per_file = 1000, .write_master_file = true}; pusher.StartDataCollection(start); // writer ACKs START, then stalls (stops reading) auto sender = std::async(std::launch::async, [&] { for (int64_t i = 0; i < N; i++) { ZeroCopyReturnValue *slot = nullptr; while ((slot = image_buffer.GetImageSlot()) == nullptr) std::this_thread::sleep_for(std::chrono::milliseconds(1)); std::memset(slot->GetImage(), 0, 256); slot->SetImageNumber(i); slot->SetImageSize(256); // arbitrary payload; the writer double discards it slot->ReadyToSend(); pusher.SendImage(*slot); } }); // During the stall the queue is full; SendImage must block, not drop the connection. std::this_thread::sleep_for(stall); CHECK(pusher.GetConnectedWriters() == 1); CHECK(sender.wait_for(std::chrono::milliseconds(0)) != std::future_status::ready); // Writer recovers and starts draining. writer.Release(); REQUIRE(sender.wait_for(std::chrono::seconds(30)) == std::future_status::ready); sender.get(); // Every image makes it across once the stall clears. for (int attempt = 0; attempt < 1200 && writer.DataFramesReceived() < static_cast(N); ++attempt) std::this_thread::sleep_for(std::chrono::milliseconds(25)); CHECK(writer.DataFramesReceived() == static_cast(N)); // Queue fully drained: END now hands over cleanly without racing data frames. EndMessage end{}; CHECK(pusher.EndDataCollection(end) == true); for (int attempt = 0; attempt < 200 && !writer.EndAcked(); ++attempt) std::this_thread::sleep_for(std::chrono::milliseconds(25)); CHECK(writer.EndAcked()); CHECK(pusher.GetConnectedWriters() == 1); image_buffer.Finalize(std::chrono::seconds(5)); } TEST_CASE("TCPImageCommTest_WedgedWriter_DroppedByBackpressureCap", "[TCP]") { // A writer that keeps heartbeating but never drains (e.g. a permanently wedged // filesystem) must not block the run or its finalization forever. The hard // backpressure cap tears the connection down even though BUSY keeps arriving, and // well before the (longer) peer-liveness timeout that those heartbeats keep at bay. constexpr int64_t N = 1000; constexpr auto liveness = std::chrono::milliseconds(5000); // kept fresh by heartbeats constexpr auto max_backpressure = std::chrono::milliseconds(1500); TCPStreamPusher pusher("tcp://127.0.0.1:*", 1, 16 * 1024); pusher.SetPeerLivenessTimeout(liveness); pusher.SetMaxBackpressureTimeout(max_backpressure); StallableWriterDouble writer(pusher.GetAddress()[0], 16 * 1024); // never released REQUIRE(writer.Connected()); for (int attempt = 0; attempt < 200 && pusher.GetConnectedWriters() < 1; ++attempt) std::this_thread::sleep_for(std::chrono::milliseconds(25)); REQUIRE(pusher.GetConnectedWriters() == 1); ImageBuffer image_buffer(16 * 1024 * 1024); image_buffer.StartMeasurement(static_cast(4096)); StartMessage start{.images_per_file = 1000, .write_master_file = true}; pusher.StartDataCollection(start); // writer ACKs START, then stalls forever auto sender = std::async(std::launch::async, [&] { for (int64_t i = 0; i < N; i++) { ZeroCopyReturnValue *slot = nullptr; while ((slot = image_buffer.GetImageSlot()) == nullptr) std::this_thread::sleep_for(std::chrono::milliseconds(1)); std::memset(slot->GetImage(), 0, 256); slot->SetImageNumber(i); slot->SetImageSize(256); slot->ReadyToSend(); pusher.SendImage(*slot); } }); // The cap must fire and drop the connection despite continuous heartbeats. bool dropped = false; for (int attempt = 0; attempt < 400 && !dropped; ++attempt) { if (pusher.GetConnectedWriters() == 0) dropped = true; else std::this_thread::sleep_for(std::chrono::milliseconds(25)); } CHECK(dropped); // Neither the producers nor finalization may hang once the writer is wedged. REQUIRE(sender.wait_for(std::chrono::seconds(10)) == std::future_status::ready); sender.get(); EndMessage end{}; CHECK(pusher.EndDataCollection(end) == false); // bounded, and reports failure image_buffer.Finalize(std::chrono::seconds(5)); }