TCP: Implemented ACK return stream, as a feedback channel (to be read properly!)

This commit is contained in:
2026-03-04 12:24:05 +01:00
parent 939bb02ce2
commit a3a986830b
13 changed files with 611 additions and 80 deletions

View File

@@ -201,8 +201,6 @@ std::unique_ptr<ImagePusher> ParseTCPImagePusher(const org::openapitools::server
auto tmp = std::make_unique<TCPStreamPusher>(j.getZeromq().getImageSocket(), send_buffer_size);
if (j.getZeromq().writerNotificationSocketIsSet())
tmp->WriterNotificationSocket(j.getZeromq().getWriterNotificationSocket());
return std::move(tmp);
}

View File

@@ -6,15 +6,33 @@
#include <cstdint>
constexpr uint32_t JFJOCH_TCP_MAGIC = 0x4A464A54; // JFJT
constexpr uint32_t JFJOCH_TCP_VERSION = 1;
constexpr uint32_t JFJOCH_TCP_VERSION = 2;
enum class TCPFrameType : uint16_t {
START = 1,
DATA = 2,
CALIBRATION = 3,
END = 4
END = 4,
ACK = 5,
CANCEL = 6
};
/// Result code carried in the `ack_code` field of an ACK frame header.
/// The numeric values travel over the wire between puller and pusher,
/// so existing enumerators should keep their values.
enum class TCPAckCode : uint16_t {
None = 0, ///< No error; used for successful START/END/CANCEL ACKs
StartFailed = 1, ///< Sent by StreamWriter when processing the START message threw
DataWriteFailed = 2, ///< Sent by StreamWriter when writing an image threw
EndFailed = 3, ///< Sent by StreamWriter when finalizing while in the error state
DiskQuotaExceeded = 4, ///< Only used by the fatal-ACK test in the code visible here
NoSpaceLeft = 5, ///< Not emitted by the code visible here
PermissionDenied = 6, ///< Not emitted by the code visible here
IoError = 7, ///< Not emitted by the code visible here
ProtocolError = 8 ///< Not emitted by the code visible here
};
// Bits stored in TcpFrameHeader::flags of an ACK frame.
constexpr uint32_t TCP_ACK_FLAG_OK = 1u << 0; // set when PullerAckMessage::ok is true
constexpr uint32_t TCP_ACK_FLAG_FATAL = 1u << 1; // set when PullerAckMessage::fatal is true
constexpr uint32_t TCP_ACK_FLAG_HAS_ERROR_TEXT = 1u << 2; // set when a non-empty error text follows the header as payload
struct alignas(64) TcpFrameHeader {
uint32_t magic = JFJOCH_TCP_MAGIC;
uint16_t version = JFJOCH_TCP_VERSION ;
@@ -24,5 +42,10 @@ struct alignas(64) TcpFrameHeader {
uint32_t socket_number = 0;
uint32_t flags = 0;
uint64_t run_number = 0;
uint64_t reserved[4] = {0, 0, 0, 0};
uint32_t ack_processed_images = 0;
uint16_t ack_code = 0;
uint16_t ack_for = 0;
uint64_t reserved[2] = {0, 0};
};

View File

@@ -11,6 +11,19 @@
#include "../frame_serialize/CBORStream2Deserializer.h"
#include "../common/ThreadSafeFIFO.h"
#include "../common/JfjochTCP.h"
#include "../common/JfjochTCP.h"
/// Acknowledgement a puller sends back to the pusher over the TCP stream.
/// TCPImagePuller::SendAck() copies these fields into a TcpFrameHeader of
/// type ACK; error_text (if any) is sent as the frame payload.
struct PullerAckMessage {
    /// Frame type that is being acknowledged.
    TCPFrameType ack_for{TCPFrameType::DATA};
    /// True when the acknowledged operation succeeded (TCP_ACK_FLAG_OK).
    bool ok{true};
    /// True when the failure is fatal for the stream (TCP_ACK_FLAG_FATAL).
    bool fatal{false};
    /// Copied to the header's run_number field.
    uint64_t run_number{0};
    /// Copied to the header's socket_number field.
    uint32_t socket_number{0};
    /// Copied to the header's image_number field.
    uint64_t image_number{0};
    /// Copied to the header's ack_processed_images field.
    uint64_t processed_images{0};
    /// Copied to the header's ack_code field.
    TCPAckCode error_code{TCPAckCode::None};
    /// Optional human-readable error description; empty means no payload.
    std::string error_text;
};
struct RawFrame {
TcpFrameHeader header{};
@@ -42,6 +55,9 @@ public:
[[nodiscard]] size_t GetMaxFifoUtilization() const;
[[nodiscard]] size_t GetCurrentFifoUtilization() const;
/// Tells callers whether SendAck() is worth calling; the base implementation returns false.
virtual bool SupportsAck() const { return false; }
/// Base implementation: ignores the message and returns false (nothing was sent).
virtual bool SendAck(const PullerAckMessage &) { return false; }
virtual void Disconnect() = 0;
};

View File

@@ -51,6 +51,77 @@ TCPImagePuller::TCPImagePuller(const std::string &tcp_addr,
cbor_thread = std::thread(&TCPImagePuller::CBORThread, this);
}
bool TCPImagePuller::SendAll(const void *buf, size_t len) {
const auto *p = static_cast<const uint8_t *>(buf);
size_t sent = 0;
while (sent < len) {
if (disconnect)
return false;
int local_fd = -1;
{
std::unique_lock ul(fd_mutex);
local_fd = fd;
}
if (local_fd < 0)
return false;
ssize_t rc = ::send(local_fd, p + sent, len - sent, MSG_NOSIGNAL);
if (rc < 0) {
if (errno == EINTR)
continue;
return false;
}
sent += static_cast<size_t>(rc);
}
return true;
}
/// Sends an ACK frame describing `ack` back to the pusher.
///
/// The frame consists of a TcpFrameHeader of type ACK, followed by the error
/// text as payload when one is present. Flag bits are derived from `ack.ok`,
/// `ack.fatal` and the presence of an error text.
///
/// @param ack  Acknowledgement to transmit.
/// @return true when header and (optional) payload were sent completely.
bool TCPImagePuller::SendAck(const PullerAckMessage &ack) {
    TcpFrameHeader h{};
    h.type = static_cast<uint16_t>(TCPFrameType::ACK);
    h.run_number = ack.run_number;
    h.socket_number = ack.socket_number;
    h.image_number = ack.image_number;

    h.flags = 0;
    if (ack.ok)
        h.flags |= TCP_ACK_FLAG_OK;
    if (ack.fatal)
        h.flags |= TCP_ACK_FLAG_FATAL;
    if (!ack.error_text.empty())
        h.flags |= TCP_ACK_FLAG_HAS_ERROR_TEXT;

    h.ack_for = static_cast<uint16_t>(ack.ack_for);

    // The header field is 32-bit while the message counter is 64-bit:
    // saturate instead of silently wrapping around.
    h.ack_processed_images = ack.processed_images > UINT32_MAX
                                 ? UINT32_MAX
                                 : static_cast<uint32_t>(ack.processed_images);

    // ack_code is a 16-bit header field, matching TCPAckCode's underlying type.
    h.ack_code = static_cast<uint16_t>(ack.error_code);
    h.payload_size = ack.error_text.size();

    if (!SendAll(&h, sizeof(h)))
        return false;

    if (!ack.error_text.empty())
        return SendAll(ack.error_text.data(), ack.error_text.size());

    return true;
}
/// Worker loop that moves frames from cbor_fifo to outside_fifo.
///
/// CANCEL frames are forwarded as they are; every other frame has its payload
/// decoded with CBORStream2Deserialize first. A JFJochException raised while
/// handling a frame is logged and that frame is not forwarded. An element
/// without tcp_msg ends the loop and is itself forwarded to outside_fifo.
void TCPImagePuller::CBORThread() {
    for (auto item = cbor_fifo.GetBlocking();; item = cbor_fifo.GetBlocking()) {
        if (!item.tcp_msg) {
            // Termination marker: pass it on and stop.
            outside_fifo.PutBlocking(item);
            return;
        }

        try {
            const bool is_cancel =
                static_cast<TCPFrameType>(item.tcp_msg->header.type) == TCPFrameType::CANCEL;
            if (!is_cancel) {
                const auto &payload = item.tcp_msg->payload;
                item.cbor = CBORStream2Deserialize(payload.data(), payload.size());
            }
            outside_fifo.PutBlocking(item);
        } catch (const JFJochException &e) {
            logger.ErrorException(e);
        }
    }
}
// Destructor: tears the connection down via Disconnect(). The call is
// class-qualified, so it is not dispatched virtually from the destructor.
TCPImagePuller::~TCPImagePuller() {
TCPImagePuller::Disconnect();
}
@@ -179,6 +250,17 @@ void TCPImagePuller::ReceiverThread() {
continue;
}
// Ignore ACK on puller side
if (static_cast<TCPFrameType>(frame.header.type) == TCPFrameType::ACK) {
if (frame.header.payload_size > 0) {
std::vector<uint8_t> discard(frame.header.payload_size);
if (!ReadExact(discard.data(), discard.size())) {
CloseSocket();
}
}
continue;
}
ImagePullerOutput out;
out.tcp_msg = std::make_shared<RawFrame>();
out.tcp_msg->header = frame.header;
@@ -206,19 +288,6 @@ void TCPImagePuller::ReceiverThread() {
cbor_fifo.PutBlocking(ImagePullerOutput{});
}
void TCPImagePuller::CBORThread() {
auto ret = cbor_fifo.GetBlocking();
while (ret.tcp_msg) {
try {
ret.cbor = CBORStream2Deserialize(ret.tcp_msg->payload.data(), ret.tcp_msg->payload.size());
outside_fifo.PutBlocking(ret);
} catch (const JFJochException &e) {
logger.ErrorException(e);
}
ret = cbor_fifo.GetBlocking();
}
outside_fifo.PutBlocking(ret);
}
void TCPImagePuller::Disconnect() {
if (disconnect.exchange(true))
return;

View File

@@ -29,6 +29,7 @@ class TCPImagePuller : public ImagePuller {
Logger logger{"TCPImagePuller"};
bool ReadExact(void *buf, size_t size);
bool SendAll(const void *buf, size_t len);
bool EnsureConnected();
void CloseSocket();
void ReceiverThread();
@@ -37,5 +38,7 @@ public:
explicit TCPImagePuller(const std::string &tcp_addr, std::optional<int32_t> rcv_buffer_size = {});
~TCPImagePuller() override;
/// The TCP puller implements SendAck(), so it always reports ACK support.
bool SupportsAck() const override { return true; }
bool SendAck(const PullerAckMessage &ack) override;
void Disconnect() override;
};

View File

@@ -26,7 +26,6 @@ public:
void SendImage(ZeroCopyReturnValue &z) override;
bool SendCalibration(const CompressedImage &message) override;
std::string PrintSetup() const override;
};

View File

@@ -19,7 +19,6 @@ TCPStreamPusher::TCPStreamPusher(const std::vector<std::string> &addr,
}
}
void TCPStreamPusher::StartDataCollection(StartMessage &message) {
if (message.images_per_file < 1)
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
@@ -35,6 +34,25 @@ void TCPStreamPusher::StartDataCollection(StartMessage &message) {
"TCP accept timeout/failure on socket " + socket[i]->GetEndpointName());
}
for (auto &s : socket)
s->StartWriterThread();
std::vector<bool> started(socket.size(), false);
auto rollback_cancel = [&]() {
for (size_t i = 0; i < socket.size(); i++) {
if (!started[i] || socket[i]->IsBroken())
continue;
(void)socket[i]->Send(nullptr, 0, TCPFrameType::CANCEL);
std::string cancel_ack_err;
(void)socket[i]->WaitForAck(TCPFrameType::CANCEL, std::chrono::milliseconds(500), &cancel_ack_err);
}
for (auto &s : socket)
s->StopWriterThread();
};
for (size_t i = 0; i < socket.size(); i++) {
message.socket_number = static_cast<int64_t>(i);
if (i > 0)
@@ -44,17 +62,20 @@ void TCPStreamPusher::StartDataCollection(StartMessage &message) {
socket[i]->SetRunNumber(run_number);
if (!socket[i]->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::START)) {
// one-shot recovery: reconnect and retry START once
if (!socket[i]->AcceptConnection(std::chrono::seconds(5)) ||
!socket[i]->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::START)) {
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
"Timeout/failure sending START");
}
rollback_cancel();
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
"Timeout/failure sending START on " + socket[i]->GetEndpointName());
}
}
for (auto &s : socket)
s->StartWriterThread();
std::string ack_err;
if (!socket[i]->WaitForAck(TCPFrameType::START, std::chrono::seconds(5), &ack_err)) {
rollback_cancel();
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
"START ACK failed on " + socket[i]->GetEndpointName() + ": " + ack_err);
}
started[i] = true;
}
}
bool TCPStreamPusher::SendImage(const uint8_t *image_data, size_t image_size, int64_t image_number) {
@@ -88,12 +109,25 @@ bool TCPStreamPusher::EndDataCollection(const EndMessage &message) {
bool ret = true;
for (auto &s : socket) {
s->StopWriterThread();
if (s->IsBroken())
if (s->IsBroken()) {
ret = false;
else if (!s->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::END))
continue;
}
if (!s->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::END)) {
ret = false;
continue;
}
std::string ack_err;
if (!s->WaitForAck(TCPFrameType::END, std::chrono::seconds(10), &ack_err)) {
ret = false;
}
}
for (auto &s : socket)
s->StopWriterThread();
transmission_error = !ret;
return ret;
}
@@ -102,32 +136,16 @@ std::string TCPStreamPusher::Finalize() {
std::string ret;
if (transmission_error)
ret += "Timeout sending images (e.g., writer disabled during data collection);";
if (writer_notification_socket) {
for (size_t i = 0; i < socket.size(); i++) {
auto n = writer_notification_socket->Receive(run_number, run_name);
if (!n)
ret += "Writer " + std::to_string(i) + ": no end notification received within 1 minute from collection end";
else if (static_cast<size_t>(n->socket_number) >= socket.size())
ret += "Writer " + std::to_string(i) + ": mismatch in socket number";
else if (!n->ok)
ret += "Writer " + std::to_string(i) + ": " + n->error;
for (size_t i = 0; i < socket.size(); i++) {
if (socket[i]->IsBroken()) {
const auto reason = socket[i]->GetLastAckError();
ret += "Writer " + std::to_string(i) + ": " + (reason.empty() ? "stream broken" : reason) + ";";
}
}
return ret;
}
std::string TCPStreamPusher::GetWriterNotificationSocketAddress() const {
if (writer_notification_socket)
return writer_notification_socket->GetEndpointName();
else
return "";
}
TCPStreamPusher &TCPStreamPusher::WriterNotificationSocket(const std::string &addr) {
writer_notification_socket = std::make_unique<ZMQWriterNotificationPuller>(addr, std::chrono::minutes(1));
return *this;
}
std::string TCPStreamPusher::PrintSetup() const {
std::string output = "TCPStream2Pusher: Sending images to sockets: ";
for (const auto &s : socket)

View File

@@ -11,8 +11,6 @@ class TCPStreamPusher : public ImagePusher {
CBORStream2Serializer serializer;
std::vector<std::unique_ptr<TCPStreamPusherSocket>> socket;
std::unique_ptr<ZMQWriterNotificationPuller> writer_notification_socket;
int64_t images_per_file = 1;
uint64_t run_number = 0;
std::string run_name;
@@ -23,9 +21,6 @@ public:
std::optional<size_t> zerocopy_threshold = {},
size_t send_queue_size = 4096);
TCPStreamPusher& WriterNotificationSocket(const std::string& addr);
std::string GetWriterNotificationSocketAddress() const override;
void StartDataCollection(StartMessage& message) override;
bool EndDataCollection(const EndMessage& message) override;
bool SendImage(const uint8_t *image_data, size_t image_size, int64_t image_number) override;

View File

@@ -197,6 +197,50 @@ bool TCPStreamPusherSocket::SendAll(const void *buf, size_t len) {
return true;
}
/// Reads exactly `len` bytes from the socket into `buf`.
///
/// The socket is polled in 100 ms slices so that the loop can notice when
/// `active` is cleared or the descriptor becomes negative.
///
/// Readable data is always consumed before a hang-up or error condition is
/// honoured: when POLLIN is reported together with POLLHUP/POLLERR, recv() is
/// still called so that bytes already buffered are not lost. recv() itself
/// then reports end-of-stream (0) or the pending error.
///
/// @return true when `len` bytes were read; false on shutdown, invalid
///         descriptor, peer close or any non-retryable error.
bool TCPStreamPusherSocket::ReadExact(void *buf, size_t len) {
    auto *p = static_cast<uint8_t *>(buf);
    size_t got = 0;

    while (got < len) {
        if (!active)
            return false;

        const int local_fd = fd.load();
        if (local_fd < 0)
            return false;

        pollfd pfd{};
        pfd.fd = local_fd;
        pfd.events = POLLIN;

        const int prc = poll(&pfd, 1, 100); // 100 ms interruptibility window
        if (prc == 0)
            continue;
        if (prc < 0) {
            if (errno == EINTR)
                continue;
            return false;
        }

        // An invalid descriptor can never deliver data.
        if ((pfd.revents & POLLNVAL) != 0)
            return false;

        if ((pfd.revents & POLLIN) == 0) {
            // Nothing to read: a hang-up or error is final.
            if ((pfd.revents & (POLLERR | POLLHUP)) != 0)
                return false;
            continue;
        }

        const ssize_t rc = ::recv(local_fd, p + got, len - got, 0);
        if (rc == 0)
            return false; // orderly shutdown by the peer
        if (rc < 0) {
            if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK)
                continue;
            return false;
        }
        got += static_cast<size_t>(rc);
    }
    return true;
}
bool TCPStreamPusherSocket::SendPayloadZC(const uint8_t *data, size_t size, ZeroCopyReturnValue *z) {
#if defined(MSG_ZEROCOPY) && defined(SO_ZEROCOPY)
int local_fd = fd.load();
@@ -358,10 +402,89 @@ void TCPStreamPusherSocket::CompletionThread() {
#endif
}
void TCPStreamPusherSocket::AckThread() {
while (active) {
TcpFrameHeader h{};
if (!ReadExact(&h, sizeof(h))) {
if (active) {
broken = true;
logger.Error("TCP ACK reader disconnected on " + endpoint);
}
break;
}
if (h.magic != JFJOCH_TCP_MAGIC || h.version != JFJOCH_TCP_VERSION || static_cast<TCPFrameType>(h.type) != TCPFrameType::ACK) {
broken = true;
logger.Error("Invalid ACK frame on " + endpoint);
break;
}
std::string error_text;
if (h.payload_size > 0) {
error_text.resize(h.payload_size);
if (!ReadExact(error_text.data(), error_text.size())) {
broken = true;
break;
}
}
const auto ack_for = static_cast<TCPFrameType>(h.ack_for);
const bool ok = (h.flags & TCP_ACK_FLAG_OK) != 0;
const bool fatal = (h.flags & TCP_ACK_FLAG_FATAL) != 0;
const auto code = static_cast<TCPAckCode>(h.ack_code);
{
std::unique_lock ul(ack_state_mutex);
last_ack_code = code;
if (!error_text.empty())
last_ack_error = error_text;
if (ack_for == TCPFrameType::START) {
start_ack_received = true;
start_ack_ok = ok;
if (!ok && error_text.empty())
last_ack_error = "START rejected";
} else if (ack_for == TCPFrameType::END) {
end_ack_received = true;
end_ack_ok = ok;
if (!ok && error_text.empty())
last_ack_error = "END rejected";
} else if (ack_for == TCPFrameType::CANCEL) {
cancel_ack_received = true;
cancel_ack_ok = ok;
if (!ok && error_text.empty())
last_ack_error = "CANCEL rejected";
} else if (ack_for == TCPFrameType::DATA && (!ok || fatal)) {
broken = true;
if (error_text.empty())
last_ack_error = "DATA fatal ACK";
logger.Error("Received fatal DATA ACK on " + endpoint + ": " + last_ack_error);
}
}
ack_cv.notify_all();
}
}
/// Starts the writer, completion and ACK reader tasks for this socket.
///
/// Does nothing when the socket is already active. Otherwise all recorded ACK
/// state (START/END/CANCEL results, last error text and code) is reset under
/// ack_state_mutex before `active` is raised and the three tasks are launched
/// with std::async.
void TCPStreamPusherSocket::StartWriterThread() {
    if (active)
        return;

    {
        // Forget everything learned from ACKs of a previous run.
        std::lock_guard guard(ack_state_mutex);
        start_ack_received = start_ack_ok = false;
        end_ack_received = end_ack_ok = false;
        cancel_ack_received = cancel_ack_ok = false;
        last_ack_error.clear();
        last_ack_code = TCPAckCode::None;
    }

    active = true;

    constexpr auto policy = std::launch::async;
    send_future = std::async(policy, &TCPStreamPusherSocket::WriterThread, this);
    completion_future = std::async(policy, &TCPStreamPusherSocket::CompletionThread, this);
    ack_future = std::async(policy, &TCPStreamPusherSocket::AckThread, this);
}
void TCPStreamPusherSocket::StopWriterThread() {
@@ -369,11 +492,14 @@ void TCPStreamPusherSocket::StopWriterThread() {
return;
active = false;
queue.PutBlocking({.end = true});
ack_cv.notify_all();
if (send_future.valid())
send_future.get();
if (completion_future.valid())
completion_future.get();
if (ack_future.valid())
ack_future.get();
// Keep fd open: END frame may still be sent after writer thread stops.
// Socket is closed in destructor / explicit close path.
@@ -403,3 +529,46 @@ bool TCPStreamPusherSocket::Send(const uint8_t *data, size_t size, TCPFrameType
return SendFrame(data, size, type, image_number, nullptr);
}
/// Blocks until an ACK for `ack_for` has been recorded, the socket is marked
/// broken, or `timeout` elapses.
///
/// Only START, END and CANCEL are tracked; for any other frame type the wait
/// condition is never satisfied and the call ends with "ACK timeout".
///
/// @param ack_for     Frame type whose acknowledgement is awaited.
/// @param timeout     Maximum time to wait.
/// @param error_text  Optional output; receives a description on failure.
/// @return true only when the ACK arrived and reported success.
bool TCPStreamPusherSocket::WaitForAck(TCPFrameType ack_for, std::chrono::milliseconds timeout, std::string *error_text) {
    std::unique_lock state_lock(ack_state_mutex);

    // Select the pair of state flags that belongs to the requested frame type.
    const bool *received = nullptr;
    const bool *accepted = nullptr;
    switch (ack_for) {
        case TCPFrameType::START:
            received = &start_ack_received;
            accepted = &start_ack_ok;
            break;
        case TCPFrameType::END:
            received = &end_ack_received;
            accepted = &end_ack_ok;
            break;
        case TCPFrameType::CANCEL:
            received = &cancel_ack_received;
            accepted = &cancel_ack_ok;
            break;
        default:
            break;
    }

    const auto settled = [&] {
        return received != nullptr && (*received || broken.load());
    };

    if (!ack_cv.wait_for(state_lock, timeout, settled)) {
        if (error_text)
            *error_text = "ACK timeout";
        return false;
    }

    // A broken socket takes precedence over whatever ACK state was recorded.
    if (broken) {
        if (error_text)
            *error_text = last_ack_error.empty() ? "Socket broken" : last_ack_error;
        return false;
    }

    const bool ack_ok = (accepted != nullptr) && *accepted;
    if (!ack_ok && error_text)
        *error_text = last_ack_error.empty() ? "ACK rejected" : last_ack_error;
    return ack_ok;
}
/// Returns a copy of the most recent ACK error text, taken under ack_state_mutex.
std::string TCPStreamPusherSocket::GetLastAckError() const {
    std::lock_guard guard(ack_state_mutex);
    std::string snapshot = last_ack_error;
    return snapshot;
}

View File

@@ -26,6 +26,7 @@ class TCPStreamPusherSocket {
std::atomic<bool> active = false;
std::future<void> send_future;
std::future<void> completion_future;
std::future<void> ack_future;
ThreadSafeFIFO<ImagePusherQueueElement> queue;
@@ -40,6 +41,16 @@ class TCPStreamPusherSocket {
constexpr static auto AcceptTimeout = std::chrono::seconds(5);
std::atomic<bool> broken{false};
std::atomic<TCPAckCode> last_ack_code{TCPAckCode::None};
std::string last_ack_error;
mutable std::mutex ack_state_mutex;
std::condition_variable ack_cv;
bool start_ack_received = false;
bool start_ack_ok = false;
bool end_ack_received = false;
bool end_ack_ok = false;
bool cancel_ack_received = false;
bool cancel_ack_ok = false;
std::atomic<uint64_t> next_tx_id{1};
std::mutex inflight_mutex;
@@ -49,10 +60,12 @@ class TCPStreamPusherSocket {
void WriterThread();
void CompletionThread();
void AckThread();
void CloseDataSocket();
bool SendAll(const void *buf, size_t len);
bool ReadExact(void *buf, size_t len);
bool SendFrame(const uint8_t *data, size_t size, TCPFrameType type, int64_t image_number, ZeroCopyReturnValue *z);
bool SendPayloadZC(const uint8_t *data, size_t size, ZeroCopyReturnValue *z);
public:
@@ -74,9 +87,12 @@ public:
void StartWriterThread();
void StopWriterThread();
bool WaitForAck(TCPFrameType ack_for, std::chrono::milliseconds timeout, std::string *error_text = nullptr);
void SetRunNumber(uint64_t in_run_number);
void SendImage(ZeroCopyReturnValue &z);
bool IsBroken() const;
std::string GetLastAckError() const;
};

View File

@@ -7,7 +7,7 @@
#include "../image_pusher/TCPStreamPusher.h"
#include "../image_puller/TCPImagePuller.h"
TEST_CASE("TCPImageCommTest_2Writers", "[TCP]") {
TEST_CASE("TCPImageCommTest_2Writers_WithAck", "[TCP]") {
const size_t nframes = 128;
const int64_t npullers = 2;
const int64_t images_per_file = 16;
@@ -23,24 +23,24 @@ TEST_CASE("TCPImageCommTest_2Writers", "[TCP]") {
for (auto &i : image1) i = dist(g1);
std::vector<std::string> addr{
"tcp://127.0.0.1:19001",
"tcp://127.0.0.1:19002"
};
"tcp://127.0.0.1:19001",
"tcp://127.0.0.1:19002"
};
std::vector<std::unique_ptr<TCPImagePuller>> puller;
for (int i = 0; i < npullers; i++) {
puller.push_back(std::make_unique<TCPImagePuller>(
addr[i], 64 * 1024 * 1024)); // decoded cbor ring
puller.push_back(std::make_unique<TCPImagePuller>(addr[i], 64 * 1024 * 1024));
}
TCPStreamPusher pusher(
addr,
64 * 1024 * 1024,
128 * 1024, // zerocopy threshold
8192 // sender queue
128 * 1024,
8192
);
std::vector<size_t> received(npullers, 0);
std::vector<size_t> processed(npullers, 0);
std::thread sender([&] {
std::vector<uint8_t> serialization_buffer(16 * 1024 * 1024);
@@ -70,29 +70,204 @@ TEST_CASE("TCPImageCommTest_2Writers", "[TCP]") {
REQUIRE(pusher.EndDataCollection(end));
});
std::vector<std::thread> receivers;
receivers.reserve(npullers);
for (int w = 0; w < npullers; w++) {
bool seen_end = false;
while (!seen_end) {
auto out = puller[w]->PollImage(std::chrono::seconds(10));
REQUIRE(out.has_value());
REQUIRE(out->cbor != nullptr);
if (out->cbor->end_message) {
seen_end = true;
continue;
receivers.emplace_back([&, w] {
bool seen_start = false;
bool seen_end = false;
while (!seen_end) {
auto out = puller[w]->PollImage(std::chrono::seconds(10));
REQUIRE(out.has_value());
REQUIRE(out->cbor != nullptr);
REQUIRE(out->tcp_msg != nullptr);
const auto &h = out->tcp_msg->header;
if (out->cbor->start_message) {
PullerAckMessage ack;
ack.ack_for = TCPFrameType::START;
ack.ok = true;
ack.fatal = false;
ack.run_number = h.run_number;
ack.socket_number = h.socket_number;
ack.image_number = 0;
ack.processed_images = 0;
ack.error_code = TCPAckCode::None;
REQUIRE(puller[w]->SendAck(ack));
seen_start = true;
continue;
}
if (out->cbor->data_message) {
REQUIRE(seen_start);
auto n = out->cbor->data_message->number;
REQUIRE(((n / images_per_file) % npullers) == w);
received[w]++;
processed[w]++;
continue;
}
if (out->cbor->end_message) {
PullerAckMessage ack;
ack.ack_for = TCPFrameType::END;
ack.ok = true;
ack.fatal = false;
ack.run_number = h.run_number;
ack.socket_number = h.socket_number;
ack.image_number = 0;
ack.processed_images = processed[w];
ack.error_code = TCPAckCode::None;
REQUIRE(puller[w]->SendAck(ack));
seen_end = true;
}
}
if (out->cbor->data_message) {
auto n = out->cbor->data_message->number;
REQUIRE(((n / images_per_file) % npullers) == w);
received[w]++;
}
}
});
}
sender.join();
for (auto &t : receivers) t.join();
REQUIRE(received[0] == nframes / 2);
REQUIRE(received[1] == nframes / 2);
for (auto &p : puller)
p->Disconnect();
}
// Verifies that a fatal DATA ACK sent by one puller reaches the pusher:
// EndDataCollection() must report failure and Finalize() must contain the
// error text ("quota exceeded") that the puller attached to the ACK.
// NOTE(review): REQUIRE is evaluated on the sender and receiver threads;
// confirm that the Catch2 version in use tolerates assertions off the main thread.
TEST_CASE("TCPImageCommTest_DataFatalAck_PropagatesToPusher", "[TCP]") {
const size_t nframes = 64;
const int64_t npullers = 2;
const int64_t images_per_file = 8;
DiffractionExperiment x(DetJF(1));
x.Raw();
x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4)
.ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION);
// Fixed seed keeps the pseudo-random image content reproducible.
std::mt19937 g1(42);
std::uniform_int_distribution<uint16_t> dist;
std::vector<uint16_t> image1(x.GetPixelsNum() * nframes);
for (auto &i : image1) i = dist(g1);
// One endpoint per puller.
std::vector<std::string> addr{
"tcp://127.0.0.1:19011",
"tcp://127.0.0.1:19012"
};
std::vector<std::unique_ptr<TCPImagePuller>> puller;
for (int i = 0; i < npullers; i++) {
puller.push_back(std::make_unique<TCPImagePuller>(addr[i], 64 * 1024 * 1024));
}
TCPStreamPusher pusher(
addr,
64 * 1024 * 1024,
128 * 1024,
8192
);
// Ensures that exactly one fatal DATA ACK is sent during the whole test.
std::atomic<bool> sent_fatal{false};
// Sender: runs a full START / DATA... / END cycle and checks the failure is reported.
std::thread sender([&] {
std::vector<uint8_t> serialization_buffer(16 * 1024 * 1024);
CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size());
StartMessage start{
.images_per_file = images_per_file,
.write_master_file = true
};
EndMessage end{};
pusher.StartDataCollection(start);
for (int64_t i = 0; i < static_cast<int64_t>(nframes); i++) {
DataMessage data_message;
data_message.number = i;
data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(),
x.GetPixelsNum() * sizeof(uint16_t),
x.GetXPixelsNum(),
x.GetYPixelsNum(),
x.GetImageMode(),
x.GetCompressionAlgorithm());
serializer.SerializeImage(data_message);
// The result is deliberately ignored: sends may start failing once the fatal ACK arrived.
(void)pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i);
}
// The fatal ACK must make the end of the collection report failure ...
REQUIRE_FALSE(pusher.EndDataCollection(end));
const auto final_msg = pusher.Finalize();
// ... and the text sent with the ACK must show up in the final message.
REQUIRE_THAT(final_msg, Catch::Matchers::ContainsSubstring("quota"));
});
// Receivers: one thread per puller, acknowledging START and END;
// puller 0 additionally rejects its first DATA frame with a fatal ACK.
std::vector<std::thread> receivers;
receivers.reserve(npullers);
for (int w = 0; w < npullers; w++) {
receivers.emplace_back([&, w] {
bool seen_end = false;
bool local_fatal_sent = false;
while (!seen_end) {
auto out = puller[w]->PollImage(std::chrono::seconds(2));
if (!out.has_value()) {
// Once this receiver has sent a fatal DATA ACK, no END is guaranteed on this stream.
if (local_fatal_sent)
break;
// Any other poll timeout is a test failure.
REQUIRE(out.has_value());
}
REQUIRE(out->cbor != nullptr);
REQUIRE(out->tcp_msg != nullptr);
const auto &h = out->tcp_msg->header;
if (out->cbor->start_message) {
// Acknowledge START positively, echoing run and socket number from the frame header.
PullerAckMessage ack;
ack.ack_for = TCPFrameType::START;
ack.ok = true;
ack.run_number = h.run_number;
ack.socket_number = h.socket_number;
ack.error_code = TCPAckCode::None;
REQUIRE(puller[w]->SendAck(ack));
continue;
}
if (out->cbor->data_message) {
// Only puller 0 sends the fatal ACK, and only once (guarded by sent_fatal).
if (w == 0 && !sent_fatal.exchange(true)) {
PullerAckMessage ack;
ack.ack_for = TCPFrameType::DATA;
ack.ok = false;
ack.fatal = true;
ack.run_number = h.run_number;
ack.socket_number = h.socket_number;
ack.image_number = static_cast<uint64_t>(out->cbor->data_message->number);
ack.error_code = TCPAckCode::DiskQuotaExceeded;
ack.error_text = "quota exceeded";
REQUIRE(puller[w]->SendAck(ack));
local_fatal_sent = true;
}
continue;
}
if (out->cbor->end_message) {
// Acknowledge END positively and leave the receive loop.
PullerAckMessage ack;
ack.ack_for = TCPFrameType::END;
ack.ok = true;
ack.run_number = h.run_number;
ack.socket_number = h.socket_number;
ack.error_code = TCPAckCode::None;
REQUIRE(puller[w]->SendAck(ack));
seen_end = true;
}
}
});
}
sender.join();
for (auto &t : receivers) t.join();
for (auto &p : puller)
p->Disconnect();
}

View File

@@ -20,6 +20,27 @@ StreamWriter::StreamWriter(Logger &in_logger,
max_image_number(0) {
}
/// Sends an ACK for the given frame type back through the image puller.
///
/// Nothing happens when the puller does not support ACKs. The message is
/// filled with the writer's run number, socket number and processed image
/// count; the image number is taken from the current data message when one is
/// present. A failed send is only logged as a warning.
///
/// @param ack_for     Frame type being acknowledged.
/// @param ok          Whether the operation succeeded.
/// @param fatal       Whether the failure is fatal.
/// @param code        Error code to report.
/// @param error_text  Optional error description.
void StreamWriter::NotifyTcpAck(TCPFrameType ack_for, bool ok, bool fatal, TCPAckCode code, const std::string &error_text) {
    if (!image_puller.SupportsAck())
        return;

    PullerAckMessage message;

    // What is being acknowledged and with which outcome.
    message.ack_for = ack_for;
    message.ok = ok;
    message.fatal = fatal;
    message.error_code = code;
    message.error_text = error_text;

    // Writer-side context.
    message.run_number = run_number;
    message.socket_number = static_cast<uint32_t>(socket_number);
    message.processed_images = processed_images.load();

    const auto &decoded = image_puller_output.cbor;
    const bool has_data_message = decoded && decoded->data_message;
    if (has_data_message)
        message.image_number = decoded->data_message->number;

    const bool delivered = image_puller.SendAck(message);
    if (!delivered)
        logger.Warning("Failed to send TCP ACK");
}
void StreamWriter::ProcessStartMessage() {
if (state == StreamWriterState::Finalized)
return; // Should not happen (?)
@@ -28,6 +49,7 @@ void StreamWriter::ProcessStartMessage() {
FinalizeDataCollection();
err = "";
tcp_data_fatal_sent = false;
max_image_number = 0;
@@ -51,11 +73,13 @@ void StreamWriter::ProcessStartMessage() {
image_puller_output.cbor->start_message->file_prefix,
image_puller_output.cbor->start_message->number_of_images);
state = StreamWriterState::Started;
NotifyTcpAck(TCPFrameType::START, true, false, TCPAckCode::None);
} catch (const JFJochException &e) {
logger.ErrorException(e);
logger.Error("Error writing start message - switching to error state");
state = StreamWriterState::Error;
err = e.what();
NotifyTcpAck(TCPFrameType::START, false, true, TCPAckCode::StartFailed, err);
}
}
@@ -108,6 +132,10 @@ void StreamWriter::ProcessDataImage() {
logger.Warning("Error writing image - switching to error state");
state = StreamWriterState::Error;
err = e.what();
if (!tcp_data_fatal_sent) {
tcp_data_fatal_sent = true;
NotifyTcpAck(TCPFrameType::DATA, false, true, TCPAckCode::DataWriteFailed, err);
}
}
break;
case StreamWriterState::Error:
@@ -156,6 +184,11 @@ void StreamWriter::FinalizeDataCollection() {
}
file_writer.reset();
NotifyReceiverOnFinalizedWrite(writer_notification_zmq_addr);
NotifyTcpAck(TCPFrameType::END,
state != StreamWriterState::Error,
state == StreamWriterState::Error,
state == StreamWriterState::Error ? TCPAckCode::EndFailed : TCPAckCode::None,
state == StreamWriterState::Error ? err : "");
logger.Info("Data writing finished");
state = StreamWriterState::Finalized;
}
@@ -168,6 +201,21 @@ void StreamWriter::CollectImages() {
while (run && state != StreamWriterState::Finalized) {
run = WaitForImage();
if (image_puller_output.tcp_msg &&
static_cast<TCPFrameType>(image_puller_output.tcp_msg->header.type) == TCPFrameType::CANCEL) {
logger.Warning("Received TCP CANCEL, finalizing data collection");
if (state != StreamWriterState::Idle && state != StreamWriterState::Finalized)
FinalizeDataCollection();
NotifyTcpAck(TCPFrameType::CANCEL, true, false, TCPAckCode::None);
state = StreamWriterState::Finalized;
continue;
}
if (!image_puller_output.cbor) {
logger.Warning("Missing CBOR payload for non-CANCEL TCP frame");
continue;
}
if (image_puller_output.cbor->start_message)
ProcessStartMessage();
else if (image_puller_output.cbor->calibration)

View File

@@ -55,12 +55,14 @@ class StreamWriter {
std::vector<HDF5DataFileStatistics> hdf5_data_file_statistics;
bool debug_skip_write_notification = false;
bool tcp_data_fatal_sent = false;
ImagePuller &image_puller;
Logger &logger;
void CollectImages();
bool WaitForImage();
void NotifyReceiverOnFinalizedWrite(const std::string &detector_update_zmq_addr);
void NotifyTcpAck(TCPFrameType ack_for, bool ok, bool fatal, TCPAckCode code, const std::string &error_text = "");
void ProcessStartMessage();
void ProcessEndMessage();
void ProcessDataImage();