179 lines
5.9 KiB
C++
179 lines
5.9 KiB
C++
// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute <filip.leonarski@psi.ch>
|
|
// SPDX-License-Identifier: GPL-3.0-only
|
|
|
|
#include "TCPStreamPusher.h"
|
|
|
|
|
|
// Builds one TCPStreamPusherSocket per entry of `addr`. The position of an address
// in the list becomes that socket's number, and every socket receives the same
// send-buffer size, zero-copy threshold and send-queue size.
// A shared serialization buffer of 256 * 1024 * 1024 elements is allocated for
// the control messages (START/END/calibration) serialized by this pusher.
// Throws JFJochException when no address is given.
TCPStreamPusher::TCPStreamPusher(const std::vector<std::string> &addr,
                                 std::optional<int32_t> send_buffer_size,
                                 std::optional<size_t> zerocopy_threshold,
                                 size_t send_queue_size)
        : serialization_buffer(256 * 1024 * 1024),
          serializer(serialization_buffer.data(), serialization_buffer.size()) {
    if (addr.empty())
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "No TCP writer address provided");

    uint32_t socket_index = 0;
    for (const auto &endpoint_addr : addr) {
        socket.emplace_back(std::make_unique<TCPStreamPusherSocket>(
                endpoint_addr, socket_index, send_buffer_size, zerocopy_threshold, send_queue_size));
        ++socket_index;
    }
}
|
|
|
|
// Starts a data collection on every writer socket.
//
// Steps:
//   1. Rejects message.images_per_file < 1, then caches images_per_file,
//      run_number and run_name and clears transmission_error.
//   2. Waits up to 5 s per socket for an incoming connection.
//   3. Starts the writer thread of every socket.
//   4. For each socket in order: serializes a START message with
//      socket_number = index (write_master_file is forced to false for every
//      socket after the first), sends it and waits up to 5 s for the START ACK.
//
// If step 4 fails for any socket, sockets that already acknowledged START (and
// are not broken) receive a best-effort CANCEL, all writer threads are stopped,
// and the failure is reported by exception.
//
// Note: `message` is modified in place (socket_number, write_master_file).
//
// Throws JFJochException on invalid images_per_file, accept failure, START send
// failure or START ACK failure. Any exception raised while serializing START is
// rethrown after the rollback.
void TCPStreamPusher::StartDataCollection(StartMessage &message) {
    if (message.images_per_file < 1)
        throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
                              "Images per file cannot be zero or negative");

    images_per_file = message.images_per_file;
    run_number = message.run_number;
    run_name = message.run_name;
    transmission_error = false;

    for (size_t i = 0; i < socket.size(); i++) {
        if (!socket[i]->AcceptConnection(std::chrono::seconds(5)))
            throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
                                  "TCP accept timeout/failure on socket " + socket[i]->GetEndpointName());
    }

    for (auto &s : socket)
        s->StartWriterThread();

    // started[i] becomes true only after socket i acknowledged START.
    std::vector<bool> started(socket.size(), false);

    // Undo a partially started collection. Every socket that already acknowledged
    // START and is not broken gets a CANCEL frame, with a 500 ms ACK wait. Results
    // are deliberately ignored because this runs on the way to reporting the
    // original failure. Afterwards all writer threads are stopped.
    auto rollback_cancel = [&]() {
        for (size_t i = 0; i < socket.size(); i++) {
            if (!started[i] || socket[i]->IsBroken())
                continue;

            (void)socket[i]->Send(nullptr, 0, TCPFrameType::CANCEL);
            std::string cancel_ack_err;
            (void)socket[i]->WaitForAck(TCPFrameType::CANCEL, std::chrono::milliseconds(500), &cancel_ack_err);
        }

        for (auto &s : socket)
            s->StopWriterThread();
    };

    for (size_t i = 0; i < socket.size(); i++) {
        message.socket_number = static_cast<int64_t>(i);
        if (i > 0)
            message.write_master_file = false;

        // The writer threads are already running here. If serialization throws
        // (presumably e.g. when the message does not fit the buffer; not visible
        // in this file), roll back first so threads do not leak and earlier
        // sockets are cancelled, then propagate the original exception.
        try {
            serializer.SerializeSequenceStart(message);
        } catch (...) {
            rollback_cancel();
            throw;
        }
        socket[i]->SetRunNumber(run_number);

        if (!socket[i]->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::START)) {
            rollback_cancel();
            throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
                                  "Timeout/failure sending START on " + socket[i]->GetEndpointName());
        }

        std::string ack_err;
        if (!socket[i]->WaitForAck(TCPFrameType::START, std::chrono::seconds(5), &ack_err)) {
            rollback_cancel();
            throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
                                  "START ACK failed on " + socket[i]->GetEndpointName() + ": " + ack_err);
        }

        started[i] = true;
    }
}
|
|
|
|
bool TCPStreamPusher::SendImage(const uint8_t *image_data, size_t image_size, int64_t image_number) {
|
|
if (socket.empty())
|
|
return false;
|
|
|
|
auto socket_number = (image_number / images_per_file) % socket.size();
|
|
if (socket[socket_number]->IsBroken())
|
|
return false;
|
|
|
|
return socket[socket_number]->Send(image_data, image_size, TCPFrameType::DATA, image_number);
|
|
}
|
|
|
|
// Zero-copy variant. The image is handed to the socket chosen by the same
// block-round-robin rule as the raw-pointer overload: socket index is
// (image number / images_per_file) % socket count.
// If there is no socket, or the chosen one is broken, z.release() is called
// here instead. Presumably this returns the buffer to its owner; not visible
// in this file.
void TCPStreamPusher::SendImage(ZeroCopyReturnValue &z) {
    if (!socket.empty()) {
        auto &target = socket[(z.GetImageNumber() / images_per_file) % socket.size()];
        if (!target->IsBroken()) {
            target->SendImage(z);
            return;
        }
    }
    z.release();
}
|
|
|
|
bool TCPStreamPusher::EndDataCollection(const EndMessage &message) {
|
|
serializer.SerializeSequenceEnd(message);
|
|
|
|
bool ret = true;
|
|
for (auto &s : socket) {
|
|
if (s->IsBroken()) {
|
|
ret = false;
|
|
continue;
|
|
}
|
|
|
|
if (!s->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::END)) {
|
|
ret = false;
|
|
continue;
|
|
}
|
|
|
|
std::string ack_err;
|
|
if (!s->WaitForAck(TCPFrameType::END, std::chrono::seconds(10), &ack_err)) {
|
|
ret = false;
|
|
}
|
|
}
|
|
|
|
for (auto &s : socket)
|
|
s->StopWriterThread();
|
|
|
|
transmission_error = !ret;
|
|
return ret;
|
|
}
|
|
|
|
std::string TCPStreamPusher::Finalize() {
|
|
std::string ret;
|
|
if (transmission_error)
|
|
ret += "Timeout sending images (e.g., writer disabled during data collection);";
|
|
|
|
for (size_t i = 0; i < socket.size(); i++) {
|
|
if (socket[i]->IsBroken()) {
|
|
const auto reason = socket[i]->GetLastAckError();
|
|
ret += "Writer " + std::to_string(i) + ": " + (reason.empty() ? "stream broken" : reason) + ";";
|
|
}
|
|
}
|
|
return ret;
|
|
}
|
|
|
|
std::string TCPStreamPusher::PrintSetup() const {
|
|
std::string output = "TCPStream2Pusher: Sending images to sockets: ";
|
|
for (const auto &s : socket)
|
|
output += s->GetEndpointName() + " ";
|
|
return output;
|
|
}
|
|
|
|
std::string TCPStreamPusherSocket::GetEndpointName() const {
|
|
return endpoint;
|
|
}
|
|
|
|
void TCPStreamPusherSocket::SetRunNumber(uint64_t in_run_number) {
|
|
run_number = in_run_number;
|
|
}
|
|
|
|
bool TCPStreamPusher::SendCalibration(const CompressedImage &message) {
|
|
if (socket.empty())
|
|
return false;
|
|
serializer.SerializeCalibration(message);
|
|
return socket[0]->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::CALIBRATION);
|
|
}
|
|
|
|
std::optional<uint64_t> TCPStreamPusher::GetImagesWritten() const {
|
|
uint64_t ret = 0;
|
|
for (const auto &s : socket) {
|
|
auto p = s->GetDataAckProgress();
|
|
ret += p.data_acked_ok;
|
|
}
|
|
return ret;
|
|
}
|