// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute // SPDX-License-Identifier: GPL-3.0-only #include "TCPStreamPusherSocket.h" #include #include #include #include #include #include #include #include #include #include #include #include #include static std::pair ParseTcpAddress(const std::string& addr) { const std::string prefix = "tcp://"; if (addr.rfind(prefix, 0) != 0) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP address: " + addr); auto hp = addr.substr(prefix.size()); auto p = hp.find_last_of(':'); if (p == std::string::npos) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP address: " + addr); const auto host = hp.substr(0, p); const auto port_str = hp.substr(p + 1); int port_i = 0; try { size_t parsed = 0; port_i = std::stoi(port_str, &parsed); if (parsed != port_str.size()) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP port in address: " + addr); } catch (...) { throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP port in address: " + addr); } if (port_i < 1 || port_i > static_cast(std::numeric_limits::max())) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "TCP port out of range in address: " + addr); return {host, static_cast(port_i)}; } TCPStreamPusherSocket::TCPStreamPusherSocket(const std::string &addr, uint32_t in_socket_number, std::optional in_send_buffer_size, std::optional in_zerocopy_threshold, size_t send_queue_size) : queue(send_queue_size), endpoint(addr), socket_number(in_socket_number), zerocopy_threshold(in_zerocopy_threshold), send_buffer_size(in_send_buffer_size) { auto [host, port] = ParseTcpAddress(addr); listen_fd = ::socket(AF_INET, SOCK_STREAM, 0); if (listen_fd < 0) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "socket(listen) failed"); int one = 1; setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)); sockaddr_in sin{}; sin.sin_family = AF_INET; sin.sin_port = htons(port); if (host == "*" || host == "0.0.0.0") sin.sin_addr.s_addr = htonl(INADDR_ANY); else if (inet_pton(AF_INET, host.c_str(), &sin.sin_addr) != 1) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "inet_pton failed for " + host); if (bind(listen_fd, reinterpret_cast(&sin), sizeof(sin)) != 0) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "bind() failed to " + addr); if (listen(listen_fd, 16) != 0) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "listen() failed on " + addr); } TCPStreamPusherSocket::~TCPStreamPusherSocket() { try { StopWriterThread(); } catch (...) {} CloseDataSocket(); if (listen_fd >= 0) close(listen_fd); } void TCPStreamPusherSocket::CloseDataSocket() { int old_fd = fd.exchange(-1); if (old_fd >= 0) { shutdown(old_fd, SHUT_RDWR); close(old_fd); } } bool TCPStreamPusherSocket::AcceptConnection(std::chrono::milliseconds timeout) { std::unique_lock ul(send_mutex); // Reuse existing healthy connection instead of always forcing reconnect. if (fd.load() >= 0 && IsConnectionAlive()) { broken = false; logger.Info("TCP peer already connected on " + endpoint + ", reusing existing connection"); return true; } CloseDataSocket(); broken = false; pollfd pfd{}; pfd.fd = listen_fd; pfd.events = POLLIN; const int prc = poll(&pfd, 1, static_cast(timeout.count())); if (prc == 0) { logger.Error("TCP accept timeout (" + std::to_string(timeout.count()) + " ms) on " + endpoint); return false; } if (prc < 0) { if (errno == EINTR) return false; logger.Error("TCP poll() failed on " + endpoint); return false; } if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0) { logger.Error("TCP listen socket error on " + endpoint); return false; } int new_fd = accept(listen_fd, nullptr, nullptr); if (new_fd < 0) return false; int one = 1; setsockopt(new_fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one)); if (send_buffer_size) setsockopt(new_fd, SOL_SOCKET, SO_SNDBUF, &send_buffer_size.value(), sizeof(int32_t)); #ifdef SO_ZEROCOPY setsockopt(new_fd, SOL_SOCKET, SO_ZEROCOPY, &one, sizeof(one)); #endif fd.store(new_fd); logger.Info("TCP peer connected on " + endpoint); return true; } bool TCPStreamPusherSocket::IsConnectionAlive() const { if (broken) return false; int local_fd = fd.load(); if (local_fd < 0) return false; pollfd pfd{}; pfd.fd = local_fd; pfd.events = POLLOUT; if (poll(&pfd, 1, 0) < 0) return false; if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0) return false; int so_error = 0; socklen_t len = sizeof(so_error); if (getsockopt(local_fd, SOL_SOCKET, SO_ERROR, &so_error, &len) != 0) return false; return so_error == 0; } bool TCPStreamPusherSocket::SendAll(const void *buf, size_t len) { const uint8_t *p = static_cast(buf); size_t sent = 0; while (sent < len) { int local_fd = fd.load(); if (local_fd < 0 || broken) return false; ssize_t rc = ::send(local_fd, p + sent, len - sent, MSG_NOSIGNAL); if (rc < 0) { if (errno == EINTR) continue; if (errno == EPIPE || errno == ECONNRESET || errno == ENOTCONN) { CloseDataSocket(); broken = true; logger.Error("TCP peer disconnected on " + endpoint + ", stopping this stream"); return false; } return false; } sent += static_cast(rc); } return true; } bool TCPStreamPusherSocket::SendPayloadZC(const uint8_t *data, size_t size, ZeroCopyReturnValue *z) { #if defined(MSG_ZEROCOPY) && defined(SO_ZEROCOPY) int local_fd = fd.load(); if (local_fd < 0) return false; msghdr msg{}; iovec iov{}; iov.iov_base = const_cast(data); iov.iov_len = size; msg.msg_iov = &iov; msg.msg_iovlen = 1; while (true) { ssize_t rc = ::sendmsg(local_fd, &msg, MSG_ZEROCOPY | MSG_NOSIGNAL); if (rc < 0) { if (errno == EINTR) continue; if (errno == EAGAIN) return SendAll(data, size); return false; } if (static_cast(rc) != size) return false; break; } std::unique_lock ul(inflight_mutex); inflight.push_back(InflightZC{.z = z, .tx_id = next_tx_id.fetch_add(1)}); return true; #else (void) z; return SendAll(data, size); #endif } bool TCPStreamPusherSocket::SendFrame(const uint8_t *data, size_t size, TCPFrameType type, int64_t image_number, ZeroCopyReturnValue *z) { TcpFrameHeader h{}; h.type = static_cast(type); h.payload_size = size; h.image_number = image_number >= 0 ? static_cast(image_number) : 0; h.socket_number = socket_number; h.run_number = run_number; if (!SendAll(&h, sizeof(h))) { if (z) z->release(); return false; } if (size == 0) { if (z) z->release(); return true; } if (z && zerocopy_threshold && size >= zerocopy_threshold.value()) { bool ok = SendPayloadZC(data, size, z); if (!ok) z->release(); return ok; } bool ok = SendAll(data, size); if (z) z->release(); return ok; } void TCPStreamPusherSocket::WriterThread() { while (active) { const auto e = queue.GetBlocking(); if (e.end) break; if (!e.z) continue; if (broken) { e.z->release(); continue; } bool ok = false; { std::unique_lock ul(send_mutex); ok = SendFrame(static_cast(e.z->GetImage()), e.z->GetImageSize(), TCPFrameType::DATA, e.z->GetImageNumber(), e.z); } if (!ok) { broken = true; logger.Error("TCP send failed on " + endpoint + ", stopping this stream"); } } } bool TCPStreamPusherSocket::IsBroken() const { return broken; } void TCPStreamPusherSocket::CompletionThread() { #if defined(MSG_ZEROCOPY) && defined(SO_ZEROCOPY) while (active) { int local_fd = fd.load(); if (local_fd < 0) { std::this_thread::sleep_for(std::chrono::milliseconds(1)); continue; } char cmsgbuf[512]; char dummy[1]; iovec iov{.iov_base = dummy, .iov_len = sizeof(dummy)}; msghdr msg{}; msg.msg_iov = &iov; msg.msg_iovlen = 1; msg.msg_control = cmsgbuf; msg.msg_controllen = sizeof(cmsgbuf); ssize_t rc = ::recvmsg(local_fd, &msg, MSG_ERRQUEUE | MSG_DONTWAIT); if (rc < 0) { if (errno == EAGAIN || errno == EINTR) { std::this_thread::sleep_for(std::chrono::microseconds(100)); continue; } std::this_thread::sleep_for(std::chrono::milliseconds(1)); continue; } for (cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) { if (cmsg->cmsg_level != SOL_IP || cmsg->cmsg_type != IP_RECVERR) continue; auto *serr = reinterpret_cast(CMSG_DATA(cmsg)); if (!serr || serr->ee_origin != SO_EE_ORIGIN_ZEROCOPY) continue; uint64_t first = serr->ee_info; uint64_t last = serr->ee_data; std::unique_lock ul(inflight_mutex); while (!inflight.empty() && inflight.front().tx_id <= last) { auto item = inflight.front(); inflight.pop_front(); if (item.tx_id >= first && item.z) item.z->release(); } } } std::unique_lock ul(inflight_mutex); while (!inflight.empty()) { if (inflight.front().z) inflight.front().z->release(); inflight.pop_front(); } #endif } void TCPStreamPusherSocket::StartWriterThread() { active = true; send_future = std::async(std::launch::async, &TCPStreamPusherSocket::WriterThread, this); completion_future = std::async(std::launch::async, &TCPStreamPusherSocket::CompletionThread, this); } void TCPStreamPusherSocket::StopWriterThread() { if (!active) return; active = false; queue.PutBlocking({.end = true}); if (send_future.valid()) send_future.get(); if (completion_future.valid()) completion_future.get(); // Keep fd open: END frame may still be sent after writer thread stops. // Socket is closed in destructor / explicit close path. } void TCPStreamPusherSocket::SendImage(ZeroCopyReturnValue &z) { queue.PutBlocking(ImagePusherQueueElement{ .image_data = static_cast(z.GetImage()), .z = &z, .end = false }); } bool TCPStreamPusherSocket::Send(const uint8_t *data, size_t size, TCPFrameType type, int64_t image_number) { if (broken) return false; std::unique_lock ul(send_mutex); if (fd.load() < 0) return false; if (!IsConnectionAlive()) { broken = true; return false; } return SendFrame(data, size, type, image_number, nullptr); }