// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute // SPDX-License-Identifier: GPL-3.0-only #include #include #include "../image_pusher/TCPStreamPusher.h" #include "../image_puller/TCPImagePuller.h" #include "../image_puller/ZMQImagePuller.h" TEST_CASE("TCPImageCommTest_2Writers_WithAck", "[TCP]") { const size_t nframes = 128; const int64_t npullers = 2; const int64_t images_per_file = 16; DiffractionExperiment x(DetJF(1)); x.Raw(); x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4) .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION); x.RunNumber(567); std::mt19937 g1(1387); std::uniform_int_distribution dist; std::vector image1(x.GetPixelsNum() * nframes); for (auto &i: image1) i = dist(g1); std::string addr = "tcp://127.0.0.1:19001"; std::vector > puller; for (int i = 0; i < npullers; i++) { puller.push_back(std::make_unique(addr, 64 * 1024 * 1024)); } TCPStreamPusher pusher(addr, npullers); // Wait for all pullers to connect before starting data collection for (int attempt = 0; attempt < 100 && pusher.GetConnectedWriters() < static_cast(npullers); ++attempt) std::this_thread::sleep_for(std::chrono::milliseconds(50)); REQUIRE(pusher.GetConnectedWriters() == static_cast(npullers)); std::vector received(npullers, 0); std::vector processed(npullers, 0); std::thread sender([&] { std::vector serialization_buffer(16 * 1024 * 1024); CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size()); StartMessage start{ .images_per_file = images_per_file, .run_number = x.GetRunNumber(), .write_master_file = true }; EndMessage end{}; pusher.StartDataCollection(start); for (int64_t i = 0; i < static_cast(nframes); i++) { DataMessage data_message; data_message.number = i; data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(), x.GetPixelsNum() * sizeof(uint16_t), x.GetXPixelsNum(), x.GetYPixelsNum(), x.GetImageMode(), x.GetCompressionAlgorithm()); serializer.SerializeImage(data_message); REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i)); } REQUIRE(pusher.EndDataCollection(end)); }); std::vector receivers; receivers.reserve(npullers); std::mutex counts_mutex; std::vector received_by_socket(npullers, 0); std::vector processed_by_socket(npullers, 0); for (int w = 0; w < npullers; w++) { receivers.emplace_back([&, w] { bool seen_start = false; bool seen_end = false; std::optional my_socket_number; while (!seen_end) { auto out = puller[w]->PollImage(std::chrono::seconds(10)); REQUIRE(out.has_value()); REQUIRE(out->cbor != nullptr); REQUIRE(out->tcp_msg != nullptr); const auto &h = out->tcp_msg->header; if (out->cbor->start_message) { my_socket_number = h.socket_number; PullerAckMessage ack; ack.ack_for = TCPFrameType::START; ack.ok = true; ack.fatal = false; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.image_number = 0; ack.processed_images = 0; ack.error_code = TCPAckCode::None; REQUIRE(ack.run_number == x.GetRunNumber()); REQUIRE(puller[w]->SendAck(ack)); seen_start = true; continue; } if (out->cbor->data_message) { REQUIRE(seen_start); REQUIRE(my_socket_number.has_value()); auto n = out->cbor->data_message->number; REQUIRE(((n / images_per_file) % npullers) == static_cast(*my_socket_number)); { std::lock_guard lg(counts_mutex); received_by_socket.at(*my_socket_number)++; processed_by_socket.at(*my_socket_number)++; } continue; } if (out->cbor->end_message) { REQUIRE(my_socket_number.has_value()); PullerAckMessage ack; ack.ack_for = TCPFrameType::END; ack.ok = true; ack.fatal = false; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.image_number = 0; { std::lock_guard lg(counts_mutex); ack.processed_images = processed_by_socket.at(*my_socket_number); } ack.error_code = TCPAckCode::None; REQUIRE(puller[w]->SendAck(ack)); seen_end = true; } } }); } sender.join(); for (auto &t: receivers) t.join(); REQUIRE(received_by_socket[0] == nframes / 2); REQUIRE(received_by_socket[1] == nframes / 2); for (auto &p: puller) p->Disconnect(); } TEST_CASE("TCPImageCommTest_DataFatalAck_PropagatesToPusher", "[TCP]") { const size_t nframes = 64; const int64_t npullers = 2; const int64_t images_per_file = 8; DiffractionExperiment x(DetJF(1)); x.Raw(); x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4) .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION); std::mt19937 g1(42); std::uniform_int_distribution dist; std::vector image1(x.GetPixelsNum() * nframes); for (auto &i: image1) i = dist(g1); std::string addr = "tcp://127.0.0.1:19003"; std::vector > puller; for (int i = 0; i < npullers; i++) puller.push_back(std::make_unique(addr, 64 * 1024 * 1024)); TCPStreamPusher pusher(addr, npullers); // Wait for all pullers to connect before starting data collection for (int attempt = 0; attempt < 100 && pusher.GetConnectedWriters() < static_cast(npullers); ++attempt) std::this_thread::sleep_for(std::chrono::milliseconds(50)); REQUIRE(pusher.GetConnectedWriters() == static_cast(npullers)); std::atomic sent_fatal{false}; std::thread sender([&] { std::vector serialization_buffer(16 * 1024 * 1024); CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size()); StartMessage start{ .images_per_file = images_per_file, .write_master_file = true }; EndMessage end{}; pusher.StartDataCollection(start); for (int64_t i = 0; i < static_cast(nframes); i++) { DataMessage data_message; data_message.number = i; data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(), x.GetPixelsNum() * sizeof(uint16_t), x.GetXPixelsNum(), x.GetYPixelsNum(), x.GetImageMode(), x.GetCompressionAlgorithm()); serializer.SerializeImage(data_message); (void) pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i); } REQUIRE(pusher.EndDataCollection(end)); const auto final_msg = pusher.Finalize(); REQUIRE_THAT(final_msg, Catch::Matchers::ContainsSubstring("quota")); }); std::vector receivers; receivers.reserve(npullers); for (int w = 0; w < npullers; w++) { receivers.emplace_back([&, w] { bool seen_end = false; bool local_fatal_sent = false; while (!seen_end) { auto out = puller[w]->PollImage(std::chrono::seconds(2)); if (!out.has_value()) { // Once this receiver has sent a fatal DATA ACK, no END is guaranteed on this stream. if (local_fatal_sent) break; REQUIRE(out.has_value()); } REQUIRE(out->cbor != nullptr); REQUIRE(out->tcp_msg != nullptr); const auto &h = out->tcp_msg->header; if (out->cbor->start_message) { PullerAckMessage ack; ack.ack_for = TCPFrameType::START; ack.ok = true; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.error_code = TCPAckCode::None; REQUIRE(puller[w]->SendAck(ack)); continue; } if (out->cbor->data_message) { if (w == 0 && !sent_fatal.exchange(true)) { PullerAckMessage ack; ack.ack_for = TCPFrameType::DATA; ack.ok = false; ack.fatal = true; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.image_number = static_cast(out->cbor->data_message->number); ack.error_code = TCPAckCode::DiskQuotaExceeded; ack.error_text = "quota exceeded"; REQUIRE(puller[w]->SendAck(ack)); local_fatal_sent = true; } continue; } if (out->cbor->end_message) { PullerAckMessage ack; ack.ack_for = TCPFrameType::END; ack.ok = true; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.error_code = TCPAckCode::None; REQUIRE(puller[w]->SendAck(ack)); seen_end = true; } } }); } sender.join(); for (auto &t: receivers) t.join(); for (auto &p: puller) p->Disconnect(); } TEST_CASE("TCPImageCommTest_GetAckProgress_Correct", "[TCP]") { const size_t nframes = 64; const int64_t npullers = 2; const int64_t images_per_file = 8; DiffractionExperiment x(DetJF(1)); x.Raw(); x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4) .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION); std::mt19937 g1(123); std::uniform_int_distribution dist; std::vector image1(x.GetPixelsNum() * nframes); for (auto &i: image1) i = dist(g1); std::string addr = "tcp://127.0.0.1:19004"; std::vector > puller; for (int i = 0; i < npullers; i++) puller.push_back(std::make_unique(addr, 64 * 1024 * 1024)); TCPStreamPusher pusher(addr, npullers); // Wait for all pullers to connect before starting data collection for (int attempt = 0; attempt < 100 && pusher.GetConnectedWriters() < static_cast(npullers); ++attempt) std::this_thread::sleep_for(std::chrono::milliseconds(50)); REQUIRE(pusher.GetConnectedWriters() == static_cast(npullers)); std::thread sender([&] { std::vector serialization_buffer(16 * 1024 * 1024); CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size()); StartMessage start{ .images_per_file = images_per_file, .write_master_file = true }; EndMessage end{}; pusher.StartDataCollection(start); for (int64_t i = 0; i < static_cast(nframes); i++) { DataMessage data_message; data_message.number = i; data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(), x.GetPixelsNum() * sizeof(uint16_t), x.GetXPixelsNum(), x.GetYPixelsNum(), x.GetImageMode(), x.GetCompressionAlgorithm()); serializer.SerializeImage(data_message); REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i)); } REQUIRE(pusher.EndDataCollection(end)); std::this_thread::sleep_for(std::chrono::seconds(5)); auto progress = pusher.GetImagesWritten(); REQUIRE(progress.has_value()); REQUIRE(progress == nframes / 2); }); std::vector receivers; receivers.reserve(npullers); for (int w = 0; w < npullers; w++) { receivers.emplace_back([&, w] { bool seen_end = false; uint64_t processed = 0; while (!seen_end) { auto out = puller[w]->PollImage(std::chrono::seconds(10)); REQUIRE(out.has_value()); REQUIRE(out->cbor != nullptr); REQUIRE(out->tcp_msg != nullptr); const auto &h = out->tcp_msg->header; if (out->cbor->start_message) { PullerAckMessage ack; ack.ack_for = TCPFrameType::START; ack.ok = true; ack.fatal = false; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.image_number = 0; ack.processed_images = 0; ack.error_code = TCPAckCode::None; REQUIRE(puller[w]->SendAck(ack)); continue; } if (out->cbor->data_message) { auto number = out->cbor->data_message->number; processed++; PullerAckMessage ack; ack.ack_for = TCPFrameType::DATA; ack.ok = (number % 2 == 0) ? true : false; ack.fatal = false; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.image_number = number; ack.processed_images = processed; ack.error_code = TCPAckCode::None; REQUIRE(puller[w]->SendAck(ack)); continue; } if (out->cbor->end_message) { PullerAckMessage ack; ack.ack_for = TCPFrameType::END; ack.ok = true; ack.fatal = false; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.image_number = 0; ack.processed_images = processed; ack.error_code = TCPAckCode::None; REQUIRE(puller[w]->SendAck(ack)); seen_end = true; } } }); } sender.join(); for (auto &t: receivers) t.join(); for (auto &p: puller) p->Disconnect(); } TEST_CASE("TCPImageCommTest_AutoPort_StarBind", "[TCP]") { const size_t nframes = 8; const int64_t images_per_file = 4; DiffractionExperiment x(DetJF(1)); x.Raw(); x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4) .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION); std::vector image1(x.GetPixelsNum() * nframes, 7u); TCPStreamPusher pusher("tcp://127.0.0.1:*", 1); TCPImagePuller puller(pusher.GetAddress()[0], 64 * 1024 * 1024); std::thread receiver([&] { bool seen_end = false; uint64_t processed = 0; while (!seen_end) { auto out = puller.PollImage(std::chrono::seconds(10)); REQUIRE(out.has_value()); REQUIRE(out->cbor != nullptr); REQUIRE(out->tcp_msg != nullptr); const auto &h = out->tcp_msg->header; if (out->cbor->start_message) { PullerAckMessage ack{ .ack_for = TCPFrameType::START, .ok = true, .run_number = h.run_number, .socket_number = h.socket_number, .error_code = TCPAckCode::None }; REQUIRE(puller.SendAck(ack)); } else if (out->cbor->data_message) { processed++; } else if (out->cbor->end_message) { PullerAckMessage ack{ .ack_for = TCPFrameType::END, .ok = true, .run_number = h.run_number, .socket_number = h.socket_number, .processed_images = processed, .error_code = TCPAckCode::None }; REQUIRE(puller.SendAck(ack)); seen_end = true; } } }); std::vector serialization_buffer(16 * 1024 * 1024); CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size()); StartMessage start{.images_per_file = images_per_file, .write_master_file = true}; EndMessage end{}; pusher.StartDataCollection(start); for (int64_t i = 0; i < static_cast(nframes); i++) { DataMessage data_message; data_message.number = i; data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(), x.GetPixelsNum() * sizeof(uint16_t), x.GetXPixelsNum(), x.GetYPixelsNum(), x.GetImageMode(), x.GetCompressionAlgorithm()); serializer.SerializeImage(data_message); REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i)); } REQUIRE(pusher.EndDataCollection(end)); receiver.join(); puller.Disconnect(); } TEST_CASE("TCPImageCommTest_DisconnectMidWrite_NoHang", "[TCP]") { const size_t nframes = 256; const int64_t images_per_file = 16; DiffractionExperiment x(DetJF(1)); x.Raw(); x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4) .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION); std::vector image1(x.GetPixelsNum() * nframes, 11u); TCPStreamPusher pusher("tcp://127.0.0.1:*", 1); TCPImagePuller puller(pusher.GetAddress()[0], 64 * 1024 * 1024); std::thread receiver([&] { bool disconnected = false; while (!disconnected) { auto out = puller.PollImage(std::chrono::seconds(10)); REQUIRE(out.has_value()); REQUIRE(out->cbor != nullptr); REQUIRE(out->tcp_msg != nullptr); const auto &h = out->tcp_msg->header; if (out->cbor->start_message) { PullerAckMessage ack{ .ack_for = TCPFrameType::START, .ok = true, .run_number = h.run_number, .socket_number = h.socket_number, .error_code = TCPAckCode::None }; REQUIRE(puller.SendAck(ack)); } else if (out->cbor->data_message) { puller.Disconnect(); // simulate puller disappearing mid-stream disconnected = true; } } }); auto sender = std::async(std::launch::async, [&] { std::vector serialization_buffer(16 * 1024 * 1024); CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size()); StartMessage start{.images_per_file = images_per_file, .write_master_file = true}; EndMessage end{}; pusher.StartDataCollection(start); for (int64_t i = 0; i < static_cast(nframes); i++) { DataMessage data_message; data_message.number = i; data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(), x.GetPixelsNum() * sizeof(uint16_t), x.GetXPixelsNum(), x.GetYPixelsNum(), x.GetImageMode(), x.GetCompressionAlgorithm()); serializer.SerializeImage(data_message); (void) pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i); } return pusher.EndDataCollection(end); }); REQUIRE(sender.wait_for(std::chrono::seconds(20)) == std::future_status::ready); CHECK(sender.get() == false); receiver.join(); } TEST_CASE("TCPImageCommTest_RepubToZMQ", "[TCP][ZeroMQ]") { // Chain: TCPStreamPusher --TCP--> TCPImagePuller --ZMQ repub--> ZMQImagePuller const size_t nframes = 64; const int64_t images_per_file = 8; DiffractionExperiment x(DetJF(1)); x.Raw(); x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4) .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION); std::mt19937 g1(9999); std::uniform_int_distribution dist; std::vector image1(x.GetPixelsNum() * nframes); for (auto &i : image1) i = dist(g1); // 1. Create TCP pusher on an auto-assigned port TCPStreamPusher pusher("tcp://127.0.0.1:*", 1); // 2. Create TCP puller with repub over ZMQ (ipc, auto-assigned) const std::string repub_addr = "ipc://*"; // Need to figure out the actual repub endpoint after bind — ZMQ ipc://* picks a temp path. // However, ZMQSocket::Bind with "ipc://*" is used in project; the repub socket binds internally, // so we need a known address. Use a tcp address instead for testability. const std::string repub_bind_addr = "tcp://127.0.0.1:19010"; TCPImagePuller tcp_puller(pusher.GetAddress()[0], {}, repub_bind_addr); // 3. Create ZMQ puller that connects to the repub address ZMQImagePuller zmq_puller(repub_bind_addr); // Wait for TCP connection for (int attempt = 0; attempt < 100 && pusher.GetConnectedWriters() < 1; ++attempt) std::this_thread::sleep_for(std::chrono::milliseconds(50)); REQUIRE(pusher.GetConnectedWriters() == 1); // Sender thread: push frames over TCP std::thread sender([&] { std::vector serialization_buffer(16 * 1024 * 1024); CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size()); StartMessage start{ .images_per_file = images_per_file, .write_master_file = true }; EndMessage end{}; pusher.StartDataCollection(start); for (int64_t i = 0; i < static_cast(nframes); i++) { DataMessage data_message; data_message.number = i; data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(), x.GetPixelsNum() * sizeof(uint16_t), x.GetXPixelsNum(), x.GetYPixelsNum(), x.GetImageMode(), x.GetCompressionAlgorithm()); serializer.SerializeImage(data_message); REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i)); } REQUIRE(pusher.EndDataCollection(end)); }); // TCP puller consumer: drains the TCP side (with ACKs) so data keeps flowing std::thread tcp_consumer([&] { bool seen_end = false; while (!seen_end) { auto out = tcp_puller.PollImage(std::chrono::seconds(10)); REQUIRE(out.has_value()); REQUIRE(out->cbor != nullptr); REQUIRE(out->tcp_msg != nullptr); const auto &h = out->tcp_msg->header; if (out->cbor->start_message) { PullerAckMessage ack{}; ack.ack_for = TCPFrameType::START; ack.ok = true; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.error_code = TCPAckCode::None; REQUIRE(tcp_puller.SendAck(ack)); } else if (out->cbor->end_message) { PullerAckMessage ack{}; ack.ack_for = TCPFrameType::END; ack.ok = true; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.error_code = TCPAckCode::None; REQUIRE(tcp_puller.SendAck(ack)); seen_end = true; } // data frames: no ack needed for this test } }); // ZMQ puller consumer: verify the republished stream size_t zmq_nimages = 0; size_t zmq_errors = 0; bool zmq_seen_start = false; bool zmq_seen_end = false; std::thread zmq_consumer([&] { auto timeout = std::chrono::seconds(30); // First message should be START auto img = zmq_puller.PollImage(timeout); if (!img || !img->cbor || !img->cbor->start_message) { zmq_errors++; return; } zmq_seen_start = true; // Republished START should have writer_notification_zmq_addr cleared if (!img->cbor->start_message->writer_notification_zmq_addr.empty()) { zmq_errors++; } // Consume data and END img = zmq_puller.PollImage(timeout); while (img && img->cbor && !img->cbor->end_message) { if (img->cbor->data_message) { auto n = img->cbor->data_message->number; if (img->cbor->data_message->image.GetCompressedSize() != x.GetPixelsNum() * sizeof(uint16_t)) zmq_errors++; else if (memcmp(img->cbor->data_message->image.GetCompressed(), image1.data() + n * x.GetPixelsNum(), x.GetPixelsNum() * sizeof(uint16_t)) != 0) zmq_errors++; zmq_nimages++; } img = zmq_puller.PollImage(timeout); } if (img && img->cbor && img->cbor->end_message) zmq_seen_end = true; }); sender.join(); tcp_consumer.join(); zmq_consumer.join(); tcp_puller.Disconnect(); zmq_puller.Disconnect(); // The repub uses non-blocking Put for data, so some frames *could* be dropped // under extreme back-pressure, but with only 64 frames we expect all of them. REQUIRE(zmq_seen_start); REQUIRE(zmq_seen_end); REQUIRE(zmq_nimages == nframes); REQUIRE(zmq_errors == 0); }