Some checks failed
Build Packages / build:rpm (ubuntu2204) (push) Has been cancelled
Build Packages / build:rpm (ubuntu2404) (push) Has been cancelled
Build Packages / Generate python client (push) Has been cancelled
Build Packages / Build documentation (push) Has been cancelled
Build Packages / Unit tests (push) Has been cancelled
Build Packages / Create release (push) Has been cancelled
Build Packages / build:rpm (rocky9_nocuda) (push) Has started running
Build Packages / build:rpm (rocky8_nocuda) (push) Has been cancelled
Build Packages / build:rpm (ubuntu2404_nocuda) (push) Has been cancelled
Build Packages / build:rpm (rocky9_sls9) (push) Has been cancelled
Build Packages / build:rpm (rocky8_sls9) (push) Has been cancelled
Build Packages / build:rpm (rocky8) (push) Has been cancelled
Build Packages / build:rpm (rocky9) (push) Has been cancelled
Build Packages / build:rpm (ubuntu2204_nocuda) (push) Has been cancelled
525 lines
20 KiB
C++
525 lines
20 KiB
C++
// SPDX-FileCopyrightText: 2025 Filip Leonarski, Paul Scherrer Institute <filip.leonarski@psi.ch>
|
|
// SPDX-License-Identifier: GPL-3.0-only
|
|
|
|
#include <atomic>
#include <chrono>
#include <cstdint>
#include <future>
#include <memory>
#include <random>
#include <string>
#include <thread>
#include <vector>

#include <catch2/catch_all.hpp>

#include "../image_pusher/TCPStreamPusher.h"
#include "../image_puller/TCPImagePuller.h"
|
|
|
|
// Two pullers acknowledge START and END; DATA frames must be dealt out to the pullers
// file by file (images_per_file consecutive images per file), round-robin.
// NOTE(review): REQUIRE is used inside worker threads. Unless Catch2's thread-safe
// assertions are enabled this is racy, and a failing REQUIRE throws out of the
// std::thread entry point (std::terminate).
TEST_CASE("TCPImageCommTest_2Writers_WithAck", "[TCP]") {
    const size_t nframes = 128;
    const int64_t npullers = 2;
    const int64_t images_per_file = 16;

    // The per-puller expectation below relies on whole files splitting evenly.
    static_assert(nframes % static_cast<size_t>(images_per_file * npullers) == 0,
                  "nframes must be a multiple of images_per_file * npullers");

    DiffractionExperiment x(DetJF(1));
    x.Raw();
    x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4)
            .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION);

    std::mt19937 g1(1387);
    std::uniform_int_distribution<uint16_t> dist;
    std::vector<uint16_t> image1(x.GetPixelsNum() * nframes);
    for (auto &i : image1) i = dist(g1);

    std::string addr = "tcp://127.0.0.1:19001";

    std::vector<std::unique_ptr<TCPImagePuller>> puller;
    for (int i = 0; i < npullers; i++) {
        puller.push_back(std::make_unique<TCPImagePuller>(addr, 64 * 1024 * 1024));
    }

    TCPStreamPusher pusher(addr, npullers);

    // Wait for all pullers to connect before starting data collection
    for (int attempt = 0; attempt < 100 && pusher.GetConnectedWriters() < static_cast<size_t>(npullers); ++attempt)
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
    REQUIRE(pusher.GetConnectedWriters() == static_cast<size_t>(npullers));

    // Number of DATA frames seen by each puller; each receiver thread touches only its own slot.
    std::vector<size_t> received(npullers, 0);

    std::thread sender([&] {
        std::vector<uint8_t> serialization_buffer(16 * 1024 * 1024);
        CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size());

        StartMessage start{
            .images_per_file = images_per_file,
            .write_master_file = true
        };
        EndMessage end{};

        pusher.StartDataCollection(start);

        for (int64_t i = 0; i < static_cast<int64_t>(nframes); i++) {
            DataMessage data_message;
            data_message.number = i;
            data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(),
                                                 x.GetPixelsNum() * sizeof(uint16_t),
                                                 x.GetXPixelsNum(),
                                                 x.GetYPixelsNum(),
                                                 x.GetImageMode(),
                                                 x.GetCompressionAlgorithm());
            serializer.SerializeImage(data_message);
            REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i));
        }

        REQUIRE(pusher.EndDataCollection(end));
    });

    std::vector<std::thread> receivers;
    receivers.reserve(npullers);

    for (int w = 0; w < npullers; w++) {
        receivers.emplace_back([&, w] {
            bool seen_start = false;
            bool seen_end = false;

            while (!seen_end) {
                auto out = puller[w]->PollImage(std::chrono::seconds(10));
                REQUIRE(out.has_value());
                REQUIRE(out->cbor != nullptr);
                REQUIRE(out->tcp_msg != nullptr);

                const auto &h = out->tcp_msg->header;

                if (out->cbor->start_message) {
                    PullerAckMessage ack;
                    ack.ack_for = TCPFrameType::START;
                    ack.ok = true;
                    ack.fatal = false;
                    ack.run_number = h.run_number;
                    ack.socket_number = h.socket_number;
                    ack.image_number = 0;
                    ack.processed_images = 0;
                    ack.error_code = TCPAckCode::None;
                    REQUIRE(puller[w]->SendAck(ack));
                    seen_start = true;
                    continue;
                }

                if (out->cbor->data_message) {
                    REQUIRE(seen_start);
                    auto n = out->cbor->data_message->number;
                    // File index n / images_per_file selects the puller round-robin.
                    REQUIRE(((n / images_per_file) % npullers) == w);
                    received[w]++;
                    continue;
                }

                if (out->cbor->end_message) {
                    PullerAckMessage ack;
                    ack.ack_for = TCPFrameType::END;
                    ack.ok = true;
                    ack.fatal = false;
                    ack.run_number = h.run_number;
                    ack.socket_number = h.socket_number;
                    ack.image_number = 0;
                    ack.processed_images = received[w];
                    ack.error_code = TCPAckCode::None;
                    REQUIRE(puller[w]->SendAck(ack));
                    seen_end = true;
                }
            }
        });
    }

    sender.join();
    for (auto &t : receivers) t.join();

    // Every puller gets the same share, and together they saw every frame exactly once.
    size_t total = 0;
    for (int64_t w = 0; w < npullers; w++) {
        REQUIRE(received[w] == nframes / static_cast<size_t>(npullers));
        total += received[w];
    }
    REQUIRE(total == nframes);

    for (auto &p : puller)
        p->Disconnect();
}
|
|
|
|
// Puller 0 answers its first DATA frame with a fatal ack (DiskQuotaExceeded); the
// error text must be reflected in the message returned by TCPStreamPusher::Finalize().
// NOTE(review): REQUIRE is used inside worker threads. Unless Catch2's thread-safe
// assertions are enabled this is racy, and a failing REQUIRE throws out of the
// std::thread entry point (std::terminate).
TEST_CASE("TCPImageCommTest_DataFatalAck_PropagatesToPusher", "[TCP]") {
    const size_t nframes = 64;
    const int64_t npullers = 2;
    const int64_t images_per_file = 8;

    DiffractionExperiment x(DetJF(1));
    x.Raw();
    x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4)
            .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION);

    std::mt19937 g1(42);
    std::uniform_int_distribution<uint16_t> dist;
    std::vector<uint16_t> image1(x.GetPixelsNum() * nframes);
    for (auto &i : image1) i = dist(g1);

    std::string addr = "tcp://127.0.0.1:19003";

    std::vector<std::unique_ptr<TCPImagePuller>> puller;
    for (int i = 0; i < npullers; i++)
        puller.push_back(std::make_unique<TCPImagePuller>(addr, 64 * 1024 * 1024));

    TCPStreamPusher pusher(addr, npullers);

    // Wait for all pullers to connect before starting data collection
    for (int attempt = 0; attempt < 100 && pusher.GetConnectedWriters() < static_cast<size_t>(npullers); ++attempt)
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
    REQUIRE(pusher.GetConnectedWriters() == static_cast<size_t>(npullers));

    // Guarantees that only one fatal DATA ack is sent during the run.
    std::atomic<bool> sent_fatal{false};

    std::thread sender([&] {
        std::vector<uint8_t> serialization_buffer(16 * 1024 * 1024);
        CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size());

        StartMessage start{
            .images_per_file = images_per_file,
            .write_master_file = true
        };
        EndMessage end{};

        pusher.StartDataCollection(start);

        for (int64_t i = 0; i < static_cast<int64_t>(nframes); i++) {
            DataMessage data_message;
            data_message.number = i;
            data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(),
                                                 x.GetPixelsNum() * sizeof(uint16_t),
                                                 x.GetXPixelsNum(),
                                                 x.GetYPixelsNum(),
                                                 x.GetImageMode(),
                                                 x.GetCompressionAlgorithm());
            serializer.SerializeImage(data_message);
            // The result is deliberately ignored: sends may fail after the fatal ack.
            (void)pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i);
        }

        REQUIRE(pusher.EndDataCollection(end));
        const auto final_msg = pusher.Finalize();
        REQUIRE_THAT(final_msg, Catch::Matchers::ContainsSubstring("quota"));
    });

    std::vector<std::thread> receivers;
    receivers.reserve(npullers);

    for (int w = 0; w < npullers; w++) {
        receivers.emplace_back([&, w] {
            bool seen_end = false;
            bool local_fatal_sent = false;

            while (!seen_end) {
                auto out = puller[w]->PollImage(std::chrono::seconds(2));
                if (!out.has_value()) {
                    // Once this receiver has sent a fatal DATA ACK, no END is guaranteed on this stream.
                    if (local_fatal_sent)
                        break;
                    REQUIRE(out.has_value());
                }

                REQUIRE(out->cbor != nullptr);
                REQUIRE(out->tcp_msg != nullptr);

                const auto &h = out->tcp_msg->header;

                // Acks are value-initialized ({}): PullerAckMessage is an aggregate, so plain
                // default-initialization would leave fields not assigned below (e.g. `fatal`)
                // with indeterminate values.
                if (out->cbor->start_message) {
                    PullerAckMessage ack{};
                    ack.ack_for = TCPFrameType::START;
                    ack.ok = true;
                    ack.run_number = h.run_number;
                    ack.socket_number = h.socket_number;
                    ack.error_code = TCPAckCode::None;
                    REQUIRE(puller[w]->SendAck(ack));
                    continue;
                }

                if (out->cbor->data_message) {
                    if (w == 0 && !sent_fatal.exchange(true)) {
                        PullerAckMessage ack{};
                        ack.ack_for = TCPFrameType::DATA;
                        ack.ok = false;
                        ack.fatal = true;
                        ack.run_number = h.run_number;
                        ack.socket_number = h.socket_number;
                        ack.image_number = static_cast<uint64_t>(out->cbor->data_message->number);
                        ack.error_code = TCPAckCode::DiskQuotaExceeded;
                        ack.error_text = "quota exceeded";
                        REQUIRE(puller[w]->SendAck(ack));
                        local_fatal_sent = true;
                    }
                    continue;
                }

                if (out->cbor->end_message) {
                    PullerAckMessage ack{};
                    ack.ack_for = TCPFrameType::END;
                    ack.ok = true;
                    ack.run_number = h.run_number;
                    ack.socket_number = h.socket_number;
                    ack.error_code = TCPAckCode::None;
                    REQUIRE(puller[w]->SendAck(ack));
                    seen_end = true;
                }
            }
        });
    }

    sender.join();
    for (auto &t : receivers) t.join();

    for (auto &p : puller)
        p->Disconnect();
}
|
|
|
|
// Each puller acks every DATA frame, with ok=true only for even-numbered images.
// TCPStreamPusher::GetImagesWritten() must then report nframes / 2.
// NOTE(review): REQUIRE is used inside worker threads. Unless Catch2's thread-safe
// assertions are enabled this is racy, and a failing REQUIRE throws out of the
// std::thread entry point (std::terminate).
TEST_CASE("TCPImageCommTest_GetAckProgress_Correct", "[TCP]") {
    const size_t nframes = 64;
    const int64_t npullers = 2;
    const int64_t images_per_file = 8;

    DiffractionExperiment x(DetJF(1));
    x.Raw();
    x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4)
            .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION);

    std::mt19937 g1(123);
    std::uniform_int_distribution<uint16_t> dist;
    std::vector<uint16_t> image1(x.GetPixelsNum() * nframes);
    for (auto &i : image1) i = dist(g1);

    std::string addr = "tcp://127.0.0.1:19004";

    std::vector<std::unique_ptr<TCPImagePuller>> puller;
    for (int i = 0; i < npullers; i++)
        puller.push_back(std::make_unique<TCPImagePuller>(addr, 64 * 1024 * 1024));

    TCPStreamPusher pusher(addr, npullers);

    // Wait for all pullers to connect before starting data collection
    for (int attempt = 0; attempt < 100 && pusher.GetConnectedWriters() < static_cast<size_t>(npullers); ++attempt)
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
    REQUIRE(pusher.GetConnectedWriters() == static_cast<size_t>(npullers));

    std::thread sender([&] {
        std::vector<uint8_t> serialization_buffer(16 * 1024 * 1024);
        CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size());

        StartMessage start{
            .images_per_file = images_per_file,
            .write_master_file = true
        };
        EndMessage end{};

        pusher.StartDataCollection(start);

        for (int64_t i = 0; i < static_cast<int64_t>(nframes); i++) {
            DataMessage data_message;
            data_message.number = i;
            data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(),
                                                 x.GetPixelsNum() * sizeof(uint16_t),
                                                 x.GetXPixelsNum(),
                                                 x.GetYPixelsNum(),
                                                 x.GetImageMode(),
                                                 x.GetCompressionAlgorithm());
            serializer.SerializeImage(data_message);
            REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i));
        }

        REQUIRE(pusher.EndDataCollection(end));
    });

    std::vector<std::thread> receivers;
    receivers.reserve(npullers);

    for (int w = 0; w < npullers; w++) {
        receivers.emplace_back([&, w] {
            bool seen_end = false;
            uint64_t processed = 0;

            while (!seen_end) {
                auto out = puller[w]->PollImage(std::chrono::seconds(10));
                REQUIRE(out.has_value());
                REQUIRE(out->cbor != nullptr);
                REQUIRE(out->tcp_msg != nullptr);

                const auto &h = out->tcp_msg->header;

                if (out->cbor->start_message) {
                    PullerAckMessage ack;
                    ack.ack_for = TCPFrameType::START;
                    ack.ok = true;
                    ack.fatal = false;
                    ack.run_number = h.run_number;
                    ack.socket_number = h.socket_number;
                    ack.image_number = 0;
                    ack.processed_images = 0;
                    ack.error_code = TCPAckCode::None;
                    REQUIRE(puller[w]->SendAck(ack));
                    continue;
                }

                if (out->cbor->data_message) {
                    const auto number = out->cbor->data_message->number;

                    processed++;
                    PullerAckMessage ack;
                    ack.ack_for = TCPFrameType::DATA;
                    // Only even-numbered images are acked as ok.
                    ack.ok = (number % 2 == 0);
                    ack.fatal = false;
                    ack.run_number = h.run_number;
                    ack.socket_number = h.socket_number;
                    ack.image_number = static_cast<uint64_t>(number);
                    ack.processed_images = processed;
                    ack.error_code = TCPAckCode::None;
                    REQUIRE(puller[w]->SendAck(ack));
                    continue;
                }

                if (out->cbor->end_message) {
                    PullerAckMessage ack;
                    ack.ack_for = TCPFrameType::END;
                    ack.ok = true;
                    ack.fatal = false;
                    ack.run_number = h.run_number;
                    ack.socket_number = h.socket_number;
                    ack.image_number = 0;
                    ack.processed_images = processed;
                    ack.error_code = TCPAckCode::None;
                    REQUIRE(puller[w]->SendAck(ack));
                    seen_end = true;
                }
            }
        });
    }

    sender.join();
    for (auto &t : receivers) t.join();

    // Ack processing in the pusher may lag behind EndDataCollection(). Poll with a bounded
    // deadline instead of sleeping a fixed interval, then assert on the main thread so a
    // failure does not throw out of a std::thread.
    const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(10);
    while (pusher.GetImagesWritten() != nframes / 2 && std::chrono::steady_clock::now() < deadline)
        std::this_thread::sleep_for(std::chrono::milliseconds(50));

    auto progress = pusher.GetImagesWritten();
    REQUIRE(progress.has_value());
    REQUIRE(progress == nframes / 2);

    for (auto &p : puller)
        p->Disconnect();
}
|
|
|
|
// The pusher binds to an OS-chosen port ("*"), and GetAddress()[0] must give a usable
// address. A single puller connects to it and must receive every DATA frame.
TEST_CASE("TCPImageCommTest_AutoPort_StarBind", "[TCP]") {
    const size_t nframes = 8;
    const int64_t images_per_file = 4;

    DiffractionExperiment x(DetJF(1));
    x.Raw();
    x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4)
            .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION);

    std::vector<uint16_t> image1(x.GetPixelsNum() * nframes, 7u);

    TCPStreamPusher pusher("tcp://127.0.0.1:*", 1);
    TCPImagePuller puller(pusher.GetAddress()[0], 64 * 1024 * 1024);

    // Wait for the puller to connect before starting data collection (as in the other tests)
    for (int attempt = 0; attempt < 100 && pusher.GetConnectedWriters() < 1; ++attempt)
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
    REQUIRE(pusher.GetConnectedWriters() == 1);

    // Written only by the receiver thread; read on this thread after receiver.join().
    uint64_t processed = 0;

    std::thread receiver([&] {
        bool seen_end = false;

        while (!seen_end) {
            auto out = puller.PollImage(std::chrono::seconds(10));
            REQUIRE(out.has_value());
            REQUIRE(out->cbor != nullptr);
            REQUIRE(out->tcp_msg != nullptr);

            const auto &h = out->tcp_msg->header;
            if (out->cbor->start_message) {
                PullerAckMessage ack{.ack_for = TCPFrameType::START, .ok = true, .run_number = h.run_number,
                                     .socket_number = h.socket_number, .error_code = TCPAckCode::None};
                REQUIRE(puller.SendAck(ack));
            } else if (out->cbor->data_message) {
                processed++;
            } else if (out->cbor->end_message) {
                PullerAckMessage ack{.ack_for = TCPFrameType::END, .ok = true, .run_number = h.run_number,
                                     .socket_number = h.socket_number, .processed_images = processed,
                                     .error_code = TCPAckCode::None};
                REQUIRE(puller.SendAck(ack));
                seen_end = true;
            }
        }
    });

    std::vector<uint8_t> serialization_buffer(16 * 1024 * 1024);
    CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size());

    StartMessage start{.images_per_file = images_per_file, .write_master_file = true};
    EndMessage end{};
    pusher.StartDataCollection(start);

    for (int64_t i = 0; i < static_cast<int64_t>(nframes); i++) {
        DataMessage data_message;
        data_message.number = i;
        data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(),
                                             x.GetPixelsNum() * sizeof(uint16_t),
                                             x.GetXPixelsNum(), x.GetYPixelsNum(),
                                             x.GetImageMode(), x.GetCompressionAlgorithm());
        serializer.SerializeImage(data_message);
        REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i));
    }

    REQUIRE(pusher.EndDataCollection(end));
    receiver.join();

    // With a single puller, every image must have been delivered to it.
    CHECK(processed == nframes);

    puller.Disconnect();
}
|
|
|
|
// The puller disconnects on its first DATA frame. The sending side must still complete
// within a bounded time, and EndDataCollection() is expected to report failure.
TEST_CASE("TCPImageCommTest_DisconnectMidWrite_NoHang", "[TCP]") {
    const size_t nframes = 256;
    const int64_t images_per_file = 16;

    DiffractionExperiment x(DetJF(1));
    x.Raw();
    x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4)
            .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION);

    std::vector<uint16_t> image1(x.GetPixelsNum() * nframes, 11u);

    TCPStreamPusher pusher("tcp://127.0.0.1:*", 1);
    TCPImagePuller puller(pusher.GetAddress()[0], 64 * 1024 * 1024);

    // Wait for the puller to connect before starting data collection (as in the other tests)
    for (int attempt = 0; attempt < 100 && pusher.GetConnectedWriters() < 1; ++attempt)
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
    REQUIRE(pusher.GetConnectedWriters() == 1);

    std::thread receiver([&] {
        bool disconnected = false;
        while (!disconnected) {
            auto out = puller.PollImage(std::chrono::seconds(10));
            REQUIRE(out.has_value());
            REQUIRE(out->cbor != nullptr);
            REQUIRE(out->tcp_msg != nullptr);

            const auto &h = out->tcp_msg->header;
            if (out->cbor->start_message) {
                PullerAckMessage ack{.ack_for = TCPFrameType::START, .ok = true, .run_number = h.run_number,
                                     .socket_number = h.socket_number, .error_code = TCPAckCode::None};
                REQUIRE(puller.SendAck(ack));
            } else if (out->cbor->data_message) {
                puller.Disconnect(); // simulate puller disappearing mid-stream
                disconnected = true;
            }
        }
    });

    auto sender = std::async(std::launch::async, [&] {
        std::vector<uint8_t> serialization_buffer(16 * 1024 * 1024);
        CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size());

        StartMessage start{.images_per_file = images_per_file, .write_master_file = true};
        EndMessage end{};

        pusher.StartDataCollection(start);

        for (int64_t i = 0; i < static_cast<int64_t>(nframes); i++) {
            DataMessage data_message;
            data_message.number = i;
            data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(),
                                                 x.GetPixelsNum() * sizeof(uint16_t),
                                                 x.GetXPixelsNum(), x.GetYPixelsNum(),
                                                 x.GetImageMode(), x.GetCompressionAlgorithm());
            serializer.SerializeImage(data_message);
            // The result is deliberately ignored: sends are expected to fail after the disconnect.
            (void)pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i);
        }

        return pusher.EndDataCollection(end);
    });

    // Join the receiver first; it exits after its first DATA frame. Doing this before the
    // REQUIRE below prevents a failing assertion from unwinding past a joinable
    // std::thread, which would call std::terminate.
    receiver.join();

    REQUIRE(sender.wait_for(std::chrono::seconds(20)) == std::future_status::ready);
    CHECK(sender.get() == false);
}