2a9fd084ab
Follow-up simplifications after removing the zerocopy machinery, plus a real backpressure bug the cleanup surfaced: - SendImage(ZeroCopyReturnValue&) imposed a hard 2s deadline on enqueueing and then marked the connection broken. At high frame rate the 128-deep queue fills in tens of ms, so any filesystem stall longer than ~2s dropped the run even though the writer was alive and heartbeating -- defeating the whole BUSY-heartbeat backpressure design. Block instead while the peer is alive (!broken && active); the real liveness decision already lives in SendAll's peer-liveness timeout, which the writer's BUSY heartbeats keep fresh. This makes the queue path consistent with the send path: both wait out arbitrarily long stalls and only give up when the peer goes genuinely silent. - Drop the dead per-connection data_sent counter (written, never read) and the redundant ImagePusherQueueElement.image_data set on the TCP path (only the HDF5 pusher reads that field). - Add SetPeerLivenessTimeout() so the liveness window is tunable (and testable). Add TCPImageCommTest_StalledWriter_SurvivesViaHeartbeat: a controllable raw writer double connects, ACKs START, then stops draining for 4s while still sending BUSY heartbeats (peer-liveness window set to 2s). The run must ride out the stall on the zero-copy queue path and deliver all 1000 images. Verified to fail (115/1000 delivered, connection dropped) against the old 2s-deadline behavior. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
933 lines
36 KiB
C++
933 lines
36 KiB
C++
// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute <filip.leonarski@psi.ch>
|
|
// SPDX-License-Identifier: GPL-3.0-only
|
|
|
|
#include <random>
|
|
#include <future>
|
|
#include <thread>
|
|
#include <atomic>
|
|
#include <mutex>
|
|
#include <condition_variable>
|
|
#include <chrono>
|
|
#include <vector>
|
|
#include <utility>
|
|
#include <cstring>
|
|
#include <sys/socket.h>
|
|
#include <netinet/in.h>
|
|
#include <arpa/inet.h>
|
|
#include <unistd.h>
|
|
#include <catch2/catch_all.hpp>
|
|
|
|
#include "../image_pusher/TCPStreamPusher.h"
|
|
#include "../image_puller/TCPImagePuller.h"
|
|
#include "../image_puller/ZMQImagePuller.h"
|
|
#include "../common/ImageBuffer.h"
|
|
#include "../common/ZeroCopyReturnValue.h"
|
|
|
|
TEST_CASE("TCPImageCommTest_2Writers_WithAck", "[TCP]") {
    // Two TCP pullers share one pusher. Image n is routed to socket
    // (n / images_per_file) % npullers; every puller ACKs START and END, and at the end each
    // puller must have received exactly its share of the frames.
    const size_t nframes = 128;
    const int64_t npullers = 2;
    const int64_t images_per_file = 16;

    DiffractionExperiment x(DetJF(1));
    x.Raw();
    x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4)
            .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION);
    x.RunNumber(567);

    std::mt19937 g1(1387);
    std::uniform_int_distribution<uint16_t> dist;
    std::vector<uint16_t> image1(x.GetPixelsNum() * nframes);
    for (auto &i: image1) i = dist(g1);

    TCPStreamPusher pusher("tcp://127.0.0.1:*", npullers);
    std::vector<std::unique_ptr<TCPImagePuller> > puller;
    for (int i = 0; i < npullers; i++)
        puller.push_back(std::make_unique<TCPImagePuller>(pusher.GetAddress()[0], 64 * 1024 * 1024));

    // Wait for all pullers to connect before starting data collection
    for (int attempt = 0; attempt < 100 && pusher.GetConnectedWriters() < static_cast<size_t>(npullers); ++attempt)
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
    REQUIRE(pusher.GetConnectedWriters() == static_cast<size_t>(npullers));

    // Per-socket counters shared with the receiver tasks. Declared before any task is launched
    // so they outlive the tasks even if the test body is left early through an exception.
    std::mutex counts_mutex;
    std::vector<size_t> received_by_socket(npullers, 0);
    std::vector<size_t> processed_by_socket(npullers, 0);

    // Tasks run via std::async (as in TCPImageCommTest_AutoPort_StarBind) instead of std::thread:
    // a failing REQUIRE throws, and an exception escaping a std::thread body calls std::terminate.
    // A future stores the exception and rethrows it from get(), and its destructor waits for the
    // task instead of terminating the process.
    auto sender = std::async(std::launch::async, [&] {
        std::vector<uint8_t> serialization_buffer(16 * 1024 * 1024);
        CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size());

        StartMessage start{
            .images_per_file = images_per_file,
            .run_number = x.GetRunNumber(),
            .write_master_file = true
        };
        EndMessage end{};

        pusher.StartDataCollection(start);

        for (int64_t i = 0; i < static_cast<int64_t>(nframes); i++) {
            DataMessage data_message;
            data_message.number = i;
            data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(),
                                                 x.GetPixelsNum() * sizeof(uint16_t),
                                                 x.GetXPixelsNum(),
                                                 x.GetYPixelsNum(),
                                                 x.GetImageMode(),
                                                 x.GetCompressionAlgorithm());
            serializer.SerializeImage(data_message);
            REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i));
        }

        REQUIRE(pusher.EndDataCollection(end));
    });

    std::vector<std::future<void> > receivers;
    receivers.reserve(npullers);

    for (int w = 0; w < npullers; w++) {
        receivers.push_back(std::async(std::launch::async, [&, w] {
            bool seen_start = false;
            bool seen_end = false;
            std::optional<uint32_t> my_socket_number;

            while (!seen_end) {
                auto out = puller[w]->PollImage(std::chrono::seconds(10));
                REQUIRE(out.has_value());
                REQUIRE(out->cbor != nullptr);
                REQUIRE(out->tcp_msg != nullptr);

                const auto &h = out->tcp_msg->header;

                if (out->cbor->start_message) {
                    my_socket_number = h.socket_number;

                    PullerAckMessage ack;
                    ack.ack_for = TCPFrameType::START;
                    ack.ok = true;
                    ack.fatal = false;
                    ack.run_number = h.run_number;
                    ack.socket_number = h.socket_number;
                    ack.image_number = 0;
                    ack.processed_images = 0;
                    ack.error_code = TCPAckCode::None;
                    REQUIRE(ack.run_number == x.GetRunNumber());
                    REQUIRE(puller[w]->SendAck(ack));
                    seen_start = true;
                    continue;
                }

                if (out->cbor->data_message) {
                    REQUIRE(seen_start);
                    REQUIRE(my_socket_number.has_value());

                    // Each file's images must land on the socket chosen by the round-robin rule.
                    auto n = out->cbor->data_message->number;
                    REQUIRE(((n / images_per_file) % npullers) == static_cast<int64_t>(*my_socket_number));

                    {
                        std::lock_guard<std::mutex> lg(counts_mutex);
                        received_by_socket.at(*my_socket_number)++;
                        processed_by_socket.at(*my_socket_number)++;
                    }
                    continue;
                }

                if (out->cbor->end_message) {
                    REQUIRE(my_socket_number.has_value());

                    PullerAckMessage ack;
                    ack.ack_for = TCPFrameType::END;
                    ack.ok = true;
                    ack.fatal = false;
                    ack.run_number = h.run_number;
                    ack.socket_number = h.socket_number;
                    ack.image_number = 0;
                    {
                        std::lock_guard<std::mutex> lg(counts_mutex);
                        ack.processed_images = processed_by_socket.at(*my_socket_number);
                    }
                    ack.error_code = TCPAckCode::None;
                    REQUIRE(puller[w]->SendAck(ack));
                    seen_end = true;
                }
            }
        }));
    }

    // Wait for every task, rethrowing the first captured failure in launch order.
    sender.get();
    for (auto &r: receivers) r.get();

    // nframes / images_per_file files are spread round-robin over the pullers; with these
    // constants every puller receives nframes / npullers frames.
    for (size_t s = 0; s < received_by_socket.size(); s++)
        REQUIRE(received_by_socket[s] == nframes / static_cast<size_t>(npullers));

    for (auto &p: puller)
        p->Disconnect();
}
|
|
|
|
TEST_CASE("TCPImageCommTest_DataFatalAck_PropagatesToPusher", "[TCP]") {
    // Puller 0 answers the first DATA frame it sees with a non-ok, fatal ACK carrying
    // TCPAckCode::DiskQuotaExceeded / "quota exceeded". The run must still be closed, and the
    // message returned by pusher.Finalize() must contain the writer's error text.
    const size_t nframes = 64;
    const int64_t npullers = 2;
    const int64_t images_per_file = 8;

    DiffractionExperiment x(DetJF(1));
    x.Raw();
    x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4)
            .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION);

    std::mt19937 g1(42);
    std::uniform_int_distribution<uint16_t> dist;
    std::vector<uint16_t> image1(x.GetPixelsNum() * nframes);
    for (auto &i: image1) i = dist(g1);

    TCPStreamPusher pusher("tcp://127.0.0.1:*", npullers);
    std::vector<std::unique_ptr<TCPImagePuller> > puller;
    for (int i = 0; i < npullers; i++)
        puller.push_back(std::make_unique<TCPImagePuller>(pusher.GetAddress()[0], 64 * 1024 * 1024));

    // Wait for all pullers to connect before starting data collection
    for (int attempt = 0; attempt < 100 && pusher.GetConnectedWriters() < static_cast<size_t>(npullers); ++attempt)
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
    REQUIRE(pusher.GetConnectedWriters() == static_cast<size_t>(npullers));

    // exchange(true) below guarantees exactly one fatal ACK is sent in total.
    std::atomic<bool> sent_fatal{false};

    // std::async rather than std::thread: a REQUIRE failing inside a std::thread body calls
    // std::terminate, whereas a future captures the exception and rethrows it from get().
    auto sender = std::async(std::launch::async, [&] {
        std::vector<uint8_t> serialization_buffer(16 * 1024 * 1024);
        CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size());

        StartMessage start{
            .images_per_file = images_per_file,
            .write_master_file = true
        };
        EndMessage end{};

        pusher.StartDataCollection(start);

        for (int64_t i = 0; i < static_cast<int64_t>(nframes); i++) {
            DataMessage data_message;
            data_message.number = i;
            data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(),
                                                 x.GetPixelsNum() * sizeof(uint16_t),
                                                 x.GetXPixelsNum(),
                                                 x.GetYPixelsNum(),
                                                 x.GetImageMode(),
                                                 x.GetCompressionAlgorithm());
            serializer.SerializeImage(data_message);
            // Result deliberately ignored: sends may start failing once the fatal ACK lands.
            (void) pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i);
        }

        REQUIRE(pusher.EndDataCollection(end));
        const auto final_msg = pusher.Finalize();
        REQUIRE_THAT(final_msg, Catch::Matchers::ContainsSubstring("quota"));
    });

    std::vector<std::future<void> > receivers;
    receivers.reserve(npullers);

    for (int w = 0; w < npullers; w++) {
        receivers.push_back(std::async(std::launch::async, [&, w] {
            bool seen_end = false;
            bool local_fatal_sent = false;

            while (!seen_end) {
                auto out = puller[w]->PollImage(std::chrono::seconds(2));
                if (!out.has_value()) {
                    // Once this receiver has sent a fatal DATA ACK, no END is guaranteed on this stream.
                    if (local_fatal_sent)
                        break;
                    REQUIRE(out.has_value());
                }

                REQUIRE(out->cbor != nullptr);
                REQUIRE(out->tcp_msg != nullptr);

                const auto &h = out->tcp_msg->header;

                // ACKs are value-initialized ({}) so fields not assigned below (fatal,
                // image_number, processed_images, ...) never carry indeterminate values.
                if (out->cbor->start_message) {
                    PullerAckMessage ack{};
                    ack.ack_for = TCPFrameType::START;
                    ack.ok = true;
                    ack.run_number = h.run_number;
                    ack.socket_number = h.socket_number;
                    ack.error_code = TCPAckCode::None;
                    REQUIRE(puller[w]->SendAck(ack));
                    continue;
                }

                if (out->cbor->data_message) {
                    if (w == 0 && !sent_fatal.exchange(true)) {
                        PullerAckMessage ack{};
                        ack.ack_for = TCPFrameType::DATA;
                        ack.ok = false;
                        ack.fatal = true;
                        ack.run_number = h.run_number;
                        ack.socket_number = h.socket_number;
                        ack.image_number = static_cast<uint64_t>(out->cbor->data_message->number);
                        ack.error_code = TCPAckCode::DiskQuotaExceeded;
                        ack.error_text = "quota exceeded";
                        REQUIRE(puller[w]->SendAck(ack));
                        local_fatal_sent = true;
                    }
                    continue;
                }

                if (out->cbor->end_message) {
                    PullerAckMessage ack{};
                    ack.ack_for = TCPFrameType::END;
                    ack.ok = true;
                    ack.run_number = h.run_number;
                    ack.socket_number = h.socket_number;
                    ack.error_code = TCPAckCode::None;
                    REQUIRE(puller[w]->SendAck(ack));
                    seen_end = true;
                }
            }
        }));
    }

    // Wait for every task, rethrowing the first captured failure in launch order.
    sender.get();
    for (auto &r: receivers) r.get();

    for (auto &p: puller)
        p->Disconnect();
}
|
|
|
|
TEST_CASE("TCPImageCommTest_GetAckProgress_Correct", "[TCP]") {
    // Pullers ACK every DATA frame, with ok == true only for even image numbers (non-fatal
    // otherwise). After the run, pusher.GetImagesWritten() is expected to equal nframes / 2,
    // presumably because only ok ACKs count as written -- confirm against TCPStreamPusher.
    const size_t nframes = 64;
    const int64_t npullers = 2;
    const int64_t images_per_file = 8;

    DiffractionExperiment x(DetJF(1));
    x.Raw();
    x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4)
            .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION);

    std::mt19937 g1(123);
    std::uniform_int_distribution<uint16_t> dist;
    std::vector<uint16_t> image1(x.GetPixelsNum() * nframes);
    for (auto &i: image1) i = dist(g1);

    TCPStreamPusher pusher("tcp://127.0.0.1:*", npullers);
    std::vector<std::unique_ptr<TCPImagePuller> > puller;
    for (int i = 0; i < npullers; i++)
        puller.push_back(std::make_unique<TCPImagePuller>(pusher.GetAddress()[0], 64 * 1024 * 1024));

    // Wait for all pullers to connect before starting data collection
    for (int attempt = 0; attempt < 100 && pusher.GetConnectedWriters() < static_cast<size_t>(npullers); ++attempt)
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
    REQUIRE(pusher.GetConnectedWriters() == static_cast<size_t>(npullers));

    // std::async rather than std::thread: a REQUIRE failing inside a std::thread body calls
    // std::terminate, whereas a future captures the exception and rethrows it from get().
    auto sender = std::async(std::launch::async, [&] {
        std::vector<uint8_t> serialization_buffer(16 * 1024 * 1024);
        CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size());

        StartMessage start{
            .images_per_file = images_per_file,
            .write_master_file = true
        };
        EndMessage end{};

        pusher.StartDataCollection(start);

        for (int64_t i = 0; i < static_cast<int64_t>(nframes); i++) {
            DataMessage data_message;
            data_message.number = i;
            data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(),
                                                 x.GetPixelsNum() * sizeof(uint16_t),
                                                 x.GetXPixelsNum(),
                                                 x.GetYPixelsNum(),
                                                 x.GetImageMode(),
                                                 x.GetCompressionAlgorithm());
            serializer.SerializeImage(data_message);
            REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i));
        }

        REQUIRE(pusher.EndDataCollection(end));

        // Give the pusher time to account for all DATA ACKs before reading progress.
        std::this_thread::sleep_for(std::chrono::seconds(5));

        auto progress = pusher.GetImagesWritten();
        REQUIRE(progress.has_value());
        REQUIRE(progress == nframes / 2);
    });

    std::vector<std::future<void> > receivers;
    receivers.reserve(npullers);

    for (int w = 0; w < npullers; w++) {
        receivers.push_back(std::async(std::launch::async, [&, w] {
            bool seen_end = false;
            uint64_t processed = 0;

            while (!seen_end) {
                auto out = puller[w]->PollImage(std::chrono::seconds(10));
                REQUIRE(out.has_value());
                REQUIRE(out->cbor != nullptr);
                REQUIRE(out->tcp_msg != nullptr);

                const auto &h = out->tcp_msg->header;

                if (out->cbor->start_message) {
                    PullerAckMessage ack;
                    ack.ack_for = TCPFrameType::START;
                    ack.ok = true;
                    ack.fatal = false;
                    ack.run_number = h.run_number;
                    ack.socket_number = h.socket_number;
                    ack.image_number = 0;
                    ack.processed_images = 0;
                    ack.error_code = TCPAckCode::None;
                    REQUIRE(puller[w]->SendAck(ack));
                    continue;
                }

                if (out->cbor->data_message) {
                    auto number = out->cbor->data_message->number;

                    processed++;
                    PullerAckMessage ack;
                    ack.ack_for = TCPFrameType::DATA;
                    ack.ok = (number % 2 == 0); // even images succeed, odd ones are NACKed (non-fatal)
                    ack.fatal = false;
                    ack.run_number = h.run_number;
                    ack.socket_number = h.socket_number;
                    ack.image_number = number;
                    ack.processed_images = processed;
                    ack.error_code = TCPAckCode::None;
                    REQUIRE(puller[w]->SendAck(ack));
                    continue;
                }

                if (out->cbor->end_message) {
                    PullerAckMessage ack;
                    ack.ack_for = TCPFrameType::END;
                    ack.ok = true;
                    ack.fatal = false;
                    ack.run_number = h.run_number;
                    ack.socket_number = h.socket_number;
                    ack.image_number = 0;
                    ack.processed_images = processed;
                    ack.error_code = TCPAckCode::None;
                    REQUIRE(puller[w]->SendAck(ack));
                    seen_end = true;
                }
            }
        }));
    }

    // Wait for every task, rethrowing the first captured failure in launch order.
    sender.get();
    for (auto &r: receivers) r.get();

    for (auto &p: puller)
        p->Disconnect();
}
|
|
|
|
TEST_CASE("TCPImageCommTest_AutoPort_StarBind", "[TCP]") {
    // A pusher bound to "tcp://127.0.0.1:*" must report a connectable address via GetAddress(),
    // and a small run must complete end to end over it.
    const size_t nframes = 8;
    const int64_t images_per_file = 4;

    DiffractionExperiment x(DetJF(1));
    x.Raw();
    x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4)
            .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION);

    std::vector<uint16_t> image1(x.GetPixelsNum() * nframes, 7u);

    TCPStreamPusher pusher("tcp://127.0.0.1:*", 1);
    TCPImagePuller puller(pusher.GetAddress()[0], 64 * 1024 * 1024);

    // Bounded poll (up to ~5 s) like the other tests, instead of a fixed 2 s sleep: returns as
    // soon as the puller is connected and tolerates a slower connect.
    for (int attempt = 0; attempt < 100 && pusher.GetConnectedWriters() < 1; ++attempt)
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
    REQUIRE(pusher.GetConnectedWriters() == 1);

    // Receiver runs via std::async so a failing REQUIRE inside it is rethrown from get().
    std::future<void> receiver = std::async(std::launch::async, [&] {
        bool seen_end = false;
        uint64_t processed = 0;

        while (!seen_end) {
            auto out = puller.PollImage(std::chrono::seconds(10));
            REQUIRE(out.has_value());
            REQUIRE(out->cbor != nullptr);
            REQUIRE(out->tcp_msg != nullptr);

            const auto &h = out->tcp_msg->header;
            if (out->cbor->start_message) {
                PullerAckMessage ack{
                    .ack_for = TCPFrameType::START, .ok = true, .run_number = h.run_number,
                    .socket_number = h.socket_number, .error_code = TCPAckCode::None
                };
                REQUIRE(puller.SendAck(ack));
            } else if (out->cbor->data_message) {
                processed++;
            } else if (out->cbor->end_message) {
                PullerAckMessage ack{
                    .ack_for = TCPFrameType::END, .ok = true, .run_number = h.run_number,
                    .socket_number = h.socket_number, .processed_images = processed, .error_code = TCPAckCode::None
                };
                REQUIRE(puller.SendAck(ack));
                seen_end = true;
            }
        }
    });

    std::vector<uint8_t> serialization_buffer(16 * 1024 * 1024);
    CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size());

    StartMessage start{.images_per_file = images_per_file, .write_master_file = true};
    EndMessage end{};
    pusher.StartDataCollection(start);

    for (int64_t i = 0; i < static_cast<int64_t>(nframes); i++) {
        DataMessage data_message;
        data_message.number = i;
        data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(),
                                             x.GetPixelsNum() * sizeof(uint16_t),
                                             x.GetXPixelsNum(), x.GetYPixelsNum(),
                                             x.GetImageMode(), x.GetCompressionAlgorithm());
        serializer.SerializeImage(data_message);
        REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i));
    }

    REQUIRE(pusher.EndDataCollection(end));
    REQUIRE_NOTHROW(receiver.get());
    puller.Disconnect();
}
|
|
|
|
TEST_CASE("TCPImageCommTest_DisconnectMidWrite_NoHang", "[TCP]") {
    // The only puller ACKs START and then disconnects on the first DATA frame. The sender
    // (all SendImage calls plus EndDataCollection) must finish within 20 s, and
    // EndDataCollection must report failure.
    const size_t nframes = 256;
    const int64_t images_per_file = 16;

    DiffractionExperiment x(DetJF(1));
    x.Raw();
    x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4)
            .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION);

    std::vector<uint16_t> image1(x.GetPixelsNum() * nframes, 11u);

    TCPStreamPusher pusher("tcp://127.0.0.1:*", 1);
    TCPImagePuller puller(pusher.GetAddress()[0], 64 * 1024 * 1024);

    // std::async rather than std::thread: a REQUIRE failing inside a std::thread body calls
    // std::terminate, and so does destroying a still-joinable std::thread if the REQUIRE on the
    // main thread below throws. A future stores the exception and waits on destruction instead.
    auto receiver = std::async(std::launch::async, [&] {
        bool disconnected = false;
        while (!disconnected) {
            auto out = puller.PollImage(std::chrono::seconds(10));
            REQUIRE(out.has_value());
            REQUIRE(out->cbor != nullptr);
            REQUIRE(out->tcp_msg != nullptr);

            const auto &h = out->tcp_msg->header;
            if (out->cbor->start_message) {
                PullerAckMessage ack{
                    .ack_for = TCPFrameType::START, .ok = true, .run_number = h.run_number,
                    .socket_number = h.socket_number, .error_code = TCPAckCode::None
                };
                REQUIRE(puller.SendAck(ack));
            } else if (out->cbor->data_message) {
                puller.Disconnect(); // simulate puller disappearing mid-stream
                disconnected = true;
            }
        }
    });

    auto sender = std::async(std::launch::async, [&] {
        std::vector<uint8_t> serialization_buffer(16 * 1024 * 1024);
        CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size());

        StartMessage start{.images_per_file = images_per_file, .write_master_file = true};
        EndMessage end{};

        pusher.StartDataCollection(start);

        for (int64_t i = 0; i < static_cast<int64_t>(nframes); i++) {
            DataMessage data_message;
            data_message.number = i;
            data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(),
                                                 x.GetPixelsNum() * sizeof(uint16_t),
                                                 x.GetXPixelsNum(), x.GetYPixelsNum(),
                                                 x.GetImageMode(), x.GetCompressionAlgorithm());
            serializer.SerializeImage(data_message);
            // Failures are expected after the disconnect; only the overall outcome is checked.
            (void) pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i);
        }

        return pusher.EndDataCollection(end);
    });

    REQUIRE(sender.wait_for(std::chrono::seconds(20)) == std::future_status::ready);
    CHECK(sender.get() == false);

    receiver.get();
}
|
|
|
|
TEST_CASE("TCPImageCommTest_RepubToZMQ", "[TCP][ZeroMQ]") {
    // Chain: TCPStreamPusher --TCP--> TCPImagePuller --ZMQ repub--> ZMQImagePuller
    const size_t nframes = 64;
    const int64_t images_per_file = 8;

    DiffractionExperiment x(DetJF(1));
    x.Raw();
    x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4)
            .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION);

    std::mt19937 g1(9999);
    std::uniform_int_distribution<uint16_t> dist;
    std::vector<uint16_t> image1(x.GetPixelsNum() * nframes);
    for (auto &i : image1) i = dist(g1);

    // 1. Create TCP pusher on an auto-assigned port
    TCPStreamPusher pusher("tcp://127.0.0.1:*", 1);

    // 2. Create TCP puller that republishes over ZMQ. The repub socket binds internally, so the
    //    ZMQ side needs an address known up front; an "ipc://*" endpoint would bind to a generated
    //    path this test cannot see. A fixed TCP address is used instead.
    //    NOTE(review): a fixed port can collide with another process on the host.
    const std::string repub_bind_addr = "tcp://127.0.0.1:19010";
    TCPImagePuller tcp_puller(pusher.GetAddress()[0], {}, repub_bind_addr);

    // 3. Create ZMQ puller that connects to the repub address
    ZMQImagePuller zmq_puller(repub_bind_addr);

    // Wait for TCP connection
    for (int attempt = 0; attempt < 100 && pusher.GetConnectedWriters() < 1; ++attempt)
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
    REQUIRE(pusher.GetConnectedWriters() == 1);

    // Results written by the ZMQ consumer task and read only after it has been joined via get().
    // Declared before any task is launched so they outlive every task.
    size_t zmq_nimages = 0;
    size_t zmq_errors = 0;
    bool zmq_seen_start = false;
    bool zmq_seen_end = false;

    // All tasks use std::async rather than std::thread: a REQUIRE failing inside a std::thread
    // body calls std::terminate, whereas a future captures it and rethrows it from get().

    // Sender task: push frames over TCP
    auto sender = std::async(std::launch::async, [&] {
        std::vector<uint8_t> serialization_buffer(16 * 1024 * 1024);
        CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size());

        StartMessage start{
            .images_per_file = images_per_file,
            .write_master_file = true
        };
        EndMessage end{};

        pusher.StartDataCollection(start);

        for (int64_t i = 0; i < static_cast<int64_t>(nframes); i++) {
            DataMessage data_message;
            data_message.number = i;
            data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(),
                                                 x.GetPixelsNum() * sizeof(uint16_t),
                                                 x.GetXPixelsNum(),
                                                 x.GetYPixelsNum(),
                                                 x.GetImageMode(),
                                                 x.GetCompressionAlgorithm());
            serializer.SerializeImage(data_message);
            REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i));
        }

        REQUIRE(pusher.EndDataCollection(end));
    });

    // TCP puller consumer: drains the TCP side (with ACKs) so data keeps flowing
    auto tcp_consumer = std::async(std::launch::async, [&] {
        bool seen_end = false;
        while (!seen_end) {
            auto out = tcp_puller.PollImage(std::chrono::seconds(10));
            REQUIRE(out.has_value());
            REQUIRE(out->cbor != nullptr);
            REQUIRE(out->tcp_msg != nullptr);

            const auto &h = out->tcp_msg->header;

            if (out->cbor->start_message) {
                PullerAckMessage ack{};
                ack.ack_for = TCPFrameType::START;
                ack.ok = true;
                ack.run_number = h.run_number;
                ack.socket_number = h.socket_number;
                ack.error_code = TCPAckCode::None;
                REQUIRE(tcp_puller.SendAck(ack));
            } else if (out->cbor->end_message) {
                PullerAckMessage ack{};
                ack.ack_for = TCPFrameType::END;
                ack.ok = true;
                ack.run_number = h.run_number;
                ack.socket_number = h.socket_number;
                ack.error_code = TCPAckCode::None;
                REQUIRE(tcp_puller.SendAck(ack));
                seen_end = true;
            }
            // data frames: no ack needed for this test
        }
    });

    // ZMQ puller consumer: verify the republished stream. Records failures in counters instead
    // of asserting; the main thread checks them after joining.
    auto zmq_consumer = std::async(std::launch::async, [&] {
        auto timeout = std::chrono::seconds(30);

        // First message should be START
        auto img = zmq_puller.PollImage(timeout);
        if (!img || !img->cbor || !img->cbor->start_message) {
            zmq_errors++;
            return;
        }
        zmq_seen_start = true;

        // Republished START should have writer_notification_zmq_addr cleared
        if (!img->cbor->start_message->writer_notification_zmq_addr.empty()) {
            zmq_errors++;
        }

        // Consume data and END
        img = zmq_puller.PollImage(timeout);
        while (img && img->cbor && !img->cbor->end_message) {
            if (img->cbor->data_message) {
                const auto n = img->cbor->data_message->number;
                if (n < 0 || static_cast<size_t>(n) >= nframes)
                    zmq_errors++; // unexpected number: must not be used to index image1
                else if (img->cbor->data_message->image.GetCompressedSize() != x.GetPixelsNum() * sizeof(uint16_t))
                    zmq_errors++;
                else if (memcmp(img->cbor->data_message->image.GetCompressed(),
                                image1.data() + n * x.GetPixelsNum(),
                                x.GetPixelsNum() * sizeof(uint16_t)) != 0)
                    zmq_errors++;
                zmq_nimages++;
            }
            img = zmq_puller.PollImage(timeout);
        }
        if (img && img->cbor && img->cbor->end_message)
            zmq_seen_end = true;
    });

    // Wait for every task, rethrowing the first captured failure in launch order.
    sender.get();
    tcp_consumer.get();
    zmq_consumer.get();

    tcp_puller.Disconnect();
    zmq_puller.Disconnect();

    // The repub uses non-blocking Put for data, so some frames *could* be dropped
    // under extreme back-pressure, but with only 64 frames we expect all of them.
    REQUIRE(zmq_seen_start);
    REQUIRE(zmq_seen_end);
    REQUIRE(zmq_nimages == nframes);
    REQUIRE(zmq_errors == 0);
}
|
|
|
|
namespace {
|
|
|
|
// Controllable TCP "writer" peer for backpressure tests. Connects to the pusher, ACKs
|
|
// START, then *stalls* (stops draining the socket) until Release() is called, while a
|
|
// background thread keeps sending BUSY heartbeats — i.e. a writer that is alive but
|
|
// wedged (e.g. on a slow filesystem at high frame rate). Catch2 assertion macros are not
|
|
// thread-safe, so the worker threads only touch atomics; the test thread asserts.
|
|
class StallableWriterDouble {
|
|
public:
|
|
StallableWriterDouble(const std::string &tcp_addr, int rcvbuf_bytes) {
|
|
auto [host, port] = ParseHostPort(tcp_addr);
|
|
fd_ = ::socket(AF_INET, SOCK_STREAM, 0);
|
|
if (fd_ < 0)
|
|
return;
|
|
setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &rcvbuf_bytes, sizeof(rcvbuf_bytes));
|
|
sockaddr_in sin{};
|
|
sin.sin_family = AF_INET;
|
|
sin.sin_port = htons(port);
|
|
inet_pton(AF_INET, host.c_str(), &sin.sin_addr);
|
|
if (::connect(fd_, reinterpret_cast<sockaddr *>(&sin), sizeof(sin)) != 0) {
|
|
::close(fd_);
|
|
fd_ = -1;
|
|
return;
|
|
}
|
|
busy_thread_ = std::thread([this] { BusyLoop(); });
|
|
reader_thread_ = std::thread([this] { ReaderLoop(); });
|
|
}
|
|
|
|
~StallableWriterDouble() {
|
|
stop_ = true;
|
|
Release();
|
|
if (fd_ >= 0)
|
|
::shutdown(fd_, SHUT_RDWR);
|
|
if (reader_thread_.joinable())
|
|
reader_thread_.join();
|
|
if (busy_thread_.joinable())
|
|
busy_thread_.join();
|
|
if (fd_ >= 0)
|
|
::close(fd_);
|
|
}
|
|
|
|
[[nodiscard]] bool Connected() const { return fd_ >= 0; }
|
|
|
|
// Stop stalling: let the reader drain DATA and ACK END.
|
|
void Release() {
|
|
{
|
|
std::lock_guard<std::mutex> lg(mtx_);
|
|
released_ = true;
|
|
}
|
|
cv_.notify_all();
|
|
}
|
|
|
|
[[nodiscard]] size_t DataFramesReceived() const { return data_frames_.load(); }
|
|
[[nodiscard]] bool EndAcked() const { return end_acked_.load(); }
|
|
|
|
private:
|
|
static std::pair<std::string, uint16_t> ParseHostPort(const std::string &addr) {
|
|
const std::string prefix = "tcp://";
|
|
const auto hp = addr.substr(prefix.size());
|
|
const auto p = hp.find_last_of(':');
|
|
return {hp.substr(0, p), static_cast<uint16_t>(std::stoi(hp.substr(p + 1)))};
|
|
}
|
|
|
|
bool SendHeader(TCPFrameType type, TCPFrameType ack_for, uint64_t run, uint32_t sock, uint32_t flags) {
|
|
TcpFrameHeader h{};
|
|
h.type = static_cast<uint16_t>(type);
|
|
h.ack_for = static_cast<uint16_t>(ack_for);
|
|
h.run_number = run;
|
|
h.socket_number = sock;
|
|
h.flags = flags;
|
|
h.payload_size = 0;
|
|
std::lock_guard<std::mutex> lg(send_mtx_);
|
|
if (fd_ < 0)
|
|
return false;
|
|
return ::send(fd_, &h, sizeof(h), MSG_NOSIGNAL) == static_cast<ssize_t>(sizeof(h));
|
|
}
|
|
|
|
bool ReadExact(void *buf, size_t len) {
|
|
auto *p = static_cast<uint8_t *>(buf);
|
|
size_t got = 0;
|
|
while (got < len) {
|
|
const ssize_t rc = ::recv(fd_, p + got, len - got, 0);
|
|
if (rc <= 0)
|
|
return false;
|
|
got += static_cast<size_t>(rc);
|
|
}
|
|
return true;
|
|
}
|
|
|
|
void BusyLoop() {
|
|
// Heartbeat keeps the pusher's peer-liveness fresh even while we are not draining.
|
|
while (!stop_) {
|
|
SendHeader(TCPFrameType::BUSY, TCPFrameType::DATA, run_.load(), sock_.load(), 0);
|
|
for (int i = 0; i < 5 && !stop_; i++)
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(50));
|
|
}
|
|
}
|
|
|
|
void ReaderLoop() {
|
|
std::vector<uint8_t> discard;
|
|
while (!stop_) {
|
|
TcpFrameHeader h{};
|
|
if (!ReadExact(&h, sizeof(h)))
|
|
return;
|
|
if (h.magic != JFJOCH_TCP_MAGIC || h.version != JFJOCH_TCP_VERSION)
|
|
return;
|
|
if (h.payload_size > 0) {
|
|
discard.resize(h.payload_size);
|
|
if (!ReadExact(discard.data(), discard.size()))
|
|
return;
|
|
}
|
|
switch (static_cast<TCPFrameType>(h.type)) {
|
|
case TCPFrameType::START:
|
|
run_.store(h.run_number);
|
|
sock_.store(h.socket_number);
|
|
SendHeader(TCPFrameType::ACK, TCPFrameType::START, h.run_number, h.socket_number, TCP_ACK_FLAG_OK);
|
|
{ // Stall: stop reading until released.
|
|
std::unique_lock<std::mutex> ul(mtx_);
|
|
cv_.wait(ul, [this] { return released_ || stop_; });
|
|
}
|
|
break;
|
|
case TCPFrameType::DATA:
|
|
data_frames_.fetch_add(1);
|
|
break;
|
|
case TCPFrameType::END:
|
|
SendHeader(TCPFrameType::ACK, TCPFrameType::END, h.run_number, h.socket_number, TCP_ACK_FLAG_OK);
|
|
end_acked_.store(true);
|
|
return;
|
|
default:
|
|
break; // ignore KEEPALIVE etc.
|
|
}
|
|
}
|
|
}
|
|
|
|
int fd_ = -1;
|
|
std::thread reader_thread_;
|
|
std::thread busy_thread_;
|
|
std::atomic<bool> stop_{false};
|
|
std::atomic<uint64_t> run_{0};
|
|
std::atomic<uint32_t> sock_{0};
|
|
std::atomic<size_t> data_frames_{0};
|
|
std::atomic<bool> end_acked_{false};
|
|
std::mutex send_mtx_;
|
|
std::mutex mtx_;
|
|
std::condition_variable cv_;
|
|
bool released_ = false;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
TEST_CASE("TCPImageCommTest_StalledWriter_SurvivesViaHeartbeat", "[TCP]") {
    // A writer that is alive (still heartbeating) but has stopped draining — e.g. wedged
    // on a slow filesystem at high frame rate — must NOT be dropped mid-run. The pusher
    // rides out the backpressure on the production zero-copy queue path until the writer
    // recovers. Regression for the queue-path send giving up on a fixed deadline, and for
    // the BUSY heartbeat keeping the connection alive past the peer-liveness window.
    constexpr int64_t N = 1000; // > queue depth + socket buffers
    constexpr auto liveness = std::chrono::milliseconds(2000);
    constexpr auto stall = std::chrono::milliseconds(4000); // > liveness AND > old send deadline

    // Small SO_SNDBUF/SO_RCVBUF so backpressure reaches the queue after few images.
    TCPStreamPusher pusher("tcp://127.0.0.1:*", 1, 16 * 1024);
    pusher.SetPeerLivenessTimeout(liveness);

    StallableWriterDouble writer(pusher.GetAddress()[0], 16 * 1024);
    REQUIRE(writer.Connected());

    for (int attempt = 0; attempt < 200 && pusher.GetConnectedWriters() < 1; ++attempt)
        std::this_thread::sleep_for(std::chrono::milliseconds(25));
    REQUIRE(pusher.GetConnectedWriters() == 1);

    ImageBuffer image_buffer(16 * 1024 * 1024);
    image_buffer.StartMeasurement(static_cast<size_t>(4096));

    StartMessage start{.images_per_file = 1000, .write_master_file = true};
    pusher.StartDataCollection(start); // writer ACKs START, then stalls (stops reading)

    // Set only on the failure path (by abort_guard below) so the producer can stop early.
    std::atomic<bool> abort_sender{false};

    auto sender = std::async(std::launch::async, [&] {
        for (int64_t i = 0; i < N && !abort_sender; i++) {
            ZeroCopyReturnValue *slot = nullptr;
            while ((slot = image_buffer.GetImageSlot()) == nullptr) {
                if (abort_sender)
                    return;
                std::this_thread::sleep_for(std::chrono::milliseconds(1));
            }
            std::memset(slot->GetImage(), 0, 256);
            slot->SetImageNumber(i);
            slot->SetImageSize(256); // arbitrary payload; the writer double discards it
            slot->ReadyToSend();
            pusher.SendImage(*slot);
        }
    });

    // Declared after `sender`, so it is destroyed first. If an assertion below aborts the test,
    // it stops the producer and un-stalls the writer. Otherwise ~future() would wait on the task,
    // which may be spinning for a free slot or blocked behind the stalled writer, and the test
    // would hang instead of failing. On the normal path both actions are no-ops.
    struct SenderAbortGuard {
        std::atomic<bool> &flag;
        StallableWriterDouble &stalled_writer;

        ~SenderAbortGuard() {
            flag = true;
            stalled_writer.Release();
        }
    } abort_guard{abort_sender, writer};

    // During the stall the queue is full; SendImage must block, not drop the connection.
    std::this_thread::sleep_for(stall);
    CHECK(pusher.GetConnectedWriters() == 1);
    CHECK(sender.wait_for(std::chrono::milliseconds(0)) != std::future_status::ready);

    // Writer recovers and starts draining.
    writer.Release();

    REQUIRE(sender.wait_for(std::chrono::seconds(30)) == std::future_status::ready);
    sender.get();

    // Every image makes it across once the stall clears.
    for (int attempt = 0; attempt < 1200 && writer.DataFramesReceived() < static_cast<size_t>(N); ++attempt)
        std::this_thread::sleep_for(std::chrono::milliseconds(25));
    CHECK(writer.DataFramesReceived() == static_cast<size_t>(N));

    // Queue fully drained: END now hands over cleanly without racing data frames.
    EndMessage end{};
    CHECK(pusher.EndDataCollection(end) == true);

    for (int attempt = 0; attempt < 200 && !writer.EndAcked(); ++attempt)
        std::this_thread::sleep_for(std::chrono::milliseconds(25));
    CHECK(writer.EndAcked());
    CHECK(pusher.GetConnectedWriters() == 1);

    image_buffer.Finalize(std::chrono::seconds(5));
}
|