// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute // SPDX-License-Identifier: GPL-3.0-only #include "TCPStreamPusher.h" TCPStreamPusher::TCPStreamPusher(const std::vector &addr, std::optional send_buffer_size, std::optional zerocopy_threshold, size_t send_queue_size) : serialization_buffer(256 * 1024 * 1024), serializer(serialization_buffer.data(), serialization_buffer.size()) { if (addr.empty()) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "No TCP writer address provided"); for (size_t i = 0; i < addr.size(); i++) { socket.emplace_back(std::make_unique( addr[i], static_cast(i), send_buffer_size, zerocopy_threshold, send_queue_size)); } } void TCPStreamPusher::StartDataCollection(StartMessage &message) { if (message.images_per_file < 1) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Images per file cannot be zero or negative"); images_per_file = message.images_per_file; run_number = message.run_number; run_name = message.run_name; transmission_error = false; for (size_t i = 0; i < socket.size(); i++) { if (!socket[i]->AcceptConnection(std::chrono::seconds(5))) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "TCP accept timeout/failure on socket " + socket[i]->GetEndpointName()); } for (auto &s : socket) s->StartWriterThread(); std::vector started(socket.size(), false); auto rollback_cancel = [&]() { for (size_t i = 0; i < socket.size(); i++) { if (!started[i] || socket[i]->IsBroken()) continue; (void)socket[i]->Send(nullptr, 0, TCPFrameType::CANCEL); std::string cancel_ack_err; (void)socket[i]->WaitForAck(TCPFrameType::CANCEL, std::chrono::milliseconds(500), &cancel_ack_err); } for (auto &s : socket) s->StopWriterThread(); }; for (size_t i = 0; i < socket.size(); i++) { message.socket_number = static_cast(i); if (i > 0) message.write_master_file = false; serializer.SerializeSequenceStart(message); socket[i]->SetRunNumber(run_number); if (!socket[i]->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::START)) { rollback_cancel(); throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Timeout/failure sending START on " + socket[i]->GetEndpointName()); } std::string ack_err; if (!socket[i]->WaitForAck(TCPFrameType::START, std::chrono::seconds(5), &ack_err)) { rollback_cancel(); throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "START ACK failed on " + socket[i]->GetEndpointName() + ": " + ack_err); } started[i] = true; } } bool TCPStreamPusher::SendImage(const uint8_t *image_data, size_t image_size, int64_t image_number) { if (socket.empty()) return false; auto socket_number = (image_number / images_per_file) % socket.size(); if (socket[socket_number]->IsBroken()) return false; return socket[socket_number]->Send(image_data, image_size, TCPFrameType::DATA, image_number); } void TCPStreamPusher::SendImage(ZeroCopyReturnValue &z) { if (socket.empty()) { z.release(); return; } auto socket_number = (z.GetImageNumber() / images_per_file) % socket.size(); if (socket[socket_number]->IsBroken()) { z.release(); return; } socket[socket_number]->SendImage(z); } bool TCPStreamPusher::EndDataCollection(const EndMessage &message) { serializer.SerializeSequenceEnd(message); bool ret = true; for (auto &s : socket) { if (s->IsBroken()) { ret = false; continue; } if (!s->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::END)) { ret = false; continue; } std::string ack_err; if (!s->WaitForAck(TCPFrameType::END, std::chrono::seconds(10), &ack_err)) { ret = false; } } for (auto &s : socket) s->StopWriterThread(); transmission_error = !ret; return ret; } std::string TCPStreamPusher::Finalize() { std::string ret; if (transmission_error) ret += "Timeout sending images (e.g., writer disabled during data collection);"; for (size_t i = 0; i < socket.size(); i++) { if (socket[i]->IsBroken()) { const auto reason = socket[i]->GetLastAckError(); ret += "Writer " + std::to_string(i) + ": " + (reason.empty() ? "stream broken" : reason) + ";"; } } return ret; } std::string TCPStreamPusher::PrintSetup() const { std::string output = "TCPStream2Pusher: Sending images to sockets: "; for (const auto &s : socket) output += s->GetEndpointName() + " "; return output; } std::string TCPStreamPusherSocket::GetEndpointName() const { return endpoint; } void TCPStreamPusherSocket::SetRunNumber(uint64_t in_run_number) { run_number = in_run_number; } bool TCPStreamPusher::SendCalibration(const CompressedImage &message) { if (socket.empty()) return false; serializer.SerializeCalibration(message); return socket[0]->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::CALIBRATION); } std::optional TCPStreamPusher::GetImagesWritten() const { uint64_t ret = 0; for (const auto &s : socket) { auto p = s->GetDataAckProgress(); ret += p.data_acked_ok; } return ret; }