From 91591a3cc3bb7aed96b48b422361875f2f4f02de Mon Sep 17 00:00:00 2001 From: Filip Leonarski Date: Thu, 5 Mar 2026 08:30:16 +0100 Subject: [PATCH] TCPStreamPusher: Implement KEEPALIVE + writers stay connected --- common/JfjochTCP.h | 3 +- image_puller/TCPImagePuller.cpp | 26 ++- image_pusher/TCPStreamPusher.cpp | 332 +++++++++++++++++++++++++------ image_pusher/TCPStreamPusher.h | 42 +++- 4 files changed, 327 insertions(+), 76 deletions(-) diff --git a/common/JfjochTCP.h b/common/JfjochTCP.h index 51c6d2a4..3ee50967 100644 --- a/common/JfjochTCP.h +++ b/common/JfjochTCP.h @@ -14,7 +14,8 @@ enum class TCPFrameType : uint16_t { CALIBRATION = 3, END = 4, ACK = 5, - CANCEL = 6 + CANCEL = 6, + KEEPALIVE = 7, }; enum class TCPAckCode : uint16_t { diff --git a/image_puller/TCPImagePuller.cpp b/image_puller/TCPImagePuller.cpp index b002e662..bbab529d 100644 --- a/image_puller/TCPImagePuller.cpp +++ b/image_puller/TCPImagePuller.cpp @@ -250,8 +250,32 @@ void TCPImagePuller::ReceiverThread() { continue; } + const auto frame_type = static_cast(frame.header.type); + + // Respond to keepalive ping with a keepalive pong + if (frame_type == TCPFrameType::KEEPALIVE) { + if (frame.header.payload_size > 0) { + std::vector discard(frame.header.payload_size); + if (!ReadExact(discard.data(), discard.size())) { + CloseSocket(); + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + continue; + } + } + // Send keepalive pong back + TcpFrameHeader pong{}; + pong.type = static_cast(TCPFrameType::KEEPALIVE); + pong.payload_size = 0; + if (!SendAll(&pong, sizeof(pong))) { + logger.Info("Keepalive pong send failed, reconnecting to " + addr); + CloseSocket(); + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + } + continue; + } + // Ignore ACK on puller side - if (static_cast(frame.header.type) == TCPFrameType::ACK) { + if (frame_type == TCPFrameType::ACK) { if (frame.header.payload_size > 0) { std::vector discard(frame.header.payload_size); if (!ReadExact(discard.data(), discard.size())) { diff --git a/image_pusher/TCPStreamPusher.cpp b/image_pusher/TCPStreamPusher.cpp index 22b48383..6c56ef2a 100644 --- a/image_pusher/TCPStreamPusher.cpp +++ b/image_pusher/TCPStreamPusher.cpp @@ -90,24 +90,48 @@ void TCPStreamPusher::CloseFd(std::atomic& fd) { } TCPStreamPusher::TCPStreamPusher(const std::string& addr, - size_t in_expected_connections, + size_t in_max_connections, std::optional in_send_buffer_size) : serialization_buffer(256 * 1024 * 1024), serializer(serialization_buffer.data(), serialization_buffer.size()), endpoint(addr), - expected_connections(in_expected_connections), + max_connections(in_max_connections), send_buffer_size(in_send_buffer_size) { if (endpoint.empty()) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "No TCP writer address provided"); - if (expected_connections == 0) - throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Expected TCP connections cannot be zero"); + if (max_connections == 0) + throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Max TCP connections cannot be zero"); + + listen_fd.store(OpenListenSocket(endpoint)); + acceptor_running = true; + acceptor_future = std::async(std::launch::async, &TCPStreamPusher::AcceptorThread, this); + keepalive_future = std::async(std::launch::async, &TCPStreamPusher::KeepaliveThread, this); + + logger.Info("TCPStreamPusher listening on " + endpoint + " (max " + std::to_string(max_connections) + " connections)"); } TCPStreamPusher::~TCPStreamPusher() { - for (auto& c : connections) { - StopConnectionThreads(*c); - CloseFd(c->fd); + acceptor_running = false; + int lfd = listen_fd.exchange(-1); + if (lfd >= 0) { + shutdown(lfd, SHUT_RDWR); + close(lfd); } + if (acceptor_future.valid()) + acceptor_future.get(); + if (keepalive_future.valid()) + keepalive_future.get(); + + std::lock_guard lg(connections_mutex); + for (auto& c : connections) { + StopDataCollectionThreads(*c); + c->connected = false; + c->broken = true; + CloseFd(c->fd); + if (c->persistent_ack_future.valid()) + c->persistent_ack_future.get(); + } + connections.clear(); } bool TCPStreamPusher::IsConnectionAlive(const Connection& c) const { @@ -392,6 +416,50 @@ bool TCPStreamPusher::ReadExact(Connection& c, void* buf, size_t len) { return true; } +bool TCPStreamPusher::ReadExactPersistent(Connection& c, void* buf, size_t len) { + auto* p = static_cast(buf); + size_t got = 0; + + while (got < len) { + if (!c.connected) + return false; + + const int local_fd = c.fd.load(); + if (local_fd < 0) + return false; + + pollfd pfd{}; + pfd.fd = local_fd; + pfd.events = POLLIN; + + const int prc = poll(&pfd, 1, 500); + if (prc == 0) + continue; + if (prc < 0) { + if (errno == EINTR) + continue; + return false; + } + if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0) + return false; + if ((pfd.revents & POLLIN) == 0) + continue; + + ssize_t rc = ::recv(local_fd, p + got, len - got, 0); + if (rc == 0) + return false; + if (rc < 0) { + if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK) + continue; + return false; + } + + got += static_cast(rc); + } + + return true; +} + void TCPStreamPusher::WriterThread(Connection* c) { while (c->active) { auto e = c->queue.GetBlocking(); @@ -418,27 +486,46 @@ void TCPStreamPusher::WriterThread(Connection* c) { } } -void TCPStreamPusher::AckThread(Connection* c) { - while (c->active) { +void TCPStreamPusher::PersistentAckThread(Connection* c) { + while (c->connected && !c->broken) { TcpFrameHeader h{}; - if (!ReadExact(*c, &h, sizeof(h))) { - if (c->active) { + if (!ReadExactPersistent(*c, &h, sizeof(h))) { + if (c->connected) { c->broken = true; - logger.Error("TCP ACK reader disconnected on socket " + std::to_string(c->socket_number)); + logger.Info("Persistent connection lost on socket " + std::to_string(c->socket_number)); } break; } - if (h.magic != JFJOCH_TCP_MAGIC || h.version != JFJOCH_TCP_VERSION || static_cast(h.type) != TCPFrameType::ACK) { + if (h.magic != JFJOCH_TCP_MAGIC || h.version != JFJOCH_TCP_VERSION) { c->broken = true; - logger.Error("Invalid ACK frame on socket " + std::to_string(c->socket_number)); + logger.Error("Invalid frame on persistent connection, socket " + std::to_string(c->socket_number)); break; } + const auto frame_type = static_cast(h.type); + + // Keepalive pong from the writer + if (frame_type == TCPFrameType::KEEPALIVE) { + c->last_keepalive_recv = std::chrono::steady_clock::now(); + if (h.payload_size > 0) { + std::vector discard(h.payload_size); + ReadExactPersistent(*c, discard.data(), discard.size()); + } + continue; + } + + if (frame_type != TCPFrameType::ACK) { + c->broken = true; + logger.Error("Unexpected frame type " + std::to_string(h.type) + " on socket " + std::to_string(c->socket_number)); + break; + } + + // ACK frame — forward to data-collection ack logic std::string error_text; if (h.payload_size > 0) { error_text.resize(h.payload_size); - if (!ReadExact(*c, error_text.data(), error_text.size())) { + if (!ReadExactPersistent(*c, error_text.data(), error_text.size())) { c->broken = true; break; } @@ -479,8 +566,6 @@ void TCPStreamPusher::AckThread(Connection* c) { } else { c->data_acked_bad.fetch_add(1, std::memory_order_relaxed); total_data_acked_bad.fetch_add(1, std::memory_order_relaxed); - - // Soft failure: remember it for Finalize(), do NOT mark socket broken. c->data_ack_error_reported = true; if (!error_text.empty()) { c->data_ack_error_text = error_text; @@ -495,7 +580,136 @@ void TCPStreamPusher::AckThread(Connection* c) { } } -void TCPStreamPusher::StartConnectionThreads(Connection& c) { +void TCPStreamPusher::AcceptorThread() { + uint32_t next_socket_number = 0; + + while (acceptor_running) { + int lfd = listen_fd.load(); + if (lfd < 0) + break; + + int new_fd = AcceptOne(lfd, std::chrono::milliseconds(500)); + if (new_fd < 0) + continue; + + std::lock_guard lg(connections_mutex); + + RemoveDeadConnections(); + + if (connections.size() >= max_connections) { + logger.Warning("Max connections (" + std::to_string(max_connections) + + ") reached, rejecting new connection"); + shutdown(new_fd, SHUT_RDWR); + close(new_fd); + continue; + } + + SetupNewConnection(new_fd, next_socket_number++); + logger.Info("Accepted writer connection (socket_number=" + std::to_string(next_socket_number - 1) + + ", total=" + std::to_string(connections.size()) + ")"); + } +} + +void TCPStreamPusher::SetupNewConnection(int new_fd, uint32_t socket_number) { + auto c = std::make_unique(send_queue_size); + c->socket_number = socket_number; + c->fd.store(new_fd); + + int one = 1; + setsockopt(new_fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one)); + + // Enable OS-level TCP keep-alive + setsockopt(new_fd, SOL_SOCKET, SO_KEEPALIVE, &one, sizeof(one)); + int idle = 10; + int intvl = 5; + int cnt = 3; + setsockopt(new_fd, IPPROTO_TCP, TCP_KEEPIDLE, &idle, sizeof(idle)); + setsockopt(new_fd, IPPROTO_TCP, TCP_KEEPINTVL, &intvl, sizeof(intvl)); + setsockopt(new_fd, IPPROTO_TCP, TCP_KEEPCNT, &cnt, sizeof(cnt)); + + if (send_buffer_size) + setsockopt(new_fd, SOL_SOCKET, SO_SNDBUF, &send_buffer_size.value(), sizeof(int32_t)); + +#if defined(SO_ZEROCOPY) + int zc_one = 1; + if (setsockopt(new_fd, SOL_SOCKET, SO_ZEROCOPY, &zc_one, sizeof(zc_one)) == 0) + c->zerocopy_enabled.store(true, std::memory_order_relaxed); + else + c->zerocopy_enabled.store(false, std::memory_order_relaxed); +#endif + + c->connected = true; + c->broken = false; + auto now = std::chrono::steady_clock::now(); + c->last_keepalive_sent = now; + c->last_keepalive_recv = now; + + auto* raw = c.get(); + c->persistent_ack_future = std::async(std::launch::async, &TCPStreamPusher::PersistentAckThread, this, raw); + + connections.emplace_back(std::move(c)); +} + +void TCPStreamPusher::RemoveDeadConnections() { + // Must be called with connections_mutex held + auto it = connections.begin(); + while (it != connections.end()) { + auto& c = **it; + if (c.broken || !c.connected || !IsConnectionAlive(c)) { + c.connected = false; + c.broken = true; + StopDataCollectionThreads(c); + CloseFd(c.fd); + if (c.persistent_ack_future.valid()) + c.persistent_ack_future.get(); + logger.Info("Removed dead connection (socket_number=" + std::to_string(c.socket_number) + ")"); + it = connections.erase(it); + } else { + ++it; + } + } +} + +void TCPStreamPusher::KeepaliveThread() { + while (acceptor_running) { + std::this_thread::sleep_for(std::chrono::seconds(5)); + if (!acceptor_running) + break; + + // During data collection, the data flow itself serves as heartbeat + if (data_collection_active) + continue; + + std::lock_guard lg(connections_mutex); + for (auto& cptr : connections) { + auto& c = *cptr; + if (c.broken || !c.connected) + continue; + + std::unique_lock ul(c.send_mutex); + if (!SendFrame(c, nullptr, 0, TCPFrameType::KEEPALIVE, -1, nullptr)) { + logger.Warning("Keepalive send failed on socket " + std::to_string(c.socket_number)); + c.broken = true; + } else { + c.last_keepalive_sent = std::chrono::steady_clock::now(); + } + } + + RemoveDeadConnections(); + } +} + +size_t TCPStreamPusher::GetConnectedWriters() const { + std::lock_guard lg(connections_mutex); + size_t count = 0; + for (const auto& c : connections) { + if (c->connected && !c->broken) + ++count; + } + return count; +} + +void TCPStreamPusher::StartDataCollectionThreads(Connection& c) { { std::unique_lock ul(c.ack_mutex); c.start_ack_received = false; @@ -523,11 +737,10 @@ void TCPStreamPusher::StartConnectionThreads(Connection& c) { c.active = true; c.writer_future = std::async(std::launch::async, &TCPStreamPusher::WriterThread, this, &c); - c.ack_future = std::async(std::launch::async, &TCPStreamPusher::AckThread, this, &c); c.zc_future = std::async(std::launch::async, &TCPStreamPusher::ZeroCopyCompletionThread, this, &c); } -void TCPStreamPusher::StopConnectionThreads(Connection& c) { +void TCPStreamPusher::StopDataCollectionThreads(Connection& c) { if (!c.active) return; @@ -549,8 +762,6 @@ void TCPStreamPusher::StopConnectionThreads(Connection& c) { c.ack_cv.notify_all(); } - if (c.ack_future.valid()) - c.ack_future.get(); if (c.zc_future.valid()) c.zc_future.get(); @@ -603,47 +814,27 @@ void TCPStreamPusher::StartDataCollection(StartMessage& message) { total_data_acked_bad.store(0, std::memory_order_relaxed); total_data_acked_total.store(0, std::memory_order_relaxed); - for (auto& c : connections) { - StopConnectionThreads(*c); - CloseFd(c->fd); + // Stop any leftover data-collection threads and clean up dead connections + { + std::lock_guard lg(connections_mutex); + for (auto& c : connections) + StopDataCollectionThreads(*c); + RemoveDeadConnections(); } - connections.clear(); - connections.reserve(expected_connections); - int listen_fd = OpenListenSocket(endpoint); - try { - for (size_t i = 0; i < expected_connections; i++) { - int new_fd = AcceptOne(listen_fd, std::chrono::seconds(5)); - if (new_fd < 0) - throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "TCP accept timeout/failure on " + endpoint); + std::lock_guard lg(connections_mutex); - auto c = std::make_unique(send_queue_size); - c->socket_number = static_cast(i); - c->fd.store(new_fd); + if (connections.empty()) + throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, + "No writers connected to " + endpoint); - int one = 1; - setsockopt(new_fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one)); - if (send_buffer_size) - setsockopt(new_fd, SOL_SOCKET, SO_SNDBUF, &send_buffer_size.value(), sizeof(int32_t)); + logger.Info("Starting data collection with " + std::to_string(connections.size()) + " connected writers"); -#if defined(SO_ZEROCOPY) - int zc_one = 1; - if (setsockopt(new_fd, SOL_SOCKET, SO_ZEROCOPY, &zc_one, sizeof(zc_one)) == 0) { - c->zerocopy_enabled.store(true, std::memory_order_relaxed); - } else { - c->zerocopy_enabled.store(false, std::memory_order_relaxed); - } -#endif - connections.emplace_back(std::move(c)); - } - } catch (...) { - close(listen_fd); - throw; - } - close(listen_fd); + data_collection_active = true; + // Start writer + zerocopy threads for each connection for (auto& c : connections) - StartConnectionThreads(*c); + StartDataCollectionThreads(*c); std::vector started(connections.size(), false); @@ -661,7 +852,9 @@ void TCPStreamPusher::StartDataCollection(StartMessage& message) { } for (auto& c : connections) - StopConnectionThreads(*c); + StopDataCollectionThreads(*c); + + data_collection_active = false; }; for (size_t i = 0; i < connections.size(); i++) { @@ -677,7 +870,7 @@ void TCPStreamPusher::StartDataCollection(StartMessage& message) { if (!SendFrame(c, serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::START, -1, nullptr)) { rollback_cancel(); throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, - "Timeout/failure sending START on socket " + std::to_string(i)); + "Timeout/failure sending START on socket " + std::to_string(c.socket_number)); } } @@ -685,7 +878,7 @@ void TCPStreamPusher::StartDataCollection(StartMessage& message) { if (!WaitForAck(c, TCPFrameType::START, std::chrono::seconds(5), &ack_err)) { rollback_cancel(); throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, - "START ACK failed on socket " + std::to_string(i) + ": " + ack_err); + "START ACK failed on socket " + std::to_string(c.socket_number) + ": " + ack_err); } started[i] = true; @@ -693,6 +886,7 @@ void TCPStreamPusher::StartDataCollection(StartMessage& message) { } bool TCPStreamPusher::SendImage(const uint8_t *image_data, size_t image_size, int64_t image_number) { + std::lock_guard lg(connections_mutex); if (connections.empty()) return false; @@ -707,6 +901,7 @@ bool TCPStreamPusher::SendImage(const uint8_t *image_data, size_t image_size, in } void TCPStreamPusher::SendImage(ZeroCopyReturnValue &z) { + std::lock_guard lg(connections_mutex); if (connections.empty()) { z.release(); return; @@ -731,6 +926,9 @@ bool TCPStreamPusher::EndDataCollection(const EndMessage& message) { serializer.SerializeSequenceEnd(message); bool ret = true; + + std::lock_guard lg(connections_mutex); + for (auto& cptr : connections) { auto& c = *cptr; if (c.broken) { @@ -751,14 +949,17 @@ bool TCPStreamPusher::EndDataCollection(const EndMessage& message) { ret = false; } + // Stop only data-collection threads, keep connections alive for (auto& c : connections) - StopConnectionThreads(*c); + StopDataCollectionThreads(*c); + data_collection_active = false; transmission_error = !ret; return ret; } bool TCPStreamPusher::SendCalibration(const CompressedImage& message) { + std::lock_guard lg(connections_mutex); if (connections.empty()) return false; @@ -777,15 +978,16 @@ std::string TCPStreamPusher::Finalize() { if (transmission_error) ret += "Timeout sending images (e.g., writer disabled during data collection);"; + std::lock_guard lg(connections_mutex); for (size_t i = 0; i < connections.size(); i++) { auto& c = *connections[i]; { std::unique_lock ul(c.ack_mutex); if (c.data_ack_error_reported && !c.data_ack_error_text.empty()) { - ret += "Writer " + std::to_string(i) + ": " + c.data_ack_error_text + ";"; + ret += "Writer " + std::to_string(c.socket_number) + ": " + c.data_ack_error_text + ";"; } else if (!c.last_ack_error.empty()) { - ret += "Writer " + std::to_string(i) + ": " + c.last_ack_error + ";"; + ret += "Writer " + std::to_string(c.socket_number) + ": " + c.last_ack_error + ";"; } } } @@ -794,9 +996,11 @@ std::string TCPStreamPusher::Finalize() { } std::string TCPStreamPusher::PrintSetup() const { - return "TCPStreamPusher: endpoint=" + endpoint + " expected_connections=" + std::to_string(expected_connections); + return "TCPStreamPusher: endpoint=" + endpoint + + " max_connections=" + std::to_string(max_connections) + + " connected=" + std::to_string(GetConnectedWriters()); } std::optional TCPStreamPusher::GetImagesWritten() const { return total_data_acked_ok.load(std::memory_order_relaxed); -} \ No newline at end of file +} diff --git a/image_pusher/TCPStreamPusher.h b/image_pusher/TCPStreamPusher.h index f006c60d..8f4701f9 100644 --- a/image_pusher/TCPStreamPusher.h +++ b/image_pusher/TCPStreamPusher.h @@ -3,8 +3,6 @@ #pragma once -#pragma once - #include #include #include @@ -26,15 +24,18 @@ class TCPStreamPusher : public ImagePusher { std::atomic fd{-1}; uint32_t socket_number = 0; - std::atomic active{false}; + std::atomic active{false}; // data-collection threads running std::atomic broken{false}; + std::atomic connected{false}; // persistent connection is alive std::atomic zerocopy_enabled{false}; ThreadSafeFIFO queue; std::future writer_future; - std::future ack_future; std::future zc_future; + // Persistent ack/keepalive reader (runs as long as the connection is alive) + std::future persistent_ack_future; + std::mutex send_mutex; std::mutex ack_mutex; std::condition_variable ack_cv; @@ -69,22 +70,34 @@ class TCPStreamPusher : public ImagePusher { std::atomic data_acked_ok{0}; std::atomic data_acked_bad{0}; std::atomic data_acked_total{0}; + + std::chrono::steady_clock::time_point last_keepalive_sent{}; + std::chrono::steady_clock::time_point last_keepalive_recv{}; }; std::vector serialization_buffer; CBORStream2Serializer serializer; std::string endpoint; - size_t expected_connections = 0; + size_t max_connections; std::optional send_buffer_size; size_t send_queue_size = 128; + // Persistent connection pool, guarded by connections_mutex + mutable std::mutex connections_mutex; std::vector> connections; + // Acceptor thread state + std::atomic listen_fd{-1}; + std::atomic acceptor_running{false}; + std::future acceptor_future; + std::future keepalive_future; + int64_t images_per_file = 1; uint64_t run_number = 0; std::string run_name; std::atomic transmission_error = false; + std::atomic data_collection_active{false}; std::atomic total_data_acked_ok{0}; std::atomic total_data_acked_bad{0}; @@ -101,14 +114,20 @@ class TCPStreamPusher : public ImagePusher { bool SendAll(Connection& c, const void* buf, size_t len, bool allow_zerocopy, bool* zc_used = nullptr, uint32_t* zc_first = nullptr, uint32_t* zc_last = nullptr); bool ReadExact(Connection& c, void* buf, size_t len); + bool ReadExactPersistent(Connection& c, void* buf, size_t len); bool SendFrame(Connection& c, const uint8_t* data, size_t size, TCPFrameType type, int64_t image_number, ZeroCopyReturnValue* z); void WriterThread(Connection* c); - void AckThread(Connection* c); + void PersistentAckThread(Connection* c); void ZeroCopyCompletionThread(Connection* c); + void AcceptorThread(); + void KeepaliveThread(); - void StartConnectionThreads(Connection& c); - void StopConnectionThreads(Connection& c); + void SetupNewConnection(int new_fd, uint32_t socket_number); + void RemoveDeadConnections(); + + void StartDataCollectionThreads(Connection& c); + void StopDataCollectionThreads(Connection& c); void EnqueueZeroCopyPending(Connection& c, ZeroCopyReturnValue* z, uint32_t first_id, uint32_t last_id); void ReleaseCompletedZeroCopy(Connection& c); @@ -118,11 +137,14 @@ class TCPStreamPusher : public ImagePusher { bool WaitForAck(Connection& c, TCPFrameType ack_for, std::chrono::milliseconds timeout, std::string* error_text); public: explicit TCPStreamPusher(const std::string& addr, - size_t in_expected_connections, + size_t in_max_connections, std::optional in_send_buffer_size = {}); ~TCPStreamPusher() override; + /// Returns the number of currently connected writers (can be called at any time) + size_t GetConnectedWriters() const; + void StartDataCollection(StartMessage& message) override; bool EndDataCollection(const EndMessage& message) override; bool SendImage(const uint8_t *image_data, size_t image_size, int64_t image_number) override; @@ -133,4 +155,4 @@ public: std::string PrintSetup() const override; std::optional GetImagesWritten() const override; -}; \ No newline at end of file +};