diff --git a/image_pusher/CMakeLists.txt b/image_pusher/CMakeLists.txt index 7da43fb3..a2e14370 100644 --- a/image_pusher/CMakeLists.txt +++ b/image_pusher/CMakeLists.txt @@ -9,8 +9,6 @@ ADD_LIBRARY(ImagePusher STATIC NonePusher.h ZMQStream2PusherSocket.cpp ZMQStream2PusherSocket.h - TCPStreamPusherSocket.cpp - TCPStreamPusherSocket.h TCPStreamPusher.cpp TCPStreamPusher.h ) diff --git a/image_pusher/TCPStreamPusher.cpp b/image_pusher/TCPStreamPusher.cpp index 75066630..22b48383 100644 --- a/image_pusher/TCPStreamPusher.cpp +++ b/image_pusher/TCPStreamPusher.cpp @@ -1,77 +1,691 @@ -// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute -// SPDX-License-Identifier: GPL-3.0-only - #include "TCPStreamPusher.h" +#include +#include +#include +#include +#include +#include +#include +#if defined(MSG_ZEROCOPY) +#include +#endif -TCPStreamPusher::TCPStreamPusher(const std::vector &addr, - std::optional send_buffer_size, - std::optional zerocopy_threshold, - size_t send_queue_size) - : serialization_buffer(256 * 1024 * 1024), - serializer(serialization_buffer.data(), serialization_buffer.size()) { - if (addr.empty()) - throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "No TCP writer address provided"); +std::pair TCPStreamPusher::ParseTcpAddress(const std::string& addr) { + const std::string prefix = "tcp://"; + if (addr.rfind(prefix, 0) != 0) + throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP address: " + addr); - for (size_t i = 0; i < addr.size(); i++) { - socket.emplace_back(std::make_unique( - addr[i], static_cast(i), send_buffer_size, zerocopy_threshold, send_queue_size)); + auto hp = addr.substr(prefix.size()); + auto p = hp.find_last_of(':'); + if (p == std::string::npos) + throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP address: " + addr); + + const auto host = hp.substr(0, p); + const auto port_str = hp.substr(p + 1); + + int port_i = 0; + try { + size_t parsed = 0; + port_i = std::stoi(port_str, &parsed); + if (parsed != port_str.size()) + throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP port in address: " + addr); + } catch (...) { + throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP port in address: " + addr); + } + + if (port_i < 1 || port_i > static_cast(std::numeric_limits::max())) + throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "TCP port out of range in address: " + addr); + + return {host, static_cast(port_i)}; +} + +int TCPStreamPusher::OpenListenSocket(const std::string& addr) { + auto [host, port] = ParseTcpAddress(addr); + + int listen_fd = ::socket(AF_INET, SOCK_STREAM, 0); + if (listen_fd < 0) + throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "socket(listen) failed"); + + int one = 1; + setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)); + + sockaddr_in sin{}; + sin.sin_family = AF_INET; + sin.sin_port = htons(port); + if (host == "*" || host == "0.0.0.0") + sin.sin_addr.s_addr = htonl(INADDR_ANY); + else if (inet_pton(AF_INET, host.c_str(), &sin.sin_addr) != 1) + throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "inet_pton failed for " + host); + + if (bind(listen_fd, reinterpret_cast(&sin), sizeof(sin)) != 0) + throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "bind() failed to " + addr); + + if (listen(listen_fd, 64) != 0) + throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "listen() failed on " + addr); + + return listen_fd; +} + +int TCPStreamPusher::AcceptOne(int listen_fd, std::chrono::milliseconds timeout) { + pollfd pfd{}; + pfd.fd = listen_fd; + pfd.events = POLLIN; + + const int prc = poll(&pfd, 1, static_cast(timeout.count())); + if (prc <= 0) + return -1; + if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0) + return -1; + + return accept(listen_fd, nullptr, nullptr); +} + +void TCPStreamPusher::CloseFd(std::atomic& fd) { + int old_fd = fd.exchange(-1); + if (old_fd >= 0) { + shutdown(old_fd, SHUT_RDWR); + close(old_fd); } } -void TCPStreamPusher::StartDataCollection(StartMessage &message) { +TCPStreamPusher::TCPStreamPusher(const std::string& addr, + size_t in_expected_connections, + std::optional in_send_buffer_size) + : serialization_buffer(256 * 1024 * 1024), + serializer(serialization_buffer.data(), serialization_buffer.size()), + endpoint(addr), + expected_connections(in_expected_connections), + send_buffer_size(in_send_buffer_size) { + if (endpoint.empty()) + throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "No TCP writer address provided"); + if (expected_connections == 0) + throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Expected TCP connections cannot be zero"); +} + +TCPStreamPusher::~TCPStreamPusher() { + for (auto& c : connections) { + StopConnectionThreads(*c); + CloseFd(c->fd); + } +} + +bool TCPStreamPusher::IsConnectionAlive(const Connection& c) const { + if (c.broken) + return false; + + const int local_fd = c.fd.load(); + if (local_fd < 0) + return false; + + pollfd pfd{}; + pfd.fd = local_fd; + pfd.events = POLLOUT; + if (poll(&pfd, 1, 0) < 0) + return false; + if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0) + return false; + + int so_error = 0; + socklen_t len = sizeof(so_error); + if (getsockopt(local_fd, SOL_SOCKET, SO_ERROR, &so_error, &len) != 0) + return false; + + return so_error == 0; +} + +bool TCPStreamPusher::SendAll(Connection& c, const void* buf, size_t len, bool allow_zerocopy, + bool* zc_used_out, uint32_t* zc_first_out, uint32_t* zc_last_out) { + const auto* p = static_cast(buf); + size_t sent = 0; + bool zc_used = false; + uint32_t zc_first = 0; + uint32_t zc_last = 0; + + bool try_zerocopy = false; +#if defined(MSG_ZEROCOPY) + try_zerocopy = allow_zerocopy && c.zerocopy_enabled.load(std::memory_order_relaxed); +#endif + + while (sent < len) { + const int local_fd = c.fd.load(); + if (local_fd < 0 || c.broken) { + if (zc_used_out) *zc_used_out = zc_used; + if (zc_first_out) *zc_first_out = zc_first; + if (zc_last_out) *zc_last_out = zc_last; + return false; + } + + int flags = MSG_NOSIGNAL; +#if defined(MSG_ZEROCOPY) + if (try_zerocopy) + flags |= MSG_ZEROCOPY; +#endif + + ssize_t rc = ::send(local_fd, p + sent, len - sent, flags); + if (rc < 0) { + if (errno == EINTR) + continue; + +#if defined(MSG_ZEROCOPY) + if (try_zerocopy && (errno == EOPNOTSUPP || errno == EINVAL || errno == ENOBUFS)) { + try_zerocopy = false; + c.zerocopy_enabled.store(false, std::memory_order_relaxed); + continue; + } +#endif + + if (errno == EPIPE || errno == ECONNRESET || errno == ENOTCONN) { + c.broken = true; + CloseFd(c.fd); + } + + if (zc_used_out) *zc_used_out = zc_used; + if (zc_first_out) *zc_first_out = zc_first; + if (zc_last_out) *zc_last_out = zc_last; + return false; + } + +#if defined(MSG_ZEROCOPY) + if (try_zerocopy && rc > 0) { + const uint32_t this_id = c.zc_next_id++; + if (!zc_used) { + zc_used = true; + zc_first = this_id; + } + zc_last = this_id; + } +#endif + + sent += static_cast(rc); + } + + if (zc_used_out) *zc_used_out = zc_used; + if (zc_first_out) *zc_first_out = zc_first; + if (zc_last_out) *zc_last_out = zc_last; + return true; +} + +bool TCPStreamPusher::SendFrame(Connection& c, const uint8_t* data, size_t size, TCPFrameType type, int64_t image_number, ZeroCopyReturnValue* z) { + TcpFrameHeader h{}; + h.type = static_cast(type); + h.payload_size = size; + h.image_number = image_number >= 0 ? static_cast(image_number) : 0; + h.socket_number = c.socket_number; + h.run_number = run_number; + + if (!SendAll(c, &h, sizeof(h), false)) { + if (z) z->release(); + return false; + } + + bool zc_used = false; + uint32_t zc_first = 0; + uint32_t zc_last = 0; + + if (size > 0) { + const bool allow_zerocopy = (type == TCPFrameType::DATA || type == TCPFrameType::CALIBRATION); + if (!SendAll(c, data, size, allow_zerocopy, &zc_used, &zc_first, &zc_last)) { + if (z) { + if (zc_used) EnqueueZeroCopyPending(c, z, zc_first, zc_last); + else z->release(); + } + return false; + } + } + + if (z) { + if (zc_used) EnqueueZeroCopyPending(c, z, zc_first, zc_last); + else z->release(); + } + + if (type == TCPFrameType::DATA) + c.data_sent.fetch_add(1, std::memory_order_relaxed); + + return true; +} + +void TCPStreamPusher::ReleaseCompletedZeroCopy(Connection& c) { + while (!c.zc_pending.empty()) { + const auto& front = c.zc_pending.front(); + if (c.zc_completed_id == std::numeric_limits::max() || front.last_id > c.zc_completed_id) + break; + if (front.z) + front.z->release(); + c.zc_pending.pop_front(); + } +} + +void TCPStreamPusher::EnqueueZeroCopyPending(Connection& c, ZeroCopyReturnValue* z, uint32_t first_id, uint32_t last_id) { + std::unique_lock ul(c.zc_mutex); + c.zc_pending.push_back(Connection::PendingZC{ + .first_id = first_id, + .last_id = last_id, + .z = z + }); + c.zc_cv.notify_all(); +} + +void TCPStreamPusher::ForceReleasePendingZeroCopy(Connection& c) { + std::unique_lock ul(c.zc_mutex); + while (!c.zc_pending.empty()) { + auto p = c.zc_pending.front(); + c.zc_pending.pop_front(); + if (p.z) + p.z->release(); + } + c.zc_cv.notify_all(); +} + +bool TCPStreamPusher::WaitForZeroCopyDrain(Connection& c, std::chrono::milliseconds timeout) { + std::unique_lock ul(c.zc_mutex); + return c.zc_cv.wait_for(ul, timeout, [&] { + return c.zc_pending.empty() || c.broken.load(); + }); +} + +void TCPStreamPusher::ZeroCopyCompletionThread(Connection* c) { +#if defined(MSG_ZEROCOPY) + while (c->active || !c->zc_pending.empty()) { + const int local_fd = c->fd.load(); + if (local_fd < 0) + break; + + pollfd pfd{}; + pfd.fd = local_fd; + pfd.events = POLLERR; + + const int prc = poll(&pfd, 1, 100); + if (prc < 0) { + if (errno == EINTR) + continue; + break; + } + if (prc == 0) + continue; + if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) == 0) + continue; + + uint8_t control[512]; + uint8_t data[1]; + iovec iov{}; + iov.iov_base = data; + iov.iov_len = sizeof(data); + + msghdr msg{}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + while (true) { + const ssize_t rr = recvmsg(local_fd, &msg, MSG_ERRQUEUE | MSG_DONTWAIT); + if (rr < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) + break; + c->broken = true; + break; + } + for (cmsghdr* cm = CMSG_FIRSTHDR(&msg); cm != nullptr; cm = CMSG_NXTHDR(&msg, cm)) { + if (cm->cmsg_level != SOL_IP || cm->cmsg_type != IP_RECVERR) + continue; + + auto* se = reinterpret_cast(CMSG_DATA(cm)); + if (!se || se->ee_origin != SO_EE_ORIGIN_ZEROCOPY) + continue; + + const uint32_t end_id = se->ee_data; + std::unique_lock ul(c->zc_mutex); + if (c->zc_completed_id == std::numeric_limits::max() || end_id > c->zc_completed_id) + c->zc_completed_id = end_id; + ReleaseCompletedZeroCopy(*c); + c->zc_cv.notify_all(); + } + } + } +#else + (void)c; +#endif +} + +bool TCPStreamPusher::ReadExact(Connection& c, void* buf, size_t len) { + auto* p = static_cast(buf); + size_t got = 0; + + while (got < len) { + if (!c.active) + return false; + + const int local_fd = c.fd.load(); + if (local_fd < 0) + return false; + + pollfd pfd{}; + pfd.fd = local_fd; + pfd.events = POLLIN; + + const int prc = poll(&pfd, 1, 100); + if (prc == 0) + continue; + if (prc < 0) { + if (errno == EINTR) + continue; + return false; + } + if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0) + return false; + if ((pfd.revents & POLLIN) == 0) + continue; + + ssize_t rc = ::recv(local_fd, p + got, len - got, 0); + if (rc == 0) + return false; + if (rc < 0) { + if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK) + continue; + return false; + } + + got += static_cast(rc); + } + + return true; +} + +void TCPStreamPusher::WriterThread(Connection* c) { + while (c->active) { + auto e = c->queue.GetBlocking(); + if (e.end) + break; + if (!e.z) + continue; + + if (c->broken) { + e.z->release(); + continue; + } + + std::unique_lock ul(c->send_mutex); + if (!SendFrame(*c, + static_cast(e.z->GetImage()), + e.z->GetImageSize(), + TCPFrameType::DATA, + e.z->GetImageNumber(), + e.z)) { + c->broken = true; + logger.Error("TCP send failed on socket " + std::to_string(c->socket_number)); + } + } +} + +void TCPStreamPusher::AckThread(Connection* c) { + while (c->active) { + TcpFrameHeader h{}; + if (!ReadExact(*c, &h, sizeof(h))) { + if (c->active) { + c->broken = true; + logger.Error("TCP ACK reader disconnected on socket " + std::to_string(c->socket_number)); + } + break; + } + + if (h.magic != JFJOCH_TCP_MAGIC || h.version != JFJOCH_TCP_VERSION || static_cast(h.type) != TCPFrameType::ACK) { + c->broken = true; + logger.Error("Invalid ACK frame on socket " + std::to_string(c->socket_number)); + break; + } + + std::string error_text; + if (h.payload_size > 0) { + error_text.resize(h.payload_size); + if (!ReadExact(*c, error_text.data(), error_text.size())) { + c->broken = true; + break; + } + } + + const auto ack_for = static_cast(h.ack_for); + const bool ok = (h.flags & TCP_ACK_FLAG_OK) != 0; + const bool fatal = (h.flags & TCP_ACK_FLAG_FATAL) != 0; + const auto code = static_cast(h.ack_code); + + { + std::unique_lock ul(c->ack_mutex); + c->last_ack_code = code; + if (!error_text.empty()) + c->last_ack_error = error_text; + + if (ack_for == TCPFrameType::START) { + c->start_ack_received = true; + c->start_ack_ok = ok; + if (!ok && error_text.empty()) + c->last_ack_error = "START rejected"; + } else if (ack_for == TCPFrameType::END) { + c->end_ack_received = true; + c->end_ack_ok = ok; + if (!ok && error_text.empty()) + c->last_ack_error = "END rejected"; + } else if (ack_for == TCPFrameType::CANCEL) { + c->cancel_ack_received = true; + c->cancel_ack_ok = ok; + if (!ok && error_text.empty()) + c->last_ack_error = "CANCEL rejected"; + } else if (ack_for == TCPFrameType::DATA) { + c->data_acked_total.fetch_add(1, std::memory_order_relaxed); + total_data_acked_total.fetch_add(1, std::memory_order_relaxed); + if (ok && !fatal) { + c->data_acked_ok.fetch_add(1, std::memory_order_relaxed); + total_data_acked_ok.fetch_add(1, std::memory_order_relaxed); + } else { + c->data_acked_bad.fetch_add(1, std::memory_order_relaxed); + total_data_acked_bad.fetch_add(1, std::memory_order_relaxed); + + // Soft failure: remember it for Finalize(), do NOT mark socket broken. + c->data_ack_error_reported = true; + if (!error_text.empty()) { + c->data_ack_error_text = error_text; + } else if (c->data_ack_error_text.empty()) { + c->data_ack_error_text = "DATA ACK failed"; + } + } + } + } + + c->ack_cv.notify_all(); + } +} + +void TCPStreamPusher::StartConnectionThreads(Connection& c) { + { + std::unique_lock ul(c.ack_mutex); + c.start_ack_received = false; + c.start_ack_ok = false; + c.end_ack_received = false; + c.end_ack_ok = false; + c.cancel_ack_received = false; + c.cancel_ack_ok = false; + c.last_ack_error.clear(); + c.last_ack_code = TCPAckCode::None; + c.data_ack_error_reported = false; + c.data_ack_error_text.clear(); + } + + c.data_sent.store(0, std::memory_order_relaxed); + c.data_acked_ok.store(0, std::memory_order_relaxed); + c.data_acked_bad.store(0, std::memory_order_relaxed); + c.data_acked_total.store(0, std::memory_order_relaxed); + c.zc_next_id = 0; + c.zc_completed_id = std::numeric_limits::max(); + { + std::unique_lock ul(c.zc_mutex); + c.zc_pending.clear(); + } + + c.active = true; + c.writer_future = std::async(std::launch::async, &TCPStreamPusher::WriterThread, this, &c); + c.ack_future = std::async(std::launch::async, &TCPStreamPusher::AckThread, this, &c); + c.zc_future = std::async(std::launch::async, &TCPStreamPusher::ZeroCopyCompletionThread, this, &c); +} + +void TCPStreamPusher::StopConnectionThreads(Connection& c) { + if (!c.active) + return; + + c.active = false; + c.queue.PutBlocking({.end = true}); + c.ack_cv.notify_all(); + c.zc_cv.notify_all(); + + if (c.writer_future.valid()) + c.writer_future.get(); + + constexpr auto zc_drain_timeout = std::chrono::seconds(2); + if (!WaitForZeroCopyDrain(c, zc_drain_timeout)) { + logger.Warning("TCP zerocopy completion drain timeout on socket " + std::to_string(c.socket_number) + + "; forcing socket close"); + c.broken = true; + CloseFd(c.fd); + c.zc_cv.notify_all(); + c.ack_cv.notify_all(); + } + + if (c.ack_future.valid()) + c.ack_future.get(); + if (c.zc_future.valid()) + c.zc_future.get(); + + ForceReleasePendingZeroCopy(c); +} + +bool TCPStreamPusher::WaitForAck(Connection& c, TCPFrameType ack_for, std::chrono::milliseconds timeout, std::string* error_text) { + std::unique_lock ul(c.ack_mutex); + const bool ok = c.ack_cv.wait_for(ul, timeout, [&] { + if (ack_for == TCPFrameType::START) + return c.start_ack_received || c.broken.load(); + if (ack_for == TCPFrameType::END) + return c.end_ack_received || c.broken.load(); + if (ack_for == TCPFrameType::CANCEL) + return c.cancel_ack_received || c.broken.load(); + return false; + }); + + if (!ok) { + if (error_text) *error_text = "ACK timeout"; + return false; + } + + if (c.broken) { + if (error_text) *error_text = c.last_ack_error.empty() ? "Socket broken" : c.last_ack_error; + return false; + } + + bool ack_ok = false; + if (ack_for == TCPFrameType::START) ack_ok = c.start_ack_ok; + if (ack_for == TCPFrameType::END) ack_ok = c.end_ack_ok; + if (ack_for == TCPFrameType::CANCEL) ack_ok = c.cancel_ack_ok; + + if (!ack_ok && error_text) + *error_text = c.last_ack_error.empty() ? "ACK rejected" : c.last_ack_error; + + return ack_ok; +} + +void TCPStreamPusher::StartDataCollection(StartMessage& message) { if (message.images_per_file < 1) - throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, - "Images per file cannot be zero or negative"); + throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Images per file cannot be zero or negative"); + images_per_file = message.images_per_file; run_number = message.run_number; run_name = message.run_name; transmission_error = false; - for (size_t i = 0; i < socket.size(); i++) { - if (!socket[i]->AcceptConnection(std::chrono::seconds(5))) - throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, - "TCP accept timeout/failure on socket " + socket[i]->GetEndpointName()); + total_data_acked_ok.store(0, std::memory_order_relaxed); + total_data_acked_bad.store(0, std::memory_order_relaxed); + total_data_acked_total.store(0, std::memory_order_relaxed); + + for (auto& c : connections) { + StopConnectionThreads(*c); + CloseFd(c->fd); } + connections.clear(); + connections.reserve(expected_connections); - for (auto &s : socket) - s->StartWriterThread(); + int listen_fd = OpenListenSocket(endpoint); + try { + for (size_t i = 0; i < expected_connections; i++) { + int new_fd = AcceptOne(listen_fd, std::chrono::seconds(5)); + if (new_fd < 0) + throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "TCP accept timeout/failure on " + endpoint); - std::vector started(socket.size(), false); + auto c = std::make_unique(send_queue_size); + c->socket_number = static_cast(i); + c->fd.store(new_fd); + + int one = 1; + setsockopt(new_fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one)); + if (send_buffer_size) + setsockopt(new_fd, SOL_SOCKET, SO_SNDBUF, &send_buffer_size.value(), sizeof(int32_t)); + +#if defined(SO_ZEROCOPY) + int zc_one = 1; + if (setsockopt(new_fd, SOL_SOCKET, SO_ZEROCOPY, &zc_one, sizeof(zc_one)) == 0) { + c->zerocopy_enabled.store(true, std::memory_order_relaxed); + } else { + c->zerocopy_enabled.store(false, std::memory_order_relaxed); + } +#endif + connections.emplace_back(std::move(c)); + } + } catch (...) { + close(listen_fd); + throw; + } + close(listen_fd); + + for (auto& c : connections) + StartConnectionThreads(*c); + + std::vector started(connections.size(), false); auto rollback_cancel = [&]() { - for (size_t i = 0; i < socket.size(); i++) { - if (!started[i] || socket[i]->IsBroken()) + for (size_t i = 0; i < connections.size(); i++) { + auto& c = *connections[i]; + if (!started[i] || c.broken) continue; - (void)socket[i]->Send(nullptr, 0, TCPFrameType::CANCEL); + std::unique_lock ul(c.send_mutex); + (void)SendFrame(c, nullptr, 0, TCPFrameType::CANCEL, -1, nullptr); + std::string cancel_ack_err; - (void)socket[i]->WaitForAck(TCPFrameType::CANCEL, std::chrono::milliseconds(500), &cancel_ack_err); + (void)WaitForAck(c, TCPFrameType::CANCEL, std::chrono::milliseconds(500), &cancel_ack_err); } - for (auto &s : socket) - s->StopWriterThread(); + for (auto& c : connections) + StopConnectionThreads(*c); }; - for (size_t i = 0; i < socket.size(); i++) { + for (size_t i = 0; i < connections.size(); i++) { + auto& c = *connections[i]; + message.socket_number = static_cast(i); - if (i > 0) - message.write_master_file = false; + message.write_master_file = (i == 0); serializer.SerializeSequenceStart(message); - socket[i]->SetRunNumber(run_number); - if (!socket[i]->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::START)) { - rollback_cancel(); - throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, - "Timeout/failure sending START on " + socket[i]->GetEndpointName()); + { + std::unique_lock ul(c.send_mutex); + if (!SendFrame(c, serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::START, -1, nullptr)) { + rollback_cancel(); + throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, + "Timeout/failure sending START on socket " + std::to_string(i)); + } } std::string ack_err; - if (!socket[i]->WaitForAck(TCPFrameType::START, std::chrono::seconds(5), &ack_err)) { + if (!WaitForAck(c, TCPFrameType::START, std::chrono::seconds(5), &ack_err)) { rollback_cancel(); throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, - "START ACK failed on " + socket[i]->GetEndpointName() + ": " + ack_err); + "START ACK failed on socket " + std::to_string(i) + ": " + ack_err); } started[i] = true; @@ -79,100 +693,110 @@ void TCPStreamPusher::StartDataCollection(StartMessage &message) { } bool TCPStreamPusher::SendImage(const uint8_t *image_data, size_t image_size, int64_t image_number) { - if (socket.empty()) + if (connections.empty()) return false; - auto socket_number = (image_number / images_per_file) % socket.size(); - if (socket[socket_number]->IsBroken()) + auto idx = static_cast((image_number / images_per_file) % static_cast(connections.size())); + auto& c = *connections[idx]; + + if (c.broken || !IsConnectionAlive(c)) return false; - return socket[socket_number]->Send(image_data, image_size, TCPFrameType::DATA, image_number); + std::unique_lock ul(c.send_mutex); + return SendFrame(c, image_data, image_size, TCPFrameType::DATA, image_number, nullptr); } void TCPStreamPusher::SendImage(ZeroCopyReturnValue &z) { - if (socket.empty()) { + if (connections.empty()) { z.release(); return; } - auto socket_number = (z.GetImageNumber() / images_per_file) % socket.size(); - if (socket[socket_number]->IsBroken()) { + auto idx = static_cast((z.GetImageNumber() / images_per_file) % static_cast(connections.size())); + auto& c = *connections[idx]; + + if (c.broken) { z.release(); return; } - socket[socket_number]->SendImage(z); + c.queue.PutBlocking(ImagePusherQueueElement{ + .image_data = static_cast(z.GetImage()), + .z = &z, + .end = false + }); } -bool TCPStreamPusher::EndDataCollection(const EndMessage &message) { +bool TCPStreamPusher::EndDataCollection(const EndMessage& message) { serializer.SerializeSequenceEnd(message); bool ret = true; - for (auto &s : socket) { - if (s->IsBroken()) { + for (auto& cptr : connections) { + auto& c = *cptr; + if (c.broken) { ret = false; continue; } - if (!s->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::END)) { - ret = false; - continue; + { + std::unique_lock ul(c.send_mutex); + if (!SendFrame(c, serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::END, -1, nullptr)) { + ret = false; + continue; + } } std::string ack_err; - if (!s->WaitForAck(TCPFrameType::END, std::chrono::seconds(10), &ack_err)) { + if (!WaitForAck(c, TCPFrameType::END, std::chrono::seconds(10), &ack_err)) ret = false; - } } - for (auto &s : socket) - s->StopWriterThread(); + for (auto& c : connections) + StopConnectionThreads(*c); transmission_error = !ret; return ret; } +bool TCPStreamPusher::SendCalibration(const CompressedImage& message) { + if (connections.empty()) + return false; + + serializer.SerializeCalibration(message); + + auto& c = *connections[0]; + if (c.broken) + return false; + + std::unique_lock ul(c.send_mutex); + return SendFrame(c, serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::CALIBRATION, -1, nullptr); +} + std::string TCPStreamPusher::Finalize() { std::string ret; if (transmission_error) ret += "Timeout sending images (e.g., writer disabled during data collection);"; - for (size_t i = 0; i < socket.size(); i++) { - if (socket[i]->IsBroken()) { - const auto reason = socket[i]->GetLastAckError(); - ret += "Writer " + std::to_string(i) + ": " + (reason.empty() ? "stream broken" : reason) + ";"; + for (size_t i = 0; i < connections.size(); i++) { + auto& c = *connections[i]; + { + std::unique_lock ul(c.ack_mutex); + + if (c.data_ack_error_reported && !c.data_ack_error_text.empty()) { + ret += "Writer " + std::to_string(i) + ": " + c.data_ack_error_text + ";"; + } else if (!c.last_ack_error.empty()) { + ret += "Writer " + std::to_string(i) + ": " + c.last_ack_error + ";"; + } } } + return ret; } std::string TCPStreamPusher::PrintSetup() const { - std::string output = "TCPStream2Pusher: Sending images to sockets: "; - for (const auto &s : socket) - output += s->GetEndpointName() + " "; - return output; -} - -std::string TCPStreamPusherSocket::GetEndpointName() const { - return endpoint; -} - -void TCPStreamPusherSocket::SetRunNumber(uint64_t in_run_number) { - run_number = in_run_number; -} - -bool TCPStreamPusher::SendCalibration(const CompressedImage &message) { - if (socket.empty()) - return false; - serializer.SerializeCalibration(message); - return socket[0]->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::CALIBRATION); + return "TCPStreamPusher: endpoint=" + endpoint + " expected_connections=" + std::to_string(expected_connections); } std::optional TCPStreamPusher::GetImagesWritten() const { - uint64_t ret = 0; - for (const auto &s : socket) { - auto p = s->GetDataAckProgress(); - ret += p.data_acked_ok; - } - return ret; -} + return total_data_acked_ok.load(std::memory_order_relaxed); +} \ No newline at end of file diff --git a/image_pusher/TCPStreamPusher.h b/image_pusher/TCPStreamPusher.h index 5451b714..f006c60d 100644 --- a/image_pusher/TCPStreamPusher.h +++ b/image_pusher/TCPStreamPusher.h @@ -3,23 +3,125 @@ #pragma once -#include "TCPStreamPusherSocket.h" +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ImagePusher.h" #include "ZMQWriterNotificationPuller.h" +#include "../common/ThreadSafeFIFO.h" +#include "../common/Logger.h" +#include "../common/JfjochTCP.h" class TCPStreamPusher : public ImagePusher { + struct Connection { + explicit Connection(size_t queue_size) : queue(queue_size) {} + + std::atomic fd{-1}; + uint32_t socket_number = 0; + std::atomic active{false}; + std::atomic broken{false}; + std::atomic zerocopy_enabled{false}; + + ThreadSafeFIFO queue; + std::future writer_future; + std::future ack_future; + std::future zc_future; + + std::mutex send_mutex; + std::mutex ack_mutex; + std::condition_variable ack_cv; + + struct PendingZC { + uint32_t first_id = 0; + uint32_t last_id = 0; + ZeroCopyReturnValue* z = nullptr; + }; + + std::mutex zc_mutex; + std::condition_variable zc_cv; + std::deque zc_pending; + uint32_t zc_next_id = 0; + uint32_t zc_completed_id = std::numeric_limits::max(); + + bool start_ack_received = false; + bool start_ack_ok = false; + bool end_ack_received = false; + bool end_ack_ok = false; + bool cancel_ack_received = false; + bool cancel_ack_ok = false; + + std::string last_ack_error; + std::atomic last_ack_code{TCPAckCode::None}; + + // Soft writer failure reported via DATA ACK (do not break stream on this alone) + std::atomic data_ack_error_reported{false}; + std::string data_ack_error_text; + + std::atomic data_sent{0}; + std::atomic data_acked_ok{0}; + std::atomic data_acked_bad{0}; + std::atomic data_acked_total{0}; + }; + std::vector serialization_buffer; CBORStream2Serializer serializer; - std::vector> socket; + + std::string endpoint; + size_t expected_connections = 0; + std::optional send_buffer_size; + size_t send_queue_size = 128; + + std::vector> connections; int64_t images_per_file = 1; uint64_t run_number = 0; std::string run_name; std::atomic transmission_error = false; + + std::atomic total_data_acked_ok{0}; + std::atomic total_data_acked_bad{0}; + std::atomic total_data_acked_total{0}; + + Logger logger{"TCPStreamPusher"}; + + static std::pair ParseTcpAddress(const std::string& addr); + static int OpenListenSocket(const std::string& addr); + static int AcceptOne(int listen_fd, std::chrono::milliseconds timeout); + + static void CloseFd(std::atomic& fd); + bool IsConnectionAlive(const Connection& c) const; + bool SendAll(Connection& c, const void* buf, size_t len, bool allow_zerocopy, + bool* zc_used = nullptr, uint32_t* zc_first = nullptr, uint32_t* zc_last = nullptr); + bool ReadExact(Connection& c, void* buf, size_t len); + bool SendFrame(Connection& c, const uint8_t* data, size_t size, TCPFrameType type, int64_t image_number, ZeroCopyReturnValue* z); + + void WriterThread(Connection* c); + void AckThread(Connection* c); + void ZeroCopyCompletionThread(Connection* c); + + void StartConnectionThreads(Connection& c); + void StopConnectionThreads(Connection& c); + + void EnqueueZeroCopyPending(Connection& c, ZeroCopyReturnValue* z, uint32_t first_id, uint32_t last_id); + void ReleaseCompletedZeroCopy(Connection& c); + void ForceReleasePendingZeroCopy(Connection& c); + bool WaitForZeroCopyDrain(Connection& c, std::chrono::milliseconds timeout); + + bool WaitForAck(Connection& c, TCPFrameType ack_for, std::chrono::milliseconds timeout, std::string* error_text); public: - explicit TCPStreamPusher(const std::vector& addr, - std::optional send_buffer_size = {}, - std::optional zerocopy_threshold = {}, - size_t send_queue_size = 4096); + explicit TCPStreamPusher(const std::string& addr, + size_t in_expected_connections, + std::optional in_send_buffer_size = {}); + + ~TCPStreamPusher() override; void StartDataCollection(StartMessage& message) override; bool EndDataCollection(const EndMessage& message) override; diff --git a/image_pusher/TCPStreamPusherSocket.cpp b/image_pusher/TCPStreamPusherSocket.cpp deleted file mode 100644 index 4aa9b9a3..00000000 --- a/image_pusher/TCPStreamPusherSocket.cpp +++ /dev/null @@ -1,598 +0,0 @@ -// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute -// SPDX-License-Identifier: GPL-3.0-only - -#include "TCPStreamPusherSocket.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -static std::pair ParseTcpAddress(const std::string& addr) { - const std::string prefix = "tcp://"; - if (addr.rfind(prefix, 0) != 0) - throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP address: " + addr); - auto hp = addr.substr(prefix.size()); - auto p = hp.find_last_of(':'); - if (p == std::string::npos) - throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP address: " + addr); - - const auto host = hp.substr(0, p); - const auto port_str = hp.substr(p + 1); - - int port_i = 0; - try { - size_t parsed = 0; - port_i = std::stoi(port_str, &parsed); - if (parsed != port_str.size()) - throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP port in address: " + addr); - } catch (...) { - throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP port in address: " + addr); - } - - if (port_i < 1 || port_i > static_cast(std::numeric_limits::max())) - throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "TCP port out of range in address: " + addr); - - return {host, static_cast(port_i)}; -} - -TCPStreamPusherSocket::TCPStreamPusherSocket(const std::string &addr, - uint32_t in_socket_number, - std::optional in_send_buffer_size, - std::optional in_zerocopy_threshold, - size_t send_queue_size) - : queue(send_queue_size), - endpoint(addr), - socket_number(in_socket_number), - zerocopy_threshold(in_zerocopy_threshold), - send_buffer_size(in_send_buffer_size) { - auto [host, port] = ParseTcpAddress(addr); - - listen_fd = ::socket(AF_INET, SOCK_STREAM, 0); - if (listen_fd < 0) - throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "socket(listen) failed"); - - int one = 1; - setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)); - - sockaddr_in sin{}; - sin.sin_family = AF_INET; - sin.sin_port = htons(port); - if (host == "*" || host == "0.0.0.0") - sin.sin_addr.s_addr = htonl(INADDR_ANY); - else if (inet_pton(AF_INET, host.c_str(), &sin.sin_addr) != 1) - throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "inet_pton failed for " + host); - - if (bind(listen_fd, reinterpret_cast(&sin), sizeof(sin)) != 0) - throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "bind() failed to " + addr); - - if (listen(listen_fd, 16) != 0) - throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "listen() failed on " + addr); -} - -TCPStreamPusherSocket::~TCPStreamPusherSocket() { - try { - StopWriterThread(); - } catch (...) {} - CloseDataSocket(); - if (listen_fd >= 0) - close(listen_fd); -} - -void TCPStreamPusherSocket::CloseDataSocket() { - int old_fd = fd.exchange(-1); - if (old_fd >= 0) { - shutdown(old_fd, SHUT_RDWR); - close(old_fd); - } -} - -bool TCPStreamPusherSocket::AcceptConnection(std::chrono::milliseconds timeout) { - std::unique_lock ul(send_mutex); - - // Reuse existing healthy connection instead of always forcing reconnect. - if (fd.load() >= 0 && IsConnectionAlive()) { - broken = false; - logger.Info("TCP peer already connected on " + endpoint + ", reusing existing connection"); - return true; - } - - CloseDataSocket(); - broken = false; - - pollfd pfd{}; - pfd.fd = listen_fd; - pfd.events = POLLIN; - - const int prc = poll(&pfd, 1, static_cast(timeout.count())); - if (prc == 0) { - logger.Error("TCP accept timeout (" + std::to_string(timeout.count()) + " ms) on " + endpoint); - return false; - } - if (prc < 0) { - if (errno == EINTR) - return false; - logger.Error("TCP poll() failed on " + endpoint); - return false; - } - if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0) { - logger.Error("TCP listen socket error on " + endpoint); - return false; - } - - int new_fd = accept(listen_fd, nullptr, nullptr); - if (new_fd < 0) - return false; - - int one = 1; - setsockopt(new_fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one)); - if (send_buffer_size) - setsockopt(new_fd, SOL_SOCKET, SO_SNDBUF, &send_buffer_size.value(), sizeof(int32_t)); -#ifdef SO_ZEROCOPY - setsockopt(new_fd, SOL_SOCKET, SO_ZEROCOPY, &one, sizeof(one)); -#endif - - fd.store(new_fd); - - logger.Info("TCP peer connected on " + endpoint); - return true; -} - -bool TCPStreamPusherSocket::IsConnectionAlive() const { - if (broken) - return false; - - int local_fd = fd.load(); - if (local_fd < 0) - return false; - - pollfd pfd{}; - pfd.fd = local_fd; - pfd.events = POLLOUT; - if (poll(&pfd, 1, 0) < 0) - return false; - if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0) - return false; - - int so_error = 0; - socklen_t len = sizeof(so_error); - if (getsockopt(local_fd, SOL_SOCKET, SO_ERROR, &so_error, &len) != 0) - return false; - - return so_error == 0; -} - -bool TCPStreamPusherSocket::SendAll(const void *buf, size_t len) { - const uint8_t *p = static_cast(buf); - size_t sent = 0; - while (sent < len) { - int local_fd = fd.load(); - if (local_fd < 0 || broken) - return false; - - ssize_t rc = ::send(local_fd, p + sent, len - sent, MSG_NOSIGNAL); - if (rc < 0) { - if (errno == EINTR) - continue; - - if (errno == EPIPE || errno == ECONNRESET || errno == ENOTCONN) { - CloseDataSocket(); - broken = true; - logger.Error("TCP peer disconnected on " + endpoint + ", stopping this stream"); - return false; - } - return false; - } - sent += static_cast(rc); - } - return true; -} - -bool TCPStreamPusherSocket::ReadExact(void *buf, size_t len) { - auto *p = static_cast(buf); - size_t got = 0; - - while (got < len) { - if (!active) - return false; - - int local_fd = fd.load(); - if (local_fd < 0) - return false; - - pollfd pfd{}; - pfd.fd = local_fd; - pfd.events = POLLIN; - - const int prc = poll(&pfd, 1, 100); // 100 ms interruptibility window - if (prc == 0) - continue; - if (prc < 0) { - if (errno == EINTR) - continue; - return false; - } - if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0) - return false; - if ((pfd.revents & POLLIN) == 0) - continue; - - ssize_t rc = ::recv(local_fd, p + got, len - got, 0); - if (rc == 0) - return false; - if (rc < 0) { - if (errno == EINTR) - continue; - if (errno == EAGAIN || errno == EWOULDBLOCK) - continue; - return false; - } - got += static_cast(rc); - } - return true; -} - -bool TCPStreamPusherSocket::SendPayloadZC(const uint8_t *data, size_t size, ZeroCopyReturnValue *z) { -#if defined(MSG_ZEROCOPY) && defined(SO_ZEROCOPY) - int local_fd = fd.load(); - if (local_fd < 0) - return false; - - msghdr msg{}; - iovec iov{}; - iov.iov_base = const_cast(data); - iov.iov_len = size; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - while (true) { - ssize_t rc = ::sendmsg(local_fd, &msg, MSG_ZEROCOPY | MSG_NOSIGNAL); - if (rc < 0) { - if (errno == EINTR) - continue; - if (errno == EAGAIN) - return SendAll(data, size); - return false; - } - if (static_cast(rc) != size) - return false; - break; - } - - std::unique_lock ul(inflight_mutex); - inflight.push_back(InflightZC{.z = z, .tx_id = next_tx_id.fetch_add(1)}); - return true; -#else - (void) z; - return SendAll(data, size); -#endif -} - -bool TCPStreamPusherSocket::SendFrame(const uint8_t *data, size_t size, TCPFrameType type, int64_t image_number, ZeroCopyReturnValue *z) { - TcpFrameHeader h{}; - h.type = static_cast(type); - h.payload_size = size; - h.image_number = image_number >= 0 ? static_cast(image_number) : 0; - h.socket_number = socket_number; - h.run_number = run_number; - - if (!SendAll(&h, sizeof(h))) { - if (z) - z->release(); - return false; - } - - if (size == 0) { - if (z) - z->release(); - return true; - } - - bool ok; - if (z && zerocopy_threshold && size >= zerocopy_threshold.value()) { - ok = SendPayloadZC(data, size, z); - if (!ok) - z->release(); - } else { - ok = SendAll(data, size); - if (z) - z->release(); - } - - if (ok && type == TCPFrameType::DATA) - data_sent.fetch_add(1, std::memory_order_relaxed); - - return ok; -} - -void TCPStreamPusherSocket::WriterThread() { - while (active) { - const auto e = queue.GetBlocking(); - if (e.end) - break; - - if (!e.z) - continue; - - if (broken) { - e.z->release(); - continue; - } - - bool ok = false; - { - std::unique_lock ul(send_mutex); - ok = SendFrame(static_cast(e.z->GetImage()), - e.z->GetImageSize(), - TCPFrameType::DATA, - e.z->GetImageNumber(), - e.z); - } - - if (!ok) { - broken = true; - logger.Error("TCP send failed on " + endpoint + ", stopping this stream"); - } - } -} - -bool TCPStreamPusherSocket::IsBroken() const { - return broken; -} - -void TCPStreamPusherSocket::CompletionThread() { -#if defined(MSG_ZEROCOPY) && defined(SO_ZEROCOPY) - while (active) { - int local_fd = fd.load(); - if (local_fd < 0) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - continue; - } - - char cmsgbuf[512]; - char dummy[1]; - iovec iov{.iov_base = dummy, .iov_len = sizeof(dummy)}; - msghdr msg{}; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - msg.msg_control = cmsgbuf; - msg.msg_controllen = sizeof(cmsgbuf); - - ssize_t rc = ::recvmsg(local_fd, &msg, MSG_ERRQUEUE | MSG_DONTWAIT); - if (rc < 0) { - if (errno == EAGAIN || errno == EINTR) { - std::this_thread::sleep_for(std::chrono::microseconds(100)); - continue; - } - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - continue; - } - - for (cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) { - if (cmsg->cmsg_level != SOL_IP || cmsg->cmsg_type != IP_RECVERR) - continue; - auto *serr = reinterpret_cast(CMSG_DATA(cmsg)); - if (!serr || serr->ee_origin != SO_EE_ORIGIN_ZEROCOPY) - continue; - - uint64_t first = serr->ee_info; - uint64_t last = serr->ee_data; - - std::unique_lock ul(inflight_mutex); - while (!inflight.empty() && inflight.front().tx_id <= last) { - auto item = inflight.front(); - inflight.pop_front(); - if (item.tx_id >= first && item.z) - item.z->release(); - } - } - } - - std::unique_lock ul(inflight_mutex); - while (!inflight.empty()) { - if (inflight.front().z) - inflight.front().z->release(); - inflight.pop_front(); - } -#endif -} - -void TCPStreamPusherSocket::AckThread() { - while (active) { - TcpFrameHeader h{}; - if (!ReadExact(&h, sizeof(h))) { - if (active) { - broken = true; - logger.Error("TCP ACK reader disconnected on " + endpoint); - } - break; - } - - if (h.magic != JFJOCH_TCP_MAGIC || h.version != JFJOCH_TCP_VERSION || static_cast(h.type) != TCPFrameType::ACK) { - broken = true; - logger.Error("Invalid ACK frame on " + endpoint); - break; - } - - std::string error_text; - if (h.payload_size > 0) { - error_text.resize(h.payload_size); - if (!ReadExact(error_text.data(), error_text.size())) { - broken = true; - break; - } - } - - const auto ack_for = static_cast(h.ack_for); - const bool ok = (h.flags & TCP_ACK_FLAG_OK) != 0; - const bool fatal = (h.flags & TCP_ACK_FLAG_FATAL) != 0; - const auto code = static_cast(h.ack_code); - - { - std::unique_lock ul(ack_state_mutex); - last_ack_code = code; - if (!error_text.empty()) - last_ack_error = error_text; - - if (ack_for == TCPFrameType::START) { - start_ack_received = true; - start_ack_ok = ok; - if (!ok && error_text.empty()) - last_ack_error = "START rejected"; - } else if (ack_for == TCPFrameType::END) { - end_ack_received = true; - end_ack_ok = ok; - if (!ok && error_text.empty()) - last_ack_error = "END rejected"; - } else if (ack_for == TCPFrameType::CANCEL) { - cancel_ack_received = true; - cancel_ack_ok = ok; - if (!ok && error_text.empty()) - last_ack_error = "CANCEL rejected"; - } else if (ack_for == TCPFrameType::DATA) { - data_acked_total.fetch_add(1, std::memory_order_relaxed); - if (ok && !fatal) { - data_acked_ok.fetch_add(1, std::memory_order_relaxed); - } else { - data_acked_bad.fetch_add(1, std::memory_order_relaxed); - if (error_text.empty()) - last_ack_error = "DATA ACK failed"; - logger.Error("Received failing DATA ACK on " + endpoint + ": " + last_ack_error); - } - } - } - ack_cv.notify_all(); - } -} - -void TCPStreamPusherSocket::StartWriterThread() { - if (active) - return; - - { - std::unique_lock ul(ack_state_mutex); - start_ack_received = false; - start_ack_ok = false; - end_ack_received = false; - end_ack_ok = false; - cancel_ack_received = false; - cancel_ack_ok = false; - last_ack_error.clear(); - last_ack_code = TCPAckCode::None; - } - - data_sent.store(0, std::memory_order_relaxed); - data_acked_ok.store(0, std::memory_order_relaxed); - data_acked_bad.store(0, std::memory_order_relaxed); - data_acked_total.store(0, std::memory_order_relaxed); - - active = true; - send_future = std::async(std::launch::async, &TCPStreamPusherSocket::WriterThread, this); - completion_future = std::async(std::launch::async, &TCPStreamPusherSocket::CompletionThread, this); - ack_future = std::async(std::launch::async, &TCPStreamPusherSocket::AckThread, this); -} - -void TCPStreamPusherSocket::StopWriterThread() { - if (!active) - return; - active = false; - queue.PutBlocking({.end = true}); - ack_cv.notify_all(); - - if (send_future.valid()) - send_future.get(); - if (completion_future.valid()) - completion_future.get(); - if (ack_future.valid()) - ack_future.get(); - - // Keep fd open: END frame may still be sent after writer thread stops. - // Socket is closed in destructor / explicit close path. -} - -void TCPStreamPusherSocket::SendImage(ZeroCopyReturnValue &z) { - queue.PutBlocking(ImagePusherQueueElement{ - .image_data = static_cast(z.GetImage()), - .z = &z, - .end = false - }); -} - -bool TCPStreamPusherSocket::Send(const uint8_t *data, size_t size, TCPFrameType type, int64_t image_number) { - if (broken) - return false; - - std::unique_lock ul(send_mutex); - - if (fd.load() < 0) - return false; - - if (!IsConnectionAlive()) { - broken = true; - return false; - } - - return SendFrame(data, size, type, image_number, nullptr); -} - -bool TCPStreamPusherSocket::WaitForAck(TCPFrameType ack_for, std::chrono::milliseconds timeout, std::string *error_text) { - std::unique_lock ul(ack_state_mutex); - const bool ok = ack_cv.wait_for(ul, timeout, [&] { - if (ack_for == TCPFrameType::START) - return start_ack_received || broken.load(); - if (ack_for == TCPFrameType::END) - return end_ack_received || broken.load(); - if (ack_for == TCPFrameType::CANCEL) - return cancel_ack_received || broken.load(); - return false; - }); - - if (!ok) { - if (error_text) - *error_text = "ACK timeout"; - return false; - } - - if (broken) { - if (error_text) - *error_text = last_ack_error.empty() ? "Socket broken" : last_ack_error; - return false; - } - - bool ack_ok = false; - if (ack_for == TCPFrameType::START) - ack_ok = start_ack_ok; - else if (ack_for == TCPFrameType::END) - ack_ok = end_ack_ok; - else if (ack_for == TCPFrameType::CANCEL) - ack_ok = cancel_ack_ok; - - if (!ack_ok && error_text) - *error_text = last_ack_error.empty() ? "ACK rejected" : last_ack_error; - - return ack_ok; -} - -std::string TCPStreamPusherSocket::GetLastAckError() const { - std::unique_lock ul(ack_state_mutex); - return last_ack_error; -} - -ImagePusherAckProgress TCPStreamPusherSocket::GetDataAckProgress() const { - ImagePusherAckProgress p; - p.data_sent = data_sent.load(std::memory_order_relaxed); - p.data_acked_ok = data_acked_ok.load(std::memory_order_relaxed); - p.data_acked_bad = data_acked_bad.load(std::memory_order_relaxed); - p.data_acked_total = data_acked_total.load(std::memory_order_relaxed); - p.data_ack_pending = (p.data_sent >= p.data_acked_total) ? (p.data_sent - p.data_acked_total) : 0; - return p; -} \ No newline at end of file diff --git a/image_pusher/TCPStreamPusherSocket.h b/image_pusher/TCPStreamPusherSocket.h deleted file mode 100644 index b128f204..00000000 --- a/image_pusher/TCPStreamPusherSocket.h +++ /dev/null @@ -1,113 +0,0 @@ -// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute -// SPDX-License-Identifier: GPL-3.0-only - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include "ImagePusher.h" -#include "../common/ThreadSafeFIFO.h" -#include "../common/Logger.h" -#include "../common/JfjochTCP.h" - -struct ImagePusherAckProgress { - uint64_t data_sent = 0; - uint64_t data_acked_ok = 0; - uint64_t data_acked_bad = 0; - uint64_t data_acked_total = 0; - uint64_t data_ack_pending = 0; -}; - -class TCPStreamPusherSocket { - struct InflightZC { - ZeroCopyReturnValue *z = nullptr; - uint64_t tx_id = 0; - }; - - std::mutex send_mutex; - std::atomic active = false; - std::future send_future; - std::future completion_future; - std::future ack_future; - - ThreadSafeFIFO queue; - - std::atomic fd{-1}; - int listen_fd = -1; - std::string endpoint; - uint32_t socket_number = 0; - uint64_t run_number = 0; - std::optional zerocopy_threshold; - std::optional send_buffer_size; - - constexpr static auto AcceptTimeout = std::chrono::seconds(5); - - std::atomic broken{false}; - std::atomic last_ack_code{TCPAckCode::None}; - std::string last_ack_error; - mutable std::mutex ack_state_mutex; - std::condition_variable ack_cv; - bool start_ack_received = false; - bool start_ack_ok = false; - bool end_ack_received = false; - bool end_ack_ok = false; - bool cancel_ack_received = false; - bool cancel_ack_ok = false; - - std::atomic next_tx_id{1}; - std::mutex inflight_mutex; - std::deque inflight; - - Logger logger{"TCPStream2PusherSocket"}; - - std::atomic data_sent{0}; - std::atomic data_acked_ok{0}; - std::atomic data_acked_bad{0}; - std::atomic data_acked_total{0}; - - void WriterThread(); - void CompletionThread(); - void AckThread(); - - void CloseDataSocket(); - - bool SendAll(const void *buf, size_t len); - bool ReadExact(void *buf, size_t len); - bool SendFrame(const uint8_t *data, size_t size, TCPFrameType type, int64_t image_number, ZeroCopyReturnValue *z); - bool SendPayloadZC(const uint8_t *data, size_t size, ZeroCopyReturnValue *z); -public: - explicit TCPStreamPusherSocket(const std::string& addr, - uint32_t in_socket_number, - std::optional send_buffer_size, - std::optional in_zerocopy_threshold, - size_t send_queue_size = 4096); - - ~TCPStreamPusherSocket(); - - std::string GetEndpointName() const; - - bool Send(const uint8_t *data, size_t size, TCPFrameType type, int64_t image_number = -1); - - bool AcceptConnection(std::chrono::milliseconds timeout = std::chrono::duration_cast(AcceptTimeout)); - bool IsConnectionAlive() const; - - void StartWriterThread(); - void StopWriterThread(); - - bool WaitForAck(TCPFrameType ack_for, std::chrono::milliseconds timeout, std::string *error_text = nullptr); - - void SetRunNumber(uint64_t in_run_number); - - void SendImage(ZeroCopyReturnValue &z); - - bool IsBroken() const; - std::string GetLastAckError() const; - - ImagePusherAckProgress GetDataAckProgress() const; -}; diff --git a/tests/JFJochReceiverProcessingTest.cpp b/tests/JFJochReceiverProcessingTest.cpp index ad8524d6..741293c8 100644 --- a/tests/JFJochReceiverProcessingTest.cpp +++ b/tests/JFJochReceiverProcessingTest.cpp @@ -1511,7 +1511,7 @@ TEST_CASE("JFJochIntegrationTest_TCP_lysozyme_spot_and_index", "[JFJochReceiver] aq_devices.Add(std::move(test)); - TCPStreamPusher pusher({"tcp://127.0.0.1:9121"}); + TCPStreamPusher pusher("tcp://127.0.0.1:9121", 1); TCPImagePuller puller("tcp://127.0.0.1:9121"); StreamWriter writer(logger, puller); diff --git a/tests/TCPImagePusherTest.cpp b/tests/TCPImagePusherTest.cpp index b7d1f454..dce1b93f 100644 --- a/tests/TCPImagePusherTest.cpp +++ b/tests/TCPImagePusherTest.cpp @@ -22,22 +22,14 @@ TEST_CASE("TCPImageCommTest_2Writers_WithAck", "[TCP]") { std::vector image1(x.GetPixelsNum() * nframes); for (auto &i : image1) i = dist(g1); - std::vector addr{ - "tcp://127.0.0.1:19001", - "tcp://127.0.0.1:19002" - }; + std::string addr = "tcp://127.0.0.1:19001"; std::vector> puller; for (int i = 0; i < npullers; i++) { - puller.push_back(std::make_unique(addr[i], 64 * 1024 * 1024)); + puller.push_back(std::make_unique(addr, 64 * 1024 * 1024)); } - TCPStreamPusher pusher( - addr, - 64 * 1024 * 1024, - 128 * 1024, - 8192 - ); + TCPStreamPusher pusher(addr,npullers); std::vector received(npullers, 0); std::vector processed(npullers, 0); @@ -152,22 +144,13 @@ TEST_CASE("TCPImageCommTest_DataFatalAck_PropagatesToPusher", "[TCP]") { std::vector image1(x.GetPixelsNum() * nframes); for (auto &i : image1) i = dist(g1); - std::vector addr{ - "tcp://127.0.0.1:19011", - "tcp://127.0.0.1:19012" - }; + std::string addr = "tcp://127.0.0.1:19003"; std::vector> puller; - for (int i = 0; i < npullers; i++) { - puller.push_back(std::make_unique(addr[i], 64 * 1024 * 1024)); - } + for (int i = 0; i < npullers; i++) + puller.push_back(std::make_unique(addr, 64 * 1024 * 1024)); - TCPStreamPusher pusher( - addr, - 64 * 1024 * 1024, - 128 * 1024, - 8192 - ); + TCPStreamPusher pusher(addr,npullers); std::atomic sent_fatal{false}; @@ -196,7 +179,7 @@ TEST_CASE("TCPImageCommTest_DataFatalAck_PropagatesToPusher", "[TCP]") { (void)pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i); } - REQUIRE_FALSE(pusher.EndDataCollection(end)); + REQUIRE(pusher.EndDataCollection(end)); const auto final_msg = pusher.Finalize(); REQUIRE_THAT(final_msg, Catch::Matchers::ContainsSubstring("quota")); }); @@ -287,22 +270,13 @@ TEST_CASE("TCPImageCommTest_GetAckProgress_Correct", "[TCP]") { std::vector image1(x.GetPixelsNum() * nframes); for (auto &i : image1) i = dist(g1); - std::vector addr{ - "tcp://127.0.0.1:19021", - "tcp://127.0.0.1:19022" - }; + std::string addr = "tcp://127.0.0.1:19004"; std::vector> puller; - for (int i = 0; i < npullers; i++) { - puller.push_back(std::make_unique(addr[i], 64 * 1024 * 1024)); - } + for (int i = 0; i < npullers; i++) + puller.push_back(std::make_unique(addr, 64 * 1024 * 1024)); - TCPStreamPusher pusher( - addr, - 64 * 1024 * 1024, - 128 * 1024, - 8192 - ); + TCPStreamPusher pusher(addr,npullers); std::thread sender([&] { std::vector serialization_buffer(16 * 1024 * 1024);