From ad1e724bcf8abc4c09a58a1ca9263d9e04f38ac8 Mon Sep 17 00:00:00 2001 From: Filip Leonarski Date: Thu, 5 Mar 2026 17:22:47 +0100 Subject: [PATCH] jfjoch_writer: Repub ZeroMQ from TCP image stream --- image_puller/TCPImagePuller.cpp | 69 ++++++++++++-- image_puller/TCPImagePuller.h | 13 ++- tests/TCPImagePusherTest.cpp | 153 ++++++++++++++++++++++++++++++++ writer/jfjoch_writer.cpp | 7 +- 4 files changed, 227 insertions(+), 15 deletions(-) diff --git a/image_puller/TCPImagePuller.cpp b/image_puller/TCPImagePuller.cpp index bbab529d..ecaa2162 100644 --- a/image_puller/TCPImagePuller.cpp +++ b/image_puller/TCPImagePuller.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: GPL-3.0-only #include "TCPImagePuller.h" +#include "../frame_serialize/CBORStream2Serializer.h" #include #include @@ -10,7 +11,7 @@ #include #include -static std::pair ParseTcpAddressPull(const std::string& addr) { +static std::pair ParseTcpAddressPull(const std::string &addr) { const std::string prefix = "tcp://"; if (addr.rfind(prefix, 0) != 0) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP address: " + addr); @@ -28,27 +29,39 @@ static std::pair ParseTcpAddressPull(const std::string& a size_t parsed = 0; port_i = std::stoi(port_str, &parsed); if (parsed != port_str.size()) - throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP port in address: " + addr); + throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, + "Invalid TCP port in address: " + addr); } catch (...) { throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "Invalid TCP port in address: " + addr); } if (port_i < 1 || port_i > static_cast(std::numeric_limits::max())) - throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, "TCP port out of range in address: " + addr); + throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, + "TCP port out of range in address: " + addr); return {host, static_cast(port_i)}; } TCPImagePuller::TCPImagePuller(const std::string &tcp_addr, - std::optional rcv_buffer_size) - : addr(tcp_addr), - receive_buffer_size(rcv_buffer_size) { + std::optional rcv_buffer_size, + const std::string &repub_address, + const std::optional &repub_watermark) + : addr(tcp_addr), + receive_buffer_size(rcv_buffer_size) { auto parsed = ParseTcpAddressPull(tcp_addr); host = parsed.first; port = parsed.second; receiver_thread = std::thread(&TCPImagePuller::ReceiverThread, this); cbor_thread = std::thread(&TCPImagePuller::CBORThread, this); + + if (!repub_address.empty()) { + repub_socket = std::make_unique(ZMQSocketType::Push); + repub_socket->SendWaterMark(repub_watermark.value_or(default_repub_watermark)); + repub_socket->SendTimeout(RepubTimeout); + repub_socket->Bind(repub_address); + repub_thread = std::thread(&TCPImagePuller::RepubThread, this); + } } bool TCPImagePuller::SendAll(const void *buf, size_t len) { @@ -113,15 +126,53 @@ void TCPImagePuller::CBORThread() { } else { ret.cbor = CBORStream2Deserialize(ret.tcp_msg->payload.data(), ret.tcp_msg->payload.size()); outside_fifo.PutBlocking(ret); + if (repub_socket) { + if ((ret.cbor->msg_type == CBORImageType::START) + || (ret.cbor->msg_type == CBORImageType::END)) + repub_fifo.PutBlocking(ret); + else + repub_fifo.Put(ret); + } } } catch (const JFJochException &e) { logger.ErrorException(e); } ret = cbor_fifo.GetBlocking(); } + if (repub_socket) + repub_fifo.PutBlocking(ret); outside_fifo.PutBlocking(ret); } +void TCPImagePuller::RepubThread() { + auto ret = repub_fifo.GetBlocking(); + bool repub_active = false; + + while (ret.tcp_msg) { + try { + if (ret.cbor->msg_type == CBORImageType::START) { + // Start message needs to be cleaned when running republish + StartMessage msg = ret.cbor->start_message.value(); + msg.writer_notification_zmq_addr = ""; + std::vector serialization_buffer(256 * 1024 * 1024); + CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size()); + serializer.SerializeSequenceStart(msg); + repub_active = repub_socket->Send(serialization_buffer.data(), serializer.GetBufferSize(), true); + if (repub_active) + logger.Info("Republish active"); + } else { + if (repub_active) + repub_socket->Send(ret.tcp_msg->payload.data(), ret.tcp_msg->payload.size(), true); + } + } catch (const JFJochException &e) { + logger.ErrorException(e); + } + ret = repub_fifo.GetBlocking(); + } + if (repub_active) + logger.Info("Republish finished"); +} + TCPImagePuller::~TCPImagePuller() { TCPImagePuller::Disconnect(); } @@ -150,7 +201,7 @@ bool TCPImagePuller::EnsureConnected() { } addrinfo hints{}; - hints.ai_family = AF_UNSPEC; // Allow IPv4 or IPv6 + hints.ai_family = AF_UNSPEC; // Allow IPv4 or IPv6 hints.ai_socktype = SOCK_STREAM; hints.ai_protocol = IPPROTO_TCP; @@ -322,4 +373,6 @@ void TCPImagePuller::Disconnect() { receiver_thread.join(); if (cbor_thread.joinable()) cbor_thread.join(); -} \ No newline at end of file + if (repub_thread.joinable()) + repub_thread.join(); +} diff --git a/image_puller/TCPImagePuller.h b/image_puller/TCPImagePuller.h index 714a3f9b..5b5f444f 100644 --- a/image_puller/TCPImagePuller.h +++ b/image_puller/TCPImagePuller.h @@ -22,20 +22,31 @@ class TCPImagePuller : public ImagePuller { std::atomic disconnect{false}; ThreadSafeFIFO cbor_fifo{200}; + ThreadSafeFIFO repub_fifo{200}; + + std::unique_ptr repub_socket; std::thread receiver_thread; std::thread cbor_thread; + std::thread repub_thread; Logger logger{"TCPImagePuller"}; + static constexpr uint32_t default_repub_watermark = 220; + static constexpr auto RepubTimeout = std::chrono::milliseconds(100); + bool ReadExact(void *buf, size_t size); bool SendAll(const void *buf, size_t len); bool EnsureConnected(); void CloseSocket(); void ReceiverThread(); void CBORThread(); + void RepubThread(); public: - explicit TCPImagePuller(const std::string &tcp_addr, std::optional rcv_buffer_size = {}); + explicit TCPImagePuller(const std::string &tcp_addr, + std::optional rcv_buffer_size = {}, + const std::string &repub_address = "", + const std::optional &repub_watermark = {}); ~TCPImagePuller() override; bool SupportsAck() const override { return true; } diff --git a/tests/TCPImagePusherTest.cpp b/tests/TCPImagePusherTest.cpp index 05586959..af29315a 100644 --- a/tests/TCPImagePusherTest.cpp +++ b/tests/TCPImagePusherTest.cpp @@ -6,6 +6,7 @@ #include "../image_pusher/TCPStreamPusher.h" #include "../image_puller/TCPImagePuller.h" +#include "../image_puller/ZMQImagePuller.h" TEST_CASE("TCPImageCommTest_2Writers_WithAck", "[TCP]") { const size_t nframes = 128; @@ -550,3 +551,155 @@ TEST_CASE("TCPImageCommTest_DisconnectMidWrite_NoHang", "[TCP]") { receiver.join(); } + +TEST_CASE("TCPImageCommTest_RepubToZMQ", "[TCP][ZeroMQ]") { + // Chain: TCPStreamPusher --TCP--> TCPImagePuller --ZMQ repub--> ZMQImagePuller + const size_t nframes = 64; + const int64_t images_per_file = 8; + + DiffractionExperiment x(DetJF(1)); + x.Raw(); + x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4) + .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION); + + std::mt19937 g1(9999); + std::uniform_int_distribution dist; + std::vector image1(x.GetPixelsNum() * nframes); + for (auto &i : image1) i = dist(g1); + + // 1. Create TCP pusher on an auto-assigned port + TCPStreamPusher pusher("tcp://127.0.0.1:*", 1); + + // 2. Create TCP puller with repub over ZMQ (ipc, auto-assigned) + const std::string repub_addr = "ipc://*"; + // Need to figure out the actual repub endpoint after bind — ZMQ ipc://* picks a temp path. + // However, ZMQSocket::Bind with "ipc://*" is used in project; the repub socket binds internally, + // so we need a known address. Use a tcp address instead for testability. + const std::string repub_bind_addr = "tcp://127.0.0.1:19010"; + TCPImagePuller tcp_puller(pusher.GetAddress()[0], {}, repub_bind_addr); + + // 3. Create ZMQ puller that connects to the repub address + ZMQImagePuller zmq_puller(repub_bind_addr); + + // Wait for TCP connection + for (int attempt = 0; attempt < 100 && pusher.GetConnectedWriters() < 1; ++attempt) + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + REQUIRE(pusher.GetConnectedWriters() == 1); + + // Sender thread: push frames over TCP + std::thread sender([&] { + std::vector serialization_buffer(16 * 1024 * 1024); + CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size()); + + StartMessage start{ + .images_per_file = images_per_file, + .write_master_file = true + }; + EndMessage end{}; + + pusher.StartDataCollection(start); + + for (int64_t i = 0; i < static_cast(nframes); i++) { + DataMessage data_message; + data_message.number = i; + data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(), + x.GetPixelsNum() * sizeof(uint16_t), + x.GetXPixelsNum(), + x.GetYPixelsNum(), + x.GetImageMode(), + x.GetCompressionAlgorithm()); + serializer.SerializeImage(data_message); + REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i)); + } + + REQUIRE(pusher.EndDataCollection(end)); + }); + + // TCP puller consumer: drains the TCP side (with ACKs) so data keeps flowing + std::thread tcp_consumer([&] { + bool seen_end = false; + while (!seen_end) { + auto out = tcp_puller.PollImage(std::chrono::seconds(10)); + REQUIRE(out.has_value()); + REQUIRE(out->cbor != nullptr); + REQUIRE(out->tcp_msg != nullptr); + + const auto &h = out->tcp_msg->header; + + if (out->cbor->start_message) { + PullerAckMessage ack{}; + ack.ack_for = TCPFrameType::START; + ack.ok = true; + ack.run_number = h.run_number; + ack.socket_number = h.socket_number; + ack.error_code = TCPAckCode::None; + REQUIRE(tcp_puller.SendAck(ack)); + } else if (out->cbor->end_message) { + PullerAckMessage ack{}; + ack.ack_for = TCPFrameType::END; + ack.ok = true; + ack.run_number = h.run_number; + ack.socket_number = h.socket_number; + ack.error_code = TCPAckCode::None; + REQUIRE(tcp_puller.SendAck(ack)); + seen_end = true; + } + // data frames: no ack needed for this test + } + }); + + // ZMQ puller consumer: verify the republished stream + size_t zmq_nimages = 0; + size_t zmq_errors = 0; + bool zmq_seen_start = false; + bool zmq_seen_end = false; + + std::thread zmq_consumer([&] { + auto timeout = std::chrono::seconds(30); + + // First message should be START + auto img = zmq_puller.PollImage(timeout); + if (!img || !img->cbor || !img->cbor->start_message) { + zmq_errors++; + return; + } + zmq_seen_start = true; + + // Republished START should have writer_notification_zmq_addr cleared + if (!img->cbor->start_message->writer_notification_zmq_addr.empty()) { + zmq_errors++; + } + + // Consume data and END + img = zmq_puller.PollImage(timeout); + while (img && img->cbor && !img->cbor->end_message) { + if (img->cbor->data_message) { + auto n = img->cbor->data_message->number; + if (img->cbor->data_message->image.GetCompressedSize() != x.GetPixelsNum() * sizeof(uint16_t)) + zmq_errors++; + else if (memcmp(img->cbor->data_message->image.GetCompressed(), + image1.data() + n * x.GetPixelsNum(), + x.GetPixelsNum() * sizeof(uint16_t)) != 0) + zmq_errors++; + zmq_nimages++; + } + img = zmq_puller.PollImage(timeout); + } + if (img && img->cbor && img->cbor->end_message) + zmq_seen_end = true; + }); + + sender.join(); + tcp_consumer.join(); + zmq_consumer.join(); + + tcp_puller.Disconnect(); + zmq_puller.Disconnect(); + + // The repub uses non-blocking Put for data, so some frames *could* be dropped + // under extreme back-pressure, but with only 64 frames we expect all of them. + REQUIRE(zmq_seen_start); + REQUIRE(zmq_seen_end); + REQUIRE(zmq_nimages == nframes); + REQUIRE(zmq_errors == 0); +} diff --git a/writer/jfjoch_writer.cpp b/writer/jfjoch_writer.cpp index 57f2ef8d..0006f83a 100644 --- a/writer/jfjoch_writer.cpp +++ b/writer/jfjoch_writer.cpp @@ -155,11 +155,6 @@ int main(int argc, char **argv) { exit(EXIT_FAILURE); } - if (raw_tcp && zmq_repub_port > 0) { - logger.Error("Republish option at the moment only possible with ZeroMQ socket (no -T"); - exit(EXIT_FAILURE); - } - if (!root_dir.empty()) { try { std::filesystem::current_path(root_dir); @@ -234,7 +229,7 @@ int main(int argc, char **argv) { std::unique_ptr puller; if (raw_tcp) - puller = std::make_unique(argv[first_argc]); + puller = std::make_unique(argv[first_argc], std::nullopt, repub_zmq_addr, repub_watermark); else puller = std::make_unique(argv[first_argc], repub_zmq_addr, rcv_watermark, repub_watermark);