TCP: Allow to get written image count

This commit is contained in:
2026-03-04 16:13:14 +01:00
parent 536fc0761d
commit d2c66edd45
10 changed files with 48 additions and 220 deletions

View File

@@ -11,7 +11,6 @@ void HDF5FilePusher::StartDataCollection(StartMessage &message) {
writer = std::make_unique<FileWriter>(message);
writer_future = std::async(std::launch::async, &HDF5FilePusher::WriterThread, this);
images_written = 0;
images_err = 0;
}
bool HDF5FilePusher::EndDataCollection(const EndMessage &message) {
@@ -36,12 +35,8 @@ bool HDF5FilePusher::SendImage(const uint8_t *image_data, size_t image_size, int
auto deserialized = CBORStream2Deserialize(image_data, image_size);
if (deserialized->data_message) {
try {
writer->Write(*deserialized->data_message);
images_written++;
} catch (const JFJochException &e) {
images_err++;
}
writer->Write(*deserialized->data_message);
images_written++;
} else
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
"HDF5FilePusher::SendImage accepts only data image");
@@ -78,14 +73,6 @@ std::string HDF5FilePusher::PrintSetup() const {
return "HDF5FilePusher: Images are written directly to file in base directory " + currentPath.string();
}
std::optional<ImagePusherAckProgress> HDF5FilePusher::GetAckProgress() const {
uint64_t ack_ok = images_written;
uint64_t ack_bad = images_err;
uint64_t ack_total = ack_ok + ack_bad;
return ImagePusherAckProgress{
.data_acked_ok = ack_ok,
.data_acked_bad = ack_bad,
.data_acked_total = ack_total
};
std::optional<uint64_t> HDF5FilePusher::GetImagesWritten() const {
return images_written;
}

View File

@@ -19,7 +19,6 @@ class HDF5FilePusher : public ImagePusher {
void WriterThread();
std::atomic<uint64_t> images_written = 0;
std::atomic<uint64_t> images_err = 0;
public:
// Thread safety: StartDataCollection, EndDataCollection and SendCalibration must run poorly in serial context
// SendImage can be executed in parallel
@@ -31,7 +30,7 @@ public:
std::string PrintSetup() const override;
std::optional<ImagePusherAckProgress> GetAckProgress() const override;
std::optional<uint64_t> GetImagesWritten() const override;
};

View File

@@ -20,14 +20,6 @@ struct ImagePusherQueueElement {
bool end;
};
struct ImagePusherAckProgress {
uint64_t data_sent = 0;
uint64_t data_acked_ok = 0;
uint64_t data_acked_bad = 0;
uint64_t data_acked_total = 0;
uint64_t data_ack_pending = 0;
};
void PrepareCBORImage(DataMessage& message,
const DiffractionExperiment &experiment,
void *image, size_t image_size);
@@ -43,7 +35,7 @@ public:
virtual std::string GetWriterNotificationSocketAddress() const;
virtual ~ImagePusher() = default;
virtual std::string PrintSetup() const = 0;
virtual std::optional<ImagePusherAckProgress> GetAckProgress() const { return std::nullopt; }
virtual std::optional<uint64_t> GetImagesWritten() const { return std::nullopt; }
};

View File

@@ -168,16 +168,11 @@ bool TCPStreamPusher::SendCalibration(const CompressedImage &message) {
return socket[0]->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::CALIBRATION);
}
std::optional<ImagePusherAckProgress> TCPStreamPusher::GetAckProgress() const {
ImagePusherAckProgress out{
};
std::optional<uint64_t> TCPStreamPusher::GetImagesWritten() const {
uint64_t ret = 0;
for (const auto &s : socket) {
auto p = s->GetDataAckProgress();
out.data_sent += p.data_sent;
out.data_acked_ok += p.data_acked_ok;
out.data_acked_bad += p.data_acked_bad;
out.data_acked_total += p.data_acked_total;
out.data_ack_pending += p.data_ack_pending;
ret += p.data_acked_ok;
}
return out;
return ret;
}

View File

@@ -30,5 +30,5 @@ public:
std::string Finalize() override;
std::string PrintSetup() const override;
std::optional<ImagePusherAckProgress> GetAckProgress() const override;
std::optional<uint64_t> GetImagesWritten() const override;
};

View File

@@ -16,6 +16,14 @@
#include "../common/Logger.h"
#include "../common/JfjochTCP.h"
struct ImagePusherAckProgress {
uint64_t data_sent = 0;
uint64_t data_acked_ok = 0;
uint64_t data_acked_bad = 0;
uint64_t data_acked_total = 0;
uint64_t data_ack_pending = 0;
};
class TCPStreamPusherSocket {
struct InflightZC {
ZeroCopyReturnValue *z = nullptr;

View File

@@ -37,6 +37,10 @@ void ZMQStream2Pusher::SendImage(ZeroCopyReturnValue &z) {
}
void ZMQStream2Pusher::StartDataCollection(StartMessage& message) {
{
std::unique_lock ul(images_written_mutex);
images_written = std::nullopt;
}
if (message.images_per_file < 1)
throw JFJochException(JFJochExceptionCategory::InputParameterInvalid,
"Images per file cannot be zero or negative");
@@ -93,6 +97,7 @@ std::string ZMQStream2Pusher::Finalize() {
std::string ret;
if (transmission_error)
ret += "Timeout sending images (e.g., writer disabled during data collection);";
uint64_t images = 0;
if (writer_notification_socket) {
for (int i = 0; i < socket.size(); i++) {
auto n = writer_notification_socket->Receive(run_number, run_name);
@@ -100,10 +105,18 @@ std::string ZMQStream2Pusher::Finalize() {
ret += "Writer " + std::to_string(i) + ": no end notification received within 1 minute from collection end";
else if (n->socket_number >= socket.size())
ret += "Writer " + std::to_string(i) + ": mismatch in socket number";
else if (!n->ok)
ret += "Writer " + std::to_string(i) + ": " + n->error;
else {
if (!n->ok)
ret += "Writer " + std::to_string(i) + ": " + n->error;
images += n->processed_images;
}
}
}
{
std::unique_lock ul(images_written_mutex);
images_written = images;
}
return ret;
}
@@ -125,3 +138,8 @@ std::string ZMQStream2Pusher::PrintSetup() const {
output += s->GetEndpointName() + " ";
return output;
}
std::optional<uint64_t> ZMQStream2Pusher::GetImagesWritten() const {
std::unique_lock ul(images_written_mutex);
return images_written;
}

View File

@@ -22,6 +22,9 @@ class ZMQStream2Pusher : public ImagePusher {
uint64_t run_number = 0;
std::string run_name;
std::atomic<bool> transmission_error = false;
mutable std::mutex images_written_mutex;
std::optional<uint64_t> images_written;
public:
explicit ZMQStream2Pusher(const std::vector<std::string>& addr,
std::optional<int32_t> send_buffer_high_watermark = {},
@@ -43,8 +46,8 @@ public:
std::string Finalize() override;
std::string PrintSetup() const override;
std::optional<uint64_t> GetImagesWritten() const override;
};
#endif //JUNGFRAUJOCH_ZMQSTREAM2PUSHER_H

View File

@@ -1548,11 +1548,7 @@ TEST_CASE("JFJochIntegrationTest_TCP_lysozyme_spot_and_index", "[JFJochReceiver]
REQUIRE_NOTHROW(writer_future.get());
auto ack = pusher.GetAckProgress();
auto ack = pusher.GetImagesWritten();
REQUIRE(ack.has_value());
CHECK(ack->data_acked_ok == experiment.GetImageNum());
CHECK(ack->data_acked_bad == 0);
CHECK(ack->data_acked_total == experiment.GetImageNum());
CHECK(ack->data_ack_pending == 0);
CHECK(ack->data_sent == experiment.GetImageNum());
CHECK(ack == experiment.GetImageNum());
}

View File

@@ -331,22 +331,11 @@ TEST_CASE("TCPImageCommTest_GetAckProgress_Correct", "[TCP]") {
REQUIRE(pusher.EndDataCollection(end));
const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(5);
while (std::chrono::steady_clock::now() < deadline) {
auto progress = pusher.GetAckProgress();
REQUIRE(progress.has_value());
if (progress->data_acked_total == nframes)
break;
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
std::this_thread::sleep_for(std::chrono::seconds(5));
auto progress = pusher.GetAckProgress();
auto progress = pusher.GetImagesWritten();
REQUIRE(progress.has_value());
REQUIRE(progress->data_sent == nframes);
REQUIRE(progress->data_acked_ok == nframes / 2);
REQUIRE(progress->data_acked_bad == nframes / 2);
REQUIRE(progress->data_acked_total == nframes);
REQUIRE(progress->data_ack_pending == 0);
REQUIRE(progress == nframes / 2);
});
std::vector<std::thread> receivers;
@@ -419,162 +408,3 @@ TEST_CASE("TCPImageCommTest_GetAckProgress_Correct", "[TCP]") {
for (auto &p : puller)
p->Disconnect();
}
TEST_CASE("TCPImageCommTest_GetAckProgress_InFlightPending", "[TCP]") {
const size_t nframes = 128;
const int64_t npullers = 2;
const int64_t images_per_file = 8;
DiffractionExperiment x(DetJF(1));
x.Raw();
x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4)
.ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION);
std::mt19937 g1(321);
std::uniform_int_distribution<uint16_t> dist;
std::vector<uint16_t> image1(x.GetPixelsNum() * nframes);
for (auto &i : image1) i = dist(g1);
std::vector<std::string> addr{
"tcp://127.0.0.1:19031",
"tcp://127.0.0.1:19032"
};
std::vector<std::unique_ptr<TCPImagePuller>> puller;
for (int i = 0; i < npullers; i++) {
puller.push_back(std::make_unique<TCPImagePuller>(addr[i], 64 * 1024 * 1024));
}
TCPStreamPusher pusher(
addr,
64 * 1024 * 1024,
128 * 1024,
8192
);
std::atomic<bool> observed_pending{false};
std::thread sender([&] {
std::vector<uint8_t> serialization_buffer(16 * 1024 * 1024);
CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size());
StartMessage start{
.images_per_file = images_per_file,
.write_master_file = true
};
EndMessage end{};
pusher.StartDataCollection(start);
for (int64_t i = 0; i < static_cast<int64_t>(nframes); i++) {
DataMessage data_message;
data_message.number = i;
data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(),
x.GetPixelsNum() * sizeof(uint16_t),
x.GetXPixelsNum(),
x.GetYPixelsNum(),
x.GetImageMode(),
x.GetCompressionAlgorithm());
serializer.SerializeImage(data_message);
REQUIRE(pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i));
if (i >= 16 && !observed_pending.load()) {
auto progress = pusher.GetAckProgress();
REQUIRE(progress.has_value());
if (progress->data_sent > progress->data_acked_total && progress->data_ack_pending > 0) {
observed_pending = true;
}
}
}
REQUIRE(pusher.EndDataCollection(end));
REQUIRE(observed_pending.load());
const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(5);
while (std::chrono::steady_clock::now() < deadline) {
auto progress = pusher.GetAckProgress();
REQUIRE(progress.has_value());
if (progress->data_acked_total == nframes)
break;
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
auto progress = pusher.GetAckProgress();
REQUIRE(progress.has_value());
REQUIRE(progress->data_sent == nframes);
REQUIRE(progress->data_acked_ok == nframes);
REQUIRE(progress->data_acked_bad == 0);
REQUIRE(progress->data_acked_total == nframes);
REQUIRE(progress->data_ack_pending == 0);
});
std::vector<std::thread> receivers;
receivers.reserve(npullers);
for (int w = 0; w < npullers; w++) {
receivers.emplace_back([&, w] {
bool seen_end = false;
uint64_t processed = 0;
while (!seen_end) {
auto out = puller[w]->PollImage(std::chrono::seconds(10));
REQUIRE(out.has_value());
REQUIRE(out->cbor != nullptr);
REQUIRE(out->tcp_msg != nullptr);
const auto &h = out->tcp_msg->header;
if (out->cbor->start_message) {
PullerAckMessage ack;
ack.ack_for = TCPFrameType::START;
ack.ok = true;
ack.fatal = false;
ack.run_number = h.run_number;
ack.socket_number = h.socket_number;
ack.image_number = 0;
ack.processed_images = 0;
ack.error_code = TCPAckCode::None;
REQUIRE(puller[w]->SendAck(ack));
continue;
}
if (out->cbor->data_message) {
processed++;
std::this_thread::sleep_for(std::chrono::milliseconds(3));
PullerAckMessage ack;
ack.ack_for = TCPFrameType::DATA;
ack.ok = true;
ack.fatal = false;
ack.run_number = h.run_number;
ack.socket_number = h.socket_number;
ack.image_number = static_cast<uint64_t>(out->cbor->data_message->number);
ack.processed_images = processed;
ack.error_code = TCPAckCode::None;
REQUIRE(puller[w]->SendAck(ack));
continue;
}
if (out->cbor->end_message) {
PullerAckMessage ack;
ack.ack_for = TCPFrameType::END;
ack.ok = true;
ack.fatal = false;
ack.run_number = h.run_number;
ack.socket_number = h.socket_number;
ack.image_number = 0;
ack.processed_images = processed;
ack.error_code = TCPAckCode::None;
REQUIRE(puller[w]->SendAck(ack));
seen_end = true;
}
}
});
}
sender.join();
for (auto &t : receivers) t.join();
for (auto &p : puller)
p->Disconnect();
}