diff --git a/tests/TCPImagePusherTest.cpp b/tests/TCPImagePusherTest.cpp index 0170514e..991d4b40 100644 --- a/tests/TCPImagePusherTest.cpp +++ b/tests/TCPImagePusherTest.cpp @@ -20,16 +20,16 @@ TEST_CASE("TCPImageCommTest_2Writers_WithAck", "[TCP]") { std::mt19937 g1(1387); std::uniform_int_distribution dist; std::vector image1(x.GetPixelsNum() * nframes); - for (auto &i : image1) i = dist(g1); + for (auto &i: image1) i = dist(g1); std::string addr = "tcp://127.0.0.1:19001"; - std::vector> puller; + std::vector > puller; for (int i = 0; i < npullers; i++) { puller.push_back(std::make_unique(addr, 64 * 1024 * 1024)); } - TCPStreamPusher pusher(addr,npullers); + TCPStreamPusher pusher(addr, npullers); // Wait for all pullers to connect before starting data collection for (int attempt = 0; attempt < 100 && pusher.GetConnectedWriters() < static_cast(npullers); ++attempt) @@ -44,8 +44,8 @@ TEST_CASE("TCPImageCommTest_2Writers_WithAck", "[TCP]") { CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size()); StartMessage start{ - .images_per_file = images_per_file, - .write_master_file = true + .images_per_file = images_per_file, + .write_master_file = true }; EndMessage end{}; @@ -70,10 +70,15 @@ TEST_CASE("TCPImageCommTest_2Writers_WithAck", "[TCP]") { std::vector receivers; receivers.reserve(npullers); + std::mutex counts_mutex; + std::vector received_by_socket(npullers, 0); + std::vector processed_by_socket(npullers, 0); + for (int w = 0; w < npullers; w++) { receivers.emplace_back([&, w] { bool seen_start = false; bool seen_end = false; + std::optional my_socket_number; while (!seen_end) { auto out = puller[w]->PollImage(std::chrono::seconds(10)); @@ -84,6 +89,8 @@ TEST_CASE("TCPImageCommTest_2Writers_WithAck", "[TCP]") { const auto &h = out->tcp_msg->header; if (out->cbor->start_message) { + my_socket_number = h.socket_number; + PullerAckMessage ack; ack.ack_for = TCPFrameType::START; ack.ok = true; @@ -100,14 +107,22 @@ TEST_CASE("TCPImageCommTest_2Writers_WithAck", "[TCP]") { if (out->cbor->data_message) { REQUIRE(seen_start); + REQUIRE(my_socket_number.has_value()); + auto n = out->cbor->data_message->number; - REQUIRE(((n / images_per_file) % npullers) == w); - received[w]++; - processed[w]++; + REQUIRE(((n / images_per_file) % npullers) == static_cast(*my_socket_number)); + + { + std::lock_guard lg(counts_mutex); + received_by_socket.at(*my_socket_number)++; + processed_by_socket.at(*my_socket_number)++; + } continue; } if (out->cbor->end_message) { + REQUIRE(my_socket_number.has_value()); + PullerAckMessage ack; ack.ack_for = TCPFrameType::END; ack.ok = true; @@ -115,7 +130,10 @@ TEST_CASE("TCPImageCommTest_2Writers_WithAck", "[TCP]") { ack.run_number = h.run_number; ack.socket_number = h.socket_number; ack.image_number = 0; - ack.processed_images = processed[w]; + { + std::lock_guard lg(counts_mutex); + ack.processed_images = processed_by_socket.at(*my_socket_number); + } ack.error_code = TCPAckCode::None; REQUIRE(puller[w]->SendAck(ack)); seen_end = true; @@ -125,12 +143,12 @@ TEST_CASE("TCPImageCommTest_2Writers_WithAck", "[TCP]") { } sender.join(); - for (auto &t : receivers) t.join(); + for (auto &t: receivers) t.join(); - REQUIRE(received[0] == nframes / 2); - REQUIRE(received[1] == nframes / 2); + REQUIRE(received_by_socket[0] == nframes / 2); + REQUIRE(received_by_socket[1] == nframes / 2); - for (auto &p : puller) + for (auto &p: puller) p->Disconnect(); } @@ -147,15 +165,15 @@ TEST_CASE("TCPImageCommTest_DataFatalAck_PropagatesToPusher", "[TCP]") { std::mt19937 g1(42); std::uniform_int_distribution dist; std::vector image1(x.GetPixelsNum() * nframes); - for (auto &i : image1) i = dist(g1); + for (auto &i: image1) i = dist(g1); std::string addr = "tcp://127.0.0.1:19003"; - std::vector> puller; + std::vector > puller; for (int i = 0; i < npullers; i++) puller.push_back(std::make_unique(addr, 64 * 1024 * 1024)); - TCPStreamPusher pusher(addr,npullers); + TCPStreamPusher pusher(addr, npullers); // Wait for all pullers to connect before starting data collection for (int attempt = 0; attempt < 100 && pusher.GetConnectedWriters() < static_cast(npullers); ++attempt) @@ -169,8 +187,8 @@ TEST_CASE("TCPImageCommTest_DataFatalAck_PropagatesToPusher", "[TCP]") { CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size()); StartMessage start{ - .images_per_file = images_per_file, - .write_master_file = true + .images_per_file = images_per_file, + .write_master_file = true }; EndMessage end{}; @@ -186,7 +204,7 @@ TEST_CASE("TCPImageCommTest_DataFatalAck_PropagatesToPusher", "[TCP]") { x.GetImageMode(), x.GetCompressionAlgorithm()); serializer.SerializeImage(data_message); - (void)pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i); + (void) pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i); } REQUIRE(pusher.EndDataCollection(end)); @@ -197,7 +215,7 @@ TEST_CASE("TCPImageCommTest_DataFatalAck_PropagatesToPusher", "[TCP]") { std::vector receivers; receivers.reserve(npullers); - for (int w = 0; w < npullers; w++) { + for (int w = 0; w < npullers; w++) { receivers.emplace_back([&, w] { bool seen_end = false; bool local_fatal_sent = false; @@ -259,9 +277,9 @@ TEST_CASE("TCPImageCommTest_DataFatalAck_PropagatesToPusher", "[TCP]") { } sender.join(); - for (auto &t : receivers) t.join(); + for (auto &t: receivers) t.join(); - for (auto &p : puller) + for (auto &p: puller) p->Disconnect(); } @@ -278,15 +296,15 @@ TEST_CASE("TCPImageCommTest_GetAckProgress_Correct", "[TCP]") { std::mt19937 g1(123); std::uniform_int_distribution dist; std::vector image1(x.GetPixelsNum() * nframes); - for (auto &i : image1) i = dist(g1); + for (auto &i: image1) i = dist(g1); std::string addr = "tcp://127.0.0.1:19004"; - std::vector> puller; + std::vector > puller; for (int i = 0; i < npullers; i++) puller.push_back(std::make_unique(addr, 64 * 1024 * 1024)); - TCPStreamPusher pusher(addr,npullers); + TCPStreamPusher pusher(addr, npullers); // Wait for all pullers to connect before starting data collection for (int attempt = 0; attempt < 100 && pusher.GetConnectedWriters() < static_cast(npullers); ++attempt) @@ -298,8 +316,8 @@ TEST_CASE("TCPImageCommTest_GetAckProgress_Correct", "[TCP]") { CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size()); StartMessage start{ - .images_per_file = images_per_file, - .write_master_file = true + .images_per_file = images_per_file, + .write_master_file = true }; EndMessage end{}; @@ -392,9 +410,9 @@ TEST_CASE("TCPImageCommTest_GetAckProgress_Correct", "[TCP]") { } sender.join(); - for (auto &t : receivers) t.join(); + for (auto &t: receivers) t.join(); - for (auto &p : puller) + for (auto &p: puller) p->Disconnect(); } @@ -424,14 +442,18 @@ TEST_CASE("TCPImageCommTest_AutoPort_StarBind", "[TCP]") { const auto &h = out->tcp_msg->header; if (out->cbor->start_message) { - PullerAckMessage ack{.ack_for = TCPFrameType::START, .ok = true, .run_number = h.run_number, - .socket_number = h.socket_number, .error_code = TCPAckCode::None}; + PullerAckMessage ack{ + .ack_for = TCPFrameType::START, .ok = true, .run_number = h.run_number, + .socket_number = h.socket_number, .error_code = TCPAckCode::None + }; REQUIRE(puller.SendAck(ack)); } else if (out->cbor->data_message) { processed++; } else if (out->cbor->end_message) { - PullerAckMessage ack{.ack_for = TCPFrameType::END, .ok = true, .run_number = h.run_number, - .socket_number = h.socket_number, .processed_images = processed, .error_code = TCPAckCode::None}; + PullerAckMessage ack{ + .ack_for = TCPFrameType::END, .ok = true, .run_number = h.run_number, + .socket_number = h.socket_number, .processed_images = processed, .error_code = TCPAckCode::None + }; REQUIRE(puller.SendAck(ack)); seen_end = true; } @@ -485,8 +507,10 @@ TEST_CASE("TCPImageCommTest_DisconnectMidWrite_NoHang", "[TCP]") { const auto &h = out->tcp_msg->header; if (out->cbor->start_message) { - PullerAckMessage ack{.ack_for = TCPFrameType::START, .ok = true, .run_number = h.run_number, - .socket_number = h.socket_number, .error_code = TCPAckCode::None}; + PullerAckMessage ack{ + .ack_for = TCPFrameType::START, .ok = true, .run_number = h.run_number, + .socket_number = h.socket_number, .error_code = TCPAckCode::None + }; REQUIRE(puller.SendAck(ack)); } else if (out->cbor->data_message) { puller.Disconnect(); // simulate puller disappearing mid-stream @@ -512,7 +536,7 @@ TEST_CASE("TCPImageCommTest_DisconnectMidWrite_NoHang", "[TCP]") { x.GetXPixelsNum(), x.GetYPixelsNum(), x.GetImageMode(), x.GetCompressionAlgorithm()); serializer.SerializeImage(data_message); - (void)pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i); + (void) pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i); } return pusher.EndDataCollection(end); @@ -522,4 +546,4 @@ TEST_CASE("TCPImageCommTest_DisconnectMidWrite_NoHang", "[TCP]") { CHECK(sender.get() == false); receiver.join(); -} \ No newline at end of file +}