Files
Jungfraujoch/image_pusher/TCPStreamPusherSocket.cpp

598 lines
18 KiB
C++

// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute <filip.leonarski@psi.ch>
// SPDX-License-Identifier: GPL-3.0-only
#include "TCPStreamPusherSocket.h"
#include <cstring>
#include <thread>
#include <chrono>
#include <limits>
#include <poll.h>
#include <cerrno>
#include <sys/socket.h>
#include <sys/types.h>
#include <arpa/inet.h>
#include <netinet/tcp.h>
#include <unistd.h>
#include <linux/errqueue.h>
#include <netdb.h>
// Parses an endpoint of the form "tcp://<host>:<port>" into (host, port).
// The host part is returned verbatim (e.g. "*", an IPv4 literal, or even empty);
// the port must consist solely of decimal digits and lie in [1, 65535].
// Throws JFJochException(InputParameterInvalid) on any malformed input.
static std::pair<std::string, uint16_t> ParseTcpAddress(const std::string& addr) {
    const std::string prefix = "tcp://";
    if (addr.rfind(prefix, 0) != 0)
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP address: " + addr);

    const auto hp = addr.substr(prefix.size());
    // Split on the LAST ':' so the port is always the trailing component.
    const auto p = hp.find_last_of(':');
    if (p == std::string::npos)
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP address: " + addr);

    const auto host = hp.substr(0, p);
    const auto port_str = hp.substr(p + 1);

    // Accept ASCII digits only. std::stoi would silently accept leading whitespace
    // and a sign ("  80", "+80"), which are not valid port specifications.
    if (port_str.empty())
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP port in address: " + addr);

    uint32_t port_i = 0;
    for (const char c : port_str) {
        if (c < '0' || c > '9')
            throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP port in address: " + addr);
        // Saturate once the value exceeds the port range: this cannot overflow and
        // still lets the range check below report "out of range" for long numbers.
        if (port_i <= std::numeric_limits<uint16_t>::max())
            port_i = port_i * 10 + static_cast<uint32_t>(c - '0');
    }

    if (port_i < 1 || port_i > std::numeric_limits<uint16_t>::max())
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "TCP port out of range in address: " + addr);
    return {host, static_cast<uint16_t>(port_i)};
}
// Creates the pusher socket: parses `addr` ("tcp://host:port"), then opens an IPv4
// listening socket bound to it (SO_REUSEADDR, backlog 16). The data connection itself
// is established later by AcceptConnection().
// Throws JFJochException(InputParameterInvalid) if the address is malformed or the
// socket cannot be created, bound, or put into listening state.
TCPStreamPusherSocket::TCPStreamPusherSocket(const std::string &addr,
                                             uint32_t in_socket_number,
                                             std::optional<int32_t> in_send_buffer_size,
                                             std::optional<size_t> in_zerocopy_threshold,
                                             size_t send_queue_size)
    : queue(send_queue_size),
      endpoint(addr),
      socket_number(in_socket_number),
      zerocopy_threshold(in_zerocopy_threshold),
      send_buffer_size(in_send_buffer_size) {
    auto [host, port] = ParseTcpAddress(addr);

    listen_fd = ::socket(AF_INET, SOCK_STREAM, 0);
    if (listen_fd < 0)
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "socket(listen) failed");

    // A constructor that throws never runs the destructor, so the listening socket
    // must be closed explicitly on every failure path below; otherwise it leaks.
    auto fail = [this](const std::string &msg) {
        close(listen_fd);
        listen_fd = -1;
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, msg);
    };

    int one = 1;
    setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one));

    sockaddr_in sin{};
    sin.sin_family = AF_INET;
    sin.sin_port = htons(port);
    // "*" and "0.0.0.0" both mean "listen on all interfaces".
    if (host == "*" || host == "0.0.0.0")
        sin.sin_addr.s_addr = htonl(INADDR_ANY);
    else if (inet_pton(AF_INET, host.c_str(), &sin.sin_addr) != 1)
        fail("inet_pton failed for " + host);

    if (bind(listen_fd, reinterpret_cast<sockaddr *>(&sin), sizeof(sin)) != 0)
        fail("bind() failed to " + addr);
    if (listen(listen_fd, 16) != 0)
        fail("listen() failed on " + addr);
}
// Tears the pusher down: stops the background threads, then closes the data
// connection and finally the listening socket.
TCPStreamPusherSocket::~TCPStreamPusherSocket() {
    // Destructors must not throw, so any failure while stopping is deliberately dropped.
    try {
        StopWriterThread();
    } catch (...) {
        // best effort only
    }

    CloseDataSocket();

    if (listen_fd < 0)
        return;
    close(listen_fd);
}
// Closes the current data connection, if any. The descriptor is swapped out
// atomically for -1, so concurrent callers can never close the same fd twice.
void TCPStreamPusherSocket::CloseDataSocket() {
    const int detached = fd.exchange(-1);
    if (detached < 0)
        return;

    // Shut both directions down before closing; presumably so that threads still
    // polling/reading this descriptor observe the disconnect.
    shutdown(detached, SHUT_RDWR);
    close(detached);
}
bool TCPStreamPusherSocket::AcceptConnection(std::chrono::milliseconds timeout) {
std::unique_lock ul(send_mutex);
// Reuse existing healthy connection instead of always forcing reconnect.
if (fd.load() >= 0 && IsConnectionAlive()) {
broken = false;
logger.Info("TCP peer already connected on " + endpoint + ", reusing existing connection");
return true;
}
CloseDataSocket();
broken = false;
pollfd pfd{};
pfd.fd = listen_fd;
pfd.events = POLLIN;
const int prc = poll(&pfd, 1, static_cast<int>(timeout.count()));
if (prc == 0) {
logger.Error("TCP accept timeout (" + std::to_string(timeout.count()) + " ms) on " + endpoint);
return false;
}
if (prc < 0) {
if (errno == EINTR)
return false;
logger.Error("TCP poll() failed on " + endpoint);
return false;
}
if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0) {
logger.Error("TCP listen socket error on " + endpoint);
return false;
}
int new_fd = accept(listen_fd, nullptr, nullptr);
if (new_fd < 0)
return false;
int one = 1;
setsockopt(new_fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one));
if (send_buffer_size)
setsockopt(new_fd, SOL_SOCKET, SO_SNDBUF, &send_buffer_size.value(), sizeof(int32_t));
#ifdef SO_ZEROCOPY
setsockopt(new_fd, SOL_SOCKET, SO_ZEROCOPY, &one, sizeof(one));
#endif
fd.store(new_fd);
logger.Info("TCP peer connected on " + endpoint);
return true;
}
// Cheap, non-blocking health check of the data connection. It reports false if the
// stream was flagged broken, no socket is open, poll() reports error/hang-up/invalid,
// or a pending socket error exists (note: reading SO_ERROR also clears it).
bool TCPStreamPusherSocket::IsConnectionAlive() const {
    if (broken)
        return false;

    const int sock = fd.load();
    if (sock < 0)
        return false;

    pollfd probe{};
    probe.fd = sock;
    probe.events = POLLOUT;
    constexpr short kDeadMask = POLLERR | POLLHUP | POLLNVAL;
    if (poll(&probe, 1, 0) < 0 || (probe.revents & kDeadMask) != 0)
        return false;

    int pending_error = 0;
    socklen_t opt_len = sizeof(pending_error);
    const bool query_ok = getsockopt(sock, SOL_SOCKET, SO_ERROR, &pending_error, &opt_len) == 0;
    return query_ok && pending_error == 0;
}
// Writes exactly `len` bytes from `buf` to the data socket, retrying on EINTR and
// on short writes. Returns false if the socket is closed/broken or send() fails.
// EPIPE/ECONNRESET/ENOTCONN are treated as peer loss: the socket is closed and the
// stream is flagged broken.
bool TCPStreamPusherSocket::SendAll(const void *buf, size_t len) {
    const auto *cursor = static_cast<const uint8_t *>(buf);
    size_t remaining = len;

    while (remaining > 0) {
        const int sock = fd.load();
        if (sock < 0 || broken)
            return false;

        const ssize_t n = ::send(sock, cursor, remaining, MSG_NOSIGNAL);
        if (n >= 0) {
            cursor += n;
            remaining -= static_cast<size_t>(n);
            continue;
        }

        const int err = errno;
        if (err == EINTR)
            continue;

        const bool peer_gone = err == EPIPE || err == ECONNRESET || err == ENOTCONN;
        if (peer_gone) {
            CloseDataSocket();
            broken = true;
            logger.Error("TCP peer disconnected on " + endpoint + ", stopping this stream");
        }
        return false;
    }
    return true;
}
// Reads exactly `len` bytes from the data socket into `buf`.
// Polls in 100 ms slices so the loop notices `active` being cleared. Returns false
// if stopped, the socket is closed, the peer reached EOF, or an error occurs.
bool TCPStreamPusherSocket::ReadExact(void *buf, size_t len) {
    auto *p = static_cast<uint8_t *>(buf);
    size_t got = 0;
    while (got < len) {
        if (!active)
            return false;
        const int local_fd = fd.load();
        if (local_fd < 0)
            return false;

        pollfd pfd{};
        pfd.fd = local_fd;
        pfd.events = POLLIN;
        const int prc = poll(&pfd, 1, 100); // 100 ms interruptibility window
        if (prc == 0)
            continue;
        if (prc < 0) {
            if (errno == EINTR)
                continue;
            return false;
        }

        if ((pfd.revents & POLLIN) == 0) {
            // No data to read: an error/hang-up/invalid fd ends the read, anything else
            // is a spurious wake-up.
            if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0)
                return false;
            continue;
        }

        // Data is readable. Consume it even if POLLERR/POLLHUP is also set, so a frame
        // (e.g. a final ACK) sent right before the peer disconnected is not thrown
        // away. Once the buffer is drained, recv() reports EOF/error on a later pass.
        const ssize_t rc = ::recv(local_fd, p + got, len - got, 0);
        if (rc == 0)
            return false;
        if (rc < 0) {
            if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK)
                continue;
            return false;
        }
        got += static_cast<size_t>(rc);
    }
    return true;
}
bool TCPStreamPusherSocket::SendPayloadZC(const uint8_t *data, size_t size, ZeroCopyReturnValue *z) {
#if defined(MSG_ZEROCOPY) && defined(SO_ZEROCOPY)
int local_fd = fd.load();
if (local_fd < 0)
return false;
msghdr msg{};
iovec iov{};
iov.iov_base = const_cast<uint8_t *>(data);
iov.iov_len = size;
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
while (true) {
ssize_t rc = ::sendmsg(local_fd, &msg, MSG_ZEROCOPY | MSG_NOSIGNAL);
if (rc < 0) {
if (errno == EINTR)
continue;
if (errno == EAGAIN)
return SendAll(data, size);
return false;
}
if (static_cast<size_t>(rc) != size)
return false;
break;
}
std::unique_lock ul(inflight_mutex);
inflight.push_back(InflightZC{.z = z, .tx_id = next_tx_id.fetch_add(1)});
return true;
#else
(void) z;
return SendAll(data, size);
#endif
}
// Sends one protocol frame: a TcpFrameHeader followed by `size` payload bytes.
// If `z` is non-null it is released exactly once. SendPayloadZC() takes over that duty
// when it succeeds; otherwise it is released here. Zerocopy is used only for payloads
// of at least `zerocopy_threshold` bytes that come with a `z`.
// Every successfully sent DATA frame, including one with an empty payload, is counted
// in data_sent so that it matches the per-frame DATA ACK accounting.
// Returns true if the header and the whole payload were sent.
bool TCPStreamPusherSocket::SendFrame(const uint8_t *data, size_t size, TCPFrameType type, int64_t image_number, ZeroCopyReturnValue *z) {
    TcpFrameHeader h{};
    h.type = static_cast<uint16_t>(type);
    h.payload_size = size;
    h.image_number = image_number >= 0 ? static_cast<uint64_t>(image_number) : 0;
    h.socket_number = socket_number;
    h.run_number = run_number;

    bool ok = SendAll(&h, sizeof(h));
    bool z_handed_off = false; // true once SendPayloadZC() owns z's release

    if (ok && size > 0) {
        const bool use_zerocopy = z && zerocopy_threshold && size >= zerocopy_threshold.value();
        if (use_zerocopy) {
            ok = SendPayloadZC(data, size, z);
            z_handed_off = ok;
        } else {
            ok = SendAll(data, size);
        }
    }

    if (z && !z_handed_off)
        z->release();

    if (ok && type == TCPFrameType::DATA)
        data_sent.fetch_add(1, std::memory_order_relaxed);
    return ok;
}
// Writer loop: takes images queued by SendImage() and sends each as a DATA frame
// (under send_mutex). Once the stream is broken, queued images are only released.
// The loop ends only when it consumes the end marker pushed by StopWriterThread(),
// never merely because `active` dropped. Exiting early would leave the marker, and any
// images queued before it, in the queue: the images would never be released and the
// stale marker would make the next run's writer exit immediately.
void TCPStreamPusherSocket::WriterThread() {
    while (true) {
        const auto e = queue.GetBlocking();
        if (e.end)
            break;
        if (!e.z)
            continue;
        if (broken) {
            e.z->release();
            continue;
        }

        bool ok = false;
        {
            std::unique_lock ul(send_mutex);
            ok = SendFrame(static_cast<const uint8_t *>(e.z->GetImage()),
                           e.z->GetImageSize(),
                           TCPFrameType::DATA,
                           e.z->GetImageNumber(),
                           e.z);
        }
        if (!ok) {
            broken = true;
            logger.Error("TCP send failed on " + endpoint + ", stopping this stream");
        }
    }
}
// Reports whether a send/receive failure or protocol error has flagged this stream;
// AcceptConnection() clears the flag again.
bool TCPStreamPusherSocket::IsBroken() const {
    return broken.load();
}
// Harvests MSG_ZEROCOPY completion notifications from the socket error queue and
// releases the ZeroCopyReturnValue of every send the kernel reports as finished.
// After `active` is cleared, it keeps collecting completions for a bounded grace period
// (or until nothing is outstanding / the socket is gone) before force-releasing the
// remainder. Releasing a buffer while the kernel may still read it would let it be
// reused and corrupt data still in flight.
void TCPStreamPusherSocket::CompletionThread() {
#if defined(MSG_ZEROCOPY) && defined(SO_ZEROCOPY)
    constexpr auto kDrainTimeout = std::chrono::milliseconds(1000);
    bool draining = false;
    std::chrono::steady_clock::time_point drain_deadline{};

    while (true) {
        if (!active) {
            if (!draining) {
                draining = true;
                drain_deadline = std::chrono::steady_clock::now() + kDrainTimeout;
            }
            bool nothing_pending;
            {
                std::unique_lock ul(inflight_mutex);
                nothing_pending = inflight.empty();
            }
            if (nothing_pending || fd.load() < 0 || std::chrono::steady_clock::now() >= drain_deadline)
                break;
        }

        const int local_fd = fd.load();
        if (local_fd < 0) {
            std::this_thread::sleep_for(std::chrono::milliseconds(1));
            continue;
        }

        char cmsgbuf[512];
        char dummy[1];
        iovec iov{.iov_base = dummy, .iov_len = sizeof(dummy)};
        msghdr msg{};
        msg.msg_iov = &iov;
        msg.msg_iovlen = 1;
        msg.msg_control = cmsgbuf;
        msg.msg_controllen = sizeof(cmsgbuf);

        const ssize_t rc = ::recvmsg(local_fd, &msg, MSG_ERRQUEUE | MSG_DONTWAIT);
        if (rc < 0) {
            // Empty error queue (or interrupted): short back-off; other errors: longer.
            if (errno == EAGAIN || errno == EINTR)
                std::this_thread::sleep_for(std::chrono::microseconds(100));
            else
                std::this_thread::sleep_for(std::chrono::milliseconds(1));
            continue;
        }

        for (cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
            if (cmsg->cmsg_level != SOL_IP || cmsg->cmsg_type != IP_RECVERR)
                continue;
            auto *serr = reinterpret_cast<sock_extended_err *>(CMSG_DATA(cmsg));
            if (!serr || serr->ee_origin != SO_EE_ORIGIN_ZEROCOPY)
                continue;

            // The notification covers the inclusive id range [ee_info, ee_data].
            // NOTE(review): the kernel ids are 32-bit; wrap-around after 2^32 sends on
            // one socket is not handled (same as before).
            const uint64_t first = serr->ee_info;
            const uint64_t last = serr->ee_data;

            std::unique_lock ul(inflight_mutex);
            // Release exactly the reported range. Entries are queued in ascending
            // tx_id order, so the scan can stop past `last`. Ranges may arrive out of
            // order, so older entries are neither released nor dropped here.
            for (auto &item : inflight) {
                if (item.tx_id > last)
                    break;
                if (item.tx_id >= first && item.z) {
                    item.z->release();
                    item.z = nullptr;
                }
            }
            // Drop the prefix that needs no further release (completed or placeholder).
            while (!inflight.empty() && inflight.front().z == nullptr)
                inflight.pop_front();
        }
    }

    // Grace period over (or socket gone): release anything still outstanding.
    std::unique_lock ul(inflight_mutex);
    while (!inflight.empty()) {
        if (inflight.front().z)
            inflight.front().z->release();
        inflight.pop_front();
    }
#endif
}
void TCPStreamPusherSocket::AckThread() {
while (active) {
TcpFrameHeader h{};
if (!ReadExact(&h, sizeof(h))) {
if (active) {
broken = true;
logger.Error("TCP ACK reader disconnected on " + endpoint);
}
break;
}
if (h.magic != JFJOCH_TCP_MAGIC || h.version != JFJOCH_TCP_VERSION || static_cast<TCPFrameType>(h.type) != TCPFrameType::ACK) {
broken = true;
logger.Error("Invalid ACK frame on " + endpoint);
break;
}
std::string error_text;
if (h.payload_size > 0) {
error_text.resize(h.payload_size);
if (!ReadExact(error_text.data(), error_text.size())) {
broken = true;
break;
}
}
const auto ack_for = static_cast<TCPFrameType>(h.ack_for);
const bool ok = (h.flags & TCP_ACK_FLAG_OK) != 0;
const bool fatal = (h.flags & TCP_ACK_FLAG_FATAL) != 0;
const auto code = static_cast<TCPAckCode>(h.ack_code);
{
std::unique_lock ul(ack_state_mutex);
last_ack_code = code;
if (!error_text.empty())
last_ack_error = error_text;
if (ack_for == TCPFrameType::START) {
start_ack_received = true;
start_ack_ok = ok;
if (!ok && error_text.empty())
last_ack_error = "START rejected";
} else if (ack_for == TCPFrameType::END) {
end_ack_received = true;
end_ack_ok = ok;
if (!ok && error_text.empty())
last_ack_error = "END rejected";
} else if (ack_for == TCPFrameType::CANCEL) {
cancel_ack_received = true;
cancel_ack_ok = ok;
if (!ok && error_text.empty())
last_ack_error = "CANCEL rejected";
} else if (ack_for == TCPFrameType::DATA) {
data_acked_total.fetch_add(1, std::memory_order_relaxed);
if (ok && !fatal) {
data_acked_ok.fetch_add(1, std::memory_order_relaxed);
} else {
data_acked_bad.fetch_add(1, std::memory_order_relaxed);
if (error_text.empty())
last_ack_error = "DATA ACK failed";
logger.Error("Received failing DATA ACK on " + endpoint + ": " + last_ack_error);
}
}
}
ack_cv.notify_all();
}
}
void TCPStreamPusherSocket::StartWriterThread() {
if (active)
return;
{
std::unique_lock ul(ack_state_mutex);
start_ack_received = false;
start_ack_ok = false;
end_ack_received = false;
end_ack_ok = false;
cancel_ack_received = false;
cancel_ack_ok = false;
last_ack_error.clear();
last_ack_code = TCPAckCode::None;
}
data_sent.store(0, std::memory_order_relaxed);
data_acked_ok.store(0, std::memory_order_relaxed);
data_acked_bad.store(0, std::memory_order_relaxed);
data_acked_total.store(0, std::memory_order_relaxed);
active = true;
send_future = std::async(std::launch::async, &TCPStreamPusherSocket::WriterThread, this);
completion_future = std::async(std::launch::async, &TCPStreamPusherSocket::CompletionThread, this);
ack_future = std::async(std::launch::async, &TCPStreamPusherSocket::AckThread, this);
}
void TCPStreamPusherSocket::StopWriterThread() {
if (!active)
return;
active = false;
queue.PutBlocking({.end = true});
ack_cv.notify_all();
if (send_future.valid())
send_future.get();
if (completion_future.valid())
completion_future.get();
if (ack_future.valid())
ack_future.get();
// Keep fd open: END frame may still be sent after writer thread stops.
// Socket is closed in destructor / explicit close path.
}
// Queues an image for asynchronous transmission by the writer thread; blocks while the
// queue is full. From here on the writer path is responsible for releasing `z`.
void TCPStreamPusherSocket::SendImage(ZeroCopyReturnValue &z) {
    auto *const image = static_cast<uint8_t *>(z.GetImage());
    queue.PutBlocking(ImagePusherQueueElement{.image_data = image, .z = &z, .end = false});
}
// Synchronously sends one control/data frame with a copied payload, serialized with the
// writer thread via send_mutex. Returns false if the stream is broken, there is no data
// socket, or the connection fails the liveness probe (which also flags it broken).
bool TCPStreamPusherSocket::Send(const uint8_t *data, size_t size, TCPFrameType type, int64_t image_number) {
    if (broken)
        return false;

    std::unique_lock lock(send_mutex);
    if (fd.load() < 0)
        return false;

    if (IsConnectionAlive())
        return SendFrame(data, size, type, image_number, nullptr);

    broken = true;
    return false;
}
// Waits up to `timeout` for the peer's ACK of a START, END or CANCEL frame.
// Returns true only if that ACK arrived and signalled OK. On failure, `error_text`
// (if non-null) receives the reason: the peer's error text, "ACK rejected",
// "Socket broken", "ACK timeout" or "Unsupported ACK type".
// A received ACK takes precedence over a broken stream, so a peer that acknowledges
// and then disconnects (e.g. right after END) is not reported as a failure.
bool TCPStreamPusherSocket::WaitForAck(TCPFrameType ack_for, std::chrono::milliseconds timeout, std::string *error_text) {
    auto set_error = [error_text](const std::string &msg) {
        if (error_text)
            *error_text = msg;
    };

    // Only these frame types are tracked by AckThread; anything else could never
    // complete, so fail fast instead of sleeping through the whole timeout.
    if (ack_for != TCPFrameType::START && ack_for != TCPFrameType::END && ack_for != TCPFrameType::CANCEL) {
        set_error("Unsupported ACK type");
        return false;
    }

    std::unique_lock ul(ack_state_mutex);

    auto received = [&]() -> bool {
        switch (ack_for) {
            case TCPFrameType::START: return start_ack_received;
            case TCPFrameType::END: return end_ack_received;
            case TCPFrameType::CANCEL: return cancel_ack_received;
            default: return false;
        }
    };
    auto accepted = [&]() -> bool {
        switch (ack_for) {
            case TCPFrameType::START: return start_ack_ok;
            case TCPFrameType::END: return end_ack_ok;
            case TCPFrameType::CANCEL: return cancel_ack_ok;
            default: return false;
        }
    };

    ack_cv.wait_for(ul, timeout, [&] { return received() || broken.load(); });

    if (received()) {
        const bool ack_ok = accepted();
        if (!ack_ok)
            set_error(last_ack_error.empty() ? "ACK rejected" : last_ack_error);
        return ack_ok;
    }
    if (broken) {
        set_error(last_ack_error.empty() ? "Socket broken" : last_ack_error);
        return false;
    }
    set_error("ACK timeout");
    return false;
}
// Returns a copy of the most recent ACK error text, taken under ack_state_mutex
// because AckThread updates it concurrently.
std::string TCPStreamPusherSocket::GetLastAckError() const {
    std::lock_guard guard(ack_state_mutex);
    std::string snapshot = last_ack_error;
    return snapshot;
}
// Snapshot of the DATA frame send/ACK counters for the current run. The counters are
// read individually with relaxed loads, so they may be momentarily inconsistent with
// each other; the pending count is therefore clamped at zero.
ImagePusherAckProgress TCPStreamPusherSocket::GetDataAckProgress() const {
    ImagePusherAckProgress progress;
    progress.data_sent = data_sent.load(std::memory_order_relaxed);
    progress.data_acked_ok = data_acked_ok.load(std::memory_order_relaxed);
    progress.data_acked_bad = data_acked_bad.load(std::memory_order_relaxed);
    progress.data_acked_total = data_acked_total.load(std::memory_order_relaxed);
    progress.data_ack_pending = (progress.data_acked_total > progress.data_sent)
                                    ? 0
                                    : progress.data_sent - progress.data_acked_total;
    return progress;
}