TCPStreamPusher: Implement KEEPALIVE + writers stay connected

This commit is contained in:
2026-03-05 08:30:16 +01:00
parent cc33d5ff9c
commit 91591a3cc3
4 changed files with 327 additions and 76 deletions
+2 -1
View File
@@ -14,7 +14,8 @@ enum class TCPFrameType : uint16_t {
CALIBRATION = 3,
END = 4,
ACK = 5,
CANCEL = 6
CANCEL = 6,
KEEPALIVE = 7,
};
enum class TCPAckCode : uint16_t {
+25 -1
View File
@@ -250,8 +250,32 @@ void TCPImagePuller::ReceiverThread() {
continue;
}
const auto frame_type = static_cast<TCPFrameType>(frame.header.type);
// Respond to keepalive ping with a keepalive pong
if (frame_type == TCPFrameType::KEEPALIVE) {
if (frame.header.payload_size > 0) {
std::vector<uint8_t> discard(frame.header.payload_size);
if (!ReadExact(discard.data(), discard.size())) {
CloseSocket();
std::this_thread::sleep_for(std::chrono::milliseconds(20));
continue;
}
}
// Send keepalive pong back
TcpFrameHeader pong{};
pong.type = static_cast<uint16_t>(TCPFrameType::KEEPALIVE);
pong.payload_size = 0;
if (!SendAll(&pong, sizeof(pong))) {
logger.Info("Keepalive pong send failed, reconnecting to " + addr);
CloseSocket();
std::this_thread::sleep_for(std::chrono::milliseconds(20));
}
continue;
}
// Ignore ACK on puller side
if (static_cast<TCPFrameType>(frame.header.type) == TCPFrameType::ACK) {
if (frame_type == TCPFrameType::ACK) {
if (frame.header.payload_size > 0) {
std::vector<uint8_t> discard(frame.header.payload_size);
if (!ReadExact(discard.data(), discard.size())) {
+268 -64
View File
@@ -90,24 +90,48 @@ void TCPStreamPusher::CloseFd(std::atomic<int>& fd) {
}
TCPStreamPusher::TCPStreamPusher(const std::string& addr,
size_t in_expected_connections,
size_t in_max_connections,
std::optional<int32_t> in_send_buffer_size)
: serialization_buffer(256 * 1024 * 1024),
serializer(serialization_buffer.data(), serialization_buffer.size()),
endpoint(addr),
expected_connections(in_expected_connections),
max_connections(in_max_connections),
send_buffer_size(in_send_buffer_size) {
if (endpoint.empty())
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "No TCP writer address provided");
if (expected_connections == 0)
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Expected TCP connections cannot be zero");
if (max_connections == 0)
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Max TCP connections cannot be zero");
listen_fd.store(OpenListenSocket(endpoint));
acceptor_running = true;
acceptor_future = std::async(std::launch::async, &TCPStreamPusher::AcceptorThread, this);
keepalive_future = std::async(std::launch::async, &TCPStreamPusher::KeepaliveThread, this);
logger.Info("TCPStreamPusher listening on " + endpoint + " (max " + std::to_string(max_connections) + " connections)");
}
TCPStreamPusher::~TCPStreamPusher() {
for (auto& c : connections) {
StopConnectionThreads(*c);
CloseFd(c->fd);
acceptor_running = false;
int lfd = listen_fd.exchange(-1);
if (lfd >= 0) {
shutdown(lfd, SHUT_RDWR);
close(lfd);
}
if (acceptor_future.valid())
acceptor_future.get();
if (keepalive_future.valid())
keepalive_future.get();
std::lock_guard lg(connections_mutex);
for (auto& c : connections) {
StopDataCollectionThreads(*c);
c->connected = false;
c->broken = true;
CloseFd(c->fd);
if (c->persistent_ack_future.valid())
c->persistent_ack_future.get();
}
connections.clear();
}
bool TCPStreamPusher::IsConnectionAlive(const Connection& c) const {
@@ -392,6 +416,50 @@ bool TCPStreamPusher::ReadExact(Connection& c, void* buf, size_t len) {
return true;
}
bool TCPStreamPusher::ReadExactPersistent(Connection& c, void* buf, size_t len) {
auto* p = static_cast<uint8_t*>(buf);
size_t got = 0;
while (got < len) {
if (!c.connected)
return false;
const int local_fd = c.fd.load();
if (local_fd < 0)
return false;
pollfd pfd{};
pfd.fd = local_fd;
pfd.events = POLLIN;
const int prc = poll(&pfd, 1, 500);
if (prc == 0)
continue;
if (prc < 0) {
if (errno == EINTR)
continue;
return false;
}
if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0)
return false;
if ((pfd.revents & POLLIN) == 0)
continue;
ssize_t rc = ::recv(local_fd, p + got, len - got, 0);
if (rc == 0)
return false;
if (rc < 0) {
if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK)
continue;
return false;
}
got += static_cast<size_t>(rc);
}
return true;
}
void TCPStreamPusher::WriterThread(Connection* c) {
while (c->active) {
auto e = c->queue.GetBlocking();
@@ -418,27 +486,46 @@ void TCPStreamPusher::WriterThread(Connection* c) {
}
}
void TCPStreamPusher::AckThread(Connection* c) {
while (c->active) {
void TCPStreamPusher::PersistentAckThread(Connection* c) {
while (c->connected && !c->broken) {
TcpFrameHeader h{};
if (!ReadExact(*c, &h, sizeof(h))) {
if (c->active) {
if (!ReadExactPersistent(*c, &h, sizeof(h))) {
if (c->connected) {
c->broken = true;
logger.Error("TCP ACK reader disconnected on socket " + std::to_string(c->socket_number));
logger.Info("Persistent connection lost on socket " + std::to_string(c->socket_number));
}
break;
}
if (h.magic != JFJOCH_TCP_MAGIC || h.version != JFJOCH_TCP_VERSION || static_cast<TCPFrameType>(h.type) != TCPFrameType::ACK) {
if (h.magic != JFJOCH_TCP_MAGIC || h.version != JFJOCH_TCP_VERSION) {
c->broken = true;
logger.Error("Invalid ACK frame on socket " + std::to_string(c->socket_number));
logger.Error("Invalid frame on persistent connection, socket " + std::to_string(c->socket_number));
break;
}
const auto frame_type = static_cast<TCPFrameType>(h.type);
// Keepalive pong from the writer
if (frame_type == TCPFrameType::KEEPALIVE) {
c->last_keepalive_recv = std::chrono::steady_clock::now();
if (h.payload_size > 0) {
std::vector<uint8_t> discard(h.payload_size);
ReadExactPersistent(*c, discard.data(), discard.size());
}
continue;
}
if (frame_type != TCPFrameType::ACK) {
c->broken = true;
logger.Error("Unexpected frame type " + std::to_string(h.type) + " on socket " + std::to_string(c->socket_number));
break;
}
// ACK frame — forward to data-collection ack logic
std::string error_text;
if (h.payload_size > 0) {
error_text.resize(h.payload_size);
if (!ReadExact(*c, error_text.data(), error_text.size())) {
if (!ReadExactPersistent(*c, error_text.data(), error_text.size())) {
c->broken = true;
break;
}
@@ -479,8 +566,6 @@ void TCPStreamPusher::AckThread(Connection* c) {
} else {
c->data_acked_bad.fetch_add(1, std::memory_order_relaxed);
total_data_acked_bad.fetch_add(1, std::memory_order_relaxed);
// Soft failure: remember it for Finalize(), do NOT mark socket broken.
c->data_ack_error_reported = true;
if (!error_text.empty()) {
c->data_ack_error_text = error_text;
@@ -495,7 +580,136 @@ void TCPStreamPusher::AckThread(Connection* c) {
}
}
void TCPStreamPusher::StartConnectionThreads(Connection& c) {
void TCPStreamPusher::AcceptorThread() {
uint32_t next_socket_number = 0;
while (acceptor_running) {
int lfd = listen_fd.load();
if (lfd < 0)
break;
int new_fd = AcceptOne(lfd, std::chrono::milliseconds(500));
if (new_fd < 0)
continue;
std::lock_guard lg(connections_mutex);
RemoveDeadConnections();
if (connections.size() >= max_connections) {
logger.Warning("Max connections (" + std::to_string(max_connections) +
") reached, rejecting new connection");
shutdown(new_fd, SHUT_RDWR);
close(new_fd);
continue;
}
SetupNewConnection(new_fd, next_socket_number++);
logger.Info("Accepted writer connection (socket_number=" + std::to_string(next_socket_number - 1) +
", total=" + std::to_string(connections.size()) + ")");
}
}
void TCPStreamPusher::SetupNewConnection(int new_fd, uint32_t socket_number) {
auto c = std::make_unique<Connection>(send_queue_size);
c->socket_number = socket_number;
c->fd.store(new_fd);
int one = 1;
setsockopt(new_fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one));
// Enable OS-level TCP keep-alive
setsockopt(new_fd, SOL_SOCKET, SO_KEEPALIVE, &one, sizeof(one));
int idle = 10;
int intvl = 5;
int cnt = 3;
setsockopt(new_fd, IPPROTO_TCP, TCP_KEEPIDLE, &idle, sizeof(idle));
setsockopt(new_fd, IPPROTO_TCP, TCP_KEEPINTVL, &intvl, sizeof(intvl));
setsockopt(new_fd, IPPROTO_TCP, TCP_KEEPCNT, &cnt, sizeof(cnt));
if (send_buffer_size)
setsockopt(new_fd, SOL_SOCKET, SO_SNDBUF, &send_buffer_size.value(), sizeof(int32_t));
#if defined(SO_ZEROCOPY)
int zc_one = 1;
if (setsockopt(new_fd, SOL_SOCKET, SO_ZEROCOPY, &zc_one, sizeof(zc_one)) == 0)
c->zerocopy_enabled.store(true, std::memory_order_relaxed);
else
c->zerocopy_enabled.store(false, std::memory_order_relaxed);
#endif
c->connected = true;
c->broken = false;
auto now = std::chrono::steady_clock::now();
c->last_keepalive_sent = now;
c->last_keepalive_recv = now;
auto* raw = c.get();
c->persistent_ack_future = std::async(std::launch::async, &TCPStreamPusher::PersistentAckThread, this, raw);
connections.emplace_back(std::move(c));
}
void TCPStreamPusher::RemoveDeadConnections() {
// Must be called with connections_mutex held
auto it = connections.begin();
while (it != connections.end()) {
auto& c = **it;
if (c.broken || !c.connected || !IsConnectionAlive(c)) {
c.connected = false;
c.broken = true;
StopDataCollectionThreads(c);
CloseFd(c.fd);
if (c.persistent_ack_future.valid())
c.persistent_ack_future.get();
logger.Info("Removed dead connection (socket_number=" + std::to_string(c.socket_number) + ")");
it = connections.erase(it);
} else {
++it;
}
}
}
void TCPStreamPusher::KeepaliveThread() {
while (acceptor_running) {
std::this_thread::sleep_for(std::chrono::seconds(5));
if (!acceptor_running)
break;
// During data collection, the data flow itself serves as heartbeat
if (data_collection_active)
continue;
std::lock_guard lg(connections_mutex);
for (auto& cptr : connections) {
auto& c = *cptr;
if (c.broken || !c.connected)
continue;
std::unique_lock ul(c.send_mutex);
if (!SendFrame(c, nullptr, 0, TCPFrameType::KEEPALIVE, -1, nullptr)) {
logger.Warning("Keepalive send failed on socket " + std::to_string(c.socket_number));
c.broken = true;
} else {
c.last_keepalive_sent = std::chrono::steady_clock::now();
}
}
RemoveDeadConnections();
}
}
size_t TCPStreamPusher::GetConnectedWriters() const {
std::lock_guard lg(connections_mutex);
size_t count = 0;
for (const auto& c : connections) {
if (c->connected && !c->broken)
++count;
}
return count;
}
void TCPStreamPusher::StartDataCollectionThreads(Connection& c) {
{
std::unique_lock ul(c.ack_mutex);
c.start_ack_received = false;
@@ -523,11 +737,10 @@ void TCPStreamPusher::StartConnectionThreads(Connection& c) {
c.active = true;
c.writer_future = std::async(std::launch::async, &TCPStreamPusher::WriterThread, this, &c);
c.ack_future = std::async(std::launch::async, &TCPStreamPusher::AckThread, this, &c);
c.zc_future = std::async(std::launch::async, &TCPStreamPusher::ZeroCopyCompletionThread, this, &c);
}
void TCPStreamPusher::StopConnectionThreads(Connection& c) {
void TCPStreamPusher::StopDataCollectionThreads(Connection& c) {
if (!c.active)
return;
@@ -549,8 +762,6 @@ void TCPStreamPusher::StopConnectionThreads(Connection& c) {
c.ack_cv.notify_all();
}
if (c.ack_future.valid())
c.ack_future.get();
if (c.zc_future.valid())
c.zc_future.get();
@@ -603,47 +814,27 @@ void TCPStreamPusher::StartDataCollection(StartMessage& message) {
total_data_acked_bad.store(0, std::memory_order_relaxed);
total_data_acked_total.store(0, std::memory_order_relaxed);
for (auto& c : connections) {
StopConnectionThreads(*c);
CloseFd(c->fd);
// Stop any leftover data-collection threads and clean up dead connections
{
std::lock_guard lg(connections_mutex);
for (auto& c : connections)
StopDataCollectionThreads(*c);
RemoveDeadConnections();
}
connections.clear();
connections.reserve(expected_connections);
int listen_fd = OpenListenSocket(endpoint);
try {
for (size_t i = 0; i < expected_connections; i++) {
int new_fd = AcceptOne(listen_fd, std::chrono::seconds(5));
if (new_fd < 0)
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "TCP accept timeout/failure on " + endpoint);
std::lock_guard lg(connections_mutex);
auto c = std::make_unique<Connection>(send_queue_size);
c->socket_number = static_cast<uint32_t>(i);
c->fd.store(new_fd);
if (connections.empty())
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
"No writers connected to " + endpoint);
int one = 1;
setsockopt(new_fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one));
if (send_buffer_size)
setsockopt(new_fd, SOL_SOCKET, SO_SNDBUF, &send_buffer_size.value(), sizeof(int32_t));
logger.Info("Starting data collection with " + std::to_string(connections.size()) + " connected writers");
#if defined(SO_ZEROCOPY)
int zc_one = 1;
if (setsockopt(new_fd, SOL_SOCKET, SO_ZEROCOPY, &zc_one, sizeof(zc_one)) == 0) {
c->zerocopy_enabled.store(true, std::memory_order_relaxed);
} else {
c->zerocopy_enabled.store(false, std::memory_order_relaxed);
}
#endif
connections.emplace_back(std::move(c));
}
} catch (...) {
close(listen_fd);
throw;
}
close(listen_fd);
data_collection_active = true;
// Start writer + zerocopy threads for each connection
for (auto& c : connections)
StartConnectionThreads(*c);
StartDataCollectionThreads(*c);
std::vector<bool> started(connections.size(), false);
@@ -661,7 +852,9 @@ void TCPStreamPusher::StartDataCollection(StartMessage& message) {
}
for (auto& c : connections)
StopConnectionThreads(*c);
StopDataCollectionThreads(*c);
data_collection_active = false;
};
for (size_t i = 0; i < connections.size(); i++) {
@@ -677,7 +870,7 @@ void TCPStreamPusher::StartDataCollection(StartMessage& message) {
if (!SendFrame(c, serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::START, -1, nullptr)) {
rollback_cancel();
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
"Timeout/failure sending START on socket " + std::to_string(i));
"Timeout/failure sending START on socket " + std::to_string(c.socket_number));
}
}
@@ -685,7 +878,7 @@ void TCPStreamPusher::StartDataCollection(StartMessage& message) {
if (!WaitForAck(c, TCPFrameType::START, std::chrono::seconds(5), &ack_err)) {
rollback_cancel();
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
"START ACK failed on socket " + std::to_string(i) + ": " + ack_err);
"START ACK failed on socket " + std::to_string(c.socket_number) + ": " + ack_err);
}
started[i] = true;
@@ -693,6 +886,7 @@ void TCPStreamPusher::StartDataCollection(StartMessage& message) {
}
bool TCPStreamPusher::SendImage(const uint8_t *image_data, size_t image_size, int64_t image_number) {
std::lock_guard lg(connections_mutex);
if (connections.empty())
return false;
@@ -707,6 +901,7 @@ bool TCPStreamPusher::SendImage(const uint8_t *image_data, size_t image_size, in
}
void TCPStreamPusher::SendImage(ZeroCopyReturnValue &z) {
std::lock_guard lg(connections_mutex);
if (connections.empty()) {
z.release();
return;
@@ -731,6 +926,9 @@ bool TCPStreamPusher::EndDataCollection(const EndMessage& message) {
serializer.SerializeSequenceEnd(message);
bool ret = true;
std::lock_guard lg(connections_mutex);
for (auto& cptr : connections) {
auto& c = *cptr;
if (c.broken) {
@@ -751,14 +949,17 @@ bool TCPStreamPusher::EndDataCollection(const EndMessage& message) {
ret = false;
}
// Stop only data-collection threads, keep connections alive
for (auto& c : connections)
StopConnectionThreads(*c);
StopDataCollectionThreads(*c);
data_collection_active = false;
transmission_error = !ret;
return ret;
}
bool TCPStreamPusher::SendCalibration(const CompressedImage& message) {
std::lock_guard lg(connections_mutex);
if (connections.empty())
return false;
@@ -777,15 +978,16 @@ std::string TCPStreamPusher::Finalize() {
if (transmission_error)
ret += "Timeout sending images (e.g., writer disabled during data collection);";
std::lock_guard lg(connections_mutex);
for (size_t i = 0; i < connections.size(); i++) {
auto& c = *connections[i];
{
std::unique_lock ul(c.ack_mutex);
if (c.data_ack_error_reported && !c.data_ack_error_text.empty()) {
ret += "Writer " + std::to_string(i) + ": " + c.data_ack_error_text + ";";
ret += "Writer " + std::to_string(c.socket_number) + ": " + c.data_ack_error_text + ";";
} else if (!c.last_ack_error.empty()) {
ret += "Writer " + std::to_string(i) + ": " + c.last_ack_error + ";";
ret += "Writer " + std::to_string(c.socket_number) + ": " + c.last_ack_error + ";";
}
}
}
@@ -794,9 +996,11 @@ std::string TCPStreamPusher::Finalize() {
}
std::string TCPStreamPusher::PrintSetup() const {
return "TCPStreamPusher: endpoint=" + endpoint + " expected_connections=" + std::to_string(expected_connections);
return "TCPStreamPusher: endpoint=" + endpoint +
" max_connections=" + std::to_string(max_connections) +
" connected=" + std::to_string(GetConnectedWriters());
}
std::optional<uint64_t> TCPStreamPusher::GetImagesWritten() const {
return total_data_acked_ok.load(std::memory_order_relaxed);
}
}
+32 -10
View File
@@ -3,8 +3,6 @@
#pragma once
#pragma once
#include <atomic>
#include <future>
#include <mutex>
@@ -26,15 +24,18 @@ class TCPStreamPusher : public ImagePusher {
std::atomic<int> fd{-1};
uint32_t socket_number = 0;
std::atomic<bool> active{false};
std::atomic<bool> active{false}; // data-collection threads running
std::atomic<bool> broken{false};
std::atomic<bool> connected{false}; // persistent connection is alive
std::atomic<bool> zerocopy_enabled{false};
ThreadSafeFIFO<ImagePusherQueueElement> queue;
std::future<void> writer_future;
std::future<void> ack_future;
std::future<void> zc_future;
// Persistent ack/keepalive reader (runs as long as the connection is alive)
std::future<void> persistent_ack_future;
std::mutex send_mutex;
std::mutex ack_mutex;
std::condition_variable ack_cv;
@@ -69,22 +70,34 @@ class TCPStreamPusher : public ImagePusher {
std::atomic<uint64_t> data_acked_ok{0};
std::atomic<uint64_t> data_acked_bad{0};
std::atomic<uint64_t> data_acked_total{0};
std::chrono::steady_clock::time_point last_keepalive_sent{};
std::chrono::steady_clock::time_point last_keepalive_recv{};
};
std::vector<uint8_t> serialization_buffer;
CBORStream2Serializer serializer;
std::string endpoint;
size_t expected_connections = 0;
size_t max_connections;
std::optional<int32_t> send_buffer_size;
size_t send_queue_size = 128;
// Persistent connection pool, guarded by connections_mutex
mutable std::mutex connections_mutex;
std::vector<std::unique_ptr<Connection>> connections;
// Acceptor thread state
std::atomic<int> listen_fd{-1};
std::atomic<bool> acceptor_running{false};
std::future<void> acceptor_future;
std::future<void> keepalive_future;
int64_t images_per_file = 1;
uint64_t run_number = 0;
std::string run_name;
std::atomic<bool> transmission_error = false;
std::atomic<bool> data_collection_active{false};
std::atomic<uint64_t> total_data_acked_ok{0};
std::atomic<uint64_t> total_data_acked_bad{0};
@@ -101,14 +114,20 @@ class TCPStreamPusher : public ImagePusher {
bool SendAll(Connection& c, const void* buf, size_t len, bool allow_zerocopy,
bool* zc_used = nullptr, uint32_t* zc_first = nullptr, uint32_t* zc_last = nullptr);
bool ReadExact(Connection& c, void* buf, size_t len);
bool ReadExactPersistent(Connection& c, void* buf, size_t len);
bool SendFrame(Connection& c, const uint8_t* data, size_t size, TCPFrameType type, int64_t image_number, ZeroCopyReturnValue* z);
void WriterThread(Connection* c);
void AckThread(Connection* c);
void PersistentAckThread(Connection* c);
void ZeroCopyCompletionThread(Connection* c);
void AcceptorThread();
void KeepaliveThread();
void StartConnectionThreads(Connection& c);
void StopConnectionThreads(Connection& c);
void SetupNewConnection(int new_fd, uint32_t socket_number);
void RemoveDeadConnections();
void StartDataCollectionThreads(Connection& c);
void StopDataCollectionThreads(Connection& c);
void EnqueueZeroCopyPending(Connection& c, ZeroCopyReturnValue* z, uint32_t first_id, uint32_t last_id);
void ReleaseCompletedZeroCopy(Connection& c);
@@ -118,11 +137,14 @@ class TCPStreamPusher : public ImagePusher {
bool WaitForAck(Connection& c, TCPFrameType ack_for, std::chrono::milliseconds timeout, std::string* error_text);
public:
explicit TCPStreamPusher(const std::string& addr,
size_t in_expected_connections,
size_t in_max_connections,
std::optional<int32_t> in_send_buffer_size = {});
~TCPStreamPusher() override;
/// Returns the number of currently connected writers (can be called at any time)
size_t GetConnectedWriters() const;
void StartDataCollection(StartMessage& message) override;
bool EndDataCollection(const EndMessage& message) override;
bool SendImage(const uint8_t *image_data, size_t image_size, int64_t image_number) override;
@@ -133,4 +155,4 @@ public:
std::string PrintSetup() const override;
std::optional<uint64_t> GetImagesWritten() const override;
};
};