// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute // SPDX-License-Identifier: GPL-3.0-only #include #include #include "../image_pusher/TCPStreamPusher.h" #include "../image_puller/TCPImagePuller.h" TEST_CASE("TCPImageCommTest_2Writers_WithAck", "[TCP]") { const size_t nframes = 128; const int64_t npullers = 2; const int64_t images_per_file = 16; DiffractionExperiment x(DetJF(1)); x.Raw(); x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4) .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION); std::mt19937 g1(1387); std::uniform_int_distribution dist; std::vector image1(x.GetPixelsNum() * nframes); for (auto &i : image1) i = dist(g1); std::vector addr{ "tcp://127.0.0.1:19001", "tcp://127.0.0.1:19002" }; std::vector> puller; for (int i = 0; i < npullers; i++) { puller.push_back(std::make_unique(addr[i], 64 * 1024 * 1024)); } TCPStreamPusher pusher( addr, 64 * 1024 * 1024, 128 * 1024, 8192 ); std::vector received(npullers, 0); std::vector processed(npullers, 0); std::thread sender([&] { std::vector serialization_buffer(16 * 1024 * 1024); CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size()); StartMessage start{ .images_per_file = images_per_file, .write_master_file = true }; EndMessage end{}; pusher.StartDataCollection(start); for (int64_t i = 0; i < static_cast(nframes); i++) { DataMessage data_message; data_message.number = i; data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(), x.GetPixelsNum() * sizeof(uint16_t), x.GetXPixelsNum(), x.GetYPixelsNum(), x.GetImageMode(), x.GetCompressionAlgorithm()); serializer.SerializeImage(data_message); REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i)); } REQUIRE(pusher.EndDataCollection(end)); }); std::vector receivers; receivers.reserve(npullers); for (int w = 0; w < npullers; w++) { receivers.emplace_back([&, w] { bool seen_start = false; bool seen_end = false; while (!seen_end) { auto out = puller[w]->PollImage(std::chrono::seconds(10)); REQUIRE(out.has_value()); REQUIRE(out->cbor != nullptr); REQUIRE(out->tcp_msg != nullptr); const auto &h = out->tcp_msg->header; if (out->cbor->start_message) { PullerAckMessage ack; ack.ack_for = TCPFrameType::START; ack.ok = true; ack.fatal = false; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.image_number = 0; ack.processed_images = 0; ack.error_code = TCPAckCode::None; REQUIRE(puller[w]->SendAck(ack)); seen_start = true; continue; } if (out->cbor->data_message) { REQUIRE(seen_start); auto n = out->cbor->data_message->number; REQUIRE(((n / images_per_file) % npullers) == w); received[w]++; processed[w]++; continue; } if (out->cbor->end_message) { PullerAckMessage ack; ack.ack_for = TCPFrameType::END; ack.ok = true; ack.fatal = false; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.image_number = 0; ack.processed_images = processed[w]; ack.error_code = TCPAckCode::None; REQUIRE(puller[w]->SendAck(ack)); seen_end = true; } } }); } sender.join(); for (auto &t : receivers) t.join(); REQUIRE(received[0] == nframes / 2); REQUIRE(received[1] == nframes / 2); for (auto &p : puller) p->Disconnect(); } TEST_CASE("TCPImageCommTest_DataFatalAck_PropagatesToPusher", "[TCP]") { const size_t nframes = 64; const int64_t npullers = 2; const int64_t images_per_file = 8; DiffractionExperiment x(DetJF(1)); x.Raw(); x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4) .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION); std::mt19937 g1(42); std::uniform_int_distribution dist; std::vector image1(x.GetPixelsNum() * nframes); for (auto &i : image1) i = dist(g1); std::vector addr{ "tcp://127.0.0.1:19011", "tcp://127.0.0.1:19012" }; std::vector> puller; for (int i = 0; i < npullers; i++) { puller.push_back(std::make_unique(addr[i], 64 * 1024 * 1024)); } TCPStreamPusher pusher( addr, 64 * 1024 * 1024, 128 * 1024, 8192 ); std::atomic sent_fatal{false}; std::thread sender([&] { std::vector serialization_buffer(16 * 1024 * 1024); CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size()); StartMessage start{ .images_per_file = images_per_file, .write_master_file = true }; EndMessage end{}; pusher.StartDataCollection(start); for (int64_t i = 0; i < static_cast(nframes); i++) { DataMessage data_message; data_message.number = i; data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(), x.GetPixelsNum() * sizeof(uint16_t), x.GetXPixelsNum(), x.GetYPixelsNum(), x.GetImageMode(), x.GetCompressionAlgorithm()); serializer.SerializeImage(data_message); (void)pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i); } REQUIRE_FALSE(pusher.EndDataCollection(end)); const auto final_msg = pusher.Finalize(); REQUIRE_THAT(final_msg, Catch::Matchers::ContainsSubstring("quota")); }); std::vector receivers; receivers.reserve(npullers); for (int w = 0; w < npullers; w++) { receivers.emplace_back([&, w] { bool seen_end = false; bool local_fatal_sent = false; while (!seen_end) { auto out = puller[w]->PollImage(std::chrono::seconds(2)); if (!out.has_value()) { // Once this receiver has sent a fatal DATA ACK, no END is guaranteed on this stream. if (local_fatal_sent) break; REQUIRE(out.has_value()); } REQUIRE(out->cbor != nullptr); REQUIRE(out->tcp_msg != nullptr); const auto &h = out->tcp_msg->header; if (out->cbor->start_message) { PullerAckMessage ack; ack.ack_for = TCPFrameType::START; ack.ok = true; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.error_code = TCPAckCode::None; REQUIRE(puller[w]->SendAck(ack)); continue; } if (out->cbor->data_message) { if (w == 0 && !sent_fatal.exchange(true)) { PullerAckMessage ack; ack.ack_for = TCPFrameType::DATA; ack.ok = false; ack.fatal = true; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.image_number = static_cast(out->cbor->data_message->number); ack.error_code = TCPAckCode::DiskQuotaExceeded; ack.error_text = "quota exceeded"; REQUIRE(puller[w]->SendAck(ack)); local_fatal_sent = true; } continue; } if (out->cbor->end_message) { PullerAckMessage ack; ack.ack_for = TCPFrameType::END; ack.ok = true; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.error_code = TCPAckCode::None; REQUIRE(puller[w]->SendAck(ack)); seen_end = true; } } }); } sender.join(); for (auto &t : receivers) t.join(); for (auto &p : puller) p->Disconnect(); } TEST_CASE("TCPImageCommTest_GetAckProgress_Correct", "[TCP]") { const size_t nframes = 64; const int64_t npullers = 2; const int64_t images_per_file = 8; DiffractionExperiment x(DetJF(1)); x.Raw(); x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4) .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION); std::mt19937 g1(123); std::uniform_int_distribution dist; std::vector image1(x.GetPixelsNum() * nframes); for (auto &i : image1) i = dist(g1); std::vector addr{ "tcp://127.0.0.1:19021", "tcp://127.0.0.1:19022" }; std::vector> puller; for (int i = 0; i < npullers; i++) { puller.push_back(std::make_unique(addr[i], 64 * 1024 * 1024)); } TCPStreamPusher pusher( addr, 64 * 1024 * 1024, 128 * 1024, 8192 ); std::thread sender([&] { std::vector serialization_buffer(16 * 1024 * 1024); CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size()); StartMessage start{ .images_per_file = images_per_file, .write_master_file = true }; EndMessage end{}; pusher.StartDataCollection(start); for (int64_t i = 0; i < static_cast(nframes); i++) { DataMessage data_message; data_message.number = i; data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(), x.GetPixelsNum() * sizeof(uint16_t), x.GetXPixelsNum(), x.GetYPixelsNum(), x.GetImageMode(), x.GetCompressionAlgorithm()); serializer.SerializeImage(data_message); REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i)); } REQUIRE(pusher.EndDataCollection(end)); std::this_thread::sleep_for(std::chrono::seconds(5)); auto progress = pusher.GetImagesWritten(); REQUIRE(progress.has_value()); REQUIRE(progress == nframes / 2); }); std::vector receivers; receivers.reserve(npullers); for (int w = 0; w < npullers; w++) { receivers.emplace_back([&, w] { bool seen_end = false; uint64_t processed = 0; while (!seen_end) { auto out = puller[w]->PollImage(std::chrono::seconds(10)); REQUIRE(out.has_value()); REQUIRE(out->cbor != nullptr); REQUIRE(out->tcp_msg != nullptr); const auto &h = out->tcp_msg->header; if (out->cbor->start_message) { PullerAckMessage ack; ack.ack_for = TCPFrameType::START; ack.ok = true; ack.fatal = false; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.image_number = 0; ack.processed_images = 0; ack.error_code = TCPAckCode::None; REQUIRE(puller[w]->SendAck(ack)); continue; } if (out->cbor->data_message) { auto number = out->cbor->data_message->number; processed++; PullerAckMessage ack; ack.ack_for = TCPFrameType::DATA; ack.ok = (number % 2 == 0) ? true : false; ack.fatal = false; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.image_number = number; ack.processed_images = processed; ack.error_code = TCPAckCode::None; REQUIRE(puller[w]->SendAck(ack)); continue; } if (out->cbor->end_message) { PullerAckMessage ack; ack.ack_for = TCPFrameType::END; ack.ok = true; ack.fatal = false; ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.image_number = 0; ack.processed_images = processed; ack.error_code = TCPAckCode::None; REQUIRE(puller[w]->SendAck(ack)); seen_end = true; } } }); } sender.join(); for (auto &t : receivers) t.join(); for (auto &p : puller) p->Disconnect(); }