// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute // SPDX-License-Identifier: GPL-3.0-only #include "TCPImagePuller.h" #include #include #include #include #include #include static std::pair ParseTcpAddressPull(const std::string& addr) { const std::string prefix = "tcp://"; if (addr.rfind(prefix, 0) != 0) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP address: " + addr); auto hp = addr.substr(prefix.size()); auto p = hp.find_last_of(':'); if (p == std::string::npos) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP address: " + addr); const auto host = hp.substr(0, p); const auto port_str = hp.substr(p + 1); int port_i = 0; try { size_t parsed = 0; port_i = std::stoi(port_str, &parsed); if (parsed != port_str.size()) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP port in address: " + addr); } catch (...) { throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP port in address: " + addr); } if (port_i < 1 || port_i > static_cast(std::numeric_limits::max())) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "TCP port out of range in address: " + addr); return {host, static_cast(port_i)}; } TCPImagePuller::TCPImagePuller(const std::string &tcp_addr, std::optional rcv_buffer_size) : addr(tcp_addr), receive_buffer_size(rcv_buffer_size) { auto parsed = ParseTcpAddressPull(tcp_addr); host = parsed.first; port = parsed.second; receiver_thread = std::thread(&TCPImagePuller::ReceiverThread, this); cbor_thread = std::thread(&TCPImagePuller::CBORThread, this); } bool TCPImagePuller::SendAll(const void *buf, size_t len) { const auto *p = static_cast(buf); size_t sent = 0; while (sent < len) { if (disconnect) return false; int local_fd = -1; { std::unique_lock ul(fd_mutex); local_fd = fd; } if (local_fd < 0) return false; ssize_t rc = ::send(local_fd, p + sent, len - sent, MSG_NOSIGNAL); if (rc < 0) { if (errno == EINTR) continue; return false; } sent += static_cast(rc); } return true; } bool TCPImagePuller::SendAck(const PullerAckMessage &ack) { TcpFrameHeader h{}; h.type = static_cast(TCPFrameType::ACK); h.run_number = ack.run_number; h.socket_number = ack.socket_number; h.image_number = ack.image_number; h.flags = 0; if (ack.ok) h.flags |= TCP_ACK_FLAG_OK; if (ack.fatal) h.flags |= TCP_ACK_FLAG_FATAL; if (!ack.error_text.empty()) h.flags |= TCP_ACK_FLAG_HAS_ERROR_TEXT; h.ack_for = static_cast(ack.ack_for); h.ack_processed_images = ack.processed_images; h.ack_code = static_cast(ack.error_code); h.payload_size = ack.error_text.size(); if (!SendAll(&h, sizeof(h))) return false; if (!ack.error_text.empty()) return SendAll(ack.error_text.data(), ack.error_text.size()); return true; } void TCPImagePuller::CBORThread() { auto ret = cbor_fifo.GetBlocking(); while (ret.tcp_msg) { try { const auto type = static_cast(ret.tcp_msg->header.type); if (type == TCPFrameType::CANCEL) { outside_fifo.PutBlocking(ret); } else { ret.cbor = CBORStream2Deserialize(ret.tcp_msg->payload.data(), ret.tcp_msg->payload.size()); outside_fifo.PutBlocking(ret); } } catch (const JFJochException &e) { logger.ErrorException(e); } ret = cbor_fifo.GetBlocking(); } outside_fifo.PutBlocking(ret); } TCPImagePuller::~TCPImagePuller() { TCPImagePuller::Disconnect(); } void TCPImagePuller::CloseSocket() { int old_fd = -1; { std::unique_lock ul(fd_mutex); if (fd >= 0) { old_fd = fd; fd = -1; } } if (old_fd >= 0) { shutdown(old_fd, SHUT_RDWR); close(old_fd); } } bool TCPImagePuller::EnsureConnected() { { std::unique_lock ul(fd_mutex); if (fd >= 0) return true; } addrinfo hints{}; hints.ai_family = AF_UNSPEC; // Allow IPv4 or IPv6 hints.ai_socktype = SOCK_STREAM; hints.ai_protocol = IPPROTO_TCP; addrinfo *res = nullptr; const std::string port_str = std::to_string(port); int gai_rc = getaddrinfo(host.c_str(), port_str.c_str(), &hints, &res); if (gai_rc != 0) { logger.Error(std::string("getaddrinfo failed for ") + host + ":" + port_str + " - " + gai_strerror(gai_rc)); return false; } int new_fd = -1; for (addrinfo *ai = res; ai != nullptr; ai = ai->ai_next) { new_fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol); if (new_fd < 0) continue; if (receive_buffer_size) setsockopt(new_fd, SOL_SOCKET, SO_RCVBUF, &receive_buffer_size.value(), sizeof(int32_t)); if (connect(new_fd, ai->ai_addr, ai->ai_addrlen) == 0) break; close(new_fd); new_fd = -1; } freeaddrinfo(res); if (new_fd < 0) return false; { std::unique_lock ul(fd_mutex); if (fd >= 0) { close(new_fd); return true; } fd = new_fd; } logger.Info("TCP connected to " + addr); return true; } bool TCPImagePuller::ReadExact(void *buf, size_t size) { auto p = static_cast(buf); size_t got = 0; while (got < size) { if (disconnect) return false; int local_fd = -1; { std::unique_lock ul(fd_mutex); local_fd = fd; } if (local_fd < 0) return false; ssize_t rc = recv(local_fd, p + got, size - got, 0); if (rc == 0) return false; if (rc < 0) { if (errno == EINTR) continue; return false; } got += static_cast(rc); } return true; } void TCPImagePuller::ReceiverThread() { try { while (!disconnect) { if (!EnsureConnected()) { std::this_thread::sleep_for(std::chrono::milliseconds(20)); continue; } RawFrame frame{}; if (!ReadExact(&frame.header, sizeof(frame.header))) { logger.Info("TCP receive failed, reconnecting to " + addr); CloseSocket(); std::this_thread::sleep_for(std::chrono::milliseconds(20)); continue; } if (frame.header.magic != JFJOCH_TCP_MAGIC || frame.header.version != JFJOCH_TCP_VERSION) { logger.Error("Invalid TCP frame header, reconnecting to " + addr); CloseSocket(); std::this_thread::sleep_for(std::chrono::milliseconds(20)); continue; } const auto frame_type = static_cast(frame.header.type); // Respond to keepalive ping with a keepalive pong if (frame_type == TCPFrameType::KEEPALIVE) { if (frame.header.payload_size > 0) { std::vector discard(frame.header.payload_size); if (!ReadExact(discard.data(), discard.size())) { CloseSocket(); std::this_thread::sleep_for(std::chrono::milliseconds(20)); continue; } } // Send keepalive pong back TcpFrameHeader pong{}; pong.type = static_cast(TCPFrameType::KEEPALIVE); pong.payload_size = 0; if (!SendAll(&pong, sizeof(pong))) { logger.Info("Keepalive pong send failed, reconnecting to " + addr); CloseSocket(); std::this_thread::sleep_for(std::chrono::milliseconds(20)); } continue; } // Ignore ACK on puller side if (frame_type == TCPFrameType::ACK) { if (frame.header.payload_size > 0) { std::vector discard(frame.header.payload_size); if (!ReadExact(discard.data(), discard.size())) { CloseSocket(); } } continue; } ImagePullerOutput out; out.tcp_msg = std::make_shared(); out.tcp_msg->header = frame.header; out.tcp_msg->payload.resize(frame.header.payload_size); if (frame.header.payload_size > 0 && !ReadExact(out.tcp_msg->payload.data(), out.tcp_msg->payload.size())) { logger.Info("TCP payload read failed, reconnecting to " + addr); CloseSocket(); std::this_thread::sleep_for(std::chrono::milliseconds(20)); continue; } cbor_fifo.PutBlocking(out); } } catch (const JFJochException &e) { logger.ErrorException(e); } catch (const std::exception &e) { logger.Error(std::string("Unhandled exception in ReceiverThread: ") + e.what()); } catch (...) { logger.Error("Unhandled unknown exception in ReceiverThread"); } CloseSocket(); cbor_fifo.PutBlocking(ImagePullerOutput{}); } void TCPImagePuller::Disconnect() { if (disconnect.exchange(true)) return; CloseSocket(); if (receiver_thread.joinable()) receiver_thread.join(); if (cbor_thread.joinable()) cbor_thread.join(); }