// TCPStreamPusher implementation: listens on a TCP endpoint, accepts a fixed
// number of writer connections and pushes START / DATA / CALIBRATION / END /
// CANCEL frames to them, while per-connection threads send queued images,
// read ACK frames and collect MSG_ZEROCOPY completions.
//
// NOTE(review): template arguments, cast target types and system include names
// in this file were restored from context (call sites, sizeof() usage, field
// assignments). Confirm them against TCPStreamPusher.h before merging.

#include "TCPStreamPusher.h"

#include <arpa/inet.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <poll.h>
#include <sys/socket.h>
#include <unistd.h>

#include <cerrno>
#include <cstdint>
#include <limits>
#include <string>

#if defined(MSG_ZEROCOPY)
#include <linux/errqueue.h>
#endif

namespace {

/// Maps an image number onto a connection index in [0, connection_count).
/// Images are grouped by file (image_number / images_per_file) and files are
/// distributed round-robin. A negative remainder is folded back into range so
/// the result can never index outside the connection vector.
/// Preconditions: images_per_file >= 1 and connection_count >= 1.
size_t SelectConnection(int64_t image_number, int64_t images_per_file, size_t connection_count) {
    const auto count = static_cast<int64_t>(connection_count);
    auto idx = (image_number / images_per_file) % count;
    if (idx < 0)
        idx += count;
    return static_cast<size_t>(idx);
}

} // namespace

/// Splits "tcp://host:port" into host and port.
/// @throws JFJochException (InputParameterInvalid) when the prefix is missing,
///         no ':' separator is present, the port is not a full integer, or the
///         port lies outside 1..65535.
std::pair<std::string, uint16_t> TCPStreamPusher::ParseTcpAddress(const std::string& addr) {
    const std::string prefix = "tcp://";
    if (addr.rfind(prefix, 0) != 0)
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP address: " + addr);

    const auto hp = addr.substr(prefix.size());
    const auto p = hp.find_last_of(':');
    if (p == std::string::npos)
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP address: " + addr);

    const auto host = hp.substr(0, p);
    const auto port_str = hp.substr(p + 1);

    int port_i = 0;
    bool port_ok = false;
    try {
        size_t parsed = 0;
        port_i = std::stoi(port_str, &parsed);
        // Reject trailing garbage such as "80abc".
        port_ok = (parsed == port_str.size());
    } catch (...) {
        port_ok = false;
    }
    if (!port_ok)
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
                              "Invalid TCP port in address: " + addr);

    if (port_i < 1 || port_i > static_cast<int>(std::numeric_limits<uint16_t>::max()))
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
                              "TCP port out of range in address: " + addr);

    return {host, static_cast<uint16_t>(port_i)};
}

/// Creates an IPv4 listening socket bound to the given tcp:// address.
/// Host "*" or "0.0.0.0" binds to INADDR_ANY.
/// @return the listening file descriptor; the caller owns and must close it.
/// @throws JFJochException (InputParameterInvalid) on any failure; the socket is
///         closed before the exception leaves this function.
int TCPStreamPusher::OpenListenSocket(const std::string& addr) {
    auto [host, port] = ParseTcpAddress(addr);

    int listen_fd = ::socket(AF_INET, SOCK_STREAM, 0);
    if (listen_fd < 0)
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "socket(listen) failed");

    try {
        int one = 1;
        // Best effort: failure to set SO_REUSEADDR is not fatal.
        setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one));

        sockaddr_in sin{};
        sin.sin_family = AF_INET;
        sin.sin_port = htons(port);
        if (host == "*" || host == "0.0.0.0")
            sin.sin_addr.s_addr = htonl(INADDR_ANY);
        else if (inet_pton(AF_INET, host.c_str(), &sin.sin_addr) != 1)
            throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "inet_pton failed for " + host);

        if (bind(listen_fd, reinterpret_cast<sockaddr*>(&sin), sizeof(sin)) != 0)
            throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "bind() failed to " + addr);

        if (listen(listen_fd, 64) != 0)
            throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "listen() failed on " + addr);
    } catch (...) {
        // Do not leak the descriptor when setup fails half-way.
        close(listen_fd);
        throw;
    }

    return listen_fd;
}

/// Waits up to `timeout` for one incoming connection on listen_fd.
/// @return the accepted descriptor, or -1 on timeout, poll error, or an
///         error/hangup condition on the listening socket.
int TCPStreamPusher::AcceptOne(int listen_fd, std::chrono::milliseconds timeout) {
    pollfd pfd{};
    pfd.fd = listen_fd;
    pfd.events = POLLIN;

    const int prc = poll(&pfd, 1, static_cast<int>(timeout.count()));
    if (prc <= 0)
        return -1;
    if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0)
        return -1;

    return accept(listen_fd, nullptr, nullptr);
}

/// Atomically takes the descriptor out of `fd` (leaving -1) and, if it was
/// valid, shuts it down and closes it. The exchange ensures only one caller
/// closes a given descriptor.
void TCPStreamPusher::CloseFd(std::atomic<int>& fd) {
    const int old_fd = fd.exchange(-1);
    if (old_fd >= 0) {
        shutdown(old_fd, SHUT_RDWR);
        close(old_fd);
    }
}

/// @param addr                    tcp://host:port endpoint to listen on
/// @param in_expected_connections number of writer connections to accept per run
/// @param in_send_buffer_size     optional SO_SNDBUF value applied to accepted sockets
/// @throws JFJochException (InputParameterInvalid) for an empty address or zero connections.
TCPStreamPusher::TCPStreamPusher(const std::string& addr,
                                 size_t in_expected_connections,
                                 std::optional<int32_t> in_send_buffer_size)
    : serialization_buffer(256 * 1024 * 1024),
      serializer(serialization_buffer.data(), serialization_buffer.size()),
      endpoint(addr),
      expected_connections(in_expected_connections),
      send_buffer_size(in_send_buffer_size) {
    if (endpoint.empty())
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "No TCP writer address provided");
    if (expected_connections == 0)
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
                              "Expected TCP connections cannot be zero");
}

/// Stops the worker threads of every connection and closes its socket.
TCPStreamPusher::~TCPStreamPusher() {
    for (auto& c : connections) {
        StopConnectionThreads(*c);
        CloseFd(c->fd);
    }
}

/// Non-blocking health probe: the connection is considered alive when it is
/// not flagged broken, has a valid descriptor, poll() reports no
/// error/hangup, and SO_ERROR is zero.
bool TCPStreamPusher::IsConnectionAlive(const Connection& c) const {
    if (c.broken)
        return false;

    const int local_fd = c.fd.load();
    if (local_fd < 0)
        return false;

    pollfd pfd{};
    pfd.fd = local_fd;
    pfd.events = POLLOUT;
    if (poll(&pfd, 1, 0) < 0)
        return false;
    if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0)
        return false;

    int so_error = 0;
    socklen_t len = sizeof(so_error);
    if (getsockopt(local_fd, SOL_SOCKET, SO_ERROR, &so_error, &len) != 0)
        return false;

    return so_error == 0;
}

/// Sends exactly `len` bytes, retrying on EINTR and partial writes.
///
/// When MSG_ZEROCOPY is available, `allow_zerocopy` is set and the connection
/// has zerocopy enabled, each send() is issued with MSG_ZEROCOPY and is given
/// the next id from c.zc_next_id. If the kernel rejects zerocopy (EOPNOTSUPP,
/// EINVAL, ENOBUFS) it is disabled for this connection and the send is retried
/// as a regular copy.
///
/// @param zc_used_out  (optional) set to true if at least one zerocopy send succeeded
/// @param zc_first_out (optional) id of the first zerocopy send
/// @param zc_last_out  (optional) id of the last zerocopy send
/// @return true when all bytes were handed to the kernel; false on error or
///         when the connection is broken/closed. The out-parameters are
///         written on every exit path so the caller can track buffers that
///         were partially sent with zerocopy.
bool TCPStreamPusher::SendAll(Connection& c, const void* buf, size_t len, bool allow_zerocopy,
                              bool* zc_used_out, uint32_t* zc_first_out, uint32_t* zc_last_out) {
    const auto* p = static_cast<const uint8_t*>(buf);
    size_t sent = 0;

    bool zc_used = false;
    uint32_t zc_first = 0;
    uint32_t zc_last = 0;

    bool try_zerocopy = false;
#if defined(MSG_ZEROCOPY)
    try_zerocopy = allow_zerocopy && c.zerocopy_enabled.load(std::memory_order_relaxed);
#endif

    // Publishes the zerocopy bookkeeping and forwards the result.
    const auto finish = [&](bool result) {
        if (zc_used_out) *zc_used_out = zc_used;
        if (zc_first_out) *zc_first_out = zc_first;
        if (zc_last_out) *zc_last_out = zc_last;
        return result;
    };

    while (sent < len) {
        const int local_fd = c.fd.load();
        if (local_fd < 0 || c.broken)
            return finish(false);

        int flags = MSG_NOSIGNAL;
#if defined(MSG_ZEROCOPY)
        if (try_zerocopy)
            flags |= MSG_ZEROCOPY;
#endif

        const ssize_t rc = ::send(local_fd, p + sent, len - sent, flags);
        if (rc < 0) {
            if (errno == EINTR)
                continue;
#if defined(MSG_ZEROCOPY)
            if (try_zerocopy && (errno == EOPNOTSUPP || errno == EINVAL || errno == ENOBUFS)) {
                // Fall back to copying sends for the rest of this connection.
                try_zerocopy = false;
                c.zerocopy_enabled.store(false, std::memory_order_relaxed);
                continue;
            }
#endif
            if (errno == EPIPE || errno == ECONNRESET || errno == ENOTCONN) {
                c.broken = true;
                CloseFd(c.fd);
            }
            return finish(false);
        }

#if defined(MSG_ZEROCOPY)
        if (try_zerocopy && rc > 0) {
            const uint32_t this_id = c.zc_next_id++;
            if (!zc_used) {
                zc_used = true;
                zc_first = this_id;
            }
            zc_last = this_id;
        }
#endif

        sent += static_cast<size_t>(rc);
    }

    return finish(true);
}

/// Sends one frame: a TcpFrameHeader followed by `size` payload bytes.
///
/// Zerocopy is only attempted for DATA and CALIBRATION payloads; the header is
/// always sent as a regular copy. If `z` is given it is either queued until the
/// kernel reports the zerocopy sends complete, or released immediately when no
/// zerocopy send was made - on success and on failure alike.
///
/// The caller is expected to hold c.send_mutex (all call sites in this file do).
/// @return true when header and payload were fully sent.
bool TCPStreamPusher::SendFrame(Connection& c, const uint8_t* data, size_t size, TCPFrameType type,
                                int64_t image_number, ZeroCopyReturnValue* z) {
    TcpFrameHeader h{};
    h.type = static_cast<decltype(h.type)>(type);
    h.payload_size = size;
    h.image_number = image_number >= 0 ? static_cast<decltype(h.image_number)>(image_number) : 0;
    h.socket_number = c.socket_number;
    h.run_number = run_number;

    if (!SendAll(c, &h, sizeof(h), false)) {
        if (z)
            z->release();
        return false;
    }

    bool zc_used = false;
    uint32_t zc_first = 0;
    uint32_t zc_last = 0;

    // Hands the buffer either to the pending-completion queue (the kernel may
    // still reference it) or straight back to its owner.
    const auto settle_buffer = [&] {
        if (!z)
            return;
        if (zc_used)
            EnqueueZeroCopyPending(c, z, zc_first, zc_last);
        else
            z->release();
    };

    if (size > 0) {
        const bool allow_zerocopy = (type == TCPFrameType::DATA || type == TCPFrameType::CALIBRATION);
        if (!SendAll(c, data, size, allow_zerocopy, &zc_used, &zc_first, &zc_last)) {
            settle_buffer();
            return false;
        }
    }

    settle_buffer();

    if (type == TCPFrameType::DATA)
        c.data_sent.fetch_add(1, std::memory_order_relaxed);

    return true;
}

/// Releases, in FIFO order, every pending zerocopy buffer whose last send id is
/// covered by c.zc_completed_id. A completed id equal to the uint32_t maximum
/// is the "nothing completed yet" sentinel.
/// Does not lock; the only caller in this file holds c.zc_mutex.
void TCPStreamPusher::ReleaseCompletedZeroCopy(Connection& c) {
    while (!c.zc_pending.empty()) {
        const auto& front = c.zc_pending.front();
        if (c.zc_completed_id == std::numeric_limits<uint32_t>::max() || front.last_id > c.zc_completed_id)
            break;
        if (front.z)
            front.z->release();
        c.zc_pending.pop_front();
    }
}

/// Records that `z` must stay alive until zerocopy sends first_id..last_id complete.
void TCPStreamPusher::EnqueueZeroCopyPending(Connection& c, ZeroCopyReturnValue* z,
                                             uint32_t first_id, uint32_t last_id) {
    std::unique_lock ul(c.zc_mutex);
    c.zc_pending.push_back(Connection::PendingZC{
        .first_id = first_id,
        .last_id = last_id,
        .z = z
    });
    c.zc_cv.notify_all();
}

/// Releases every pending zerocopy buffer regardless of completion state.
void TCPStreamPusher::ForceReleasePendingZeroCopy(Connection& c) {
    std::unique_lock ul(c.zc_mutex);
    while (!c.zc_pending.empty()) {
        auto p = c.zc_pending.front();
        c.zc_pending.pop_front();
        if (p.z)
            p.z->release();
    }
    c.zc_cv.notify_all();
}

/// Blocks until no zerocopy buffers are pending, the connection is broken, or
/// `timeout` elapses.
/// @return false only when the timeout expired with buffers still pending.
bool TCPStreamPusher::WaitForZeroCopyDrain(Connection& c, std::chrono::milliseconds timeout) {
    std::unique_lock ul(c.zc_mutex);
    return c.zc_cv.wait_for(ul, timeout, [&] {
        return c.zc_pending.empty() || c.broken.load();
    });
}

/// Per-connection thread body: drains the socket error queue for MSG_ZEROCOPY
/// completion notifications and releases the buffers they cover. Runs while
/// the connection is active or buffers are still pending, and exits once the
/// descriptor has been closed. A no-op when MSG_ZEROCOPY is not available.
void TCPStreamPusher::ZeroCopyCompletionThread(Connection* c) {
#if defined(MSG_ZEROCOPY)
    // zc_pending is mutated by other threads under zc_mutex, so it must be
    // inspected under the same lock.
    const auto has_pending = [c] {
        std::unique_lock ul(c->zc_mutex);
        return !c->zc_pending.empty();
    };

    while (c->active || has_pending()) {
        const int local_fd = c->fd.load();
        if (local_fd < 0)
            break;

        pollfd pfd{};
        pfd.fd = local_fd;
        pfd.events = POLLERR;

        const int prc = poll(&pfd, 1, 100);
        if (prc < 0) {
            if (errno == EINTR)
                continue;
            break;
        }
        if (prc == 0)
            continue;
        if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) == 0)
            continue;

        uint8_t control[512];
        uint8_t data[1];

        while (true) {
            // recvmsg() overwrites msg_controllen and msg_flags, so the header
            // has to be rebuilt before every call.
            iovec iov{};
            iov.iov_base = data;
            iov.iov_len = sizeof(data);

            msghdr msg{};
            msg.msg_iov = &iov;
            msg.msg_iovlen = 1;
            msg.msg_control = control;
            msg.msg_controllen = sizeof(control);

            const ssize_t rr = recvmsg(local_fd, &msg, MSG_ERRQUEUE | MSG_DONTWAIT);
            if (rr < 0) {
                if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
                    break;
                c->broken = true;
                break;
            }

            for (cmsghdr* cm = CMSG_FIRSTHDR(&msg); cm != nullptr; cm = CMSG_NXTHDR(&msg, cm)) {
                if (cm->cmsg_level != SOL_IP || cm->cmsg_type != IP_RECVERR)
                    continue;

                auto* se = reinterpret_cast<sock_extended_err*>(CMSG_DATA(cm));
                if (!se || se->ee_origin != SO_EE_ORIGIN_ZEROCOPY)
                    continue;

                // ee_data is taken as the highest send id covered by this notification.
                const uint32_t end_id = se->ee_data;

                std::unique_lock ul(c->zc_mutex);
                if (c->zc_completed_id == std::numeric_limits<uint32_t>::max() || end_id > c->zc_completed_id)
                    c->zc_completed_id = end_id;
                ReleaseCompletedZeroCopy(*c);
                c->zc_cv.notify_all();
            }
        }
    }
#else
    (void)c;
#endif
}

/// Reads exactly `len` bytes, polling in 100 ms slices so that deactivation of
/// the connection or closing of the socket is noticed promptly.
/// @return false on EOF, socket error, closed descriptor, or when the
///         connection is no longer active.
bool TCPStreamPusher::ReadExact(Connection& c, void* buf, size_t len) {
    auto* p = static_cast<uint8_t*>(buf);
    size_t got = 0;

    while (got < len) {
        if (!c.active)
            return false;

        const int local_fd = c.fd.load();
        if (local_fd < 0)
            return false;

        pollfd pfd{};
        pfd.fd = local_fd;
        pfd.events = POLLIN;

        const int prc = poll(&pfd, 1, 100);
        if (prc == 0)
            continue;
        if (prc < 0) {
            if (errno == EINTR)
                continue;
            return false;
        }
        if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0)
            return false;
        if ((pfd.revents & POLLIN) == 0)
            continue;

        const ssize_t rc = ::recv(local_fd, p + got, len - got, 0);
        if (rc == 0)
            return false; // peer closed the connection
        if (rc < 0) {
            if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK)
                continue;
            return false;
        }

        got += static_cast<size_t>(rc);
    }

    return true;
}

/// Per-connection thread body: takes queued zero-copy images and sends them as
/// DATA frames. Once the connection is broken, remaining images are released
/// without being sent. Terminates on the queue's end marker.
void TCPStreamPusher::WriterThread(Connection* c) {
    while (c->active) {
        auto e = c->queue.GetBlocking();
        if (e.end)
            break;
        if (!e.z)
            continue;

        if (c->broken) {
            e.z->release();
            continue;
        }

        std::unique_lock ul(c->send_mutex);
        if (!SendFrame(*c, static_cast<const uint8_t*>(e.z->GetImage()), e.z->GetImageSize(),
                       TCPFrameType::DATA, e.z->GetImageNumber(), e.z)) {
            c->broken = true;
            logger.Error("TCP send failed on socket " + std::to_string(c->socket_number));
        }
    }
}

/// Per-connection thread body: reads ACK frames from the writer and records
/// their outcome. START/END/CANCEL ACKs set the matching flags that
/// WaitForAck() observes; DATA ACKs update the counters and remember the most
/// recent error text for Finalize(). An unreadable or malformed frame marks
/// the connection broken and ends the thread.
void TCPStreamPusher::AckThread(Connection* c) {
    // Upper bound for the error text carried by an ACK. The size comes from the
    // peer, so it must not be trusted for an allocation.
    constexpr uint64_t max_ack_payload = 1024 * 1024;

    while (c->active) {
        TcpFrameHeader h{};
        if (!ReadExact(*c, &h, sizeof(h))) {
            if (c->active) {
                c->broken = true;
                logger.Error("TCP ACK reader disconnected on socket " + std::to_string(c->socket_number));
            }
            break;
        }

        if (h.magic != JFJOCH_TCP_MAGIC || h.version != JFJOCH_TCP_VERSION
            || static_cast<TCPFrameType>(h.type) != TCPFrameType::ACK) {
            c->broken = true;
            logger.Error("Invalid ACK frame on socket " + std::to_string(c->socket_number));
            break;
        }

        if (h.payload_size > max_ack_payload) {
            c->broken = true;
            logger.Error("Oversized ACK payload on socket " + std::to_string(c->socket_number));
            break;
        }

        std::string error_text;
        if (h.payload_size > 0) {
            error_text.resize(h.payload_size);
            if (!ReadExact(*c, error_text.data(), error_text.size())) {
                c->broken = true;
                break;
            }
        }

        const auto ack_for = static_cast<TCPFrameType>(h.ack_for);
        const bool ok = (h.flags & TCP_ACK_FLAG_OK) != 0;
        const bool fatal = (h.flags & TCP_ACK_FLAG_FATAL) != 0;
        const auto code = static_cast<TCPAckCode>(h.ack_code);

        {
            std::unique_lock ul(c->ack_mutex);
            c->last_ack_code = code;
            if (!error_text.empty())
                c->last_ack_error = error_text;

            if (ack_for == TCPFrameType::START) {
                c->start_ack_received = true;
                c->start_ack_ok = ok;
                if (!ok && error_text.empty())
                    c->last_ack_error = "START rejected";
            } else if (ack_for == TCPFrameType::END) {
                c->end_ack_received = true;
                c->end_ack_ok = ok;
                if (!ok && error_text.empty())
                    c->last_ack_error = "END rejected";
            } else if (ack_for == TCPFrameType::CANCEL) {
                c->cancel_ack_received = true;
                c->cancel_ack_ok = ok;
                if (!ok && error_text.empty())
                    c->last_ack_error = "CANCEL rejected";
            } else if (ack_for == TCPFrameType::DATA) {
                c->data_acked_total.fetch_add(1, std::memory_order_relaxed);
                total_data_acked_total.fetch_add(1, std::memory_order_relaxed);

                if (ok && !fatal) {
                    c->data_acked_ok.fetch_add(1, std::memory_order_relaxed);
                    total_data_acked_ok.fetch_add(1, std::memory_order_relaxed);
                } else {
                    c->data_acked_bad.fetch_add(1, std::memory_order_relaxed);
                    total_data_acked_bad.fetch_add(1, std::memory_order_relaxed);

                    // Soft failure: remember it for Finalize(), do NOT mark socket broken.
                    c->data_ack_error_reported = true;
                    if (!error_text.empty()) {
                        c->data_ack_error_text = error_text;
                    } else if (c->data_ack_error_text.empty()) {
                        c->data_ack_error_text = "DATA ACK failed";
                    }
                }
            }
        }

        c->ack_cv.notify_all();
    }

    // Wake any WaitForAck() caller so it can observe `broken` right away
    // instead of sitting out its full timeout.
    c->ack_cv.notify_all();
}

/// Resets all per-run state of the connection (ACK flags, counters, zerocopy
/// bookkeeping) and launches its writer, ACK reader and zerocopy completion
/// threads.
void TCPStreamPusher::StartConnectionThreads(Connection& c) {
    {
        std::unique_lock ul(c.ack_mutex);
        c.start_ack_received = false;
        c.start_ack_ok = false;
        c.end_ack_received = false;
        c.end_ack_ok = false;
        c.cancel_ack_received = false;
        c.cancel_ack_ok = false;
        c.last_ack_error.clear();
        c.last_ack_code = TCPAckCode::None;
        c.data_ack_error_reported = false;
        c.data_ack_error_text.clear();
    }

    c.data_sent.store(0, std::memory_order_relaxed);
    c.data_acked_ok.store(0, std::memory_order_relaxed);
    c.data_acked_bad.store(0, std::memory_order_relaxed);
    c.data_acked_total.store(0, std::memory_order_relaxed);

    c.zc_next_id = 0;
    // Maximum value is the "no completion seen yet" sentinel.
    c.zc_completed_id = std::numeric_limits<uint32_t>::max();
    {
        std::unique_lock ul(c.zc_mutex);
        c.zc_pending.clear();
    }

    c.active = true;
    c.writer_future = std::async(std::launch::async, &TCPStreamPusher::WriterThread, this, &c);
    c.ack_future = std::async(std::launch::async, &TCPStreamPusher::AckThread, this, &c);
    c.zc_future = std::async(std::launch::async, &TCPStreamPusher::ZeroCopyCompletionThread, this, &c);
}

/// Stops and joins the connection's threads. The writer is stopped first via an
/// end marker in the queue; pending zerocopy buffers then get up to two seconds
/// to complete before the socket is force-closed. Any buffers still pending
/// after the threads have exited are released. No-op if the connection is not
/// active.
void TCPStreamPusher::StopConnectionThreads(Connection& c) {
    if (!c.active)
        return;

    c.active = false;
    c.queue.PutBlocking({.end = true});
    c.ack_cv.notify_all();
    c.zc_cv.notify_all();

    if (c.writer_future.valid())
        c.writer_future.get();

    constexpr auto zc_drain_timeout = std::chrono::seconds(2);
    if (!WaitForZeroCopyDrain(c, zc_drain_timeout)) {
        logger.Warning("TCP zerocopy completion drain timeout on socket " + std::to_string(c.socket_number)
                       + "; forcing socket close");
        c.broken = true;
        CloseFd(c.fd);
        c.zc_cv.notify_all();
        c.ack_cv.notify_all();
    }

    if (c.ack_future.valid())
        c.ack_future.get();
    if (c.zc_future.valid())
        c.zc_future.get();

    ForceReleasePendingZeroCopy(c);
}

/// Waits for the ACK matching a START, END or CANCEL frame.
/// @param error_text (optional) receives "ACK timeout", the last ACK error
///                   text, "Socket broken" or "ACK rejected" on failure
/// @return true only if the ACK arrived within `timeout`, the connection is
///         not broken, and the writer reported success.
bool TCPStreamPusher::WaitForAck(Connection& c, TCPFrameType ack_for, std::chrono::milliseconds timeout,
                                 std::string* error_text) {
    std::unique_lock ul(c.ack_mutex);

    const bool ok = c.ack_cv.wait_for(ul, timeout, [&] {
        if (ack_for == TCPFrameType::START)
            return c.start_ack_received || c.broken.load();
        if (ack_for == TCPFrameType::END)
            return c.end_ack_received || c.broken.load();
        if (ack_for == TCPFrameType::CANCEL)
            return c.cancel_ack_received || c.broken.load();
        return false;
    });

    if (!ok) {
        if (error_text)
            *error_text = "ACK timeout";
        return false;
    }

    if (c.broken) {
        if (error_text)
            *error_text = c.last_ack_error.empty() ? "Socket broken" : c.last_ack_error;
        return false;
    }

    bool ack_ok = false;
    if (ack_for == TCPFrameType::START)
        ack_ok = c.start_ack_ok;
    if (ack_for == TCPFrameType::END)
        ack_ok = c.end_ack_ok;
    if (ack_for == TCPFrameType::CANCEL)
        ack_ok = c.cancel_ack_ok;

    if (!ack_ok && error_text)
        *error_text = c.last_ack_error.empty() ? "ACK rejected" : c.last_ack_error;

    return ack_ok;
}

/// Begins a run: tears down connections from a previous run, accepts
/// `expected_connections` writers (5 s each), starts their threads, then sends
/// START to every writer and waits for its ACK (5 s). Only the first writer is
/// told to write the master file.
///
/// On a START failure, writers that already acknowledged START receive CANCEL
/// and all threads are stopped before the exception is thrown.
/// @throws JFJochException (InputParameterInvalid) on invalid images_per_file,
///         accept failure, or START send/ACK failure.
void TCPStreamPusher::StartDataCollection(StartMessage& message) {
    if (message.images_per_file < 1)
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
                              "Images per file cannot be zero or negative");

    images_per_file = message.images_per_file;
    run_number = message.run_number;
    run_name = message.run_name;
    transmission_error = false;

    total_data_acked_ok.store(0, std::memory_order_relaxed);
    total_data_acked_bad.store(0, std::memory_order_relaxed);
    total_data_acked_total.store(0, std::memory_order_relaxed);

    for (auto& c : connections) {
        StopConnectionThreads(*c);
        CloseFd(c->fd);
    }
    connections.clear();
    connections.reserve(expected_connections);

    int listen_fd = OpenListenSocket(endpoint);
    try {
        for (size_t i = 0; i < expected_connections; i++) {
            int new_fd = AcceptOne(listen_fd, std::chrono::seconds(5));
            if (new_fd < 0)
                throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
                                      "TCP accept timeout/failure on " + endpoint);

            auto c = std::make_unique<Connection>(send_queue_size);
            c->socket_number = static_cast<decltype(c->socket_number)>(i);
            c->fd.store(new_fd);

            int one = 1;
            setsockopt(new_fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one));

            if (send_buffer_size)
                setsockopt(new_fd, SOL_SOCKET, SO_SNDBUF, &send_buffer_size.value(), sizeof(int32_t));

#if defined(SO_ZEROCOPY)
            int zc_one = 1;
            if (setsockopt(new_fd, SOL_SOCKET, SO_ZEROCOPY, &zc_one, sizeof(zc_one)) == 0) {
                c->zerocopy_enabled.store(true, std::memory_order_relaxed);
            } else {
                c->zerocopy_enabled.store(false, std::memory_order_relaxed);
            }
#endif
            connections.emplace_back(std::move(c));
        }
    } catch (...) {
        close(listen_fd);
        // Do not keep a partial set of writers around: they never received
        // START and have no worker threads, so nothing may be sent to them.
        for (auto& c : connections)
            CloseFd(c->fd);
        connections.clear();
        throw;
    }
    close(listen_fd);

    for (auto& c : connections)
        StartConnectionThreads(*c);

    std::vector<bool> started(connections.size(), false);

    // Sends CANCEL to every writer that already accepted START, then stops all
    // connection threads.
    auto rollback_cancel = [&]() {
        for (size_t i = 0; i < connections.size(); i++) {
            auto& c = *connections[i];
            if (!started[i] || c.broken)
                continue;

            std::unique_lock ul(c.send_mutex);
            (void)SendFrame(c, nullptr, 0, TCPFrameType::CANCEL, -1, nullptr);

            std::string cancel_ack_err;
            (void)WaitForAck(c, TCPFrameType::CANCEL, std::chrono::milliseconds(500), &cancel_ack_err);
        }

        for (auto& c : connections)
            StopConnectionThreads(*c);
    };

    for (size_t i = 0; i < connections.size(); i++) {
        auto& c = *connections[i];

        message.socket_number = static_cast<int64_t>(i);
        message.write_master_file = (i == 0);
        serializer.SerializeSequenceStart(message);

        {
            std::unique_lock ul(c.send_mutex);
            if (!SendFrame(c, serialization_buffer.data(), serializer.GetBufferSize(),
                           TCPFrameType::START, -1, nullptr)) {
                rollback_cancel();
                throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
                                      "Timeout/failure sending START on socket " + std::to_string(i));
            }
        }

        std::string ack_err;
        if (!WaitForAck(c, TCPFrameType::START, std::chrono::seconds(5), &ack_err)) {
            rollback_cancel();
            throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
                                  "START ACK failed on socket " + std::to_string(i) + ": " + ack_err);
        }

        started[i] = true;
    }
}

/// Synchronously sends one image as a DATA frame on the connection selected by
/// its file index.
/// @return false when there are no connections, the selected connection is
///         broken or not alive, or the send failed.
bool TCPStreamPusher::SendImage(const uint8_t *image_data, size_t image_size, int64_t image_number) {
    if (connections.empty())
        return false;

    const auto idx = SelectConnection(image_number, images_per_file, connections.size());
    auto& c = *connections[idx];

    if (c.broken || !IsConnectionAlive(c))
        return false;

    std::unique_lock ul(c.send_mutex);
    return SendFrame(c, image_data, image_size, TCPFrameType::DATA, image_number, nullptr);
}

/// Queues a zero-copy image for the writer thread of the connection selected
/// by its file index. The buffer is released immediately when there are no
/// connections or the selected connection is broken.
void TCPStreamPusher::SendImage(ZeroCopyReturnValue &z) {
    if (connections.empty()) {
        z.release();
        return;
    }

    const auto idx = SelectConnection(z.GetImageNumber(), images_per_file, connections.size());
    auto& c = *connections[idx];

    if (c.broken) {
        z.release();
        return;
    }

    c.queue.PutBlocking(ImagePusherQueueElement{
        .image_data = static_cast<decltype(ImagePusherQueueElement::image_data)>(z.GetImage()),
        .z = &z,
        .end = false
    });
}

/// Ends a run: sends END to every healthy writer, waits up to 10 s for each
/// ACK, then stops all connection threads.
/// @return true only if every writer was healthy, received END and
///         acknowledged it successfully. The negated result is stored in
///         transmission_error for Finalize().
bool TCPStreamPusher::EndDataCollection(const EndMessage& message) {
    serializer.SerializeSequenceEnd(message);

    bool ret = true;
    for (auto& cptr : connections) {
        auto& c = *cptr;
        if (c.broken) {
            ret = false;
            continue;
        }

        {
            std::unique_lock ul(c.send_mutex);
            if (!SendFrame(c, serialization_buffer.data(), serializer.GetBufferSize(),
                           TCPFrameType::END, -1, nullptr)) {
                ret = false;
                continue;
            }
        }

        std::string ack_err;
        if (!WaitForAck(c, TCPFrameType::END, std::chrono::seconds(10), &ack_err))
            ret = false;
    }

    for (auto& c : connections)
        StopConnectionThreads(*c);

    transmission_error = !ret;
    return ret;
}

/// Sends a calibration image as a CALIBRATION frame on the first connection.
/// @return false when there are no connections, the first connection is
///         broken, or the send failed.
bool TCPStreamPusher::SendCalibration(const CompressedImage& message) {
    if (connections.empty())
        return false;

    serializer.SerializeCalibration(message);

    auto& c = *connections[0];
    if (c.broken)
        return false;

    std::unique_lock ul(c.send_mutex);
    return SendFrame(c, serialization_buffer.data(), serializer.GetBufferSize(),
                     TCPFrameType::CALIBRATION, -1, nullptr);
}

/// Builds a ';'-terminated summary of errors seen during the run: a generic
/// transmission error if EndDataCollection() failed, plus one entry per writer
/// that reported a DATA ACK error or, failing that, any other ACK error text.
/// @return the summary, empty when nothing went wrong.
std::string TCPStreamPusher::Finalize() {
    std::string ret;
    if (transmission_error)
        ret += "Timeout sending images (e.g., writer disabled during data collection);";

    for (size_t i = 0; i < connections.size(); i++) {
        auto& c = *connections[i];
        std::unique_lock ul(c.ack_mutex);
        if (c.data_ack_error_reported && !c.data_ack_error_text.empty()) {
            ret += "Writer " + std::to_string(i) + ": " + c.data_ack_error_text + ";";
        } else if (!c.last_ack_error.empty()) {
            ret += "Writer " + std::to_string(i) + ": " + c.last_ack_error + ";";
        }
    }

    return ret;
}

/// Human-readable description of the configured endpoint and connection count.
std::string TCPStreamPusher::PrintSetup() const {
    return "TCPStreamPusher: endpoint=" + endpoint
           + " expected_connections=" + std::to_string(expected_connections);
}

/// Number of DATA frames acknowledged as OK by the writers in the current run.
std::optional<uint64_t> TCPStreamPusher::GetImagesWritten() const {
    return total_data_acked_ok.load(std::memory_order_relaxed);
}