// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute // SPDX-License-Identifier: GPL-3.0-only #include "TCPStreamPusher.h" #include #include #include #include #include #include #include #include #if defined(MSG_ZEROCOPY) #include #endif std::pair> TCPStreamPusher::ParseTcpAddress(const std::string& addr) { const std::string prefix = "tcp://"; if (addr.rfind(prefix, 0) != 0) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP address: " + addr); auto hp = addr.substr(prefix.size()); auto p = hp.find_last_of(':'); if (p == std::string::npos) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP address: " + addr); const auto host = hp.substr(0, p); const auto port_str = hp.substr(p + 1); if (port_str == "*") return {host, std::nullopt}; int port_i = 0; try { size_t parsed = 0; port_i = std::stoi(port_str, &parsed); if (parsed != port_str.size()) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP port in address: " + addr); } catch (...) { throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP port in address: " + addr); } if (port_i < 1 || port_i > static_cast(std::numeric_limits::max())) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "TCP port out of range in address: " + addr); return {host, static_cast(port_i)}; } std::pair TCPStreamPusher::OpenListenSocket(const std::string& addr) { auto [host, port_opt] = ParseTcpAddress(addr); int listen_fd = ::socket(AF_INET, SOCK_STREAM, 0); if (listen_fd < 0) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "socket(listen) failed"); int one = 1; setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)); sockaddr_in sin{}; sin.sin_family = AF_INET; sin.sin_port = htons(port_opt.has_value() ? port_opt.value() : 0); if (host == "*" || host == "0.0.0.0") sin.sin_addr.s_addr = htonl(INADDR_ANY); else if (inet_pton(AF_INET, host.c_str(), &sin.sin_addr) != 1) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "inet_pton failed for " + host); if (bind(listen_fd, reinterpret_cast(&sin), sizeof(sin)) != 0) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "bind() failed to " + addr); if (listen(listen_fd, 64) != 0) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "listen() failed on " + addr); sockaddr_in actual{}; socklen_t actual_len = sizeof(actual); if (getsockname(listen_fd, reinterpret_cast(&actual), &actual_len) != 0) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "getsockname() failed on " + addr); const uint16_t bound_port = ntohs(actual.sin_port); const std::string normalized_host = (host == "*") ? "0.0.0.0" : host; const std::string bound_endpoint = "tcp://" + normalized_host + ":" + std::to_string(bound_port); return {listen_fd, bound_endpoint}; } int TCPStreamPusher::AcceptOne(int listen_fd, std::chrono::milliseconds timeout) { pollfd pfd{}; pfd.fd = listen_fd; pfd.events = POLLIN; const int prc = poll(&pfd, 1, static_cast(timeout.count())); if (prc <= 0) return -1; if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0) return -1; return accept(listen_fd, nullptr, nullptr); } void TCPStreamPusher::CloseFd(std::atomic& fd) { int old_fd = fd.exchange(-1); if (old_fd >= 0) { shutdown(old_fd, SHUT_RDWR); close(old_fd); } } TCPStreamPusher::TCPStreamPusher(const std::string& addr, size_t in_max_connections, std::optional in_send_buffer_size) : serialization_buffer(256 * 1024 * 1024), serializer(serialization_buffer.data(), serialization_buffer.size()), endpoint(addr), max_connections(in_max_connections), send_buffer_size(in_send_buffer_size) { if (endpoint.empty()) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "No TCP writer address provided"); if (max_connections == 0) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Max TCP connections cannot be zero"); auto [lfd, bound_endpoint] = OpenListenSocket(endpoint); listen_fd.store(lfd); endpoint = bound_endpoint; acceptor_running = true; acceptor_future = std::async(std::launch::async, &TCPStreamPusher::AcceptorThread, this); keepalive_future = std::async(std::launch::async, &TCPStreamPusher::KeepaliveThread, this); logger.Info("TCPStreamPusher listening on " + endpoint + " (max " + std::to_string(max_connections) + " connections)"); } TCPStreamPusher::~TCPStreamPusher() { // 1. Stop acceptor + keepalive threads (they take connections_mutex briefly) acceptor_running = false; int lfd = listen_fd.exchange(-1); if (lfd >= 0) { shutdown(lfd, SHUT_RDWR); close(lfd); } if (acceptor_future.valid()) acceptor_future.get(); if (keepalive_future.valid()) keepalive_future.get(); // 2. Now no background threads touch connections_mutex. Tear down connections. // We do NOT hold the mutex while joining futures, to avoid deadlock. std::vector> local_connections; { std::lock_guard lg(connections_mutex); local_connections = std::move(connections); connections.clear(); session_connections.clear(); } for (auto& c : local_connections) { if (c) TearDownConnection(*c); } } void TCPStreamPusher::TearDownConnection(Connection& c) { StopDataCollectionThreads(c); c.connected = false; c.broken = true; CloseFd(c.fd); if (c.persistent_ack_future.valid()) c.persistent_ack_future.get(); } bool TCPStreamPusher::IsConnectionAlive(const Connection& c) const { if (c.broken) return false; const int local_fd = c.fd.load(); if (local_fd < 0) return false; pollfd pfd{}; pfd.fd = local_fd; pfd.events = POLLOUT; if (poll(&pfd, 1, 0) < 0) return false; if ((pfd.revents & (POLLHUP | POLLNVAL)) != 0) return false; int so_error = 0; socklen_t len = sizeof(so_error); if (getsockopt(local_fd, SOL_SOCKET, SO_ERROR, &so_error, &len) != 0) return false; return so_error == 0; } bool TCPStreamPusher::SendAll(Connection& c, const void* buf, size_t len, bool allow_zerocopy, bool* zc_used_out, uint32_t* zc_first_out, uint32_t* zc_last_out) { const auto* p = static_cast(buf); size_t sent = 0; bool zc_used = false; uint32_t zc_first = 0; uint32_t zc_last = 0; const auto deadline = std::chrono::steady_clock::now() + send_total_timeout; bool try_zerocopy = false; #if defined(MSG_ZEROCOPY) try_zerocopy = allow_zerocopy && c.zerocopy_enabled.load(std::memory_order_relaxed); #endif while (sent < len) { const int local_fd = c.fd.load(); if (local_fd < 0 || c.broken) { if (zc_used_out) *zc_used_out = zc_used; if (zc_first_out) *zc_first_out = zc_first; if (zc_last_out) *zc_last_out = zc_last; return false; } if (std::chrono::steady_clock::now() >= deadline) { c.broken = true; CloseFd(c.fd); if (zc_used_out) *zc_used_out = zc_used; if (zc_first_out) *zc_first_out = zc_first; if (zc_last_out) *zc_last_out = zc_last; return false; } pollfd pfd{}; pfd.fd = local_fd; pfd.events = POLLOUT; const int prc = poll(&pfd, 1, static_cast(send_poll_timeout.count())); if (prc < 0) { if (errno == EINTR) continue; c.broken = true; CloseFd(c.fd); if (zc_used_out) *zc_used_out = zc_used; if (zc_first_out) *zc_first_out = zc_first; if (zc_last_out) *zc_last_out = zc_last; return false; } if (prc == 0) continue; if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0 && !(pfd.revents & POLLOUT)) { c.broken = true; CloseFd(c.fd); if (zc_used_out) *zc_used_out = zc_used; if (zc_first_out) *zc_first_out = zc_first; if (zc_last_out) *zc_last_out = zc_last; return false; } int flags = MSG_NOSIGNAL; #if defined(MSG_ZEROCOPY) if (try_zerocopy) flags |= MSG_ZEROCOPY; #endif ssize_t rc = ::send(local_fd, p + sent, len - sent, flags); if (rc < 0) { if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK) continue; #if defined(MSG_ZEROCOPY) if (try_zerocopy && (errno == EOPNOTSUPP || errno == EINVAL || errno == ENOBUFS)) { try_zerocopy = false; c.zerocopy_enabled.store(false, std::memory_order_relaxed); continue; } #endif if (errno == EPIPE || errno == ECONNRESET || errno == ENOTCONN) { c.broken = true; CloseFd(c.fd); } if (zc_used_out) *zc_used_out = zc_used; if (zc_first_out) *zc_first_out = zc_first; if (zc_last_out) *zc_last_out = zc_last; return false; } #if defined(MSG_ZEROCOPY) if (try_zerocopy && rc > 0) { const uint32_t this_id = c.zc_next_id++; if (!zc_used) { zc_used = true; zc_first = this_id; } zc_last = this_id; } #endif sent += static_cast(rc); } if (zc_used_out) *zc_used_out = zc_used; if (zc_first_out) *zc_first_out = zc_first; if (zc_last_out) *zc_last_out = zc_last; return true; } bool TCPStreamPusher::SendFrame(Connection& c, const uint8_t* data, size_t size, TCPFrameType type, int64_t image_number, ZeroCopyReturnValue* z) { TcpFrameHeader h{}; h.type = static_cast(type); h.payload_size = size; h.image_number = image_number >= 0 ? static_cast(image_number) : 0; h.socket_number = c.socket_number; h.run_number = run_number; if (!SendAll(c, &h, sizeof(h), false)) { if (z) z->release(); return false; } bool zc_used = false; uint32_t zc_first = 0; uint32_t zc_last = 0; if (size > 0) { const bool allow_zerocopy = (type == TCPFrameType::DATA || type == TCPFrameType::CALIBRATION); if (!SendAll(c, data, size, allow_zerocopy, &zc_used, &zc_first, &zc_last)) { if (z) { if (zc_used) EnqueueZeroCopyPending(c, z, zc_first, zc_last); else z->release(); } return false; } } if (z) { if (zc_used) EnqueueZeroCopyPending(c, z, zc_first, zc_last); else z->release(); } if (type == TCPFrameType::DATA) c.data_sent.fetch_add(1, std::memory_order_relaxed); return true; } // Caller must hold c.zc_mutex. void TCPStreamPusher::ReleaseCompletedZeroCopy(Connection& c) { while (!c.zc_pending.empty()) { const auto& front = c.zc_pending.front(); if (c.zc_completed_id == std::numeric_limits::max() || front.last_id > c.zc_completed_id) break; if (front.z) front.z->release(); c.zc_pending.pop_front(); } } void TCPStreamPusher::EnqueueZeroCopyPending(Connection& c, ZeroCopyReturnValue* z, uint32_t first_id, uint32_t last_id) { std::unique_lock ul(c.zc_mutex); c.zc_pending.push_back(Connection::PendingZC{ .first_id = first_id, .last_id = last_id, .z = z }); c.zc_cv.notify_all(); } void TCPStreamPusher::ForceReleasePendingZeroCopy(Connection& c) { std::unique_lock ul(c.zc_mutex); while (!c.zc_pending.empty()) { auto p = c.zc_pending.front(); c.zc_pending.pop_front(); if (p.z) p.z->release(); } c.zc_cv.notify_all(); } bool TCPStreamPusher::WaitForZeroCopyDrain(Connection& c, std::chrono::milliseconds timeout) { std::unique_lock ul(c.zc_mutex); return c.zc_cv.wait_for(ul, timeout, [&] { return c.zc_pending.empty() || c.broken.load(); }); } void TCPStreamPusher::ZeroCopyCompletionThread(Connection* c) { #if defined(MSG_ZEROCOPY) while (c->active || !c->zc_pending.empty()) { const int local_fd = c->fd.load(); if (local_fd < 0) break; pollfd pfd{}; pfd.fd = local_fd; pfd.events = 0; // We only care about POLLERR for errqueue const int prc = poll(&pfd, 1, 100); if (prc < 0) { if (errno == EINTR) continue; break; } if (prc == 0) continue; // Only process the error queue; POLLERR is always reported regardless of events mask uint8_t control[512]; uint8_t data[1]; iovec iov{}; iov.iov_base = data; iov.iov_len = sizeof(data); msghdr msg{}; msg.msg_iov = &iov; msg.msg_iovlen = 1; msg.msg_control = control; msg.msg_controllen = sizeof(control); while (true) { const ssize_t rr = recvmsg(local_fd, &msg, MSG_ERRQUEUE | MSG_DONTWAIT); if (rr < 0) { if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) break; // A real error on the errqueue likely means the socket is dead c->broken = true; break; } for (cmsghdr* cm = CMSG_FIRSTHDR(&msg); cm != nullptr; cm = CMSG_NXTHDR(&msg, cm)) { if (cm->cmsg_level != SOL_IP || cm->cmsg_type != IP_RECVERR) continue; auto* se = reinterpret_cast(CMSG_DATA(cm)); if (!se || se->ee_origin != SO_EE_ORIGIN_ZEROCOPY) continue; const uint32_t end_id = se->ee_data; std::unique_lock ul(c->zc_mutex); if (c->zc_completed_id == std::numeric_limits::max() || end_id > c->zc_completed_id) c->zc_completed_id = end_id; ReleaseCompletedZeroCopy(*c); c->zc_cv.notify_all(); } } } #else (void)c; #endif } bool TCPStreamPusher::ReadExact(Connection& c, void* buf, size_t len) { auto* p = static_cast(buf); size_t got = 0; while (got < len) { if (!c.active) return false; const int local_fd = c.fd.load(); if (local_fd < 0) return false; pollfd pfd{}; pfd.fd = local_fd; pfd.events = POLLIN; const int prc = poll(&pfd, 1, 100); if (prc == 0) continue; if (prc < 0) { if (errno == EINTR) continue; return false; } if ((pfd.revents & (POLLHUP | POLLNVAL)) != 0 && !(pfd.revents & POLLIN)) return false; if ((pfd.revents & POLLIN) == 0) continue; ssize_t rc = ::recv(local_fd, p + got, len - got, 0); if (rc == 0) return false; if (rc < 0) { if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK) continue; return false; } got += static_cast(rc); } return true; } bool TCPStreamPusher::ReadExactPersistent(Connection& c, void* buf, size_t len) { auto* p = static_cast(buf); size_t got = 0; while (got < len) { if (!c.connected || c.broken) return false; const int local_fd = c.fd.load(); if (local_fd < 0) return false; pollfd pfd{}; pfd.fd = local_fd; pfd.events = POLLIN; const int prc = poll(&pfd, 1, 500); if (prc == 0) continue; if (prc < 0) { if (errno == EINTR) continue; return false; } // POLLERR can fire alongside POLLIN during zerocopy — only bail if we have // POLLHUP/POLLNVAL without any readable data. if ((pfd.revents & (POLLHUP | POLLNVAL)) != 0 && !(pfd.revents & POLLIN)) return false; if ((pfd.revents & POLLIN) == 0) continue; ssize_t rc = ::recv(local_fd, p + got, len - got, 0); if (rc == 0) return false; // Peer closed connection if (rc < 0) { if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK) continue; return false; } got += static_cast(rc); } return true; } void TCPStreamPusher::WriterThread(Connection* c) { while (c->active) { auto e = c->queue.GetBlocking(); if (e.end) break; if (!e.z) continue; if (c->broken) { e.z->release(); continue; } std::unique_lock ul(c->send_mutex); if (!SendFrame(*c, static_cast(e.z->GetImage()), e.z->GetImageSize(), TCPFrameType::DATA, e.z->GetImageNumber(), e.z)) { c->broken = true; logger.Error("TCP send failed on socket " + std::to_string(c->socket_number)); } } } void TCPStreamPusher::PersistentAckThread(Connection* c) { while (c->connected && !c->broken) { TcpFrameHeader h{}; if (!ReadExactPersistent(*c, &h, sizeof(h))) { if (c->connected && !c->broken) { c->broken = true; logger.Info("Persistent connection lost on socket " + std::to_string(c->socket_number)); } break; } if (h.magic != JFJOCH_TCP_MAGIC || h.version != JFJOCH_TCP_VERSION) { c->broken = true; logger.Error("Invalid frame on persistent connection, socket " + std::to_string(c->socket_number)); break; } const auto frame_type = static_cast(h.type); // Keepalive pong from the writer if (frame_type == TCPFrameType::KEEPALIVE) { c->last_keepalive_recv = std::chrono::steady_clock::now(); if (h.payload_size > 0) { std::vector discard(h.payload_size); ReadExactPersistent(*c, discard.data(), discard.size()); } continue; } if (frame_type != TCPFrameType::ACK) { c->broken = true; logger.Error("Unexpected frame type " + std::to_string(h.type) + " on socket " + std::to_string(c->socket_number)); break; } // ACK frame — forward to data-collection ack logic std::string error_text; if (h.payload_size > 0) { error_text.resize(h.payload_size); if (!ReadExactPersistent(*c, error_text.data(), error_text.size())) { c->broken = true; break; } } const auto ack_for = static_cast(h.ack_for); const bool ok = (h.flags & TCP_ACK_FLAG_OK) != 0; const bool fatal = (h.flags & TCP_ACK_FLAG_FATAL) != 0; const auto code = static_cast(h.ack_code); { std::unique_lock ul(c->ack_mutex); c->last_ack_code = code; if (!error_text.empty()) c->last_ack_error = error_text; if (ack_for == TCPFrameType::START) { c->start_ack_received = true; c->start_ack_ok = ok; if (!ok && error_text.empty()) c->last_ack_error = "START rejected"; } else if (ack_for == TCPFrameType::END) { c->end_ack_received = true; c->end_ack_ok = ok; if (!ok && error_text.empty()) c->last_ack_error = "END rejected"; } else if (ack_for == TCPFrameType::CANCEL) { c->cancel_ack_received = true; c->cancel_ack_ok = ok; if (!ok && error_text.empty()) c->last_ack_error = "CANCEL rejected"; } else if (ack_for == TCPFrameType::DATA) { c->data_acked_total.fetch_add(1, std::memory_order_relaxed); total_data_acked_total.fetch_add(1, std::memory_order_relaxed); if (ok && !fatal) { c->data_acked_ok.fetch_add(1, std::memory_order_relaxed); total_data_acked_ok.fetch_add(1, std::memory_order_relaxed); } else { c->data_acked_bad.fetch_add(1, std::memory_order_relaxed); total_data_acked_bad.fetch_add(1, std::memory_order_relaxed); c->data_ack_error_reported = true; if (!error_text.empty()) { c->data_ack_error_text = error_text; } else if (c->data_ack_error_text.empty()) { c->data_ack_error_text = "DATA ACK failed"; } } } } c->ack_cv.notify_all(); } // If the connection broke, also wake up anyone waiting for an ACK c->ack_cv.notify_all(); } void TCPStreamPusher::AcceptorThread() { uint32_t next_socket_number = 0; while (acceptor_running) { int lfd = listen_fd.load(); if (lfd < 0) break; int new_fd = AcceptOne(lfd, std::chrono::milliseconds(500)); if (new_fd < 0) continue; std::lock_guard lg(connections_mutex); RemoveDeadConnections(); if (connections.size() >= max_connections) { logger.Warning("Max connections (" + std::to_string(max_connections) + ") reached, rejecting new connection"); shutdown(new_fd, SHUT_RDWR); close(new_fd); continue; } SetupNewConnection(new_fd, next_socket_number++); logger.Info("Accepted writer connection (socket_number=" + std::to_string(next_socket_number - 1) + ", total=" + std::to_string(connections.size()) + ")"); } } void TCPStreamPusher::SetupNewConnection(int new_fd, uint32_t socket_number) { auto c = std::make_shared(send_queue_size); c->socket_number = socket_number; c->fd.store(new_fd); int one = 1; setsockopt(new_fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one)); // Enable OS-level TCP keep-alive setsockopt(new_fd, SOL_SOCKET, SO_KEEPALIVE, &one, sizeof(one)); int idle = 30; // longer than app-level keepalive to avoid interference int intvl = 10; int cnt = 3; setsockopt(new_fd, IPPROTO_TCP, TCP_KEEPIDLE, &idle, sizeof(idle)); setsockopt(new_fd, IPPROTO_TCP, TCP_KEEPINTVL, &intvl, sizeof(intvl)); setsockopt(new_fd, IPPROTO_TCP, TCP_KEEPCNT, &cnt, sizeof(cnt)); if (send_buffer_size) setsockopt(new_fd, SOL_SOCKET, SO_SNDBUF, &send_buffer_size.value(), sizeof(int32_t)); #if defined(SO_ZEROCOPY) int zc_one = 1; if (setsockopt(new_fd, SOL_SOCKET, SO_ZEROCOPY, &zc_one, sizeof(zc_one)) == 0) c->zerocopy_enabled.store(true, std::memory_order_relaxed); else c->zerocopy_enabled.store(false, std::memory_order_relaxed); #endif c->connected = true; c->broken = false; auto now = std::chrono::steady_clock::now(); c->last_keepalive_sent = now; c->last_keepalive_recv = now; auto* raw = c.get(); c->persistent_ack_future = std::async(std::launch::async, &TCPStreamPusher::PersistentAckThread, this, raw); connections.emplace_back(std::move(c)); } void TCPStreamPusher::RemoveDeadConnections() { // Must be called with connections_mutex held. // We move dead connections out, release the mutex implicitly (caller still holds it), // then join their futures. Actually — we can join right here since PersistentAckThread // doesn't take connections_mutex, so no deadlock. auto it = connections.begin(); while (it != connections.end()) { auto c = *it; if (c->broken || !c->connected || !IsConnectionAlive(*c)) { c->connected = false; c->broken = true; StopDataCollectionThreads(*c); CloseFd(c->fd); if (c->persistent_ack_future.valid()) c->persistent_ack_future.get(); logger.Info("Removed dead connection (socket_number=" + std::to_string(c->socket_number) + ")"); it = connections.erase(it); } else { ++it; } } } void TCPStreamPusher::KeepaliveThread() { while (acceptor_running) { // Sleep in small increments so we can exit promptly for (int i = 0; i < 50 && acceptor_running; ++i) std::this_thread::sleep_for(std::chrono::milliseconds(100)); if (!acceptor_running) break; // During data collection, the data flow itself serves as heartbeat if (data_collection_active) continue; std::lock_guard lg(connections_mutex); for (auto& cptr : connections) { auto& c = *cptr; if (c.broken || !c.connected) continue; std::unique_lock ul(c.send_mutex); if (!SendFrame(c, nullptr, 0, TCPFrameType::KEEPALIVE, -1, nullptr)) { logger.Warning("Keepalive send failed on socket " + std::to_string(c.socket_number)); c.broken = true; } else { c.last_keepalive_sent = std::chrono::steady_clock::now(); } } RemoveDeadConnections(); } } size_t TCPStreamPusher::GetConnectedWriters() const { std::lock_guard lg(connections_mutex); size_t count = 0; for (const auto& c : connections) { if (c->connected && !c->broken) ++count; } return count; } void TCPStreamPusher::StartDataCollectionThreads(Connection& c) { { std::unique_lock ul(c.ack_mutex); c.start_ack_received = false; c.start_ack_ok = false; c.end_ack_received = false; c.end_ack_ok = false; c.cancel_ack_received = false; c.cancel_ack_ok = false; c.last_ack_error.clear(); c.last_ack_code = TCPAckCode::None; c.data_ack_error_reported = false; c.data_ack_error_text.clear(); } c.data_sent.store(0, std::memory_order_relaxed); c.data_acked_ok.store(0, std::memory_order_relaxed); c.data_acked_bad.store(0, std::memory_order_relaxed); c.data_acked_total.store(0, std::memory_order_relaxed); c.zc_next_id = 0; c.zc_completed_id = std::numeric_limits::max(); { std::unique_lock ul(c.zc_mutex); c.zc_pending.clear(); } c.active = true; c.writer_future = std::async(std::launch::async, &TCPStreamPusher::WriterThread, this, &c); c.zc_future = std::async(std::launch::async, &TCPStreamPusher::ZeroCopyCompletionThread, this, &c); } void TCPStreamPusher::StopDataCollectionThreads(Connection& c) { if (!c.active) return; c.active = false; // Avoid potential shutdown deadlock if queue is full and writer is stalled. if (!c.queue.PutTimeout({.end = true}, std::chrono::milliseconds(200))) { c.broken = true; CloseFd(c.fd); c.queue.Clear(); (void)c.queue.Put({.end = true}); } c.ack_cv.notify_all(); c.zc_cv.notify_all(); if (c.writer_future.valid()) c.writer_future.get(); constexpr auto zc_drain_timeout = std::chrono::seconds(2); if (!WaitForZeroCopyDrain(c, zc_drain_timeout)) { logger.Warning("TCP zerocopy completion drain timeout on socket " + std::to_string(c.socket_number) + "; forcing socket close"); c.broken = true; CloseFd(c.fd); c.zc_cv.notify_all(); c.ack_cv.notify_all(); } if (c.zc_future.valid()) c.zc_future.get(); ForceReleasePendingZeroCopy(c); } bool TCPStreamPusher::WaitForAck(Connection& c, TCPFrameType ack_for, std::chrono::milliseconds timeout, std::string* error_text) { std::unique_lock ul(c.ack_mutex); const bool ok = c.ack_cv.wait_for(ul, timeout, [&] { if (ack_for == TCPFrameType::START) return c.start_ack_received || c.broken.load(); if (ack_for == TCPFrameType::END) return c.end_ack_received || c.broken.load(); if (ack_for == TCPFrameType::CANCEL) return c.cancel_ack_received || c.broken.load(); return false; }); if (!ok) { if (error_text) *error_text = "ACK timeout"; return false; } if (c.broken) { if (error_text) *error_text = c.last_ack_error.empty() ? "Socket broken" : c.last_ack_error; return false; } bool ack_ok = false; if (ack_for == TCPFrameType::START) ack_ok = c.start_ack_ok; if (ack_for == TCPFrameType::END) ack_ok = c.end_ack_ok; if (ack_for == TCPFrameType::CANCEL) ack_ok = c.cancel_ack_ok; if (!ack_ok && error_text) *error_text = c.last_ack_error.empty() ? "ACK rejected" : c.last_ack_error; return ack_ok; } void TCPStreamPusher::StartDataCollection(StartMessage& message) { if (message.images_per_file < 1) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Images per file cannot be zero or negative"); images_per_file = message.images_per_file; run_number = message.run_number; run_name = message.run_name; transmission_error = false; total_data_acked_ok.store(0, std::memory_order_relaxed); total_data_acked_bad.store(0, std::memory_order_relaxed); total_data_acked_total.store(0, std::memory_order_relaxed); std::vector> local_connections; { std::lock_guard lg(connections_mutex); for (auto& c : connections) StopDataCollectionThreads(*c); RemoveDeadConnections(); if (connections.empty()) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "No writers connected to " + endpoint); session_connections = connections; local_connections = session_connections; } logger.Info("Starting data collection with " + std::to_string(local_connections.size()) + " connected writers"); data_collection_active = true; for (auto& c : local_connections) StartDataCollectionThreads(*c); std::vector started(local_connections.size(), false); auto rollback_cancel = [&]() { for (size_t i = 0; i < local_connections.size(); i++) { auto& c = *local_connections[i]; if (!started[i] || c.broken) continue; std::unique_lock ul(c.send_mutex); (void)SendFrame(c, nullptr, 0, TCPFrameType::CANCEL, -1, nullptr); std::string cancel_ack_err; (void)WaitForAck(c, TCPFrameType::CANCEL, std::chrono::milliseconds(500), &cancel_ack_err); } for (auto& c : local_connections) StopDataCollectionThreads(*c); { std::lock_guard lg(connections_mutex); session_connections.clear(); } data_collection_active = false; }; for (size_t i = 0; i < local_connections.size(); i++) { auto& c = *local_connections[i]; message.socket_number = static_cast(c.socket_number); message.write_master_file = (i == 0); serializer.SerializeSequenceStart(message); { std::unique_lock ul(c.send_mutex); if (!SendFrame(c, serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::START, -1, nullptr)) { rollback_cancel(); throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Timeout/failure sending START on socket " + std::to_string(c.socket_number)); } } std::string ack_err; if (!WaitForAck(c, TCPFrameType::START, std::chrono::seconds(5), &ack_err)) { rollback_cancel(); throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "START ACK failed on socket " + std::to_string(c.socket_number) + ": " + ack_err); } started[i] = true; } } bool TCPStreamPusher::SendImage(const uint8_t *image_data, size_t image_size, int64_t image_number) { std::shared_ptr target; size_t conn_count = 0; { std::lock_guard lg(connections_mutex); const auto& use = (!session_connections.empty() ? session_connections : connections); if (use.empty()) return false; conn_count = use.size(); auto idx = static_cast((image_number / images_per_file) % static_cast(conn_count)); target = use[idx]; } auto& c = *target; if (c.broken || !IsConnectionAlive(c)) return false; std::unique_lock ul(c.send_mutex); return SendFrame(c, image_data, image_size, TCPFrameType::DATA, image_number, nullptr); } void TCPStreamPusher::SendImage(ZeroCopyReturnValue &z) { // Look up the target connection while holding the mutex, but do NOT call // PutBlocking while holding it — that can block indefinitely and deadlock // against AcceptorThread/KeepaliveThread. std::shared_ptr target; { std::lock_guard lg(connections_mutex); const auto& use = (!session_connections.empty() ? session_connections : connections); if (use.empty()) { z.release(); return; } auto idx = static_cast((z.GetImageNumber() / images_per_file) % static_cast(use.size())); target = use[idx]; } if (!target || target->broken || !target->active) { z.release(); return; } const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(2); while (std::chrono::steady_clock::now() < deadline) { if (target->broken || !target->active) { z.release(); return; } if (target->queue.PutTimeout(ImagePusherQueueElement{ .image_data = static_cast(z.GetImage()), .z = &z, .end = false }, std::chrono::milliseconds(50))) { return; } } target->broken = true; z.release(); } bool TCPStreamPusher::EndDataCollection(const EndMessage& message) { serializer.SerializeSequenceEnd(message); bool ret = true; std::vector> local_connections; { std::lock_guard lg(connections_mutex); local_connections = (!session_connections.empty() ? session_connections : connections); } for (auto& cptr : local_connections) { auto& c = *cptr; if (c.broken) { ret = false; continue; } { std::unique_lock ul(c.send_mutex); if (!SendFrame(c, serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::END, -1, nullptr)) { ret = false; continue; } } std::string ack_err; if (!WaitForAck(c, TCPFrameType::END, std::chrono::seconds(10), &ack_err)) ret = false; } for (auto& c : local_connections) StopDataCollectionThreads(*c); { std::lock_guard lg(connections_mutex); session_connections.clear(); } data_collection_active = false; transmission_error = !ret; return ret; } bool TCPStreamPusher::SendCalibration(const CompressedImage& message) { std::shared_ptr target; { std::lock_guard lg(connections_mutex); if (connections.empty()) return false; target = connections[0]; } if (!target || target->broken) return false; serializer.SerializeCalibration(message); std::unique_lock ul(target->send_mutex); return SendFrame(*target, serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::CALIBRATION, -1, nullptr); } std::string TCPStreamPusher::Finalize() { std::string ret; if (transmission_error) ret += "Timeout sending images (e.g., writer disabled during data collection);"; std::lock_guard lg(connections_mutex); for (size_t i = 0; i < connections.size(); i++) { auto& c = *connections[i]; { std::unique_lock ul(c.ack_mutex); if (c.data_ack_error_reported && !c.data_ack_error_text.empty()) { ret += "Writer " + std::to_string(c.socket_number) + ": " + c.data_ack_error_text + ";"; } else if (!c.last_ack_error.empty()) { ret += "Writer " + std::to_string(c.socket_number) + ": " + c.last_ack_error + ";"; } } } return ret; } std::string TCPStreamPusher::PrintSetup() const { return "TCPStreamPusher: endpoint=" + endpoint + " max_connections=" + std::to_string(max_connections) + " connected=" + std::to_string(GetConnectedWriters()); } std::optional TCPStreamPusher::GetImagesWritten() const { return total_data_acked_ok.load(std::memory_order_relaxed); } std::optional TCPStreamPusher::GetImagesWriteError() const { return total_data_acked_bad.load(std::memory_order_relaxed); }