diff --git a/broker/JFJochBrokerParser.cpp b/broker/JFJochBrokerParser.cpp index a0d7c966..821ae5b2 100644 --- a/broker/JFJochBrokerParser.cpp +++ b/broker/JFJochBrokerParser.cpp @@ -201,8 +201,6 @@ std::unique_ptr ParseTCPImagePusher(const org::openapitools::server auto tmp = std::make_unique(j.getZeromq().getImageSocket(), send_buffer_size); - if (j.getZeromq().writerNotificationSocketIsSet()) - tmp->WriterNotificationSocket(j.getZeromq().getWriterNotificationSocket()); return std::move(tmp); } diff --git a/common/JfjochTCP.h b/common/JfjochTCP.h index b5fe36a7..51c6d2a4 100644 --- a/common/JfjochTCP.h +++ b/common/JfjochTCP.h @@ -6,15 +6,33 @@ #include constexpr uint32_t JFJOCH_TCP_MAGIC = 0x4A464A54; // JFJT -constexpr uint32_t JFJOCH_TCP_VERSION = 1; +constexpr uint32_t JFJOCH_TCP_VERSION = 2; enum class TCPFrameType : uint16_t { START = 1, DATA = 2, CALIBRATION = 3, - END = 4 + END = 4, + ACK = 5, + CANCEL = 6 }; +enum class TCPAckCode : uint16_t { + None = 0, + StartFailed = 1, + DataWriteFailed = 2, + EndFailed = 3, + DiskQuotaExceeded = 4, + NoSpaceLeft = 5, + PermissionDenied = 6, + IoError = 7, + ProtocolError = 8 +}; + +constexpr uint32_t TCP_ACK_FLAG_OK = 1u << 0; +constexpr uint32_t TCP_ACK_FLAG_FATAL = 1u << 1; +constexpr uint32_t TCP_ACK_FLAG_HAS_ERROR_TEXT = 1u << 2; + struct alignas(64) TcpFrameHeader { uint32_t magic = JFJOCH_TCP_MAGIC; uint16_t version = JFJOCH_TCP_VERSION ; @@ -24,5 +42,10 @@ struct alignas(64) TcpFrameHeader { uint32_t socket_number = 0; uint32_t flags = 0; uint64_t run_number = 0; - uint64_t reserved[4] = {0, 0, 0, 0}; + + uint32_t ack_processed_images = 0; + uint16_t ack_code = 0; + uint16_t ack_for = 0; + + uint64_t reserved[2] = {0, 0}; }; \ No newline at end of file diff --git a/image_puller/ImagePuller.h b/image_puller/ImagePuller.h index c6ef4971..66f0b545 100644 --- a/image_puller/ImagePuller.h +++ b/image_puller/ImagePuller.h @@ -11,6 +11,19 @@ #include "../frame_serialize/CBORStream2Deserializer.h" #include "../common/ThreadSafeFIFO.h" #include "../common/JfjochTCP.h" +#include "../common/JfjochTCP.h" + +struct PullerAckMessage { + TCPFrameType ack_for = TCPFrameType::DATA; + bool ok = true; + bool fatal = false; + uint64_t run_number = 0; + uint32_t socket_number = 0; + uint64_t image_number = 0; + uint64_t processed_images = 0; + TCPAckCode error_code = TCPAckCode::None; + std::string error_text; +}; struct RawFrame { TcpFrameHeader header{}; @@ -42,6 +55,9 @@ public: [[nodiscard]] size_t GetMaxFifoUtilization() const; [[nodiscard]] size_t GetCurrentFifoUtilization() const; + virtual bool SupportsAck() const { return false; } + virtual bool SendAck(const PullerAckMessage &) { return false; } + virtual void Disconnect() = 0; }; diff --git a/image_puller/TCPImagePuller.cpp b/image_puller/TCPImagePuller.cpp index 5c9fb0a8..b002e662 100644 --- a/image_puller/TCPImagePuller.cpp +++ b/image_puller/TCPImagePuller.cpp @@ -51,6 +51,77 @@ TCPImagePuller::TCPImagePuller(const std::string &tcp_addr, cbor_thread = std::thread(&TCPImagePuller::CBORThread, this); } +bool TCPImagePuller::SendAll(const void *buf, size_t len) { + const auto *p = static_cast(buf); + size_t sent = 0; + while (sent < len) { + if (disconnect) + return false; + + int local_fd = -1; + { + std::unique_lock ul(fd_mutex); + local_fd = fd; + } + if (local_fd < 0) + return false; + + ssize_t rc = ::send(local_fd, p + sent, len - sent, MSG_NOSIGNAL); + if (rc < 0) { + if (errno == EINTR) + continue; + return false; + } + sent += static_cast(rc); + } + return true; +} + +bool TCPImagePuller::SendAck(const PullerAckMessage &ack) { + TcpFrameHeader h{}; + h.type = static_cast(TCPFrameType::ACK); + h.run_number = ack.run_number; + h.socket_number = ack.socket_number; + h.image_number = ack.image_number; + h.flags = 0; + if (ack.ok) + h.flags |= TCP_ACK_FLAG_OK; + if (ack.fatal) + h.flags |= TCP_ACK_FLAG_FATAL; + if (!ack.error_text.empty()) + h.flags |= TCP_ACK_FLAG_HAS_ERROR_TEXT; + + h.ack_for = static_cast(ack.ack_for); + h.ack_processed_images = ack.processed_images; + h.ack_code = static_cast(ack.error_code); + h.payload_size = ack.error_text.size(); + + if (!SendAll(&h, sizeof(h))) + return false; + if (!ack.error_text.empty()) + return SendAll(ack.error_text.data(), ack.error_text.size()); + return true; +} + +void TCPImagePuller::CBORThread() { + auto ret = cbor_fifo.GetBlocking(); + while (ret.tcp_msg) { + try { + const auto type = static_cast(ret.tcp_msg->header.type); + if (type == TCPFrameType::CANCEL) { + outside_fifo.PutBlocking(ret); + } else { + ret.cbor = CBORStream2Deserialize(ret.tcp_msg->payload.data(), ret.tcp_msg->payload.size()); + outside_fifo.PutBlocking(ret); + } + } catch (const JFJochException &e) { + logger.ErrorException(e); + } + ret = cbor_fifo.GetBlocking(); + } + outside_fifo.PutBlocking(ret); +} + TCPImagePuller::~TCPImagePuller() { TCPImagePuller::Disconnect(); } @@ -179,6 +250,17 @@ void TCPImagePuller::ReceiverThread() { continue; } + // Ignore ACK on puller side + if (static_cast(frame.header.type) == TCPFrameType::ACK) { + if (frame.header.payload_size > 0) { + std::vector discard(frame.header.payload_size); + if (!ReadExact(discard.data(), discard.size())) { + CloseSocket(); + } + } + continue; + } + ImagePullerOutput out; out.tcp_msg = std::make_shared(); out.tcp_msg->header = frame.header; @@ -206,19 +288,6 @@ void TCPImagePuller::ReceiverThread() { cbor_fifo.PutBlocking(ImagePullerOutput{}); } -void TCPImagePuller::CBORThread() { - auto ret = cbor_fifo.GetBlocking(); - while (ret.tcp_msg) { - try { - ret.cbor = CBORStream2Deserialize(ret.tcp_msg->payload.data(), ret.tcp_msg->payload.size()); - outside_fifo.PutBlocking(ret); - } catch (const JFJochException &e) { - logger.ErrorException(e); - } - ret = cbor_fifo.GetBlocking(); - } - outside_fifo.PutBlocking(ret); -} void TCPImagePuller::Disconnect() { if (disconnect.exchange(true)) return; diff --git a/image_puller/TCPImagePuller.h b/image_puller/TCPImagePuller.h index 5f4b5b33..714a3f9b 100644 --- a/image_puller/TCPImagePuller.h +++ b/image_puller/TCPImagePuller.h @@ -29,6 +29,7 @@ class TCPImagePuller : public ImagePuller { Logger logger{"TCPImagePuller"}; bool ReadExact(void *buf, size_t size); + bool SendAll(const void *buf, size_t len); bool EnsureConnected(); void CloseSocket(); void ReceiverThread(); @@ -37,5 +38,7 @@ public: explicit TCPImagePuller(const std::string &tcp_addr, std::optional rcv_buffer_size = {}); ~TCPImagePuller() override; + bool SupportsAck() const override { return true; } + bool SendAck(const PullerAckMessage &ack) override; void Disconnect() override; }; \ No newline at end of file diff --git a/image_pusher/HDF5FilePusher.h b/image_pusher/HDF5FilePusher.h index 4a1c0940..796df4ff 100644 --- a/image_pusher/HDF5FilePusher.h +++ b/image_pusher/HDF5FilePusher.h @@ -26,7 +26,6 @@ public: void SendImage(ZeroCopyReturnValue &z) override; bool SendCalibration(const CompressedImage &message) override; - std::string PrintSetup() const override; }; diff --git a/image_pusher/TCPStreamPusher.cpp b/image_pusher/TCPStreamPusher.cpp index a369164a..8ac57279 100644 --- a/image_pusher/TCPStreamPusher.cpp +++ b/image_pusher/TCPStreamPusher.cpp @@ -19,7 +19,6 @@ TCPStreamPusher::TCPStreamPusher(const std::vector &addr, } } - void TCPStreamPusher::StartDataCollection(StartMessage &message) { if (message.images_per_file < 1) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, @@ -35,6 +34,25 @@ void TCPStreamPusher::StartDataCollection(StartMessage &message) { "TCP accept timeout/failure on socket " + socket[i]->GetEndpointName()); } + for (auto &s : socket) + s->StartWriterThread(); + + std::vector started(socket.size(), false); + + auto rollback_cancel = [&]() { + for (size_t i = 0; i < socket.size(); i++) { + if (!started[i] || socket[i]->IsBroken()) + continue; + + (void)socket[i]->Send(nullptr, 0, TCPFrameType::CANCEL); + std::string cancel_ack_err; + (void)socket[i]->WaitForAck(TCPFrameType::CANCEL, std::chrono::milliseconds(500), &cancel_ack_err); + } + + for (auto &s : socket) + s->StopWriterThread(); + }; + for (size_t i = 0; i < socket.size(); i++) { message.socket_number = static_cast(i); if (i > 0) @@ -44,17 +62,20 @@ void TCPStreamPusher::StartDataCollection(StartMessage &message) { socket[i]->SetRunNumber(run_number); if (!socket[i]->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::START)) { - // one-shot recovery: reconnect and retry START once - if (!socket[i]->AcceptConnection(std::chrono::seconds(5)) || - !socket[i]->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::START)) { - throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, - "Timeout/failure sending START"); - } + rollback_cancel(); + throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, + "Timeout/failure sending START on " + socket[i]->GetEndpointName()); } - } - for (auto &s : socket) - s->StartWriterThread(); + std::string ack_err; + if (!socket[i]->WaitForAck(TCPFrameType::START, std::chrono::seconds(5), &ack_err)) { + rollback_cancel(); + throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, + "START ACK failed on " + socket[i]->GetEndpointName() + ": " + ack_err); + } + + started[i] = true; + } } bool TCPStreamPusher::SendImage(const uint8_t *image_data, size_t image_size, int64_t image_number) { @@ -88,12 +109,25 @@ bool TCPStreamPusher::EndDataCollection(const EndMessage &message) { bool ret = true; for (auto &s : socket) { - s->StopWriterThread(); - if (s->IsBroken()) + if (s->IsBroken()) { ret = false; - else if (!s->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::END)) + continue; + } + + if (!s->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::END)) { ret = false; + continue; + } + + std::string ack_err; + if (!s->WaitForAck(TCPFrameType::END, std::chrono::seconds(10), &ack_err)) { + ret = false; + } } + + for (auto &s : socket) + s->StopWriterThread(); + transmission_error = !ret; return ret; } @@ -102,32 +136,16 @@ std::string TCPStreamPusher::Finalize() { std::string ret; if (transmission_error) ret += "Timeout sending images (e.g., writer disabled during data collection);"; - if (writer_notification_socket) { - for (size_t i = 0; i < socket.size(); i++) { - auto n = writer_notification_socket->Receive(run_number, run_name); - if (!n) - ret += "Writer " + std::to_string(i) + ": no end notification received within 1 minute from collection end"; - else if (static_cast(n->socket_number) >= socket.size()) - ret += "Writer " + std::to_string(i) + ": mismatch in socket number"; - else if (!n->ok) - ret += "Writer " + std::to_string(i) + ": " + n->error; + + for (size_t i = 0; i < socket.size(); i++) { + if (socket[i]->IsBroken()) { + const auto reason = socket[i]->GetLastAckError(); + ret += "Writer " + std::to_string(i) + ": " + (reason.empty() ? "stream broken" : reason) + ";"; } } return ret; } -std::string TCPStreamPusher::GetWriterNotificationSocketAddress() const { - if (writer_notification_socket) - return writer_notification_socket->GetEndpointName(); - else - return ""; -} - -TCPStreamPusher &TCPStreamPusher::WriterNotificationSocket(const std::string &addr) { - writer_notification_socket = std::make_unique(addr, std::chrono::minutes(1)); - return *this; -} - std::string TCPStreamPusher::PrintSetup() const { std::string output = "TCPStream2Pusher: Sending images to sockets: "; for (const auto &s : socket) diff --git a/image_pusher/TCPStreamPusher.h b/image_pusher/TCPStreamPusher.h index 3d574c22..a29c6115 100644 --- a/image_pusher/TCPStreamPusher.h +++ b/image_pusher/TCPStreamPusher.h @@ -11,8 +11,6 @@ class TCPStreamPusher : public ImagePusher { CBORStream2Serializer serializer; std::vector> socket; - std::unique_ptr writer_notification_socket; - int64_t images_per_file = 1; uint64_t run_number = 0; std::string run_name; @@ -23,9 +21,6 @@ public: std::optional zerocopy_threshold = {}, size_t send_queue_size = 4096); - TCPStreamPusher& WriterNotificationSocket(const std::string& addr); - std::string GetWriterNotificationSocketAddress() const override; - void StartDataCollection(StartMessage& message) override; bool EndDataCollection(const EndMessage& message) override; bool SendImage(const uint8_t *image_data, size_t image_size, int64_t image_number) override; diff --git a/image_pusher/TCPStreamPusherSocket.cpp b/image_pusher/TCPStreamPusherSocket.cpp index ef4291ec..f8e7df32 100644 --- a/image_pusher/TCPStreamPusherSocket.cpp +++ b/image_pusher/TCPStreamPusherSocket.cpp @@ -197,6 +197,50 @@ bool TCPStreamPusherSocket::SendAll(const void *buf, size_t len) { return true; } +bool TCPStreamPusherSocket::ReadExact(void *buf, size_t len) { + auto *p = static_cast(buf); + size_t got = 0; + + while (got < len) { + if (!active) + return false; + + int local_fd = fd.load(); + if (local_fd < 0) + return false; + + pollfd pfd{}; + pfd.fd = local_fd; + pfd.events = POLLIN; + + const int prc = poll(&pfd, 1, 100); // 100 ms interruptibility window + if (prc == 0) + continue; + if (prc < 0) { + if (errno == EINTR) + continue; + return false; + } + if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0) + return false; + if ((pfd.revents & POLLIN) == 0) + continue; + + ssize_t rc = ::recv(local_fd, p + got, len - got, 0); + if (rc == 0) + return false; + if (rc < 0) { + if (errno == EINTR) + continue; + if (errno == EAGAIN || errno == EWOULDBLOCK) + continue; + return false; + } + got += static_cast(rc); + } + return true; +} + bool TCPStreamPusherSocket::SendPayloadZC(const uint8_t *data, size_t size, ZeroCopyReturnValue *z) { #if defined(MSG_ZEROCOPY) && defined(SO_ZEROCOPY) int local_fd = fd.load(); @@ -358,10 +402,89 @@ void TCPStreamPusherSocket::CompletionThread() { #endif } +void TCPStreamPusherSocket::AckThread() { + while (active) { + TcpFrameHeader h{}; + if (!ReadExact(&h, sizeof(h))) { + if (active) { + broken = true; + logger.Error("TCP ACK reader disconnected on " + endpoint); + } + break; + } + + if (h.magic != JFJOCH_TCP_MAGIC || h.version != JFJOCH_TCP_VERSION || static_cast(h.type) != TCPFrameType::ACK) { + broken = true; + logger.Error("Invalid ACK frame on " + endpoint); + break; + } + + std::string error_text; + if (h.payload_size > 0) { + error_text.resize(h.payload_size); + if (!ReadExact(error_text.data(), error_text.size())) { + broken = true; + break; + } + } + + const auto ack_for = static_cast(h.ack_for); + const bool ok = (h.flags & TCP_ACK_FLAG_OK) != 0; + const bool fatal = (h.flags & TCP_ACK_FLAG_FATAL) != 0; + const auto code = static_cast(h.ack_code); + + { + std::unique_lock ul(ack_state_mutex); + last_ack_code = code; + if (!error_text.empty()) + last_ack_error = error_text; + + if (ack_for == TCPFrameType::START) { + start_ack_received = true; + start_ack_ok = ok; + if (!ok && error_text.empty()) + last_ack_error = "START rejected"; + } else if (ack_for == TCPFrameType::END) { + end_ack_received = true; + end_ack_ok = ok; + if (!ok && error_text.empty()) + last_ack_error = "END rejected"; + } else if (ack_for == TCPFrameType::CANCEL) { + cancel_ack_received = true; + cancel_ack_ok = ok; + if (!ok && error_text.empty()) + last_ack_error = "CANCEL rejected"; + } else if (ack_for == TCPFrameType::DATA && (!ok || fatal)) { + broken = true; + if (error_text.empty()) + last_ack_error = "DATA fatal ACK"; + logger.Error("Received fatal DATA ACK on " + endpoint + ": " + last_ack_error); + } + } + ack_cv.notify_all(); + } +} + void TCPStreamPusherSocket::StartWriterThread() { + if (active) + return; + + { + std::unique_lock ul(ack_state_mutex); + start_ack_received = false; + start_ack_ok = false; + end_ack_received = false; + end_ack_ok = false; + cancel_ack_received = false; + cancel_ack_ok = false; + last_ack_error.clear(); + last_ack_code = TCPAckCode::None; + } + active = true; send_future = std::async(std::launch::async, &TCPStreamPusherSocket::WriterThread, this); completion_future = std::async(std::launch::async, &TCPStreamPusherSocket::CompletionThread, this); + ack_future = std::async(std::launch::async, &TCPStreamPusherSocket::AckThread, this); } void TCPStreamPusherSocket::StopWriterThread() { @@ -369,11 +492,14 @@ void TCPStreamPusherSocket::StopWriterThread() { return; active = false; queue.PutBlocking({.end = true}); + ack_cv.notify_all(); if (send_future.valid()) send_future.get(); if (completion_future.valid()) completion_future.get(); + if (ack_future.valid()) + ack_future.get(); // Keep fd open: END frame may still be sent after writer thread stops. // Socket is closed in destructor / explicit close path. @@ -403,3 +529,46 @@ bool TCPStreamPusherSocket::Send(const uint8_t *data, size_t size, TCPFrameType return SendFrame(data, size, type, image_number, nullptr); } + +bool TCPStreamPusherSocket::WaitForAck(TCPFrameType ack_for, std::chrono::milliseconds timeout, std::string *error_text) { + std::unique_lock ul(ack_state_mutex); + const bool ok = ack_cv.wait_for(ul, timeout, [&] { + if (ack_for == TCPFrameType::START) + return start_ack_received || broken.load(); + if (ack_for == TCPFrameType::END) + return end_ack_received || broken.load(); + if (ack_for == TCPFrameType::CANCEL) + return cancel_ack_received || broken.load(); + return false; + }); + + if (!ok) { + if (error_text) + *error_text = "ACK timeout"; + return false; + } + + if (broken) { + if (error_text) + *error_text = last_ack_error.empty() ? "Socket broken" : last_ack_error; + return false; + } + + bool ack_ok = false; + if (ack_for == TCPFrameType::START) + ack_ok = start_ack_ok; + else if (ack_for == TCPFrameType::END) + ack_ok = end_ack_ok; + else if (ack_for == TCPFrameType::CANCEL) + ack_ok = cancel_ack_ok; + + if (!ack_ok && error_text) + *error_text = last_ack_error.empty() ? "ACK rejected" : last_ack_error; + + return ack_ok; +} + +std::string TCPStreamPusherSocket::GetLastAckError() const { + std::unique_lock ul(ack_state_mutex); + return last_ack_error; +} diff --git a/image_pusher/TCPStreamPusherSocket.h b/image_pusher/TCPStreamPusherSocket.h index c735be59..80350ef7 100644 --- a/image_pusher/TCPStreamPusherSocket.h +++ b/image_pusher/TCPStreamPusherSocket.h @@ -26,6 +26,7 @@ class TCPStreamPusherSocket { std::atomic active = false; std::future send_future; std::future completion_future; + std::future ack_future; ThreadSafeFIFO queue; @@ -40,6 +41,16 @@ class TCPStreamPusherSocket { constexpr static auto AcceptTimeout = std::chrono::seconds(5); std::atomic broken{false}; + std::atomic last_ack_code{TCPAckCode::None}; + std::string last_ack_error; + mutable std::mutex ack_state_mutex; + std::condition_variable ack_cv; + bool start_ack_received = false; + bool start_ack_ok = false; + bool end_ack_received = false; + bool end_ack_ok = false; + bool cancel_ack_received = false; + bool cancel_ack_ok = false; std::atomic next_tx_id{1}; std::mutex inflight_mutex; @@ -49,10 +60,12 @@ class TCPStreamPusherSocket { void WriterThread(); void CompletionThread(); + void AckThread(); void CloseDataSocket(); bool SendAll(const void *buf, size_t len); + bool ReadExact(void *buf, size_t len); bool SendFrame(const uint8_t *data, size_t size, TCPFrameType type, int64_t image_number, ZeroCopyReturnValue *z); bool SendPayloadZC(const uint8_t *data, size_t size, ZeroCopyReturnValue *z); public: @@ -74,9 +87,12 @@ public: void StartWriterThread(); void StopWriterThread(); + bool WaitForAck(TCPFrameType ack_for, std::chrono::milliseconds timeout, std::string *error_text = nullptr); + void SetRunNumber(uint64_t in_run_number); void SendImage(ZeroCopyReturnValue &z); bool IsBroken() const; + std::string GetLastAckError() const; }; diff --git a/tests/TCPImagePusherTest.cpp b/tests/TCPImagePusherTest.cpp index f4608d1f..0cb0a646 100644 --- a/tests/TCPImagePusherTest.cpp +++ b/tests/TCPImagePusherTest.cpp @@ -7,7 +7,7 @@ #include "../image_pusher/TCPStreamPusher.h" #include "../image_puller/TCPImagePuller.h" -TEST_CASE("TCPImageCommTest_2Writers", "[TCP]") { +TEST_CASE("TCPImageCommTest_2Writers_WithAck", "[TCP]") { const size_t nframes = 128; const int64_t npullers = 2; const int64_t images_per_file = 16; @@ -23,24 +23,24 @@ TEST_CASE("TCPImageCommTest_2Writers", "[TCP]") { for (auto &i : image1) i = dist(g1); std::vector addr{ - "tcp://127.0.0.1:19001", - "tcp://127.0.0.1:19002" -}; + "tcp://127.0.0.1:19001", + "tcp://127.0.0.1:19002" + }; std::vector> puller; for (int i = 0; i < npullers; i++) { - puller.push_back(std::make_unique( - addr[i], 64 * 1024 * 1024)); // decoded cbor ring + puller.push_back(std::make_unique(addr[i], 64 * 1024 * 1024)); } TCPStreamPusher pusher( addr, 64 * 1024 * 1024, - 128 * 1024, // zerocopy threshold - 8192 // sender queue + 128 * 1024, + 8192 ); std::vector received(npullers, 0); + std::vector processed(npullers, 0); std::thread sender([&] { std::vector serialization_buffer(16 * 1024 * 1024); @@ -70,29 +70,204 @@ TEST_CASE("TCPImageCommTest_2Writers", "[TCP]") { REQUIRE(pusher.EndDataCollection(end)); }); + std::vector receivers; + receivers.reserve(npullers); + for (int w = 0; w < npullers; w++) { - bool seen_end = false; - while (!seen_end) { - auto out = puller[w]->PollImage(std::chrono::seconds(10)); - REQUIRE(out.has_value()); - REQUIRE(out->cbor != nullptr); - if (out->cbor->end_message) { - seen_end = true; - continue; + receivers.emplace_back([&, w] { + bool seen_start = false; + bool seen_end = false; + + while (!seen_end) { + auto out = puller[w]->PollImage(std::chrono::seconds(10)); + REQUIRE(out.has_value()); + REQUIRE(out->cbor != nullptr); + REQUIRE(out->tcp_msg != nullptr); + + const auto &h = out->tcp_msg->header; + + if (out->cbor->start_message) { + PullerAckMessage ack; + ack.ack_for = TCPFrameType::START; + ack.ok = true; + ack.fatal = false; + ack.run_number = h.run_number; + ack.socket_number = h.socket_number; + ack.image_number = 0; + ack.processed_images = 0; + ack.error_code = TCPAckCode::None; + REQUIRE(puller[w]->SendAck(ack)); + seen_start = true; + continue; + } + + if (out->cbor->data_message) { + REQUIRE(seen_start); + auto n = out->cbor->data_message->number; + REQUIRE(((n / images_per_file) % npullers) == w); + received[w]++; + processed[w]++; + continue; + } + + if (out->cbor->end_message) { + PullerAckMessage ack; + ack.ack_for = TCPFrameType::END; + ack.ok = true; + ack.fatal = false; + ack.run_number = h.run_number; + ack.socket_number = h.socket_number; + ack.image_number = 0; + ack.processed_images = processed[w]; + ack.error_code = TCPAckCode::None; + REQUIRE(puller[w]->SendAck(ack)); + seen_end = true; + } } - if (out->cbor->data_message) { - auto n = out->cbor->data_message->number; - REQUIRE(((n / images_per_file) % npullers) == w); - received[w]++; - } - } + }); } sender.join(); + for (auto &t : receivers) t.join(); REQUIRE(received[0] == nframes / 2); REQUIRE(received[1] == nframes / 2); + for (auto &p : puller) + p->Disconnect(); +} + +TEST_CASE("TCPImageCommTest_DataFatalAck_PropagatesToPusher", "[TCP]") { + const size_t nframes = 64; + const int64_t npullers = 2; + const int64_t images_per_file = 8; + + DiffractionExperiment x(DetJF(1)); + x.Raw(); + x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4) + .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION); + + std::mt19937 g1(42); + std::uniform_int_distribution dist; + std::vector image1(x.GetPixelsNum() * nframes); + for (auto &i : image1) i = dist(g1); + + std::vector addr{ + "tcp://127.0.0.1:19011", + "tcp://127.0.0.1:19012" + }; + + std::vector> puller; + for (int i = 0; i < npullers; i++) { + puller.push_back(std::make_unique(addr[i], 64 * 1024 * 1024)); + } + + TCPStreamPusher pusher( + addr, + 64 * 1024 * 1024, + 128 * 1024, + 8192 + ); + + std::atomic sent_fatal{false}; + + std::thread sender([&] { + std::vector serialization_buffer(16 * 1024 * 1024); + CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size()); + + StartMessage start{ + .images_per_file = images_per_file, + .write_master_file = true + }; + EndMessage end{}; + + pusher.StartDataCollection(start); + + for (int64_t i = 0; i < static_cast(nframes); i++) { + DataMessage data_message; + data_message.number = i; + data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(), + x.GetPixelsNum() * sizeof(uint16_t), + x.GetXPixelsNum(), + x.GetYPixelsNum(), + x.GetImageMode(), + x.GetCompressionAlgorithm()); + serializer.SerializeImage(data_message); + (void)pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i); + } + + REQUIRE_FALSE(pusher.EndDataCollection(end)); + const auto final_msg = pusher.Finalize(); + REQUIRE_THAT(final_msg, Catch::Matchers::ContainsSubstring("quota")); + }); + + std::vector receivers; + receivers.reserve(npullers); + + for (int w = 0; w < npullers; w++) { + receivers.emplace_back([&, w] { + bool seen_end = false; + bool local_fatal_sent = false; + + while (!seen_end) { + auto out = puller[w]->PollImage(std::chrono::seconds(2)); + if (!out.has_value()) { + // Once this receiver has sent a fatal DATA ACK, no END is guaranteed on this stream. + if (local_fatal_sent) + break; + REQUIRE(out.has_value()); + } + + REQUIRE(out->cbor != nullptr); + REQUIRE(out->tcp_msg != nullptr); + + const auto &h = out->tcp_msg->header; + + if (out->cbor->start_message) { + PullerAckMessage ack; + ack.ack_for = TCPFrameType::START; + ack.ok = true; + ack.run_number = h.run_number; + ack.socket_number = h.socket_number; + ack.error_code = TCPAckCode::None; + REQUIRE(puller[w]->SendAck(ack)); + continue; + } + + if (out->cbor->data_message) { + if (w == 0 && !sent_fatal.exchange(true)) { + PullerAckMessage ack; + ack.ack_for = TCPFrameType::DATA; + ack.ok = false; + ack.fatal = true; + ack.run_number = h.run_number; + ack.socket_number = h.socket_number; + ack.image_number = static_cast(out->cbor->data_message->number); + ack.error_code = TCPAckCode::DiskQuotaExceeded; + ack.error_text = "quota exceeded"; + REQUIRE(puller[w]->SendAck(ack)); + local_fatal_sent = true; + } + continue; + } + + if (out->cbor->end_message) { + PullerAckMessage ack; + ack.ack_for = TCPFrameType::END; + ack.ok = true; + ack.run_number = h.run_number; + ack.socket_number = h.socket_number; + ack.error_code = TCPAckCode::None; + REQUIRE(puller[w]->SendAck(ack)); + seen_end = true; + } + } + }); + } + + sender.join(); + for (auto &t : receivers) t.join(); + for (auto &p : puller) p->Disconnect(); } \ No newline at end of file diff --git a/writer/StreamWriter.cpp b/writer/StreamWriter.cpp index 1b4ac22b..372b6abd 100644 --- a/writer/StreamWriter.cpp +++ b/writer/StreamWriter.cpp @@ -20,6 +20,27 @@ StreamWriter::StreamWriter(Logger &in_logger, max_image_number(0) { } +void StreamWriter::NotifyTcpAck(TCPFrameType ack_for, bool ok, bool fatal, TCPAckCode code, const std::string &error_text) { + if (!image_puller.SupportsAck()) + return; + + PullerAckMessage ack; + ack.ack_for = ack_for; + ack.ok = ok; + ack.fatal = fatal; + ack.error_code = code; + ack.error_text = error_text; + ack.run_number = run_number; + ack.socket_number = static_cast(socket_number); + ack.processed_images = processed_images.load(); + + if (image_puller_output.cbor && image_puller_output.cbor->data_message) + ack.image_number = image_puller_output.cbor->data_message->number; + + if (!image_puller.SendAck(ack)) + logger.Warning("Failed to send TCP ACK"); +} + void StreamWriter::ProcessStartMessage() { if (state == StreamWriterState::Finalized) return; // Should not happen (?) @@ -28,6 +49,7 @@ void StreamWriter::ProcessStartMessage() { FinalizeDataCollection(); err = ""; + tcp_data_fatal_sent = false; max_image_number = 0; @@ -51,11 +73,13 @@ void StreamWriter::ProcessStartMessage() { image_puller_output.cbor->start_message->file_prefix, image_puller_output.cbor->start_message->number_of_images); state = StreamWriterState::Started; + NotifyTcpAck(TCPFrameType::START, true, false, TCPAckCode::None); } catch (const JFJochException &e) { logger.ErrorException(e); logger.Error("Error writing start message - switching to error state"); state = StreamWriterState::Error; err = e.what(); + NotifyTcpAck(TCPFrameType::START, false, true, TCPAckCode::StartFailed, err); } } @@ -108,6 +132,10 @@ void StreamWriter::ProcessDataImage() { logger.Warning("Error writing image - switching to error state"); state = StreamWriterState::Error; err = e.what(); + if (!tcp_data_fatal_sent) { + tcp_data_fatal_sent = true; + NotifyTcpAck(TCPFrameType::DATA, false, true, TCPAckCode::DataWriteFailed, err); + } } break; case StreamWriterState::Error: @@ -156,6 +184,11 @@ void StreamWriter::FinalizeDataCollection() { } file_writer.reset(); NotifyReceiverOnFinalizedWrite(writer_notification_zmq_addr); + NotifyTcpAck(TCPFrameType::END, + state != StreamWriterState::Error, + state == StreamWriterState::Error, + state == StreamWriterState::Error ? TCPAckCode::EndFailed : TCPAckCode::None, + state == StreamWriterState::Error ? err : ""); logger.Info("Data writing finished"); state = StreamWriterState::Finalized; } @@ -168,6 +201,21 @@ void StreamWriter::CollectImages() { while (run && state != StreamWriterState::Finalized) { run = WaitForImage(); + if (image_puller_output.tcp_msg && + static_cast(image_puller_output.tcp_msg->header.type) == TCPFrameType::CANCEL) { + logger.Warning("Received TCP CANCEL, finalizing data collection"); + if (state != StreamWriterState::Idle && state != StreamWriterState::Finalized) + FinalizeDataCollection(); + NotifyTcpAck(TCPFrameType::CANCEL, true, false, TCPAckCode::None); + state = StreamWriterState::Finalized; + continue; + } + + if (!image_puller_output.cbor) { + logger.Warning("Missing CBOR payload for non-CANCEL TCP frame"); + continue; + } + if (image_puller_output.cbor->start_message) ProcessStartMessage(); else if (image_puller_output.cbor->calibration) diff --git a/writer/StreamWriter.h b/writer/StreamWriter.h index 045d0d50..301a4e6f 100644 --- a/writer/StreamWriter.h +++ b/writer/StreamWriter.h @@ -55,12 +55,14 @@ class StreamWriter { std::vector hdf5_data_file_statistics; bool debug_skip_write_notification = false; + bool tcp_data_fatal_sent = false; ImagePuller &image_puller; Logger &logger; void CollectImages(); bool WaitForImage(); void NotifyReceiverOnFinalizedWrite(const std::string &detector_update_zmq_addr); + void NotifyTcpAck(TCPFrameType ack_for, bool ok, bool fatal, TCPAckCode code, const std::string &error_text = ""); void ProcessStartMessage(); void ProcessEndMessage(); void ProcessDataImage();