// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute // SPDX-License-Identifier: GPL-3.0-only #include "TcpImagePuller.h" #include #include #include #include #include #include static std::pair ParseTcpAddressPull(const std::string& addr) { const std::string prefix = "tcp://"; if (addr.rfind(prefix, 0) != 0) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP address: " + addr); auto hp = addr.substr(prefix.size()); auto p = hp.find_last_of(':'); if (p == std::string::npos) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP address: " + addr); const auto host = hp.substr(0, p); const auto port_str = hp.substr(p + 1); int port_i = 0; try { size_t parsed = 0; port_i = std::stoi(port_str, &parsed); if (parsed != port_str.size()) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP port in address: " + addr); } catch (...) { throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP port in address: " + addr); } if (port_i < 1 || port_i > static_cast(std::numeric_limits::max())) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "TCP port out of range in address: " + addr); return {host, static_cast(port_i)}; } TCPImagePuller::TCPImagePuller(const std::string &tcp_addr, std::optional rcv_buffer_size) : addr(tcp_addr), receive_buffer_size(rcv_buffer_size) { auto parsed = ParseTcpAddressPull(tcp_addr); host = parsed.first; port = parsed.second; receiver_thread = std::thread(&TCPImagePuller::ReceiverThread, this); cbor_thread = std::thread(&TCPImagePuller::CBORThread, this); } TCPImagePuller::~TCPImagePuller() { TCPImagePuller::Disconnect(); } void TCPImagePuller::CloseSocket() { int old_fd = -1; { std::unique_lock ul(fd_mutex); if (fd >= 0) { old_fd = fd; fd = -1; } } if (old_fd >= 0) { shutdown(old_fd, SHUT_RDWR); close(old_fd); } } bool TCPImagePuller::EnsureConnected() { { std::unique_lock ul(fd_mutex); if (fd >= 0) return true; } addrinfo hints{}; hints.ai_family = AF_UNSPEC; // Allow IPv4 or IPv6 hints.ai_socktype = SOCK_STREAM; hints.ai_protocol = IPPROTO_TCP; addrinfo *res = nullptr; const std::string port_str = std::to_string(port); int gai_rc = getaddrinfo(host.c_str(), port_str.c_str(), &hints, &res); if (gai_rc != 0) { logger.Error(std::string("getaddrinfo failed for ") + host + ":" + port_str + " - " + gai_strerror(gai_rc)); return false; } int new_fd = -1; for (addrinfo *ai = res; ai != nullptr; ai = ai->ai_next) { new_fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol); if (new_fd < 0) continue; if (receive_buffer_size) setsockopt(new_fd, SOL_SOCKET, SO_RCVBUF, &receive_buffer_size.value(), sizeof(int32_t)); if (connect(new_fd, ai->ai_addr, ai->ai_addrlen) == 0) break; close(new_fd); new_fd = -1; } freeaddrinfo(res); if (new_fd < 0) return false; { std::unique_lock ul(fd_mutex); if (fd >= 0) { close(new_fd); return true; } fd = new_fd; } logger.Info("TCP connected to " + addr); return true; } bool TCPImagePuller::ReadExact(void *buf, size_t size) { auto p = static_cast(buf); size_t got = 0; while (got < size) { if (disconnect) return false; int local_fd = -1; { std::unique_lock ul(fd_mutex); local_fd = fd; } if (local_fd < 0) return false; ssize_t rc = recv(local_fd, p + got, size - got, 0); if (rc == 0) return false; if (rc < 0) { if (errno == EINTR) continue; return false; } got += static_cast(rc); } return true; } void TCPImagePuller::ReceiverThread() { try { while (!disconnect) { if (!EnsureConnected()) { std::this_thread::sleep_for(std::chrono::milliseconds(20)); continue; } RawFrame frame{}; if (!ReadExact(&frame.header, sizeof(frame.header))) { logger.Info("TCP receive failed, reconnecting to " + addr); CloseSocket(); std::this_thread::sleep_for(std::chrono::milliseconds(20)); continue; } if (frame.header.magic != JFJOCH_TCP_MAGIC || frame.header.version != JFJOCH_TCP_VERSION) { logger.Error("Invalid TCP frame header, reconnecting to " + addr); CloseSocket(); std::this_thread::sleep_for(std::chrono::milliseconds(20)); continue; } ImagePullerOutput out; out.tcp_msg = std::make_shared(); out.tcp_msg->header = frame.header; out.tcp_msg->payload.resize(frame.header.payload_size); if (frame.header.payload_size > 0 && !ReadExact(out.tcp_msg->payload.data(), out.tcp_msg->payload.size())) { logger.Info("TCP payload read failed, reconnecting to " + addr); CloseSocket(); std::this_thread::sleep_for(std::chrono::milliseconds(20)); continue; } cbor_fifo.PutBlocking(out); } } catch (const JFJochException &e) { logger.ErrorException(e); } catch (const std::exception &e) { logger.Error(std::string("Unhandled exception in ReceiverThread: ") + e.what()); } catch (...) { logger.Error("Unhandled unknown exception in ReceiverThread"); } CloseSocket(); cbor_fifo.PutBlocking(ImagePullerOutput{}); } void TCPImagePuller::CBORThread() { auto ret = cbor_fifo.GetBlocking(); while (ret.tcp_msg) { try { ret.cbor = CBORStream2Deserialize(ret.tcp_msg->payload.data(), ret.tcp_msg->payload.size()); outside_fifo.PutBlocking(ret); } catch (const JFJochException &e) { logger.ErrorException(e); } ret = cbor_fifo.GetBlocking(); } outside_fifo.PutBlocking(ret); } void TCPImagePuller::Disconnect() { if (disconnect.exchange(true)) return; CloseSocket(); if (receiver_thread.joinable()) receiver_thread.join(); if (cbor_thread.joinable()) cbor_thread.join(); }