TCP: Implemented ACK return stream, as a feedback channel (to be read properly!)
This commit is contained in:
@@ -201,8 +201,6 @@ std::unique_ptr<ImagePusher> ParseTCPImagePusher(const org::openapitools::server
|
||||
|
||||
auto tmp = std::make_unique<TCPStreamPusher>(j.getZeromq().getImageSocket(), send_buffer_size);
|
||||
|
||||
if (j.getZeromq().writerNotificationSocketIsSet())
|
||||
tmp->WriterNotificationSocket(j.getZeromq().getWriterNotificationSocket());
|
||||
return std::move(tmp);
|
||||
}
|
||||
|
||||
|
||||
@@ -6,15 +6,33 @@
|
||||
#include <cstdint>
|
||||
|
||||
constexpr uint32_t JFJOCH_TCP_MAGIC = 0x4A464A54; // JFJT
|
||||
constexpr uint32_t JFJOCH_TCP_VERSION = 1;
|
||||
constexpr uint32_t JFJOCH_TCP_VERSION = 2;
|
||||
|
||||
enum class TCPFrameType : uint16_t {
|
||||
START = 1,
|
||||
DATA = 2,
|
||||
CALIBRATION = 3,
|
||||
END = 4
|
||||
END = 4,
|
||||
ACK = 5,
|
||||
CANCEL = 6
|
||||
};
|
||||
|
||||
enum class TCPAckCode : uint16_t {
|
||||
None = 0,
|
||||
StartFailed = 1,
|
||||
DataWriteFailed = 2,
|
||||
EndFailed = 3,
|
||||
DiskQuotaExceeded = 4,
|
||||
NoSpaceLeft = 5,
|
||||
PermissionDenied = 6,
|
||||
IoError = 7,
|
||||
ProtocolError = 8
|
||||
};
|
||||
|
||||
constexpr uint32_t TCP_ACK_FLAG_OK = 1u << 0;
|
||||
constexpr uint32_t TCP_ACK_FLAG_FATAL = 1u << 1;
|
||||
constexpr uint32_t TCP_ACK_FLAG_HAS_ERROR_TEXT = 1u << 2;
|
||||
|
||||
struct alignas(64) TcpFrameHeader {
|
||||
uint32_t magic = JFJOCH_TCP_MAGIC;
|
||||
uint16_t version = JFJOCH_TCP_VERSION ;
|
||||
@@ -24,5 +42,10 @@ struct alignas(64) TcpFrameHeader {
|
||||
uint32_t socket_number = 0;
|
||||
uint32_t flags = 0;
|
||||
uint64_t run_number = 0;
|
||||
uint64_t reserved[4] = {0, 0, 0, 0};
|
||||
|
||||
uint32_t ack_processed_images = 0;
|
||||
uint16_t ack_code = 0;
|
||||
uint16_t ack_for = 0;
|
||||
|
||||
uint64_t reserved[2] = {0, 0};
|
||||
};
|
||||
@@ -11,6 +11,19 @@
|
||||
#include "../frame_serialize/CBORStream2Deserializer.h"
|
||||
#include "../common/ThreadSafeFIFO.h"
|
||||
#include "../common/JfjochTCP.h"
|
||||
#include "../common/JfjochTCP.h"
|
||||
|
||||
struct PullerAckMessage {
|
||||
TCPFrameType ack_for = TCPFrameType::DATA;
|
||||
bool ok = true;
|
||||
bool fatal = false;
|
||||
uint64_t run_number = 0;
|
||||
uint32_t socket_number = 0;
|
||||
uint64_t image_number = 0;
|
||||
uint64_t processed_images = 0;
|
||||
TCPAckCode error_code = TCPAckCode::None;
|
||||
std::string error_text;
|
||||
};
|
||||
|
||||
struct RawFrame {
|
||||
TcpFrameHeader header{};
|
||||
@@ -42,6 +55,9 @@ public:
|
||||
[[nodiscard]] size_t GetMaxFifoUtilization() const;
|
||||
[[nodiscard]] size_t GetCurrentFifoUtilization() const;
|
||||
|
||||
virtual bool SupportsAck() const { return false; }
|
||||
virtual bool SendAck(const PullerAckMessage &) { return false; }
|
||||
|
||||
virtual void Disconnect() = 0;
|
||||
};
|
||||
|
||||
|
||||
@@ -51,6 +51,77 @@ TCPImagePuller::TCPImagePuller(const std::string &tcp_addr,
|
||||
cbor_thread = std::thread(&TCPImagePuller::CBORThread, this);
|
||||
}
|
||||
|
||||
bool TCPImagePuller::SendAll(const void *buf, size_t len) {
|
||||
const auto *p = static_cast<const uint8_t *>(buf);
|
||||
size_t sent = 0;
|
||||
while (sent < len) {
|
||||
if (disconnect)
|
||||
return false;
|
||||
|
||||
int local_fd = -1;
|
||||
{
|
||||
std::unique_lock ul(fd_mutex);
|
||||
local_fd = fd;
|
||||
}
|
||||
if (local_fd < 0)
|
||||
return false;
|
||||
|
||||
ssize_t rc = ::send(local_fd, p + sent, len - sent, MSG_NOSIGNAL);
|
||||
if (rc < 0) {
|
||||
if (errno == EINTR)
|
||||
continue;
|
||||
return false;
|
||||
}
|
||||
sent += static_cast<size_t>(rc);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TCPImagePuller::SendAck(const PullerAckMessage &ack) {
|
||||
TcpFrameHeader h{};
|
||||
h.type = static_cast<uint16_t>(TCPFrameType::ACK);
|
||||
h.run_number = ack.run_number;
|
||||
h.socket_number = ack.socket_number;
|
||||
h.image_number = ack.image_number;
|
||||
h.flags = 0;
|
||||
if (ack.ok)
|
||||
h.flags |= TCP_ACK_FLAG_OK;
|
||||
if (ack.fatal)
|
||||
h.flags |= TCP_ACK_FLAG_FATAL;
|
||||
if (!ack.error_text.empty())
|
||||
h.flags |= TCP_ACK_FLAG_HAS_ERROR_TEXT;
|
||||
|
||||
h.ack_for = static_cast<uint16_t>(ack.ack_for);
|
||||
h.ack_processed_images = ack.processed_images;
|
||||
h.ack_code = static_cast<uint32_t>(ack.error_code);
|
||||
h.payload_size = ack.error_text.size();
|
||||
|
||||
if (!SendAll(&h, sizeof(h)))
|
||||
return false;
|
||||
if (!ack.error_text.empty())
|
||||
return SendAll(ack.error_text.data(), ack.error_text.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
void TCPImagePuller::CBORThread() {
|
||||
auto ret = cbor_fifo.GetBlocking();
|
||||
while (ret.tcp_msg) {
|
||||
try {
|
||||
const auto type = static_cast<TCPFrameType>(ret.tcp_msg->header.type);
|
||||
if (type == TCPFrameType::CANCEL) {
|
||||
outside_fifo.PutBlocking(ret);
|
||||
} else {
|
||||
ret.cbor = CBORStream2Deserialize(ret.tcp_msg->payload.data(), ret.tcp_msg->payload.size());
|
||||
outside_fifo.PutBlocking(ret);
|
||||
}
|
||||
} catch (const JFJochException &e) {
|
||||
logger.ErrorException(e);
|
||||
}
|
||||
ret = cbor_fifo.GetBlocking();
|
||||
}
|
||||
outside_fifo.PutBlocking(ret);
|
||||
}
|
||||
|
||||
TCPImagePuller::~TCPImagePuller() {
|
||||
TCPImagePuller::Disconnect();
|
||||
}
|
||||
@@ -179,6 +250,17 @@ void TCPImagePuller::ReceiverThread() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Ignore ACK on puller side
|
||||
if (static_cast<TCPFrameType>(frame.header.type) == TCPFrameType::ACK) {
|
||||
if (frame.header.payload_size > 0) {
|
||||
std::vector<uint8_t> discard(frame.header.payload_size);
|
||||
if (!ReadExact(discard.data(), discard.size())) {
|
||||
CloseSocket();
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
ImagePullerOutput out;
|
||||
out.tcp_msg = std::make_shared<RawFrame>();
|
||||
out.tcp_msg->header = frame.header;
|
||||
@@ -206,19 +288,6 @@ void TCPImagePuller::ReceiverThread() {
|
||||
cbor_fifo.PutBlocking(ImagePullerOutput{});
|
||||
}
|
||||
|
||||
void TCPImagePuller::CBORThread() {
|
||||
auto ret = cbor_fifo.GetBlocking();
|
||||
while (ret.tcp_msg) {
|
||||
try {
|
||||
ret.cbor = CBORStream2Deserialize(ret.tcp_msg->payload.data(), ret.tcp_msg->payload.size());
|
||||
outside_fifo.PutBlocking(ret);
|
||||
} catch (const JFJochException &e) {
|
||||
logger.ErrorException(e);
|
||||
}
|
||||
ret = cbor_fifo.GetBlocking();
|
||||
}
|
||||
outside_fifo.PutBlocking(ret);
|
||||
}
|
||||
void TCPImagePuller::Disconnect() {
|
||||
if (disconnect.exchange(true))
|
||||
return;
|
||||
|
||||
@@ -29,6 +29,7 @@ class TCPImagePuller : public ImagePuller {
|
||||
Logger logger{"TCPImagePuller"};
|
||||
|
||||
bool ReadExact(void *buf, size_t size);
|
||||
bool SendAll(const void *buf, size_t len);
|
||||
bool EnsureConnected();
|
||||
void CloseSocket();
|
||||
void ReceiverThread();
|
||||
@@ -37,5 +38,7 @@ public:
|
||||
explicit TCPImagePuller(const std::string &tcp_addr, std::optional<int32_t> rcv_buffer_size = {});
|
||||
|
||||
~TCPImagePuller() override;
|
||||
bool SupportsAck() const override { return true; }
|
||||
bool SendAck(const PullerAckMessage &ack) override;
|
||||
void Disconnect() override;
|
||||
};
|
||||
@@ -26,7 +26,6 @@ public:
|
||||
void SendImage(ZeroCopyReturnValue &z) override;
|
||||
bool SendCalibration(const CompressedImage &message) override;
|
||||
|
||||
|
||||
std::string PrintSetup() const override;
|
||||
};
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ TCPStreamPusher::TCPStreamPusher(const std::vector<std::string> &addr,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void TCPStreamPusher::StartDataCollection(StartMessage &message) {
|
||||
if (message.images_per_file < 1)
|
||||
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
|
||||
@@ -35,6 +34,25 @@ void TCPStreamPusher::StartDataCollection(StartMessage &message) {
|
||||
"TCP accept timeout/failure on socket " + socket[i]->GetEndpointName());
|
||||
}
|
||||
|
||||
for (auto &s : socket)
|
||||
s->StartWriterThread();
|
||||
|
||||
std::vector<bool> started(socket.size(), false);
|
||||
|
||||
auto rollback_cancel = [&]() {
|
||||
for (size_t i = 0; i < socket.size(); i++) {
|
||||
if (!started[i] || socket[i]->IsBroken())
|
||||
continue;
|
||||
|
||||
(void)socket[i]->Send(nullptr, 0, TCPFrameType::CANCEL);
|
||||
std::string cancel_ack_err;
|
||||
(void)socket[i]->WaitForAck(TCPFrameType::CANCEL, std::chrono::milliseconds(500), &cancel_ack_err);
|
||||
}
|
||||
|
||||
for (auto &s : socket)
|
||||
s->StopWriterThread();
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < socket.size(); i++) {
|
||||
message.socket_number = static_cast<int64_t>(i);
|
||||
if (i > 0)
|
||||
@@ -44,17 +62,20 @@ void TCPStreamPusher::StartDataCollection(StartMessage &message) {
|
||||
socket[i]->SetRunNumber(run_number);
|
||||
|
||||
if (!socket[i]->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::START)) {
|
||||
// one-shot recovery: reconnect and retry START once
|
||||
if (!socket[i]->AcceptConnection(std::chrono::seconds(5)) ||
|
||||
!socket[i]->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::START)) {
|
||||
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
|
||||
"Timeout/failure sending START");
|
||||
}
|
||||
rollback_cancel();
|
||||
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
|
||||
"Timeout/failure sending START on " + socket[i]->GetEndpointName());
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &s : socket)
|
||||
s->StartWriterThread();
|
||||
std::string ack_err;
|
||||
if (!socket[i]->WaitForAck(TCPFrameType::START, std::chrono::seconds(5), &ack_err)) {
|
||||
rollback_cancel();
|
||||
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
|
||||
"START ACK failed on " + socket[i]->GetEndpointName() + ": " + ack_err);
|
||||
}
|
||||
|
||||
started[i] = true;
|
||||
}
|
||||
}
|
||||
|
||||
bool TCPStreamPusher::SendImage(const uint8_t *image_data, size_t image_size, int64_t image_number) {
|
||||
@@ -88,12 +109,25 @@ bool TCPStreamPusher::EndDataCollection(const EndMessage &message) {
|
||||
|
||||
bool ret = true;
|
||||
for (auto &s : socket) {
|
||||
s->StopWriterThread();
|
||||
if (s->IsBroken())
|
||||
if (s->IsBroken()) {
|
||||
ret = false;
|
||||
else if (!s->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::END))
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!s->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::END)) {
|
||||
ret = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string ack_err;
|
||||
if (!s->WaitForAck(TCPFrameType::END, std::chrono::seconds(10), &ack_err)) {
|
||||
ret = false;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &s : socket)
|
||||
s->StopWriterThread();
|
||||
|
||||
transmission_error = !ret;
|
||||
return ret;
|
||||
}
|
||||
@@ -102,32 +136,16 @@ std::string TCPStreamPusher::Finalize() {
|
||||
std::string ret;
|
||||
if (transmission_error)
|
||||
ret += "Timeout sending images (e.g., writer disabled during data collection);";
|
||||
if (writer_notification_socket) {
|
||||
for (size_t i = 0; i < socket.size(); i++) {
|
||||
auto n = writer_notification_socket->Receive(run_number, run_name);
|
||||
if (!n)
|
||||
ret += "Writer " + std::to_string(i) + ": no end notification received within 1 minute from collection end";
|
||||
else if (static_cast<size_t>(n->socket_number) >= socket.size())
|
||||
ret += "Writer " + std::to_string(i) + ": mismatch in socket number";
|
||||
else if (!n->ok)
|
||||
ret += "Writer " + std::to_string(i) + ": " + n->error;
|
||||
|
||||
for (size_t i = 0; i < socket.size(); i++) {
|
||||
if (socket[i]->IsBroken()) {
|
||||
const auto reason = socket[i]->GetLastAckError();
|
||||
ret += "Writer " + std::to_string(i) + ": " + (reason.empty() ? "stream broken" : reason) + ";";
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::string TCPStreamPusher::GetWriterNotificationSocketAddress() const {
|
||||
if (writer_notification_socket)
|
||||
return writer_notification_socket->GetEndpointName();
|
||||
else
|
||||
return "";
|
||||
}
|
||||
|
||||
TCPStreamPusher &TCPStreamPusher::WriterNotificationSocket(const std::string &addr) {
|
||||
writer_notification_socket = std::make_unique<ZMQWriterNotificationPuller>(addr, std::chrono::minutes(1));
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::string TCPStreamPusher::PrintSetup() const {
|
||||
std::string output = "TCPStream2Pusher: Sending images to sockets: ";
|
||||
for (const auto &s : socket)
|
||||
|
||||
@@ -11,8 +11,6 @@ class TCPStreamPusher : public ImagePusher {
|
||||
CBORStream2Serializer serializer;
|
||||
std::vector<std::unique_ptr<TCPStreamPusherSocket>> socket;
|
||||
|
||||
std::unique_ptr<ZMQWriterNotificationPuller> writer_notification_socket;
|
||||
|
||||
int64_t images_per_file = 1;
|
||||
uint64_t run_number = 0;
|
||||
std::string run_name;
|
||||
@@ -23,9 +21,6 @@ public:
|
||||
std::optional<size_t> zerocopy_threshold = {},
|
||||
size_t send_queue_size = 4096);
|
||||
|
||||
TCPStreamPusher& WriterNotificationSocket(const std::string& addr);
|
||||
std::string GetWriterNotificationSocketAddress() const override;
|
||||
|
||||
void StartDataCollection(StartMessage& message) override;
|
||||
bool EndDataCollection(const EndMessage& message) override;
|
||||
bool SendImage(const uint8_t *image_data, size_t image_size, int64_t image_number) override;
|
||||
|
||||
@@ -197,6 +197,50 @@ bool TCPStreamPusherSocket::SendAll(const void *buf, size_t len) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TCPStreamPusherSocket::ReadExact(void *buf, size_t len) {
|
||||
auto *p = static_cast<uint8_t *>(buf);
|
||||
size_t got = 0;
|
||||
|
||||
while (got < len) {
|
||||
if (!active)
|
||||
return false;
|
||||
|
||||
int local_fd = fd.load();
|
||||
if (local_fd < 0)
|
||||
return false;
|
||||
|
||||
pollfd pfd{};
|
||||
pfd.fd = local_fd;
|
||||
pfd.events = POLLIN;
|
||||
|
||||
const int prc = poll(&pfd, 1, 100); // 100 ms interruptibility window
|
||||
if (prc == 0)
|
||||
continue;
|
||||
if (prc < 0) {
|
||||
if (errno == EINTR)
|
||||
continue;
|
||||
return false;
|
||||
}
|
||||
if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0)
|
||||
return false;
|
||||
if ((pfd.revents & POLLIN) == 0)
|
||||
continue;
|
||||
|
||||
ssize_t rc = ::recv(local_fd, p + got, len - got, 0);
|
||||
if (rc == 0)
|
||||
return false;
|
||||
if (rc < 0) {
|
||||
if (errno == EINTR)
|
||||
continue;
|
||||
if (errno == EAGAIN || errno == EWOULDBLOCK)
|
||||
continue;
|
||||
return false;
|
||||
}
|
||||
got += static_cast<size_t>(rc);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TCPStreamPusherSocket::SendPayloadZC(const uint8_t *data, size_t size, ZeroCopyReturnValue *z) {
|
||||
#if defined(MSG_ZEROCOPY) && defined(SO_ZEROCOPY)
|
||||
int local_fd = fd.load();
|
||||
@@ -358,10 +402,89 @@ void TCPStreamPusherSocket::CompletionThread() {
|
||||
#endif
|
||||
}
|
||||
|
||||
void TCPStreamPusherSocket::AckThread() {
|
||||
while (active) {
|
||||
TcpFrameHeader h{};
|
||||
if (!ReadExact(&h, sizeof(h))) {
|
||||
if (active) {
|
||||
broken = true;
|
||||
logger.Error("TCP ACK reader disconnected on " + endpoint);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
if (h.magic != JFJOCH_TCP_MAGIC || h.version != JFJOCH_TCP_VERSION || static_cast<TCPFrameType>(h.type) != TCPFrameType::ACK) {
|
||||
broken = true;
|
||||
logger.Error("Invalid ACK frame on " + endpoint);
|
||||
break;
|
||||
}
|
||||
|
||||
std::string error_text;
|
||||
if (h.payload_size > 0) {
|
||||
error_text.resize(h.payload_size);
|
||||
if (!ReadExact(error_text.data(), error_text.size())) {
|
||||
broken = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const auto ack_for = static_cast<TCPFrameType>(h.ack_for);
|
||||
const bool ok = (h.flags & TCP_ACK_FLAG_OK) != 0;
|
||||
const bool fatal = (h.flags & TCP_ACK_FLAG_FATAL) != 0;
|
||||
const auto code = static_cast<TCPAckCode>(h.ack_code);
|
||||
|
||||
{
|
||||
std::unique_lock ul(ack_state_mutex);
|
||||
last_ack_code = code;
|
||||
if (!error_text.empty())
|
||||
last_ack_error = error_text;
|
||||
|
||||
if (ack_for == TCPFrameType::START) {
|
||||
start_ack_received = true;
|
||||
start_ack_ok = ok;
|
||||
if (!ok && error_text.empty())
|
||||
last_ack_error = "START rejected";
|
||||
} else if (ack_for == TCPFrameType::END) {
|
||||
end_ack_received = true;
|
||||
end_ack_ok = ok;
|
||||
if (!ok && error_text.empty())
|
||||
last_ack_error = "END rejected";
|
||||
} else if (ack_for == TCPFrameType::CANCEL) {
|
||||
cancel_ack_received = true;
|
||||
cancel_ack_ok = ok;
|
||||
if (!ok && error_text.empty())
|
||||
last_ack_error = "CANCEL rejected";
|
||||
} else if (ack_for == TCPFrameType::DATA && (!ok || fatal)) {
|
||||
broken = true;
|
||||
if (error_text.empty())
|
||||
last_ack_error = "DATA fatal ACK";
|
||||
logger.Error("Received fatal DATA ACK on " + endpoint + ": " + last_ack_error);
|
||||
}
|
||||
}
|
||||
ack_cv.notify_all();
|
||||
}
|
||||
}
|
||||
|
||||
void TCPStreamPusherSocket::StartWriterThread() {
|
||||
if (active)
|
||||
return;
|
||||
|
||||
{
|
||||
std::unique_lock ul(ack_state_mutex);
|
||||
start_ack_received = false;
|
||||
start_ack_ok = false;
|
||||
end_ack_received = false;
|
||||
end_ack_ok = false;
|
||||
cancel_ack_received = false;
|
||||
cancel_ack_ok = false;
|
||||
last_ack_error.clear();
|
||||
last_ack_code = TCPAckCode::None;
|
||||
}
|
||||
|
||||
active = true;
|
||||
send_future = std::async(std::launch::async, &TCPStreamPusherSocket::WriterThread, this);
|
||||
completion_future = std::async(std::launch::async, &TCPStreamPusherSocket::CompletionThread, this);
|
||||
ack_future = std::async(std::launch::async, &TCPStreamPusherSocket::AckThread, this);
|
||||
}
|
||||
|
||||
void TCPStreamPusherSocket::StopWriterThread() {
|
||||
@@ -369,11 +492,14 @@ void TCPStreamPusherSocket::StopWriterThread() {
|
||||
return;
|
||||
active = false;
|
||||
queue.PutBlocking({.end = true});
|
||||
ack_cv.notify_all();
|
||||
|
||||
if (send_future.valid())
|
||||
send_future.get();
|
||||
if (completion_future.valid())
|
||||
completion_future.get();
|
||||
if (ack_future.valid())
|
||||
ack_future.get();
|
||||
|
||||
// Keep fd open: END frame may still be sent after writer thread stops.
|
||||
// Socket is closed in destructor / explicit close path.
|
||||
@@ -403,3 +529,46 @@ bool TCPStreamPusherSocket::Send(const uint8_t *data, size_t size, TCPFrameType
|
||||
|
||||
return SendFrame(data, size, type, image_number, nullptr);
|
||||
}
|
||||
|
||||
bool TCPStreamPusherSocket::WaitForAck(TCPFrameType ack_for, std::chrono::milliseconds timeout, std::string *error_text) {
|
||||
std::unique_lock ul(ack_state_mutex);
|
||||
const bool ok = ack_cv.wait_for(ul, timeout, [&] {
|
||||
if (ack_for == TCPFrameType::START)
|
||||
return start_ack_received || broken.load();
|
||||
if (ack_for == TCPFrameType::END)
|
||||
return end_ack_received || broken.load();
|
||||
if (ack_for == TCPFrameType::CANCEL)
|
||||
return cancel_ack_received || broken.load();
|
||||
return false;
|
||||
});
|
||||
|
||||
if (!ok) {
|
||||
if (error_text)
|
||||
*error_text = "ACK timeout";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (broken) {
|
||||
if (error_text)
|
||||
*error_text = last_ack_error.empty() ? "Socket broken" : last_ack_error;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ack_ok = false;
|
||||
if (ack_for == TCPFrameType::START)
|
||||
ack_ok = start_ack_ok;
|
||||
else if (ack_for == TCPFrameType::END)
|
||||
ack_ok = end_ack_ok;
|
||||
else if (ack_for == TCPFrameType::CANCEL)
|
||||
ack_ok = cancel_ack_ok;
|
||||
|
||||
if (!ack_ok && error_text)
|
||||
*error_text = last_ack_error.empty() ? "ACK rejected" : last_ack_error;
|
||||
|
||||
return ack_ok;
|
||||
}
|
||||
|
||||
std::string TCPStreamPusherSocket::GetLastAckError() const {
|
||||
std::unique_lock ul(ack_state_mutex);
|
||||
return last_ack_error;
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ class TCPStreamPusherSocket {
|
||||
std::atomic<bool> active = false;
|
||||
std::future<void> send_future;
|
||||
std::future<void> completion_future;
|
||||
std::future<void> ack_future;
|
||||
|
||||
ThreadSafeFIFO<ImagePusherQueueElement> queue;
|
||||
|
||||
@@ -40,6 +41,16 @@ class TCPStreamPusherSocket {
|
||||
constexpr static auto AcceptTimeout = std::chrono::seconds(5);
|
||||
|
||||
std::atomic<bool> broken{false};
|
||||
std::atomic<TCPAckCode> last_ack_code{TCPAckCode::None};
|
||||
std::string last_ack_error;
|
||||
mutable std::mutex ack_state_mutex;
|
||||
std::condition_variable ack_cv;
|
||||
bool start_ack_received = false;
|
||||
bool start_ack_ok = false;
|
||||
bool end_ack_received = false;
|
||||
bool end_ack_ok = false;
|
||||
bool cancel_ack_received = false;
|
||||
bool cancel_ack_ok = false;
|
||||
|
||||
std::atomic<uint64_t> next_tx_id{1};
|
||||
std::mutex inflight_mutex;
|
||||
@@ -49,10 +60,12 @@ class TCPStreamPusherSocket {
|
||||
|
||||
void WriterThread();
|
||||
void CompletionThread();
|
||||
void AckThread();
|
||||
|
||||
void CloseDataSocket();
|
||||
|
||||
bool SendAll(const void *buf, size_t len);
|
||||
bool ReadExact(void *buf, size_t len);
|
||||
bool SendFrame(const uint8_t *data, size_t size, TCPFrameType type, int64_t image_number, ZeroCopyReturnValue *z);
|
||||
bool SendPayloadZC(const uint8_t *data, size_t size, ZeroCopyReturnValue *z);
|
||||
public:
|
||||
@@ -74,9 +87,12 @@ public:
|
||||
void StartWriterThread();
|
||||
void StopWriterThread();
|
||||
|
||||
bool WaitForAck(TCPFrameType ack_for, std::chrono::milliseconds timeout, std::string *error_text = nullptr);
|
||||
|
||||
void SetRunNumber(uint64_t in_run_number);
|
||||
|
||||
void SendImage(ZeroCopyReturnValue &z);
|
||||
|
||||
bool IsBroken() const;
|
||||
std::string GetLastAckError() const;
|
||||
};
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include "../image_pusher/TCPStreamPusher.h"
|
||||
#include "../image_puller/TCPImagePuller.h"
|
||||
|
||||
TEST_CASE("TCPImageCommTest_2Writers", "[TCP]") {
|
||||
TEST_CASE("TCPImageCommTest_2Writers_WithAck", "[TCP]") {
|
||||
const size_t nframes = 128;
|
||||
const int64_t npullers = 2;
|
||||
const int64_t images_per_file = 16;
|
||||
@@ -23,24 +23,24 @@ TEST_CASE("TCPImageCommTest_2Writers", "[TCP]") {
|
||||
for (auto &i : image1) i = dist(g1);
|
||||
|
||||
std::vector<std::string> addr{
|
||||
"tcp://127.0.0.1:19001",
|
||||
"tcp://127.0.0.1:19002"
|
||||
};
|
||||
"tcp://127.0.0.1:19001",
|
||||
"tcp://127.0.0.1:19002"
|
||||
};
|
||||
|
||||
std::vector<std::unique_ptr<TCPImagePuller>> puller;
|
||||
for (int i = 0; i < npullers; i++) {
|
||||
puller.push_back(std::make_unique<TCPImagePuller>(
|
||||
addr[i], 64 * 1024 * 1024)); // decoded cbor ring
|
||||
puller.push_back(std::make_unique<TCPImagePuller>(addr[i], 64 * 1024 * 1024));
|
||||
}
|
||||
|
||||
TCPStreamPusher pusher(
|
||||
addr,
|
||||
64 * 1024 * 1024,
|
||||
128 * 1024, // zerocopy threshold
|
||||
8192 // sender queue
|
||||
128 * 1024,
|
||||
8192
|
||||
);
|
||||
|
||||
std::vector<size_t> received(npullers, 0);
|
||||
std::vector<size_t> processed(npullers, 0);
|
||||
|
||||
std::thread sender([&] {
|
||||
std::vector<uint8_t> serialization_buffer(16 * 1024 * 1024);
|
||||
@@ -70,29 +70,204 @@ TEST_CASE("TCPImageCommTest_2Writers", "[TCP]") {
|
||||
REQUIRE(pusher.EndDataCollection(end));
|
||||
});
|
||||
|
||||
std::vector<std::thread> receivers;
|
||||
receivers.reserve(npullers);
|
||||
|
||||
for (int w = 0; w < npullers; w++) {
|
||||
bool seen_end = false;
|
||||
while (!seen_end) {
|
||||
auto out = puller[w]->PollImage(std::chrono::seconds(10));
|
||||
REQUIRE(out.has_value());
|
||||
REQUIRE(out->cbor != nullptr);
|
||||
if (out->cbor->end_message) {
|
||||
seen_end = true;
|
||||
continue;
|
||||
receivers.emplace_back([&, w] {
|
||||
bool seen_start = false;
|
||||
bool seen_end = false;
|
||||
|
||||
while (!seen_end) {
|
||||
auto out = puller[w]->PollImage(std::chrono::seconds(10));
|
||||
REQUIRE(out.has_value());
|
||||
REQUIRE(out->cbor != nullptr);
|
||||
REQUIRE(out->tcp_msg != nullptr);
|
||||
|
||||
const auto &h = out->tcp_msg->header;
|
||||
|
||||
if (out->cbor->start_message) {
|
||||
PullerAckMessage ack;
|
||||
ack.ack_for = TCPFrameType::START;
|
||||
ack.ok = true;
|
||||
ack.fatal = false;
|
||||
ack.run_number = h.run_number;
|
||||
ack.socket_number = h.socket_number;
|
||||
ack.image_number = 0;
|
||||
ack.processed_images = 0;
|
||||
ack.error_code = TCPAckCode::None;
|
||||
REQUIRE(puller[w]->SendAck(ack));
|
||||
seen_start = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (out->cbor->data_message) {
|
||||
REQUIRE(seen_start);
|
||||
auto n = out->cbor->data_message->number;
|
||||
REQUIRE(((n / images_per_file) % npullers) == w);
|
||||
received[w]++;
|
||||
processed[w]++;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (out->cbor->end_message) {
|
||||
PullerAckMessage ack;
|
||||
ack.ack_for = TCPFrameType::END;
|
||||
ack.ok = true;
|
||||
ack.fatal = false;
|
||||
ack.run_number = h.run_number;
|
||||
ack.socket_number = h.socket_number;
|
||||
ack.image_number = 0;
|
||||
ack.processed_images = processed[w];
|
||||
ack.error_code = TCPAckCode::None;
|
||||
REQUIRE(puller[w]->SendAck(ack));
|
||||
seen_end = true;
|
||||
}
|
||||
}
|
||||
if (out->cbor->data_message) {
|
||||
auto n = out->cbor->data_message->number;
|
||||
REQUIRE(((n / images_per_file) % npullers) == w);
|
||||
received[w]++;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
sender.join();
|
||||
for (auto &t : receivers) t.join();
|
||||
|
||||
REQUIRE(received[0] == nframes / 2);
|
||||
REQUIRE(received[1] == nframes / 2);
|
||||
|
||||
for (auto &p : puller)
|
||||
p->Disconnect();
|
||||
}
|
||||
|
||||
TEST_CASE("TCPImageCommTest_DataFatalAck_PropagatesToPusher", "[TCP]") {
|
||||
const size_t nframes = 64;
|
||||
const int64_t npullers = 2;
|
||||
const int64_t images_per_file = 8;
|
||||
|
||||
DiffractionExperiment x(DetJF(1));
|
||||
x.Raw();
|
||||
x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4)
|
||||
.ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION);
|
||||
|
||||
std::mt19937 g1(42);
|
||||
std::uniform_int_distribution<uint16_t> dist;
|
||||
std::vector<uint16_t> image1(x.GetPixelsNum() * nframes);
|
||||
for (auto &i : image1) i = dist(g1);
|
||||
|
||||
std::vector<std::string> addr{
|
||||
"tcp://127.0.0.1:19011",
|
||||
"tcp://127.0.0.1:19012"
|
||||
};
|
||||
|
||||
std::vector<std::unique_ptr<TCPImagePuller>> puller;
|
||||
for (int i = 0; i < npullers; i++) {
|
||||
puller.push_back(std::make_unique<TCPImagePuller>(addr[i], 64 * 1024 * 1024));
|
||||
}
|
||||
|
||||
TCPStreamPusher pusher(
|
||||
addr,
|
||||
64 * 1024 * 1024,
|
||||
128 * 1024,
|
||||
8192
|
||||
);
|
||||
|
||||
std::atomic<bool> sent_fatal{false};
|
||||
|
||||
std::thread sender([&] {
|
||||
std::vector<uint8_t> serialization_buffer(16 * 1024 * 1024);
|
||||
CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size());
|
||||
|
||||
StartMessage start{
|
||||
.images_per_file = images_per_file,
|
||||
.write_master_file = true
|
||||
};
|
||||
EndMessage end{};
|
||||
|
||||
pusher.StartDataCollection(start);
|
||||
|
||||
for (int64_t i = 0; i < static_cast<int64_t>(nframes); i++) {
|
||||
DataMessage data_message;
|
||||
data_message.number = i;
|
||||
data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(),
|
||||
x.GetPixelsNum() * sizeof(uint16_t),
|
||||
x.GetXPixelsNum(),
|
||||
x.GetYPixelsNum(),
|
||||
x.GetImageMode(),
|
||||
x.GetCompressionAlgorithm());
|
||||
serializer.SerializeImage(data_message);
|
||||
(void)pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i);
|
||||
}
|
||||
|
||||
REQUIRE_FALSE(pusher.EndDataCollection(end));
|
||||
const auto final_msg = pusher.Finalize();
|
||||
REQUIRE_THAT(final_msg, Catch::Matchers::ContainsSubstring("quota"));
|
||||
});
|
||||
|
||||
std::vector<std::thread> receivers;
|
||||
receivers.reserve(npullers);
|
||||
|
||||
for (int w = 0; w < npullers; w++) {
|
||||
receivers.emplace_back([&, w] {
|
||||
bool seen_end = false;
|
||||
bool local_fatal_sent = false;
|
||||
|
||||
while (!seen_end) {
|
||||
auto out = puller[w]->PollImage(std::chrono::seconds(2));
|
||||
if (!out.has_value()) {
|
||||
// Once this receiver has sent a fatal DATA ACK, no END is guaranteed on this stream.
|
||||
if (local_fatal_sent)
|
||||
break;
|
||||
REQUIRE(out.has_value());
|
||||
}
|
||||
|
||||
REQUIRE(out->cbor != nullptr);
|
||||
REQUIRE(out->tcp_msg != nullptr);
|
||||
|
||||
const auto &h = out->tcp_msg->header;
|
||||
|
||||
if (out->cbor->start_message) {
|
||||
PullerAckMessage ack;
|
||||
ack.ack_for = TCPFrameType::START;
|
||||
ack.ok = true;
|
||||
ack.run_number = h.run_number;
|
||||
ack.socket_number = h.socket_number;
|
||||
ack.error_code = TCPAckCode::None;
|
||||
REQUIRE(puller[w]->SendAck(ack));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (out->cbor->data_message) {
|
||||
if (w == 0 && !sent_fatal.exchange(true)) {
|
||||
PullerAckMessage ack;
|
||||
ack.ack_for = TCPFrameType::DATA;
|
||||
ack.ok = false;
|
||||
ack.fatal = true;
|
||||
ack.run_number = h.run_number;
|
||||
ack.socket_number = h.socket_number;
|
||||
ack.image_number = static_cast<uint64_t>(out->cbor->data_message->number);
|
||||
ack.error_code = TCPAckCode::DiskQuotaExceeded;
|
||||
ack.error_text = "quota exceeded";
|
||||
REQUIRE(puller[w]->SendAck(ack));
|
||||
local_fatal_sent = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (out->cbor->end_message) {
|
||||
PullerAckMessage ack;
|
||||
ack.ack_for = TCPFrameType::END;
|
||||
ack.ok = true;
|
||||
ack.run_number = h.run_number;
|
||||
ack.socket_number = h.socket_number;
|
||||
ack.error_code = TCPAckCode::None;
|
||||
REQUIRE(puller[w]->SendAck(ack));
|
||||
seen_end = true;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
sender.join();
|
||||
for (auto &t : receivers) t.join();
|
||||
|
||||
for (auto &p : puller)
|
||||
p->Disconnect();
|
||||
}
|
||||
@@ -20,6 +20,27 @@ StreamWriter::StreamWriter(Logger &in_logger,
|
||||
max_image_number(0) {
|
||||
}
|
||||
|
||||
void StreamWriter::NotifyTcpAck(TCPFrameType ack_for, bool ok, bool fatal, TCPAckCode code, const std::string &error_text) {
|
||||
if (!image_puller.SupportsAck())
|
||||
return;
|
||||
|
||||
PullerAckMessage ack;
|
||||
ack.ack_for = ack_for;
|
||||
ack.ok = ok;
|
||||
ack.fatal = fatal;
|
||||
ack.error_code = code;
|
||||
ack.error_text = error_text;
|
||||
ack.run_number = run_number;
|
||||
ack.socket_number = static_cast<uint32_t>(socket_number);
|
||||
ack.processed_images = processed_images.load();
|
||||
|
||||
if (image_puller_output.cbor && image_puller_output.cbor->data_message)
|
||||
ack.image_number = image_puller_output.cbor->data_message->number;
|
||||
|
||||
if (!image_puller.SendAck(ack))
|
||||
logger.Warning("Failed to send TCP ACK");
|
||||
}
|
||||
|
||||
void StreamWriter::ProcessStartMessage() {
|
||||
if (state == StreamWriterState::Finalized)
|
||||
return; // Should not happen (?)
|
||||
@@ -28,6 +49,7 @@ void StreamWriter::ProcessStartMessage() {
|
||||
FinalizeDataCollection();
|
||||
|
||||
err = "";
|
||||
tcp_data_fatal_sent = false;
|
||||
|
||||
max_image_number = 0;
|
||||
|
||||
@@ -51,11 +73,13 @@ void StreamWriter::ProcessStartMessage() {
|
||||
image_puller_output.cbor->start_message->file_prefix,
|
||||
image_puller_output.cbor->start_message->number_of_images);
|
||||
state = StreamWriterState::Started;
|
||||
NotifyTcpAck(TCPFrameType::START, true, false, TCPAckCode::None);
|
||||
} catch (const JFJochException &e) {
|
||||
logger.ErrorException(e);
|
||||
logger.Error("Error writing start message - switching to error state");
|
||||
state = StreamWriterState::Error;
|
||||
err = e.what();
|
||||
NotifyTcpAck(TCPFrameType::START, false, true, TCPAckCode::StartFailed, err);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,6 +132,10 @@ void StreamWriter::ProcessDataImage() {
|
||||
logger.Warning("Error writing image - switching to error state");
|
||||
state = StreamWriterState::Error;
|
||||
err = e.what();
|
||||
if (!tcp_data_fatal_sent) {
|
||||
tcp_data_fatal_sent = true;
|
||||
NotifyTcpAck(TCPFrameType::DATA, false, true, TCPAckCode::DataWriteFailed, err);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case StreamWriterState::Error:
|
||||
@@ -156,6 +184,11 @@ void StreamWriter::FinalizeDataCollection() {
|
||||
}
|
||||
file_writer.reset();
|
||||
NotifyReceiverOnFinalizedWrite(writer_notification_zmq_addr);
|
||||
NotifyTcpAck(TCPFrameType::END,
|
||||
state != StreamWriterState::Error,
|
||||
state == StreamWriterState::Error,
|
||||
state == StreamWriterState::Error ? TCPAckCode::EndFailed : TCPAckCode::None,
|
||||
state == StreamWriterState::Error ? err : "");
|
||||
logger.Info("Data writing finished");
|
||||
state = StreamWriterState::Finalized;
|
||||
}
|
||||
@@ -168,6 +201,21 @@ void StreamWriter::CollectImages() {
|
||||
while (run && state != StreamWriterState::Finalized) {
|
||||
run = WaitForImage();
|
||||
|
||||
if (image_puller_output.tcp_msg &&
|
||||
static_cast<TCPFrameType>(image_puller_output.tcp_msg->header.type) == TCPFrameType::CANCEL) {
|
||||
logger.Warning("Received TCP CANCEL, finalizing data collection");
|
||||
if (state != StreamWriterState::Idle && state != StreamWriterState::Finalized)
|
||||
FinalizeDataCollection();
|
||||
NotifyTcpAck(TCPFrameType::CANCEL, true, false, TCPAckCode::None);
|
||||
state = StreamWriterState::Finalized;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!image_puller_output.cbor) {
|
||||
logger.Warning("Missing CBOR payload for non-CANCEL TCP frame");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (image_puller_output.cbor->start_message)
|
||||
ProcessStartMessage();
|
||||
else if (image_puller_output.cbor->calibration)
|
||||
|
||||
@@ -55,12 +55,14 @@ class StreamWriter {
|
||||
std::vector<HDF5DataFileStatistics> hdf5_data_file_statistics;
|
||||
|
||||
bool debug_skip_write_notification = false;
|
||||
bool tcp_data_fatal_sent = false;
|
||||
|
||||
ImagePuller &image_puller;
|
||||
Logger &logger;
|
||||
void CollectImages();
|
||||
bool WaitForImage();
|
||||
void NotifyReceiverOnFinalizedWrite(const std::string &detector_update_zmq_addr);
|
||||
void NotifyTcpAck(TCPFrameType ack_for, bool ok, bool fatal, TCPAckCode code, const std::string &error_text = "");
|
||||
void ProcessStartMessage();
|
||||
void ProcessEndMessage();
|
||||
void ProcessDataImage();
|
||||
|
||||
Reference in New Issue
Block a user