802 lines
26 KiB
C++
802 lines
26 KiB
C++
#include "TCPStreamPusher.h"
|
|
|
|
#include <limits>
|
|
#include <poll.h>
|
|
#include <cerrno>
|
|
#include <sys/socket.h>
|
|
#include <arpa/inet.h>
|
|
#include <netinet/tcp.h>
|
|
#include <unistd.h>
|
|
#if defined(MSG_ZEROCOPY)
|
|
#include <linux/errqueue.h>
|
|
#endif
|
|
|
|
std::pair<std::string, uint16_t> TCPStreamPusher::ParseTcpAddress(const std::string& addr) {
|
|
const std::string prefix = "tcp://";
|
|
if (addr.rfind(prefix, 0) != 0)
|
|
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP address: " + addr);
|
|
|
|
auto hp = addr.substr(prefix.size());
|
|
auto p = hp.find_last_of(':');
|
|
if (p == std::string::npos)
|
|
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP address: " + addr);
|
|
|
|
const auto host = hp.substr(0, p);
|
|
const auto port_str = hp.substr(p + 1);
|
|
|
|
int port_i = 0;
|
|
try {
|
|
size_t parsed = 0;
|
|
port_i = std::stoi(port_str, &parsed);
|
|
if (parsed != port_str.size())
|
|
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP port in address: " + addr);
|
|
} catch (...) {
|
|
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP port in address: " + addr);
|
|
}
|
|
|
|
if (port_i < 1 || port_i > static_cast<int>(std::numeric_limits<uint16_t>::max()))
|
|
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "TCP port out of range in address: " + addr);
|
|
|
|
return {host, static_cast<uint16_t>(port_i)};
|
|
}
|
|
|
|
int TCPStreamPusher::OpenListenSocket(const std::string& addr) {
|
|
auto [host, port] = ParseTcpAddress(addr);
|
|
|
|
int listen_fd = ::socket(AF_INET, SOCK_STREAM, 0);
|
|
if (listen_fd < 0)
|
|
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "socket(listen) failed");
|
|
|
|
int one = 1;
|
|
setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one));
|
|
|
|
sockaddr_in sin{};
|
|
sin.sin_family = AF_INET;
|
|
sin.sin_port = htons(port);
|
|
if (host == "*" || host == "0.0.0.0")
|
|
sin.sin_addr.s_addr = htonl(INADDR_ANY);
|
|
else if (inet_pton(AF_INET, host.c_str(), &sin.sin_addr) != 1)
|
|
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "inet_pton failed for " + host);
|
|
|
|
if (bind(listen_fd, reinterpret_cast<sockaddr*>(&sin), sizeof(sin)) != 0)
|
|
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "bind() failed to " + addr);
|
|
|
|
if (listen(listen_fd, 64) != 0)
|
|
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "listen() failed on " + addr);
|
|
|
|
return listen_fd;
|
|
}
|
|
|
|
// Waits up to `timeout` for a pending connection on `listen_fd` and accepts
// it. Returns the accepted descriptor, or -1 on timeout, poll failure, or an
// error condition reported on the listening socket.
int TCPStreamPusher::AcceptOne(int listen_fd, std::chrono::milliseconds timeout) {
    pollfd poll_entry{};
    poll_entry.fd = listen_fd;
    poll_entry.events = POLLIN;

    // poll() <= 0 covers both timeout (0) and error (-1).
    if (poll(&poll_entry, 1, static_cast<int>(timeout.count())) <= 0)
        return -1;

    const bool listener_bad = (poll_entry.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0;
    if (listener_bad)
        return -1;

    return accept(listen_fd, nullptr, nullptr);
}
|
|
|
|
// Atomically takes ownership of the descriptor stored in `fd` (replacing it
// with -1) and closes it. Safe to call concurrently: exactly one caller wins
// the exchange; all others see -1 and do nothing.
void TCPStreamPusher::CloseFd(std::atomic<int>& fd) {
    const int taken = fd.exchange(-1);
    if (taken < 0)
        return;
    // Shut down both directions first so blocked readers/writers wake up.
    shutdown(taken, SHUT_RDWR);
    close(taken);
}
|
|
|
|
// Constructs a pusher that listens on `addr` ("tcp://host:port") and expects
// exactly `in_expected_connections` writer clients per data collection.
//
// serialization_buffer (256 MiB) is the shared scratch space used by
// `serializer` for START/END/CALIBRATION frame payloads.
//
// `in_send_buffer_size`, if set, is applied as SO_SNDBUF on each accepted
// socket in StartDataCollection.
//
// @throws JFJochException(InputParameterInvalid) if the address is empty or
//         the expected connection count is zero.
TCPStreamPusher::TCPStreamPusher(const std::string& addr,
                                 size_t in_expected_connections,
                                 std::optional<int32_t> in_send_buffer_size)
        : serialization_buffer(256 * 1024 * 1024),
          serializer(serialization_buffer.data(), serialization_buffer.size()),
          endpoint(addr),
          expected_connections(in_expected_connections),
          send_buffer_size(in_send_buffer_size) {
    if (endpoint.empty())
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "No TCP writer address provided");
    if (expected_connections == 0)
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Expected TCP connections cannot be zero");
}
|
|
|
|
// Tears down every remaining connection: worker threads are stopped first,
// then the underlying socket is closed.
TCPStreamPusher::~TCPStreamPusher() {
    for (auto& connection : connections) {
        StopConnectionThreads(*connection);
        CloseFd(connection->fd);
    }
}
|
|
|
|
bool TCPStreamPusher::IsConnectionAlive(const Connection& c) const {
|
|
if (c.broken)
|
|
return false;
|
|
|
|
const int local_fd = c.fd.load();
|
|
if (local_fd < 0)
|
|
return false;
|
|
|
|
pollfd pfd{};
|
|
pfd.fd = local_fd;
|
|
pfd.events = POLLOUT;
|
|
if (poll(&pfd, 1, 0) < 0)
|
|
return false;
|
|
if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0)
|
|
return false;
|
|
|
|
int so_error = 0;
|
|
socklen_t len = sizeof(so_error);
|
|
if (getsockopt(local_fd, SOL_SOCKET, SO_ERROR, &so_error, &len) != 0)
|
|
return false;
|
|
|
|
return so_error == 0;
|
|
}
|
|
|
|
// Sends `len` bytes from `buf` on connection `c`, looping until everything
// is written or the connection fails.
//
// When `allow_zerocopy` is set and the socket has zerocopy enabled, data is
// sent with MSG_ZEROCOPY; the kernel then still references the buffer after
// send() returns, so the caller must keep it alive until the completion
// notification arrives on the error queue. The ids of the first and last
// zerocopy send are reported via zc_first_out/zc_last_out (meaningful only
// when *zc_used_out is true).
//
// Returns true on full delivery, false otherwise. On EPIPE/ECONNRESET/
// ENOTCONN the connection is marked broken and its fd closed. The zc_*_out
// values are filled in on ALL exit paths, including failures, because
// earlier chunks of the buffer may already be in flight via zerocopy.
bool TCPStreamPusher::SendAll(Connection& c, const void* buf, size_t len, bool allow_zerocopy,
                              bool* zc_used_out, uint32_t* zc_first_out, uint32_t* zc_last_out) {
    const auto* p = static_cast<const uint8_t*>(buf);
    size_t sent = 0;
    bool zc_used = false;
    uint32_t zc_first = 0;
    uint32_t zc_last = 0;

    bool try_zerocopy = false;
#if defined(MSG_ZEROCOPY)
    try_zerocopy = allow_zerocopy && c.zerocopy_enabled.load(std::memory_order_relaxed);
#endif

    while (sent < len) {
        // Re-check each iteration: another thread may close the fd via CloseFd().
        const int local_fd = c.fd.load();
        if (local_fd < 0 || c.broken) {
            if (zc_used_out) *zc_used_out = zc_used;
            if (zc_first_out) *zc_first_out = zc_first;
            if (zc_last_out) *zc_last_out = zc_last;
            return false;
        }

        // MSG_NOSIGNAL: report EPIPE instead of raising SIGPIPE on a dead peer.
        int flags = MSG_NOSIGNAL;
#if defined(MSG_ZEROCOPY)
        if (try_zerocopy)
            flags |= MSG_ZEROCOPY;
#endif

        ssize_t rc = ::send(local_fd, p + sent, len - sent, flags);
        if (rc < 0) {
            if (errno == EINTR)
                continue;

#if defined(MSG_ZEROCOPY)
            // Zerocopy may be unsupported on this socket/path; permanently
            // fall back to regular copies for this connection and retry.
            if (try_zerocopy && (errno == EOPNOTSUPP || errno == EINVAL || errno == ENOBUFS)) {
                try_zerocopy = false;
                c.zerocopy_enabled.store(false, std::memory_order_relaxed);
                continue;
            }
#endif

            // Peer gone: mark broken and close the socket.
            if (errno == EPIPE || errno == ECONNRESET || errno == ENOTCONN) {
                c.broken = true;
                CloseFd(c.fd);
            }

            if (zc_used_out) *zc_used_out = zc_used;
            if (zc_first_out) *zc_first_out = zc_first;
            if (zc_last_out) *zc_last_out = zc_last;
            return false;
        }

#if defined(MSG_ZEROCOPY)
        // Each successful MSG_ZEROCOPY send consumes one kernel notification
        // id; record the id range covered by this buffer.
        if (try_zerocopy && rc > 0) {
            const uint32_t this_id = c.zc_next_id++;
            if (!zc_used) {
                zc_used = true;
                zc_first = this_id;
            }
            zc_last = this_id;
        }
#endif

        sent += static_cast<size_t>(rc);
    }

    if (zc_used_out) *zc_used_out = zc_used;
    if (zc_first_out) *zc_first_out = zc_first;
    if (zc_last_out) *zc_last_out = zc_last;
    return true;
}
|
|
|
|
// Sends one framed message on connection `c`: a TcpFrameHeader followed by
// `size` payload bytes. Every caller in this file holds c.send_mutex across
// this call so header and payload are never interleaved with other frames.
//
// Zerocopy is attempted only for DATA and CALIBRATION payloads (the header
// itself is always sent with a regular copy). If zerocopy was used, `z` is
// parked on the connection's pending queue until the kernel confirms
// completion; otherwise `z` is released immediately. Either way this
// function takes over responsibility for releasing `z` (may be nullptr).
//
// NOTE(review): h.magic and h.version are never assigned here — this appears
// to rely on TcpFrameHeader having default member initializers for them
// (AckThread validates both fields on incoming frames). Confirm in the header.
//
// Returns false if any part of the frame failed to send.
bool TCPStreamPusher::SendFrame(Connection& c, const uint8_t* data, size_t size, TCPFrameType type, int64_t image_number, ZeroCopyReturnValue* z) {
    TcpFrameHeader h{};
    h.type = static_cast<uint16_t>(type);
    h.payload_size = size;
    // Negative image numbers (control frames) are transmitted as 0.
    h.image_number = image_number >= 0 ? static_cast<uint64_t>(image_number) : 0;
    h.socket_number = c.socket_number;
    h.run_number = run_number;

    if (!SendAll(c, &h, sizeof(h), false)) {
        if (z) z->release();
        return false;
    }

    bool zc_used = false;
    uint32_t zc_first = 0;
    uint32_t zc_last = 0;

    if (size > 0) {
        const bool allow_zerocopy = (type == TCPFrameType::DATA || type == TCPFrameType::CALIBRATION);
        if (!SendAll(c, data, size, allow_zerocopy, &zc_used, &zc_first, &zc_last)) {
            // Even on failure, chunks already sent with zerocopy keep the
            // buffer referenced by the kernel — park `z` until completion.
            if (z) {
                if (zc_used) EnqueueZeroCopyPending(c, z, zc_first, zc_last);
                else z->release();
            }
            return false;
        }
    }

    if (z) {
        if (zc_used) EnqueueZeroCopyPending(c, z, zc_first, zc_last);
        else z->release();
    }

    if (type == TCPFrameType::DATA)
        c.data_sent.fetch_add(1, std::memory_order_relaxed);

    return true;
}
|
|
|
|
// Releases queued zerocopy buffers whose entire id range has been confirmed
// complete (front.last_id <= zc_completed_id). Entries are ordered by send
// id, so scanning stops at the first incomplete one.
//
// Precondition: the caller must hold c.zc_mutex — this helper does not lock
// (it is invoked from ZeroCopyCompletionThread under the lock).
//
// zc_completed_id == UINT32_MAX is the "no completion seen yet" sentinel set
// by StartConnectionThreads; nothing is released in that state.
void TCPStreamPusher::ReleaseCompletedZeroCopy(Connection& c) {
    while (!c.zc_pending.empty()) {
        const auto& front = c.zc_pending.front();
        if (c.zc_completed_id == std::numeric_limits<uint32_t>::max() || front.last_id > c.zc_completed_id)
            break;
        if (front.z)
            front.z->release();
        c.zc_pending.pop_front();
    }
}
|
|
|
|
// Records that buffer `z` must stay alive until zerocopy sends in the id
// range [first_id, last_id] complete; the completion thread releases it.
void TCPStreamPusher::EnqueueZeroCopyPending(Connection& c, ZeroCopyReturnValue* z, uint32_t first_id, uint32_t last_id) {
    std::unique_lock lock(c.zc_mutex);

    Connection::PendingZC entry{};
    entry.first_id = first_id;
    entry.last_id = last_id;
    entry.z = z;
    c.zc_pending.push_back(entry);

    // Wake waiters (e.g. WaitForZeroCopyDrain) so they re-evaluate the queue.
    c.zc_cv.notify_all();
}
|
|
|
|
// Unconditionally hands every queued zerocopy buffer back to its owner,
// regardless of whether the kernel reported completion. Used during forced
// shutdown so image buffers are never leaked.
void TCPStreamPusher::ForceReleasePendingZeroCopy(Connection& c) {
    std::unique_lock lock(c.zc_mutex);

    for (auto& entry : c.zc_pending) {
        if (entry.z != nullptr)
            entry.z->release();
    }
    c.zc_pending.clear();

    c.zc_cv.notify_all();
}
|
|
|
|
// Blocks until every pending zerocopy buffer has been released or the
// connection broke; returns false if `timeout` elapses first.
bool TCPStreamPusher::WaitForZeroCopyDrain(Connection& c, std::chrono::milliseconds timeout) {
    std::unique_lock lock(c.zc_mutex);
    auto drained_or_broken = [&c] {
        return c.zc_pending.empty() || c.broken.load();
    };
    return c.zc_cv.wait_for(lock, timeout, drained_or_broken);
}
|
|
|
|
// Per-connection thread that reaps MSG_ZEROCOPY completion notifications
// from the socket error queue and releases the associated image buffers.
// Compiles to a no-op body on platforms without MSG_ZEROCOPY.
void TCPStreamPusher::ZeroCopyCompletionThread(Connection* c) {
#if defined(MSG_ZEROCOPY)
    // Keep running while the connection is active, or while buffers are
    // still awaiting completion after shutdown started.
    // NOTE(review): zc_pending.empty() is read here without zc_mutex — looks
    // like a benign race (worst case one extra loop iteration), but confirm.
    while (c->active || !c->zc_pending.empty()) {
        const int local_fd = c->fd.load();
        if (local_fd < 0)
            break;

        // Completion notifications are signalled as POLLERR on the socket.
        pollfd pfd{};
        pfd.fd = local_fd;
        pfd.events = POLLERR;

        // 100 ms timeout so the loop condition is re-evaluated periodically.
        const int prc = poll(&pfd, 1, 100);
        if (prc < 0) {
            if (errno == EINTR)
                continue;
            break;
        }
        if (prc == 0)
            continue;
        if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) == 0)
            continue;

        // Minimal data buffer plus ancillary space for the error records.
        uint8_t control[512];
        uint8_t data[1];
        iovec iov{};
        iov.iov_base = data;
        iov.iov_len = sizeof(data);

        msghdr msg{};
        msg.msg_iov = &iov;
        msg.msg_iovlen = 1;
        msg.msg_control = control;
        msg.msg_controllen = sizeof(control);

        // Drain everything currently queued on the error queue.
        while (true) {
            const ssize_t rr = recvmsg(local_fd, &msg, MSG_ERRQUEUE | MSG_DONTWAIT);
            if (rr < 0) {
                if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
                    break;
                c->broken = true;
                break;
            }
            for (cmsghdr* cm = CMSG_FIRSTHDR(&msg); cm != nullptr; cm = CMSG_NXTHDR(&msg, cm)) {
                // SOL_IP/IP_RECVERR matches the AF_INET sockets created in
                // OpenListenSocket (an IPv6 socket would report SOL_IPV6).
                if (cm->cmsg_level != SOL_IP || cm->cmsg_type != IP_RECVERR)
                    continue;

                auto* se = reinterpret_cast<sock_extended_err*>(CMSG_DATA(cm));
                if (!se || se->ee_origin != SO_EE_ORIGIN_ZEROCOPY)
                    continue;

                // ee_data is the highest send id covered by this
                // notification; advance the completion watermark and release
                // every buffer whose whole id range is now complete.
                const uint32_t end_id = se->ee_data;
                std::unique_lock ul(c->zc_mutex);
                if (c->zc_completed_id == std::numeric_limits<uint32_t>::max() || end_id > c->zc_completed_id)
                    c->zc_completed_id = end_id;
                ReleaseCompletedZeroCopy(*c);
                c->zc_cv.notify_all();
            }
        }
    }
#else
    (void)c;
#endif
}
|
|
|
|
// Reads exactly `len` bytes into `buf` from connection `c`.
//
// Blocks in 100 ms poll() slices so the c.active flag (cleared by
// StopConnectionThreads) can interrupt the wait between slices.
//
// Returns false on shutdown (c.active cleared or fd closed), peer close
// (recv() returning 0), unrecoverable poll/recv error, or socket error
// flags; returns true only when all `len` bytes were received.
bool TCPStreamPusher::ReadExact(Connection& c, void* buf, size_t len) {
    auto* p = static_cast<uint8_t*>(buf);
    size_t got = 0;

    while (got < len) {
        if (!c.active)
            return false;

        // Re-load each iteration: the fd may be closed concurrently.
        const int local_fd = c.fd.load();
        if (local_fd < 0)
            return false;

        pollfd pfd{};
        pfd.fd = local_fd;
        pfd.events = POLLIN;

        const int prc = poll(&pfd, 1, 100);
        if (prc == 0)
            continue;
        if (prc < 0) {
            if (errno == EINTR)
                continue;
            return false;
        }
        if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0)
            return false;
        if ((pfd.revents & POLLIN) == 0)
            continue;

        ssize_t rc = ::recv(local_fd, p + got, len - got, 0);
        // rc == 0 means the peer performed an orderly shutdown.
        if (rc == 0)
            return false;
        if (rc < 0) {
            if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK)
                continue;
            return false;
        }

        got += static_cast<size_t>(rc);
    }

    return true;
}
|
|
|
|
// Per-connection thread that drains the connection's image queue and sends
// each image as a DATA frame.
//
// Terminates when a sentinel element with .end == true is dequeued (posted
// by StopConnectionThreads). Queue elements carrying a ZeroCopyReturnValue
// are always released eventually: immediately here if the connection is
// already broken, otherwise by SendFrame (directly, or via the zerocopy
// pending queue once the kernel confirms completion).
void TCPStreamPusher::WriterThread(Connection* c) {
    while (c->active) {
        auto e = c->queue.GetBlocking();
        if (e.end)
            break;
        if (!e.z)
            continue;

        // Broken connection: don't attempt the send, just return the buffer.
        if (c->broken) {
            e.z->release();
            continue;
        }

        // send_mutex serializes DATA frames with control frames sent from
        // other threads (START/END/CANCEL/CALIBRATION).
        std::unique_lock ul(c->send_mutex);
        if (!SendFrame(*c,
                       static_cast<const uint8_t*>(e.z->GetImage()),
                       e.z->GetImageSize(),
                       TCPFrameType::DATA,
                       e.z->GetImageNumber(),
                       e.z)) {
            c->broken = true;
            logger.Error("TCP send failed on socket " + std::to_string(c->socket_number));
        }
    }
}
|
|
|
|
// Per-connection thread that reads ACK frames sent back by the writer
// client and updates the connection's acknowledgement state.
//
// Each ACK is a TcpFrameHeader (validated against the protocol magic and
// version) optionally followed by payload_size bytes of error text.
// START/END/CANCEL ACKs flip the corresponding *_ack_received/*_ack_ok
// flags; DATA ACKs update per-connection and global counters. A failed DATA
// ACK is a soft failure: it is recorded for Finalize() but does NOT mark
// the socket broken, so the run can continue.
//
// NOTE(review): h.payload_size is used unchecked to size the error-text
// read; a corrupt or hostile peer could force a very large allocation here —
// consider a sanity cap.
void TCPStreamPusher::AckThread(Connection* c) {
    while (c->active) {
        TcpFrameHeader h{};
        if (!ReadExact(*c, &h, sizeof(h))) {
            // Only treat the read failure as an error if we were not already
            // shutting down.
            if (c->active) {
                c->broken = true;
                logger.Error("TCP ACK reader disconnected on socket " + std::to_string(c->socket_number));
            }
            break;
        }

        if (h.magic != JFJOCH_TCP_MAGIC || h.version != JFJOCH_TCP_VERSION || static_cast<TCPFrameType>(h.type) != TCPFrameType::ACK) {
            c->broken = true;
            logger.Error("Invalid ACK frame on socket " + std::to_string(c->socket_number));
            break;
        }

        // Optional error text payload.
        std::string error_text;
        if (h.payload_size > 0) {
            error_text.resize(h.payload_size);
            if (!ReadExact(*c, error_text.data(), error_text.size())) {
                c->broken = true;
                break;
            }
        }

        const auto ack_for = static_cast<TCPFrameType>(h.ack_for);
        const bool ok = (h.flags & TCP_ACK_FLAG_OK) != 0;
        const bool fatal = (h.flags & TCP_ACK_FLAG_FATAL) != 0;
        const auto code = static_cast<TCPAckCode>(h.ack_code);

        {
            std::unique_lock ul(c->ack_mutex);
            c->last_ack_code = code;
            if (!error_text.empty())
                c->last_ack_error = error_text;

            if (ack_for == TCPFrameType::START) {
                c->start_ack_received = true;
                c->start_ack_ok = ok;
                if (!ok && error_text.empty())
                    c->last_ack_error = "START rejected";
            } else if (ack_for == TCPFrameType::END) {
                c->end_ack_received = true;
                c->end_ack_ok = ok;
                if (!ok && error_text.empty())
                    c->last_ack_error = "END rejected";
            } else if (ack_for == TCPFrameType::CANCEL) {
                c->cancel_ack_received = true;
                c->cancel_ack_ok = ok;
                if (!ok && error_text.empty())
                    c->last_ack_error = "CANCEL rejected";
            } else if (ack_for == TCPFrameType::DATA) {
                c->data_acked_total.fetch_add(1, std::memory_order_relaxed);
                total_data_acked_total.fetch_add(1, std::memory_order_relaxed);
                if (ok && !fatal) {
                    c->data_acked_ok.fetch_add(1, std::memory_order_relaxed);
                    total_data_acked_ok.fetch_add(1, std::memory_order_relaxed);
                } else {
                    c->data_acked_bad.fetch_add(1, std::memory_order_relaxed);
                    total_data_acked_bad.fetch_add(1, std::memory_order_relaxed);

                    // Soft failure: remember it for Finalize(), do NOT mark socket broken.
                    c->data_ack_error_reported = true;
                    if (!error_text.empty()) {
                        c->data_ack_error_text = error_text;
                    } else if (c->data_ack_error_text.empty()) {
                        c->data_ack_error_text = "DATA ACK failed";
                    }
                }
            }
        }

        // Wake WaitForAck() waiters outside the lock.
        c->ack_cv.notify_all();
    }
}
|
|
|
|
// Resets all per-run state on connection `c` and launches its three worker
// threads: writer (DATA frames), ACK reader and zerocopy completion reaper.
//
// zc_next_id/zc_completed_id are written without the zc_mutex, which is safe
// here only because the worker threads have not been started yet
// (zc_pending itself is still cleared under the lock).
void TCPStreamPusher::StartConnectionThreads(Connection& c) {
    {
        // Reset acknowledgement bookkeeping from any previous run.
        std::unique_lock ul(c.ack_mutex);
        c.start_ack_received = false;
        c.start_ack_ok = false;
        c.end_ack_received = false;
        c.end_ack_ok = false;
        c.cancel_ack_received = false;
        c.cancel_ack_ok = false;
        c.last_ack_error.clear();
        c.last_ack_code = TCPAckCode::None;
        c.data_ack_error_reported = false;
        c.data_ack_error_text.clear();
    }

    // Reset DATA counters.
    c.data_sent.store(0, std::memory_order_relaxed);
    c.data_acked_ok.store(0, std::memory_order_relaxed);
    c.data_acked_bad.store(0, std::memory_order_relaxed);
    c.data_acked_total.store(0, std::memory_order_relaxed);

    // Zerocopy ids restart at 0; UINT32_MAX is the "nothing completed yet"
    // sentinel checked by ReleaseCompletedZeroCopy.
    c.zc_next_id = 0;
    c.zc_completed_id = std::numeric_limits<uint32_t>::max();
    {
        std::unique_lock ul(c.zc_mutex);
        c.zc_pending.clear();
    }

    // `active` must be set before the threads start, since all three loop on it.
    c.active = true;
    c.writer_future = std::async(std::launch::async, &TCPStreamPusher::WriterThread, this, &c);
    c.ack_future = std::async(std::launch::async, &TCPStreamPusher::AckThread, this, &c);
    c.zc_future = std::async(std::launch::async, &TCPStreamPusher::ZeroCopyCompletionThread, this, &c);
}
|
|
|
|
// Stops the three worker threads started by StartConnectionThreads and
// releases any zerocopy buffers still in flight. Idempotent: returns
// immediately if the connection is not active.
//
// Shutdown order matters:
//   1. clear `active` and post the queue sentinel so WriterThread exits;
//   2. join WriterThread — no new sends after this point;
//   3. give the zerocopy reaper up to 2 s to drain pending completions; on
//      timeout the socket is force-closed so no more completions can arrive;
//   4. join the ACK and zerocopy threads;
//   5. force-release anything still parked on the pending queue so image
//      buffers are never leaked.
void TCPStreamPusher::StopConnectionThreads(Connection& c) {
    if (!c.active)
        return;

    c.active = false;
    // Sentinel unblocks WriterThread's GetBlocking().
    c.queue.PutBlocking({.end = true});
    c.ack_cv.notify_all();
    c.zc_cv.notify_all();

    if (c.writer_future.valid())
        c.writer_future.get();

    constexpr auto zc_drain_timeout = std::chrono::seconds(2);
    if (!WaitForZeroCopyDrain(c, zc_drain_timeout)) {
        logger.Warning("TCP zerocopy completion drain timeout on socket " + std::to_string(c.socket_number) +
                       "; forcing socket close");
        c.broken = true;
        CloseFd(c.fd);
        c.zc_cv.notify_all();
        c.ack_cv.notify_all();
    }

    if (c.ack_future.valid())
        c.ack_future.get();
    if (c.zc_future.valid())
        c.zc_future.get();

    ForceReleasePendingZeroCopy(c);
}
|
|
|
|
// Blocks until the ACK for `ack_for` (START/END/CANCEL) arrives on
// connection `c`, the connection breaks, or `timeout` elapses.
//
// Returns true only if the ACK arrived and was positive. On failure, a
// human-readable reason is written to *error_text (if provided): "ACK
// timeout", "Socket broken", "ACK rejected", or the peer-supplied error.
// For any frame type other than START/END/CANCEL the wait predicate never
// fires, so the call reports a timeout after `timeout`.
bool TCPStreamPusher::WaitForAck(Connection& c, TCPFrameType ack_for, std::chrono::milliseconds timeout, std::string* error_text) {
    std::unique_lock ul(c.ack_mutex);
    // Also wake on a broken connection so we don't wait out the full timeout.
    const bool ok = c.ack_cv.wait_for(ul, timeout, [&] {
        if (ack_for == TCPFrameType::START)
            return c.start_ack_received || c.broken.load();
        if (ack_for == TCPFrameType::END)
            return c.end_ack_received || c.broken.load();
        if (ack_for == TCPFrameType::CANCEL)
            return c.cancel_ack_received || c.broken.load();
        return false;
    });

    if (!ok) {
        if (error_text) *error_text = "ACK timeout";
        return false;
    }

    if (c.broken) {
        if (error_text) *error_text = c.last_ack_error.empty() ? "Socket broken" : c.last_ack_error;
        return false;
    }

    bool ack_ok = false;
    if (ack_for == TCPFrameType::START) ack_ok = c.start_ack_ok;
    if (ack_for == TCPFrameType::END) ack_ok = c.end_ack_ok;
    if (ack_for == TCPFrameType::CANCEL) ack_ok = c.cancel_ack_ok;

    if (!ack_ok && error_text)
        *error_text = c.last_ack_error.empty() ? "ACK rejected" : c.last_ack_error;

    return ack_ok;
}
|
|
|
|
// Begins a new data collection:
//   1. validates the start message and resets global per-run counters;
//   2. tears down connections left over from a previous run;
//   3. opens the listen socket and accepts exactly `expected_connections`
//      clients (5 s accept timeout each), configuring TCP_NODELAY, optional
//      SO_SNDBUF, and SO_ZEROCOPY (where available) on each;
//   4. starts the worker threads for every connection;
//   5. sends a per-socket START frame (socket 0 is told to write the master
//      file) and waits up to 5 s for each START ACK.
//
// On failure after the threads are started, rollback_cancel() sends CANCEL
// to every connection that already ACKed its START, stops all threads, and
// the function throws. The listen socket is closed on all paths.
//
// @throws JFJochException(InputParameterInvalid) on invalid parameters,
//         accept timeout, or START send/ACK failure.
void TCPStreamPusher::StartDataCollection(StartMessage& message) {
    if (message.images_per_file < 1)
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Images per file cannot be zero or negative");

    images_per_file = message.images_per_file;
    run_number = message.run_number;
    run_name = message.run_name;
    transmission_error = false;

    total_data_acked_ok.store(0, std::memory_order_relaxed);
    total_data_acked_bad.store(0, std::memory_order_relaxed);
    total_data_acked_total.store(0, std::memory_order_relaxed);

    // Dispose of any connections from a previous run.
    for (auto& c : connections) {
        StopConnectionThreads(*c);
        CloseFd(c->fd);
    }
    connections.clear();
    connections.reserve(expected_connections);

    int listen_fd = OpenListenSocket(endpoint);
    try {
        for (size_t i = 0; i < expected_connections; i++) {
            int new_fd = AcceptOne(listen_fd, std::chrono::seconds(5));
            if (new_fd < 0)
                throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "TCP accept timeout/failure on " + endpoint);

            auto c = std::make_unique<Connection>(send_queue_size);
            c->socket_number = static_cast<uint32_t>(i);
            c->fd.store(new_fd);

            // Low-latency sends; apply the configured send buffer size if any.
            int one = 1;
            setsockopt(new_fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one));
            if (send_buffer_size)
                setsockopt(new_fd, SOL_SOCKET, SO_SNDBUF, &send_buffer_size.value(), sizeof(int32_t));

#if defined(SO_ZEROCOPY)
            // Zerocopy is opportunistic: fall back silently if unsupported.
            int zc_one = 1;
            if (setsockopt(new_fd, SOL_SOCKET, SO_ZEROCOPY, &zc_one, sizeof(zc_one)) == 0) {
                c->zerocopy_enabled.store(true, std::memory_order_relaxed);
            } else {
                c->zerocopy_enabled.store(false, std::memory_order_relaxed);
            }
#endif
            connections.emplace_back(std::move(c));
        }
    } catch (...) {
        // Accepted fds stay owned by `connections` and are closed later.
        close(listen_fd);
        throw;
    }
    close(listen_fd);

    for (auto& c : connections)
        StartConnectionThreads(*c);

    // Tracks which sockets completed the START handshake, so rollback only
    // cancels those.
    std::vector<bool> started(connections.size(), false);

    auto rollback_cancel = [&]() {
        for (size_t i = 0; i < connections.size(); i++) {
            auto& c = *connections[i];
            if (!started[i] || c.broken)
                continue;

            std::unique_lock ul(c.send_mutex);
            // Best effort: ignore send/ACK failures during rollback.
            (void)SendFrame(c, nullptr, 0, TCPFrameType::CANCEL, -1, nullptr);

            std::string cancel_ack_err;
            (void)WaitForAck(c, TCPFrameType::CANCEL, std::chrono::milliseconds(500), &cancel_ack_err);
        }

        for (auto& c : connections)
            StopConnectionThreads(*c);
    };

    for (size_t i = 0; i < connections.size(); i++) {
        auto& c = *connections[i];

        // Per-socket fields of the shared START message; only writer 0
        // produces the master file.
        message.socket_number = static_cast<int64_t>(i);
        message.write_master_file = (i == 0);

        serializer.SerializeSequenceStart(message);

        {
            std::unique_lock ul(c.send_mutex);
            if (!SendFrame(c, serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::START, -1, nullptr)) {
                rollback_cancel();
                throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
                                      "Timeout/failure sending START on socket " + std::to_string(i));
            }
        }

        std::string ack_err;
        if (!WaitForAck(c, TCPFrameType::START, std::chrono::seconds(5), &ack_err)) {
            rollback_cancel();
            throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
                                  "START ACK failed on socket " + std::to_string(i) + ": " + ack_err);
        }

        started[i] = true;
    }
}
|
|
|
|
bool TCPStreamPusher::SendImage(const uint8_t *image_data, size_t image_size, int64_t image_number) {
|
|
if (connections.empty())
|
|
return false;
|
|
|
|
auto idx = static_cast<size_t>((image_number / images_per_file) % static_cast<int64_t>(connections.size()));
|
|
auto& c = *connections[idx];
|
|
|
|
if (c.broken || !IsConnectionAlive(c))
|
|
return false;
|
|
|
|
std::unique_lock ul(c.send_mutex);
|
|
return SendFrame(c, image_data, image_size, TCPFrameType::DATA, image_number, nullptr);
|
|
}
|
|
|
|
// Queue-based (asynchronous) image send: ownership of `z` passes to the
// writer thread of the selected connection, which releases it once the send
// (and any zerocopy completion) is done. If no connection can take the
// image, `z` is released immediately so the buffer is never leaked.
void TCPStreamPusher::SendImage(ZeroCopyReturnValue &z) {
    if (connections.empty()) {
        z.release();
        return;
    }

    // Same sharding as the synchronous overload: contiguous runs of
    // `images_per_file` images go to the same writer.
    const auto conn_count = static_cast<int64_t>(connections.size());
    const auto target = static_cast<size_t>((z.GetImageNumber() / images_per_file) % conn_count);
    auto& conn = *connections[target];

    if (conn.broken) {
        z.release();
        return;
    }

    ImagePusherQueueElement element{};
    element.image_data = static_cast<uint8_t *>(z.GetImage());
    element.z = &z;
    element.end = false;
    conn.queue.PutBlocking(element);
}
|
|
|
|
// Ends the data collection: sends an END frame to every healthy connection
// and waits up to 10 s for each END ACK, then stops all worker threads.
//
// Returns true only if every connection was healthy, accepted the END frame
// and ACKed it positively. The inverse of the result is latched into
// `transmission_error` for Finalize() to report.
bool TCPStreamPusher::EndDataCollection(const EndMessage& message) {
    // The END payload is identical for all connections; serialize once.
    serializer.SerializeSequenceEnd(message);

    bool ret = true;
    for (auto& cptr : connections) {
        auto& c = *cptr;
        if (c.broken) {
            ret = false;
            continue;
        }

        {
            std::unique_lock ul(c.send_mutex);
            if (!SendFrame(c, serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::END, -1, nullptr)) {
                ret = false;
                continue;
            }
        }

        std::string ack_err;
        if (!WaitForAck(c, TCPFrameType::END, std::chrono::seconds(10), &ack_err))
            ret = false;
    }

    for (auto& c : connections)
        StopConnectionThreads(*c);

    transmission_error = !ret;
    return ret;
}
|
|
|
|
bool TCPStreamPusher::SendCalibration(const CompressedImage& message) {
|
|
if (connections.empty())
|
|
return false;
|
|
|
|
serializer.SerializeCalibration(message);
|
|
|
|
auto& c = *connections[0];
|
|
if (c.broken)
|
|
return false;
|
|
|
|
std::unique_lock ul(c.send_mutex);
|
|
return SendFrame(c, serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::CALIBRATION, -1, nullptr);
|
|
}
|
|
|
|
std::string TCPStreamPusher::Finalize() {
|
|
std::string ret;
|
|
if (transmission_error)
|
|
ret += "Timeout sending images (e.g., writer disabled during data collection);";
|
|
|
|
for (size_t i = 0; i < connections.size(); i++) {
|
|
auto& c = *connections[i];
|
|
{
|
|
std::unique_lock ul(c.ack_mutex);
|
|
|
|
if (c.data_ack_error_reported && !c.data_ack_error_text.empty()) {
|
|
ret += "Writer " + std::to_string(i) + ": " + c.data_ack_error_text + ";";
|
|
} else if (!c.last_ack_error.empty()) {
|
|
ret += "Writer " + std::to_string(i) + ": " + c.last_ack_error + ";";
|
|
}
|
|
}
|
|
}
|
|
|
|
return ret;
|
|
}
|
|
|
|
std::string TCPStreamPusher::PrintSetup() const {
|
|
return "TCPStreamPusher: endpoint=" + endpoint + " expected_connections=" + std::to_string(expected_connections);
|
|
}
|
|
|
|
std::optional<uint64_t> TCPStreamPusher::GetImagesWritten() const {
|
|
return total_data_acked_ok.load(std::memory_order_relaxed);
|
|
} |