TCP: Fixes - allow for live monitoring of ACK progress

This commit is contained in:
2026-03-04 15:54:40 +01:00
parent 3f8736dd34
commit 536fc0761d
8 changed files with 318 additions and 15 deletions

View File

@@ -12,7 +12,6 @@ void HDF5FilePusher::StartDataCollection(StartMessage &message) {
writer_future = std::async(std::launch::async, &HDF5FilePusher::WriterThread, this);
images_written = 0;
images_err = 0;
last_processed_image = 0;
}
bool HDF5FilePusher::EndDataCollection(const EndMessage &message) {
@@ -40,8 +39,6 @@ bool HDF5FilePusher::SendImage(const uint8_t *image_data, size_t image_size, int
try {
writer->Write(*deserialized->data_message);
images_written++;
if (image_number > last_processed_image)
last_processed_image = image_number;
} catch (const JFJochException &e) {
images_err++;
}
@@ -85,12 +82,10 @@ std::optional<ImagePusherAckProgress> HDF5FilePusher::GetAckProgress() const {
uint64_t ack_ok = images_written;
uint64_t ack_bad = images_err;
uint64_t ack_total = ack_ok + ack_bad;
uint64_t last = last_processed_image;
return ImagePusherAckProgress{
.data_acked_ok = ack_ok,
.data_acked_bad = ack_bad,
.data_acked_total = ack_total,
.last_processed_images = last
.data_acked_total = ack_total
};
}

View File

@@ -20,7 +20,6 @@ class HDF5FilePusher : public ImagePusher {
std::atomic<uint64_t> images_written = 0;
std::atomic<uint64_t> images_err = 0;
std::atomic<uint64_t> last_processed_image = 0;
public:
// Thread safety: StartDataCollection, EndDataCollection and SendCalibration must run purely in a serial context
// SendImage can be executed in parallel

View File

@@ -26,7 +26,6 @@ struct ImagePusherAckProgress {
uint64_t data_acked_bad = 0;
uint64_t data_acked_total = 0;
uint64_t data_ack_pending = 0;
uint64_t last_processed_images = 0;
};
void PrepareCBORImage(DataMessage& message,

View File

@@ -169,7 +169,8 @@ bool TCPStreamPusher::SendCalibration(const CompressedImage &message) {
}
std::optional<ImagePusherAckProgress> TCPStreamPusher::GetAckProgress() const {
ImagePusherAckProgress out;
ImagePusherAckProgress out{
};
for (const auto &s : socket) {
auto p = s->GetDataAckProgress();
out.data_sent += p.data_sent;

View File

@@ -460,13 +460,10 @@ void TCPStreamPusherSocket::AckThread() {
last_ack_error = "CANCEL rejected";
} else if (ack_for == TCPFrameType::DATA) {
data_acked_total.fetch_add(1, std::memory_order_relaxed);
last_processed_images.store(h.ack_processed_images, std::memory_order_relaxed);
if (ok && !fatal) {
data_acked_ok.fetch_add(1, std::memory_order_relaxed);
} else {
data_acked_bad.fetch_add(1, std::memory_order_relaxed);
broken = true; // mandatory DATA ACK mode: bad DATA ACK breaks stream
if (error_text.empty())
last_ack_error = "DATA ACK failed";
logger.Error("Received failing DATA ACK on " + endpoint + ": " + last_ack_error);
@@ -497,7 +494,6 @@ void TCPStreamPusherSocket::StartWriterThread() {
data_acked_ok.store(0, std::memory_order_relaxed);
data_acked_bad.store(0, std::memory_order_relaxed);
data_acked_total.store(0, std::memory_order_relaxed);
last_processed_images.store(0, std::memory_order_relaxed);
active = true;
send_future = std::async(std::launch::async, &TCPStreamPusherSocket::WriterThread, this);
@@ -598,6 +594,5 @@ ImagePusherAckProgress TCPStreamPusherSocket::GetDataAckProgress() const {
p.data_acked_bad = data_acked_bad.load(std::memory_order_relaxed);
p.data_acked_total = data_acked_total.load(std::memory_order_relaxed);
p.data_ack_pending = (p.data_sent >= p.data_acked_total) ? (p.data_sent - p.data_acked_total) : 0;
p.last_processed_images = last_processed_images.load(std::memory_order_relaxed);
return p;
}

View File

@@ -62,7 +62,6 @@ class TCPStreamPusherSocket {
std::atomic<uint64_t> data_acked_ok{0};
std::atomic<uint64_t> data_acked_bad{0};
std::atomic<uint64_t> data_acked_total{0};
std::atomic<uint64_t> last_processed_images{0};
void WriterThread();
void CompletionThread();

View File

@@ -1547,4 +1547,12 @@ TEST_CASE("JFJochIntegrationTest_TCP_lysozyme_spot_and_index", "[JFJochReceiver]
REQUIRE(!service.GetProgress().has_value());
REQUIRE_NOTHROW(writer_future.get());
auto ack = pusher.GetAckProgress();
REQUIRE(ack.has_value());
CHECK(ack->data_acked_ok == experiment.GetImageNum());
CHECK(ack->data_acked_bad == 0);
CHECK(ack->data_acked_total == experiment.GetImageNum());
CHECK(ack->data_ack_pending == 0);
CHECK(ack->data_sent == experiment.GetImageNum());
}

View File

@@ -268,6 +268,313 @@ TEST_CASE("TCPImageCommTest_DataFatalAck_PropagatesToPusher", "[TCP]") {
sender.join();
for (auto &t : receivers) t.join();
for (auto &p : puller)
p->Disconnect();
}
// Checks the counters reported by TCPStreamPusher::GetAckProgress() once a run is over.
// 64 images (numbered 0..63) are pushed to two pullers. Each receiver acknowledges
// even-numbered images as OK and odd-numbered images as failed (ok = false, fatal = false),
// so the aggregated progress must report half of the frames in each bucket, all frames
// sent and acknowledged, and nothing pending.
TEST_CASE("TCPImageCommTest_GetAckProgress_Correct", "[TCP]") {
    const size_t nframes = 64;
    const int64_t npullers = 2;
    const int64_t images_per_file = 8;

    DiffractionExperiment experiment(DetJF(1));
    experiment.Raw();
    experiment.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4)
            .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION);

    // Deterministic pseudo-random payload for all frames (fixed seed).
    std::mt19937 rng(123);
    std::uniform_int_distribution<uint16_t> pixel_dist;
    std::vector<uint16_t> pixels(experiment.GetPixelsNum() * nframes);
    for (auto &value : pixels)
        value = pixel_dist(rng);

    std::vector<std::string> addr{
        "tcp://127.0.0.1:19021",
        "tcp://127.0.0.1:19022"
    };

    std::vector<std::unique_ptr<TCPImagePuller>> puller;
    for (int idx = 0; idx < npullers; idx++)
        puller.push_back(std::make_unique<TCPImagePuller>(addr[idx], 64 * 1024 * 1024));

    TCPStreamPusher pusher(
        addr,
        64 * 1024 * 1024,
        128 * 1024,
        8192
    );

    // Builds a non-fatal ACK frame; run/socket numbers are echoed back from the received header.
    const auto make_ack = [](TCPFrameType frame_type, bool ok, const auto &header,
                             auto image_number, auto processed_images) {
        PullerAckMessage ack;
        ack.ack_for = frame_type;
        ack.ok = ok;
        ack.fatal = false;
        ack.run_number = header.run_number;
        ack.socket_number = header.socket_number;
        ack.image_number = image_number;
        ack.processed_images = processed_images;
        ack.error_code = TCPAckCode::None;
        return ack;
    };

    std::thread sender([&] {
        std::vector<uint8_t> serialization_buffer(16 * 1024 * 1024);
        CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size());

        StartMessage start{
            .images_per_file = images_per_file,
            .write_master_file = true
        };
        EndMessage end{};

        pusher.StartDataCollection(start);

        for (int64_t img = 0; img < static_cast<int64_t>(nframes); ++img) {
            DataMessage data_message;
            data_message.number = img;
            data_message.image = CompressedImage(pixels.data() + img * experiment.GetPixelsNum(),
                                                 experiment.GetPixelsNum() * sizeof(uint16_t),
                                                 experiment.GetXPixelsNum(),
                                                 experiment.GetYPixelsNum(),
                                                 experiment.GetImageMode(),
                                                 experiment.GetCompressionAlgorithm());
            serializer.SerializeImage(data_message);
            REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), img));
        }

        REQUIRE(pusher.EndDataCollection(end));

        // Poll for up to 5 s until every DATA frame has been acknowledged.
        const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(5);
        while (std::chrono::steady_clock::now() < deadline) {
            auto progress = pusher.GetAckProgress();
            REQUIRE(progress.has_value());
            if (progress->data_acked_total == nframes)
                break;
            std::this_thread::sleep_for(std::chrono::milliseconds(10));
        }

        auto progress = pusher.GetAckProgress();
        REQUIRE(progress.has_value());
        REQUIRE(progress->data_sent == nframes);
        REQUIRE(progress->data_acked_ok == nframes / 2);
        REQUIRE(progress->data_acked_bad == nframes / 2);
        REQUIRE(progress->data_acked_total == nframes);
        REQUIRE(progress->data_ack_pending == 0);
    });

    std::vector<std::thread> receivers;
    receivers.reserve(npullers);
    for (int w = 0; w < npullers; w++) {
        receivers.emplace_back([&, w] {
            uint64_t processed = 0;
            // Keep polling until the END frame has been seen and acknowledged.
            for (bool seen_end = false; !seen_end;) {
                auto out = puller[w]->PollImage(std::chrono::seconds(10));
                REQUIRE(out.has_value());
                REQUIRE(out->cbor != nullptr);
                REQUIRE(out->tcp_msg != nullptr);
                const auto &h = out->tcp_msg->header;

                if (out->cbor->start_message) {
                    auto ack = make_ack(TCPFrameType::START, true, h, 0, 0);
                    REQUIRE(puller[w]->SendAck(ack));
                } else if (out->cbor->data_message) {
                    auto number = out->cbor->data_message->number;
                    processed++;
                    // Even image numbers are acknowledged as OK, odd ones as failed.
                    auto ack = make_ack(TCPFrameType::DATA, number % 2 == 0, h, number, processed);
                    REQUIRE(puller[w]->SendAck(ack));
                } else if (out->cbor->end_message) {
                    auto ack = make_ack(TCPFrameType::END, true, h, 0, processed);
                    REQUIRE(puller[w]->SendAck(ack));
                    seen_end = true;
                }
            }
        });
    }

    sender.join();
    for (auto &worker : receivers)
        worker.join();
    for (auto &p : puller)
        p->Disconnect();
}
// Checks that TCPStreamPusher::GetAckProgress() can be queried while a run is in flight.
// 128 images are pushed to two pullers whose receivers sleep 3 ms before acknowledging
// each DATA frame. From image index 16 onwards the sender samples the progress after every
// send until it has seen at least one snapshot with data_sent > data_acked_total and a
// non-zero data_ack_pending. After the run all frames must be acknowledged OK with
// nothing pending.
TEST_CASE("TCPImageCommTest_GetAckProgress_InFlightPending", "[TCP]") {
    const size_t nframes = 128;
    const int64_t npullers = 2;
    const int64_t images_per_file = 8;

    DiffractionExperiment experiment(DetJF(1));
    experiment.Raw();
    experiment.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4)
            .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION);

    // Deterministic pseudo-random payload for all frames (fixed seed).
    std::mt19937 rng(321);
    std::uniform_int_distribution<uint16_t> pixel_dist;
    std::vector<uint16_t> pixels(experiment.GetPixelsNum() * nframes);
    for (auto &value : pixels)
        value = pixel_dist(rng);

    std::vector<std::string> addr{
        "tcp://127.0.0.1:19031",
        "tcp://127.0.0.1:19032"
    };

    std::vector<std::unique_ptr<TCPImagePuller>> puller;
    for (int idx = 0; idx < npullers; idx++)
        puller.push_back(std::make_unique<TCPImagePuller>(addr[idx], 64 * 1024 * 1024));

    TCPStreamPusher pusher(
        addr,
        64 * 1024 * 1024,
        128 * 1024,
        8192
    );

    // Set by the sender once a snapshot with outstanding (sent but not yet acked) frames was seen.
    std::atomic<bool> observed_pending{false};

    // Builds a successful, non-fatal ACK frame; run/socket numbers are echoed back from the header.
    const auto make_ok_ack = [](TCPFrameType frame_type, const auto &header,
                                auto image_number, auto processed_images) {
        PullerAckMessage ack;
        ack.ack_for = frame_type;
        ack.ok = true;
        ack.fatal = false;
        ack.run_number = header.run_number;
        ack.socket_number = header.socket_number;
        ack.image_number = image_number;
        ack.processed_images = processed_images;
        ack.error_code = TCPAckCode::None;
        return ack;
    };

    std::thread sender([&] {
        std::vector<uint8_t> serialization_buffer(16 * 1024 * 1024);
        CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size());

        StartMessage start{
            .images_per_file = images_per_file,
            .write_master_file = true
        };
        EndMessage end{};

        pusher.StartDataCollection(start);

        for (int64_t img = 0; img < static_cast<int64_t>(nframes); ++img) {
            DataMessage data_message;
            data_message.number = img;
            data_message.image = CompressedImage(pixels.data() + img * experiment.GetPixelsNum(),
                                                 experiment.GetPixelsNum() * sizeof(uint16_t),
                                                 experiment.GetXPixelsNum(),
                                                 experiment.GetYPixelsNum(),
                                                 experiment.GetImageMode(),
                                                 experiment.GetCompressionAlgorithm());
            serializer.SerializeImage(data_message);
            REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), img));

            // Sample live progress only after the first 16 images and only until pending ACKs were seen.
            if (img >= 16 && !observed_pending.load()) {
                auto progress = pusher.GetAckProgress();
                REQUIRE(progress.has_value());
                const bool acks_outstanding = progress->data_sent > progress->data_acked_total
                                              && progress->data_ack_pending > 0;
                if (acks_outstanding)
                    observed_pending.store(true);
            }
        }

        REQUIRE(pusher.EndDataCollection(end));
        REQUIRE(observed_pending.load());

        // Poll for up to 5 s until every DATA frame has been acknowledged.
        const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(5);
        while (std::chrono::steady_clock::now() < deadline) {
            auto progress = pusher.GetAckProgress();
            REQUIRE(progress.has_value());
            if (progress->data_acked_total == nframes)
                break;
            std::this_thread::sleep_for(std::chrono::milliseconds(10));
        }

        auto progress = pusher.GetAckProgress();
        REQUIRE(progress.has_value());
        REQUIRE(progress->data_sent == nframes);
        REQUIRE(progress->data_acked_ok == nframes);
        REQUIRE(progress->data_acked_bad == 0);
        REQUIRE(progress->data_acked_total == nframes);
        REQUIRE(progress->data_ack_pending == 0);
    });

    std::vector<std::thread> receivers;
    receivers.reserve(npullers);
    for (int w = 0; w < npullers; w++) {
        receivers.emplace_back([&, w] {
            uint64_t processed = 0;
            // Keep polling until the END frame has been seen and acknowledged.
            for (bool seen_end = false; !seen_end;) {
                auto out = puller[w]->PollImage(std::chrono::seconds(10));
                REQUIRE(out.has_value());
                REQUIRE(out->cbor != nullptr);
                REQUIRE(out->tcp_msg != nullptr);
                const auto &h = out->tcp_msg->header;

                if (out->cbor->start_message) {
                    auto ack = make_ok_ack(TCPFrameType::START, h, 0, 0);
                    REQUIRE(puller[w]->SendAck(ack));
                } else if (out->cbor->data_message) {
                    processed++;
                    // Delay the ACK so that the sender gets ahead of the acknowledgements.
                    std::this_thread::sleep_for(std::chrono::milliseconds(3));
                    auto ack = make_ok_ack(TCPFrameType::DATA, h,
                                           static_cast<uint64_t>(out->cbor->data_message->number),
                                           processed);
                    REQUIRE(puller[w]->SendAck(ack));
                } else if (out->cbor->end_message) {
                    auto ack = make_ok_ack(TCPFrameType::END, h, 0, processed);
                    REQUIRE(puller[w]->SendAck(ack));
                    seen_end = true;
                }
            }
        });
    }

    sender.join();
    for (auto &worker : receivers)
        worker.join();
    for (auto &p : puller)
        p->Disconnect();
}