From cd561dd3d989946b84a4a089de00957d4bb9eda9 Mon Sep 17 00:00:00 2001 From: Filip Leonarski Date: Wed, 4 Mar 2026 09:00:43 +0100 Subject: [PATCH 01/42] ScaleAndMerge: rotation wedge can be refined --- image_analysis/scale_merge/ScaleAndMerge.cpp | 19 +++++++++++++------ image_analysis/scale_merge/ScaleAndMerge.h | 2 ++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/image_analysis/scale_merge/ScaleAndMerge.cpp b/image_analysis/scale_merge/ScaleAndMerge.cpp index a45264bc..2d813178 100644 --- a/image_analysis/scale_merge/ScaleAndMerge.cpp +++ b/image_analysis/scale_merge/ScaleAndMerge.cpp @@ -121,7 +121,6 @@ namespace { weight_(SafeInv(sigma_obs, 1.0)), delta_phi_(r.delta_phi_deg), lp_(SafeInv(r.rlp, 1.0)), - half_wedge_(wedge_deg / 2.0), c1_(r.zeta / std::sqrt(2.0)), partiality_(r.partiality) { } @@ -130,11 +129,13 @@ namespace { bool operator()(const T *const G, const T *const mosaicity, const T *const Itrue, + const T *const wedge, T *residual) const { T partiality; if (mosaicity[0] >= 0.0) { - const T arg_plus = T(delta_phi_ + half_wedge_) * T(c1_) / mosaicity[0]; - const T arg_minus = T(delta_phi_ - half_wedge_) * T(c1_) / mosaicity[0]; + const T half_wedge = wedge[0] / T(2.0); + const T arg_plus = T(delta_phi_ + half_wedge) * T(c1_) / mosaicity[0]; + const T arg_minus = T(delta_phi_ - half_wedge) * T(c1_) / mosaicity[0]; partiality = (ceres::erf(arg_plus) - ceres::erf(arg_minus)) / T(2.0); } else partiality = T(1.0); @@ -148,7 +149,6 @@ namespace { double weight_; double delta_phi_; double lp_; - double half_wedge_; double c1_; double partiality_; }; @@ -275,17 +275,20 @@ namespace { } } + double wedge = opt.wedge_deg.value_or(0.0); + std::vector is_valid_hkl_slot(nhkl, false); for (const auto &o: obs) { switch (opt.partiality_model) { case ScaleMergeOptions::PartialityModel::Rotation: { - auto *cost = new ceres::AutoDiffCostFunction( + auto *cost = new ceres::AutoDiffCostFunction( new IntensityRotResidual(*o.r, o.sigma, opt.wedge_deg.value_or(0.0))); problem.AddResidualBlock(cost, nullptr, &g[o.img_id], &mosaicity[o.img_id], - &Itrue[o.hkl_slot]); + &Itrue[o.hkl_slot], + &wedge); } break; case ScaleMergeOptions::PartialityModel::Still: { @@ -374,6 +377,10 @@ namespace { problem.SetParameterUpperBound(&mosaicity[i], 0, opt.mosaicity_max_deg); } } + if (!opt.refine_wedge) + problem.SetParameterBlockConstant(&wedge); + else + problem.SetParameterLowerBound(&wedge, 0, 0.0); } // use all available threads diff --git a/image_analysis/scale_merge/ScaleAndMerge.h b/image_analysis/scale_merge/ScaleAndMerge.h index 9710ce05..d154e941 100644 --- a/image_analysis/scale_merge/ScaleAndMerge.h +++ b/image_analysis/scale_merge/ScaleAndMerge.h @@ -48,6 +48,8 @@ struct ScaleMergeOptions { int64_t image_cluster = 1; + bool refine_wedge = false; + enum class PartialityModel {Fixed, Rotation, Unity, Still} partiality_model = PartialityModel::Fixed; }; -- 2.49.1 From 939bb02ce294e4b961d991bdc195a5b585d2c862 Mon Sep 17 00:00:00 2001 From: Filip Leonarski Date: Wed, 4 Mar 2026 11:22:00 +0100 Subject: [PATCH 02/42] XtalOptimizer: Add lattice rotation-only optimizer --- image_analysis/IndexAndRefine.cpp | 12 +- .../geom_refinement/XtalOptimizer.cpp | 161 ++++++++++++++++++ .../geom_refinement/XtalOptimizer.h | 5 + 3 files changed, 172 insertions(+), 6 deletions(-) diff --git a/image_analysis/IndexAndRefine.cpp b/image_analysis/IndexAndRefine.cpp index c9d16b67..5a9357d5 100644 --- a/image_analysis/IndexAndRefine.cpp +++ b/image_analysis/IndexAndRefine.cpp @@ -103,20 +103,20 @@ void IndexAndRefine::RefineGeometryIfNeeded(DataMessage &msg, IndexAndRefine::In .max_time = 0.04 // 40 ms is max allowed time for the operation }; - if (experiment.IsRotationIndexing()) { - data.refine_beam_center = false; - data.refine_rotation_axis = false; - data.refine_unit_cell = false; - } if (outcome.symmetry.crystal_system == gemmi::CrystalSystem::Trigonal) data.crystal_system = gemmi::CrystalSystem::Hexagonal; + switch (experiment.GetIndexingSettings().GetGeomRefinementAlgorithm()) { case GeomRefinementAlgorithmEnum::None: break; case GeomRefinementAlgorithmEnum::BeamCenter: - if (XtalOptimizer(data, msg.spots)) { + if (experiment.IsRotationIndexing()) { + XtalOptimizerRotationOnly(data, msg.spots, 0.2); + XtalOptimizerRotationOnly(data, msg.spots, 0.1); + XtalOptimizerRotationOnly(data, msg.spots, 0.05); + } else if (XtalOptimizer(data, msg.spots)) { outcome.experiment.BeamX_pxl(data.geom.GetBeamX_pxl()) .BeamY_pxl(data.geom.GetBeamY_pxl()); outcome.beam_center_updated = true; diff --git a/image_analysis/geom_refinement/XtalOptimizer.cpp b/image_analysis/geom_refinement/XtalOptimizer.cpp index 0a4685df..aec455c9 100644 --- a/image_analysis/geom_refinement/XtalOptimizer.cpp +++ b/image_analysis/geom_refinement/XtalOptimizer.cpp @@ -187,6 +187,60 @@ struct XtalResidual { gemmi::CrystalSystem symmetry; }; +struct XtalResidualRotationOnlyPrecomp { + XtalResidualRotationOnlyPrecomp(const Coord &recip_obs, + const Coord &latt_a, + const Coord &latt_b, + const Coord &latt_c, + double h, double k, double l) + : s_obs(recip_obs), + a0(latt_a), b0(latt_b), c0(latt_c), + h(h), k(k), l(l) { + } + + template + bool operator()(const T *const rot_aa, T *residual) const { + // Rotate the CURRENT lattice vectors by rot_aa (proper SO(3) rotation) + T a_in[3] = {T(a0.x), T(a0.y), T(a0.z)}; + T b_in[3] = {T(b0.x), T(b0.y), T(b0.z)}; + T c_in[3] = {T(c0.x), T(c0.y), T(c0.z)}; + + T a_rot[3], b_rot[3], c_rot[3]; + ceres::AngleAxisRotatePoint(rot_aa, a_in, a_rot); + ceres::AngleAxisRotatePoint(rot_aa, b_in, b_rot); + ceres::AngleAxisRotatePoint(rot_aa, c_in, c_rot); + + const Eigen::Matrix A(a_rot[0], a_rot[1], a_rot[2]); + const Eigen::Matrix B(b_rot[0], b_rot[1], b_rot[2]); + const Eigen::Matrix C(c_rot[0], c_rot[1], c_rot[2]); + + // Reciprocal basis from rotated direct lattice + const Eigen::Matrix BxC = B.cross(C); + const Eigen::Matrix CxA = C.cross(A); + const Eigen::Matrix AxB = A.cross(B); + + const T V = A.dot(BxC); + const T invV = T(1) / V; + + const Eigen::Matrix Astar = BxC * invV; + const Eigen::Matrix Bstar = CxA * invV; + const Eigen::Matrix Cstar = AxB * invV; + + const Eigen::Matrix s_pred = + Astar * T(h) + Bstar * T(k) + Cstar * T(l); + + // Residual in reciprocal space + residual[0] = T(s_obs.x) - s_pred[0]; + residual[1] = T(s_obs.y) - s_pred[1]; + residual[2] = T(s_obs.z) - s_pred[2]; + return true; + } + + Coord s_obs; + Coord a0, b0, c0; + double h, k, l; +}; + inline void LatticeToRodriguesAndLengths_GS(const CrystalLattice &latt, double rod[3], double lengths[3]) { @@ -647,3 +701,110 @@ bool XtalOptimizer(XtalOptimizerData &data, const std::vector &spots return XtalOptimizerInternal(data, spots, 0.1); } +bool XtalOptimizerRotationOnly(XtalOptimizerData &data, + const std::vector &spots, + const float tolerance) { + try { + // Parameter: angle-axis for the extra rotation. Identity == {0,0,0}. + double rot_aa[3] = {0.0, 0.0, 0.0}; + + const Coord a0 = data.latt.Vec0(); + const Coord b0 = data.latt.Vec1(); + const Coord c0 = data.latt.Vec2(); + + // Spot selection by current indexing (same approach as XtalOptimizerInternal) + const Coord vec0 = data.latt.Vec0(); + const Coord vec1 = data.latt.Vec1(); + const Coord vec2 = data.latt.Vec2(); + + const float tol_sq = tolerance * tolerance; + + ceres::Problem problem; + + for (const auto &pt : spots) { + if (!data.index_ice_rings && pt.ice_ring) + continue; + + // Compute fractional HKL using the CURRENT lattice + Coord recip_index = pt.ReciprocalCoord(data.geom); + if (data.axis.has_value()) + recip_index = data.axis->GetTransformationAngle(pt.phi) * recip_index; + + const double h_fp = static_cast(recip_index * vec0); + const double k_fp = static_cast(recip_index * vec1); + const double l_fp = static_cast(recip_index * vec2); + + const double h = std::round(h_fp); + const double k = std::round(k_fp); + const double l = std::round(l_fp); + + const double norm_sq = + (h - h_fp) * (h - h_fp) + + (k - k_fp) * (k - k_fp) + + (l - l_fp) * (l - l_fp); + + if (norm_sq > static_cast(tol_sq)) + continue; + + const Coord s_obs = data.geom.DetectorToRecip(pt.x, pt.y); + + auto *cost = + new ceres::AutoDiffCostFunction( + new XtalResidualRotationOnlyPrecomp(s_obs, a0, b0, c0, h, k, l) + ); + + problem.AddResidualBlock(cost, nullptr, rot_aa); + } + + if (problem.NumResidualBlocks() < data.min_spots) + return false; + + ceres::Solver::Options options; + options.linear_solver_type = ceres::DENSE_QR; + options.minimizer_progress_to_stdout = false; + options.max_solver_time_in_seconds = data.max_time; + options.logging_type = ceres::LoggingType::SILENT; + options.num_threads = 1; + + ceres::Solver::Summary summary; + ceres::Solve(options, &problem, &summary); + + // Apply rotation to lattice (L' = R * L, acting on column vectors) + double R_raw[9]; + ceres::AngleAxisToRotationMatrix(rot_aa, R_raw); // row-major 3x3 + + Eigen::Matrix3d R; + R << R_raw[0], R_raw[1], R_raw[2], + R_raw[3], R_raw[4], R_raw[5], + R_raw[6], R_raw[7], R_raw[8]; + + const Eigen::Vector3d A(a0.x, a0.y, a0.z); + const Eigen::Vector3d B(b0.x, b0.y, b0.z); + const Eigen::Vector3d C(c0.x, c0.y, c0.z); + + const Eigen::Vector3d A2 = R * A; + const Eigen::Vector3d B2 = R * B; + const Eigen::Vector3d C2 = R * C; + + data.latt = CrystalLattice( + Coord(static_cast(A2.x()), static_cast(A2.y()), static_cast(A2.z())), + Coord(static_cast(B2.x()), static_cast(B2.y()), static_cast(B2.z())), + Coord(static_cast(C2.x()), static_cast(C2.y()), static_cast(C2.z())) + ); + + double theta = std::sqrt(rot_aa[0] * rot_aa[0] + rot_aa[1] * rot_aa[1] + rot_aa[2] * rot_aa[2]); + data.angle_corr = theta; + if (theta > 1e-6) { + Coord rot; + rot.x = rot_aa[0] / theta; + rot.y = rot_aa[1] / theta; + rot.z = rot_aa[2] / theta; + data.angle_axis = rot; + } else + data.angle_axis.reset(); + + return true; + } catch (...) { + return false; + } +} diff --git a/image_analysis/geom_refinement/XtalOptimizer.h b/image_analysis/geom_refinement/XtalOptimizer.h index 79422e6e..9ca306dd 100644 --- a/image_analysis/geom_refinement/XtalOptimizer.h +++ b/image_analysis/geom_refinement/XtalOptimizer.h @@ -38,6 +38,10 @@ struct XtalOptimizerData { // output std::optional beam_corr_x; std::optional beam_corr_y; + + // For rotation only optimizer + std::optional angle_corr; + std::optional angle_axis; }; void LatticeToRodriguesAndLengths_GS(const CrystalLattice &latt, double rod[3], double lengths[3]); @@ -54,5 +58,6 @@ CrystalLattice AngleAxisAndCellToLattice(const double rod[3], double gamma_rad); bool XtalOptimizer(XtalOptimizerData &data, const std::vector &spots); +bool XtalOptimizerRotationOnly(XtalOptimizerData &data, const std::vector &spots, float tolerance); #endif //JFJOCH_XTALOPTIMIZER_H -- 2.49.1 From a3a986830bb4462fe8bdb4cf052f6129255c2612 Mon Sep 17 00:00:00 2001 From: Filip Leonarski Date: Wed, 4 Mar 2026 12:24:05 +0100 Subject: [PATCH 03/42] TCP: Implemented ACK return stream, as a feedback channel (to be read properly!) --- broker/JFJochBrokerParser.cpp | 2 - common/JfjochTCP.h | 29 +++- image_puller/ImagePuller.h | 16 ++ image_puller/TCPImagePuller.cpp | 95 +++++++++-- image_puller/TCPImagePuller.h | 3 + image_pusher/HDF5FilePusher.h | 1 - image_pusher/TCPStreamPusher.cpp | 86 ++++++---- image_pusher/TCPStreamPusher.h | 5 - image_pusher/TCPStreamPusherSocket.cpp | 169 +++++++++++++++++++ image_pusher/TCPStreamPusherSocket.h | 16 ++ tests/TCPImagePusherTest.cpp | 219 ++++++++++++++++++++++--- writer/StreamWriter.cpp | 48 ++++++ writer/StreamWriter.h | 2 + 13 files changed, 611 insertions(+), 80 deletions(-) diff --git a/broker/JFJochBrokerParser.cpp b/broker/JFJochBrokerParser.cpp index a0d7c966..821ae5b2 100644 --- a/broker/JFJochBrokerParser.cpp +++ b/broker/JFJochBrokerParser.cpp @@ -201,8 +201,6 @@ std::unique_ptr ParseTCPImagePusher(const org::openapitools::server auto tmp = std::make_unique(j.getZeromq().getImageSocket(), send_buffer_size); - if (j.getZeromq().writerNotificationSocketIsSet()) - tmp->WriterNotificationSocket(j.getZeromq().getWriterNotificationSocket()); return std::move(tmp); } diff --git a/common/JfjochTCP.h b/common/JfjochTCP.h index b5fe36a7..51c6d2a4 100644 --- a/common/JfjochTCP.h +++ b/common/JfjochTCP.h @@ -6,15 +6,33 @@ #include constexpr uint32_t JFJOCH_TCP_MAGIC = 0x4A464A54; // JFJT -constexpr uint32_t JFJOCH_TCP_VERSION = 1; +constexpr uint32_t JFJOCH_TCP_VERSION = 2; enum class TCPFrameType : uint16_t { START = 1, DATA = 2, CALIBRATION = 3, - END = 4 + END = 4, + ACK = 5, + CANCEL = 6 }; +enum class TCPAckCode : uint16_t { + None = 0, + StartFailed = 1, + DataWriteFailed = 2, + EndFailed = 3, + DiskQuotaExceeded = 4, + NoSpaceLeft = 5, + PermissionDenied = 6, + IoError = 7, + ProtocolError = 8 +}; + +constexpr uint32_t TCP_ACK_FLAG_OK = 1u << 0; +constexpr uint32_t TCP_ACK_FLAG_FATAL = 1u << 1; +constexpr uint32_t TCP_ACK_FLAG_HAS_ERROR_TEXT = 1u << 2; + struct alignas(64) TcpFrameHeader { uint32_t magic = JFJOCH_TCP_MAGIC; uint16_t version = JFJOCH_TCP_VERSION ; @@ -24,5 +42,10 @@ struct alignas(64) TcpFrameHeader { uint32_t socket_number = 0; uint32_t flags = 0; uint64_t run_number = 0; - uint64_t reserved[4] = {0, 0, 0, 0}; + + uint32_t ack_processed_images = 0; + uint16_t ack_code = 0; + uint16_t ack_for = 0; + + uint64_t reserved[2] = {0, 0}; }; \ No newline at end of file diff --git a/image_puller/ImagePuller.h b/image_puller/ImagePuller.h index c6ef4971..66f0b545 100644 --- a/image_puller/ImagePuller.h +++ b/image_puller/ImagePuller.h @@ -11,6 +11,19 @@ #include "../frame_serialize/CBORStream2Deserializer.h" #include "../common/ThreadSafeFIFO.h" #include "../common/JfjochTCP.h" +#include "../common/JfjochTCP.h" + +struct PullerAckMessage { + TCPFrameType ack_for = TCPFrameType::DATA; + bool ok = true; + bool fatal = false; + uint64_t run_number = 0; + uint32_t socket_number = 0; + uint64_t image_number = 0; + uint64_t processed_images = 0; + TCPAckCode error_code = TCPAckCode::None; + std::string error_text; +}; struct RawFrame { TcpFrameHeader header{}; @@ -42,6 +55,9 @@ public: [[nodiscard]] size_t GetMaxFifoUtilization() const; [[nodiscard]] size_t GetCurrentFifoUtilization() const; + virtual bool SupportsAck() const { return false; } + virtual bool SendAck(const PullerAckMessage &) { return false; } + virtual void Disconnect() = 0; }; diff --git a/image_puller/TCPImagePuller.cpp b/image_puller/TCPImagePuller.cpp index 5c9fb0a8..b002e662 100644 --- a/image_puller/TCPImagePuller.cpp +++ b/image_puller/TCPImagePuller.cpp @@ -51,6 +51,77 @@ TCPImagePuller::TCPImagePuller(const std::string &tcp_addr, cbor_thread = std::thread(&TCPImagePuller::CBORThread, this); } +bool TCPImagePuller::SendAll(const void *buf, size_t len) { + const auto *p = static_cast(buf); + size_t sent = 0; + while (sent < len) { + if (disconnect) + return false; + + int local_fd = -1; + { + std::unique_lock ul(fd_mutex); + local_fd = fd; + } + if (local_fd < 0) + return false; + + ssize_t rc = ::send(local_fd, p + sent, len - sent, MSG_NOSIGNAL); + if (rc < 0) { + if (errno == EINTR) + continue; + return false; + } + sent += static_cast(rc); + } + return true; +} + +bool TCPImagePuller::SendAck(const PullerAckMessage &ack) { + TcpFrameHeader h{}; + h.type = static_cast(TCPFrameType::ACK); + h.run_number = ack.run_number; + h.socket_number = ack.socket_number; + h.image_number = ack.image_number; + h.flags = 0; + if (ack.ok) + h.flags |= TCP_ACK_FLAG_OK; + if (ack.fatal) + h.flags |= TCP_ACK_FLAG_FATAL; + if (!ack.error_text.empty()) + h.flags |= TCP_ACK_FLAG_HAS_ERROR_TEXT; + + h.ack_for = static_cast(ack.ack_for); + h.ack_processed_images = ack.processed_images; + h.ack_code = static_cast(ack.error_code); + h.payload_size = ack.error_text.size(); + + if (!SendAll(&h, sizeof(h))) + return false; + if (!ack.error_text.empty()) + return SendAll(ack.error_text.data(), ack.error_text.size()); + return true; +} + +void TCPImagePuller::CBORThread() { + auto ret = cbor_fifo.GetBlocking(); + while (ret.tcp_msg) { + try { + const auto type = static_cast(ret.tcp_msg->header.type); + if (type == TCPFrameType::CANCEL) { + outside_fifo.PutBlocking(ret); + } else { + ret.cbor = CBORStream2Deserialize(ret.tcp_msg->payload.data(), ret.tcp_msg->payload.size()); + outside_fifo.PutBlocking(ret); + } + } catch (const JFJochException &e) { + logger.ErrorException(e); + } + ret = cbor_fifo.GetBlocking(); + } + outside_fifo.PutBlocking(ret); +} + TCPImagePuller::~TCPImagePuller() { TCPImagePuller::Disconnect(); } @@ -179,6 +250,17 @@ void TCPImagePuller::ReceiverThread() { continue; } + // Ignore ACK on puller side + if (static_cast(frame.header.type) == TCPFrameType::ACK) { + if (frame.header.payload_size > 0) { + std::vector discard(frame.header.payload_size); + if (!ReadExact(discard.data(), discard.size())) { + CloseSocket(); + } + } + continue; + } + ImagePullerOutput out; out.tcp_msg = std::make_shared(); out.tcp_msg->header = frame.header; @@ -206,19 +288,6 @@ void TCPImagePuller::ReceiverThread() { cbor_fifo.PutBlocking(ImagePullerOutput{}); } -void TCPImagePuller::CBORThread() { - auto ret = cbor_fifo.GetBlocking(); - while (ret.tcp_msg) { - try { - ret.cbor = CBORStream2Deserialize(ret.tcp_msg->payload.data(), ret.tcp_msg->payload.size()); - outside_fifo.PutBlocking(ret); - } catch (const JFJochException &e) { - logger.ErrorException(e); - } - ret = cbor_fifo.GetBlocking(); - } - outside_fifo.PutBlocking(ret); -} void TCPImagePuller::Disconnect() { if (disconnect.exchange(true)) return; diff --git a/image_puller/TCPImagePuller.h b/image_puller/TCPImagePuller.h index 5f4b5b33..714a3f9b 100644 --- a/image_puller/TCPImagePuller.h +++ b/image_puller/TCPImagePuller.h @@ -29,6 +29,7 @@ class TCPImagePuller : public ImagePuller { Logger logger{"TCPImagePuller"}; bool ReadExact(void *buf, size_t size); + bool SendAll(const void *buf, size_t len); bool EnsureConnected(); void CloseSocket(); void ReceiverThread(); @@ -37,5 +38,7 @@ public: explicit TCPImagePuller(const std::string &tcp_addr, std::optional rcv_buffer_size = {}); ~TCPImagePuller() override; + bool SupportsAck() const override { return true; } + bool SendAck(const PullerAckMessage &ack) override; void Disconnect() override; }; \ No newline at end of file diff --git a/image_pusher/HDF5FilePusher.h b/image_pusher/HDF5FilePusher.h index 4a1c0940..796df4ff 100644 --- a/image_pusher/HDF5FilePusher.h +++ b/image_pusher/HDF5FilePusher.h @@ -26,7 +26,6 @@ public: void SendImage(ZeroCopyReturnValue &z) override; bool SendCalibration(const CompressedImage &message) override; - std::string PrintSetup() const override; }; diff --git a/image_pusher/TCPStreamPusher.cpp b/image_pusher/TCPStreamPusher.cpp index a369164a..8ac57279 100644 --- a/image_pusher/TCPStreamPusher.cpp +++ b/image_pusher/TCPStreamPusher.cpp @@ -19,7 +19,6 @@ TCPStreamPusher::TCPStreamPusher(const std::vector &addr, } } - void TCPStreamPusher::StartDataCollection(StartMessage &message) { if (message.images_per_file < 1) throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, @@ -35,6 +34,25 @@ void TCPStreamPusher::StartDataCollection(StartMessage &message) { "TCP accept timeout/failure on socket " + socket[i]->GetEndpointName()); } + for (auto &s : socket) + s->StartWriterThread(); + + std::vector started(socket.size(), false); + + auto rollback_cancel = [&]() { + for (size_t i = 0; i < socket.size(); i++) { + if (!started[i] || socket[i]->IsBroken()) + continue; + + (void)socket[i]->Send(nullptr, 0, TCPFrameType::CANCEL); + std::string cancel_ack_err; + (void)socket[i]->WaitForAck(TCPFrameType::CANCEL, std::chrono::milliseconds(500), &cancel_ack_err); + } + + for (auto &s : socket) + s->StopWriterThread(); + }; + for (size_t i = 0; i < socket.size(); i++) { message.socket_number = static_cast(i); if (i > 0) @@ -44,17 +62,20 @@ void TCPStreamPusher::StartDataCollection(StartMessage &message) { socket[i]->SetRunNumber(run_number); if (!socket[i]->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::START)) { - // one-shot recovery: reconnect and retry START once - if (!socket[i]->AcceptConnection(std::chrono::seconds(5)) || - !socket[i]->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::START)) { - throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, - "Timeout/failure sending START"); - } + rollback_cancel(); + throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, + "Timeout/failure sending START on " + socket[i]->GetEndpointName()); } - } - for (auto &s : socket) - s->StartWriterThread(); + std::string ack_err; + if (!socket[i]->WaitForAck(TCPFrameType::START, std::chrono::seconds(5), &ack_err)) { + rollback_cancel(); + throw JFJochException(JFJochExceptionCategory::InputParameterInvalid, + "START ACK failed on " + socket[i]->GetEndpointName() + ": " + ack_err); + } + + started[i] = true; + } } bool TCPStreamPusher::SendImage(const uint8_t *image_data, size_t image_size, int64_t image_number) { @@ -88,12 +109,25 @@ bool TCPStreamPusher::EndDataCollection(const EndMessage &message) { bool ret = true; for (auto &s : socket) { - s->StopWriterThread(); - if (s->IsBroken()) + if (s->IsBroken()) { ret = false; - else if (!s->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::END)) + continue; + } + + if (!s->Send(serialization_buffer.data(), serializer.GetBufferSize(), TCPFrameType::END)) { ret = false; + continue; + } + + std::string ack_err; + if (!s->WaitForAck(TCPFrameType::END, std::chrono::seconds(10), &ack_err)) { + ret = false; + } } + + for (auto &s : socket) + s->StopWriterThread(); + transmission_error = !ret; return ret; } @@ -102,32 +136,16 @@ std::string TCPStreamPusher::Finalize() { std::string ret; if (transmission_error) ret += "Timeout sending images (e.g., writer disabled during data collection);"; - if (writer_notification_socket) { - for (size_t i = 0; i < socket.size(); i++) { - auto n = writer_notification_socket->Receive(run_number, run_name); - if (!n) - ret += "Writer " + std::to_string(i) + ": no end notification received within 1 minute from collection end"; - else if (static_cast(n->socket_number) >= socket.size()) - ret += "Writer " + std::to_string(i) + ": mismatch in socket number"; - else if (!n->ok) - ret += "Writer " + std::to_string(i) + ": " + n->error; + + for (size_t i = 0; i < socket.size(); i++) { + if (socket[i]->IsBroken()) { + const auto reason = socket[i]->GetLastAckError(); + ret += "Writer " + std::to_string(i) + ": " + (reason.empty() ? "stream broken" : reason) + ";"; } } return ret; } -std::string TCPStreamPusher::GetWriterNotificationSocketAddress() const { - if (writer_notification_socket) - return writer_notification_socket->GetEndpointName(); - else - return ""; -} - -TCPStreamPusher &TCPStreamPusher::WriterNotificationSocket(const std::string &addr) { - writer_notification_socket = std::make_unique(addr, std::chrono::minutes(1)); - return *this; -} - std::string TCPStreamPusher::PrintSetup() const { std::string output = "TCPStream2Pusher: Sending images to sockets: "; for (const auto &s : socket) diff --git a/image_pusher/TCPStreamPusher.h b/image_pusher/TCPStreamPusher.h index 3d574c22..a29c6115 100644 --- a/image_pusher/TCPStreamPusher.h +++ b/image_pusher/TCPStreamPusher.h @@ -11,8 +11,6 @@ class TCPStreamPusher : public ImagePusher { CBORStream2Serializer serializer; std::vector> socket; - std::unique_ptr writer_notification_socket; - int64_t images_per_file = 1; uint64_t run_number = 0; std::string run_name; @@ -23,9 +21,6 @@ public: std::optional zerocopy_threshold = {}, size_t send_queue_size = 4096); - TCPStreamPusher& WriterNotificationSocket(const std::string& addr); - std::string GetWriterNotificationSocketAddress() const override; - void StartDataCollection(StartMessage& message) override; bool EndDataCollection(const EndMessage& message) override; bool SendImage(const uint8_t *image_data, size_t image_size, int64_t image_number) override; diff --git a/image_pusher/TCPStreamPusherSocket.cpp b/image_pusher/TCPStreamPusherSocket.cpp index ef4291ec..f8e7df32 100644 --- a/image_pusher/TCPStreamPusherSocket.cpp +++ b/image_pusher/TCPStreamPusherSocket.cpp @@ -197,6 +197,50 @@ bool TCPStreamPusherSocket::SendAll(const void *buf, size_t len) { return true; } +bool TCPStreamPusherSocket::ReadExact(void *buf, size_t len) { + auto *p = static_cast(buf); + size_t got = 0; + + while (got < len) { + if (!active) + return false; + + int local_fd = fd.load(); + if (local_fd < 0) + return false; + + pollfd pfd{}; + pfd.fd = local_fd; + pfd.events = POLLIN; + + const int prc = poll(&pfd, 1, 100); // 100 ms interruptibility window + if (prc == 0) + continue; + if (prc < 0) { + if (errno == EINTR) + continue; + return false; + } + if ((pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0) + return false; + if ((pfd.revents & POLLIN) == 0) + continue; + + ssize_t rc = ::recv(local_fd, p + got, len - got, 0); + if (rc == 0) + return false; + if (rc < 0) { + if (errno == EINTR) + continue; + if (errno == EAGAIN || errno == EWOULDBLOCK) + continue; + return false; + } + got += static_cast(rc); + } + return true; +} + bool TCPStreamPusherSocket::SendPayloadZC(const uint8_t *data, size_t size, ZeroCopyReturnValue *z) { #if defined(MSG_ZEROCOPY) && defined(SO_ZEROCOPY) int local_fd = fd.load(); @@ -358,10 +402,89 @@ void TCPStreamPusherSocket::CompletionThread() { #endif } +void TCPStreamPusherSocket::AckThread() { + while (active) { + TcpFrameHeader h{}; + if (!ReadExact(&h, sizeof(h))) { + if (active) { + broken = true; + logger.Error("TCP ACK reader disconnected on " + endpoint); + } + break; + } + + if (h.magic != JFJOCH_TCP_MAGIC || h.version != JFJOCH_TCP_VERSION || static_cast(h.type) != TCPFrameType::ACK) { + broken = true; + logger.Error("Invalid ACK frame on " + endpoint); + break; + } + + std::string error_text; + if (h.payload_size > 0) { + error_text.resize(h.payload_size); + if (!ReadExact(error_text.data(), error_text.size())) { + broken = true; + break; + } + } + + const auto ack_for = static_cast(h.ack_for); + const bool ok = (h.flags & TCP_ACK_FLAG_OK) != 0; + const bool fatal = (h.flags & TCP_ACK_FLAG_FATAL) != 0; + const auto code = static_cast(h.ack_code); + + { + std::unique_lock ul(ack_state_mutex); + last_ack_code = code; + if (!error_text.empty()) + last_ack_error = error_text; + + if (ack_for == TCPFrameType::START) { + start_ack_received = true; + start_ack_ok = ok; + if (!ok && error_text.empty()) + last_ack_error = "START rejected"; + } else if (ack_for == TCPFrameType::END) { + end_ack_received = true; + end_ack_ok = ok; + if (!ok && error_text.empty()) + last_ack_error = "END rejected"; + } else if (ack_for == TCPFrameType::CANCEL) { + cancel_ack_received = true; + cancel_ack_ok = ok; + if (!ok && error_text.empty()) + last_ack_error = "CANCEL rejected"; + } else if (ack_for == TCPFrameType::DATA && (!ok || fatal)) { + broken = true; + if (error_text.empty()) + last_ack_error = "DATA fatal ACK"; + logger.Error("Received fatal DATA ACK on " + endpoint + ": " + last_ack_error); + } + } + ack_cv.notify_all(); + } +} + void TCPStreamPusherSocket::StartWriterThread() { + if (active) + return; + + { + std::unique_lock ul(ack_state_mutex); + start_ack_received = false; + start_ack_ok = false; + end_ack_received = false; + end_ack_ok = false; + cancel_ack_received = false; + cancel_ack_ok = false; + last_ack_error.clear(); + last_ack_code = TCPAckCode::None; + } + active = true; send_future = std::async(std::launch::async, &TCPStreamPusherSocket::WriterThread, this); completion_future = std::async(std::launch::async, &TCPStreamPusherSocket::CompletionThread, this); + ack_future = std::async(std::launch::async, &TCPStreamPusherSocket::AckThread, this); } void TCPStreamPusherSocket::StopWriterThread() { @@ -369,11 +492,14 @@ void TCPStreamPusherSocket::StopWriterThread() { return; active = false; queue.PutBlocking({.end = true}); + ack_cv.notify_all(); if (send_future.valid()) send_future.get(); if (completion_future.valid()) completion_future.get(); + if (ack_future.valid()) + ack_future.get(); // Keep fd open: END frame may still be sent after writer thread stops. // Socket is closed in destructor / explicit close path. @@ -403,3 +529,46 @@ bool TCPStreamPusherSocket::Send(const uint8_t *data, size_t size, TCPFrameType return SendFrame(data, size, type, image_number, nullptr); } + +bool TCPStreamPusherSocket::WaitForAck(TCPFrameType ack_for, std::chrono::milliseconds timeout, std::string *error_text) { + std::unique_lock ul(ack_state_mutex); + const bool ok = ack_cv.wait_for(ul, timeout, [&] { + if (ack_for == TCPFrameType::START) + return start_ack_received || broken.load(); + if (ack_for == TCPFrameType::END) + return end_ack_received || broken.load(); + if (ack_for == TCPFrameType::CANCEL) + return cancel_ack_received || broken.load(); + return false; + }); + + if (!ok) { + if (error_text) + *error_text = "ACK timeout"; + return false; + } + + if (broken) { + if (error_text) + *error_text = last_ack_error.empty() ? "Socket broken" : last_ack_error; + return false; + } + + bool ack_ok = false; + if (ack_for == TCPFrameType::START) + ack_ok = start_ack_ok; + else if (ack_for == TCPFrameType::END) + ack_ok = end_ack_ok; + else if (ack_for == TCPFrameType::CANCEL) + ack_ok = cancel_ack_ok; + + if (!ack_ok && error_text) + *error_text = last_ack_error.empty() ? "ACK rejected" : last_ack_error; + + return ack_ok; +} + +std::string TCPStreamPusherSocket::GetLastAckError() const { + std::unique_lock ul(ack_state_mutex); + return last_ack_error; +} diff --git a/image_pusher/TCPStreamPusherSocket.h b/image_pusher/TCPStreamPusherSocket.h index c735be59..80350ef7 100644 --- a/image_pusher/TCPStreamPusherSocket.h +++ b/image_pusher/TCPStreamPusherSocket.h @@ -26,6 +26,7 @@ class TCPStreamPusherSocket { std::atomic active = false; std::future send_future; std::future completion_future; + std::future ack_future; ThreadSafeFIFO queue; @@ -40,6 +41,16 @@ class TCPStreamPusherSocket { constexpr static auto AcceptTimeout = std::chrono::seconds(5); std::atomic broken{false}; + std::atomic last_ack_code{TCPAckCode::None}; + std::string last_ack_error; + mutable std::mutex ack_state_mutex; + std::condition_variable ack_cv; + bool start_ack_received = false; + bool start_ack_ok = false; + bool end_ack_received = false; + bool end_ack_ok = false; + bool cancel_ack_received = false; + bool cancel_ack_ok = false; std::atomic next_tx_id{1}; std::mutex inflight_mutex; @@ -49,10 +60,12 @@ class TCPStreamPusherSocket { void WriterThread(); void CompletionThread(); + void AckThread(); void CloseDataSocket(); bool SendAll(const void *buf, size_t len); + bool ReadExact(void *buf, size_t len); bool SendFrame(const uint8_t *data, size_t size, TCPFrameType type, int64_t image_number, ZeroCopyReturnValue *z); bool SendPayloadZC(const uint8_t *data, size_t size, ZeroCopyReturnValue *z); public: @@ -74,9 +87,12 @@ public: void StartWriterThread(); void StopWriterThread(); + bool WaitForAck(TCPFrameType ack_for, std::chrono::milliseconds timeout, std::string *error_text = nullptr); + void SetRunNumber(uint64_t in_run_number); void SendImage(ZeroCopyReturnValue &z); bool IsBroken() const; + std::string GetLastAckError() const; }; diff --git a/tests/TCPImagePusherTest.cpp b/tests/TCPImagePusherTest.cpp index f4608d1f..0cb0a646 100644 --- a/tests/TCPImagePusherTest.cpp +++ b/tests/TCPImagePusherTest.cpp @@ -7,7 +7,7 @@ #include "../image_pusher/TCPStreamPusher.h" #include "../image_puller/TCPImagePuller.h" -TEST_CASE("TCPImageCommTest_2Writers", "[TCP]") { +TEST_CASE("TCPImageCommTest_2Writers_WithAck", "[TCP]") { const size_t nframes = 128; const int64_t npullers = 2; const int64_t images_per_file = 16; @@ -23,24 +23,24 @@ TEST_CASE("TCPImageCommTest_2Writers", "[TCP]") { for (auto &i : image1) i = dist(g1); std::vector addr{ - "tcp://127.0.0.1:19001", - "tcp://127.0.0.1:19002" -}; + "tcp://127.0.0.1:19001", + "tcp://127.0.0.1:19002" + }; std::vector> puller; for (int i = 0; i < npullers; i++) { - puller.push_back(std::make_unique( - addr[i], 64 * 1024 * 1024)); // decoded cbor ring + puller.push_back(std::make_unique(addr[i], 64 * 1024 * 1024)); } TCPStreamPusher pusher( addr, 64 * 1024 * 1024, - 128 * 1024, // zerocopy threshold - 8192 // sender queue + 128 * 1024, + 8192 ); std::vector received(npullers, 0); + std::vector processed(npullers, 0); std::thread sender([&] { std::vector serialization_buffer(16 * 1024 * 1024); @@ -70,29 +70,204 @@ TEST_CASE("TCPImageCommTest_2Writers", "[TCP]") { REQUIRE(pusher.EndDataCollection(end)); }); + std::vector receivers; + receivers.reserve(npullers); + for (int w = 0; w < npullers; w++) { - bool seen_end = false; - while (!seen_end) { - auto out = puller[w]->PollImage(std::chrono::seconds(10)); - REQUIRE(out.has_value()); - REQUIRE(out->cbor != nullptr); - if (out->cbor->end_message) { - seen_end = true; - continue; + receivers.emplace_back([&, w] { + bool seen_start = false; + bool seen_end = false; + + while (!seen_end) { + auto out = puller[w]->PollImage(std::chrono::seconds(10)); + REQUIRE(out.has_value()); + REQUIRE(out->cbor != nullptr); + REQUIRE(out->tcp_msg != nullptr); + + const auto &h = out->tcp_msg->header; + + if (out->cbor->start_message) { + PullerAckMessage ack; + ack.ack_for = TCPFrameType::START; + ack.ok = true; + ack.fatal = false; + ack.run_number = h.run_number; + ack.socket_number = h.socket_number; + ack.image_number = 0; + ack.processed_images = 0; + ack.error_code = TCPAckCode::None; + REQUIRE(puller[w]->SendAck(ack)); + seen_start = true; + continue; + } + + if (out->cbor->data_message) { + REQUIRE(seen_start); + auto n = out->cbor->data_message->number; + REQUIRE(((n / images_per_file) % npullers) == w); + received[w]++; + processed[w]++; + continue; + } + + if (out->cbor->end_message) { + PullerAckMessage ack; + ack.ack_for = TCPFrameType::END; + ack.ok = true; + ack.fatal = false; + ack.run_number = h.run_number; + ack.socket_number = h.socket_number; + ack.image_number = 0; + ack.processed_images = processed[w]; + ack.error_code = TCPAckCode::None; + REQUIRE(puller[w]->SendAck(ack)); + seen_end = true; + } } - if (out->cbor->data_message) { - auto n = out->cbor->data_message->number; - REQUIRE(((n / images_per_file) % npullers) == w); - received[w]++; - } - } + }); } sender.join(); + for (auto &t : receivers) t.join(); REQUIRE(received[0] == nframes / 2); REQUIRE(received[1] == nframes / 2); + for (auto &p : puller) + p->Disconnect(); +} + +TEST_CASE("TCPImageCommTest_DataFatalAck_PropagatesToPusher", "[TCP]") { + const size_t nframes = 64; + const int64_t npullers = 2; + const int64_t images_per_file = 8; + + DiffractionExperiment x(DetJF(1)); + x.Raw(); + x.PedestalG0Frames(0).NumTriggers(1).UseInternalPacketGenerator(false).IncidentEnergy_keV(12.4) + .ImagesPerTrigger(nframes).Compression(CompressionAlgorithm::NO_COMPRESSION); + + std::mt19937 g1(42); + std::uniform_int_distribution dist; + std::vector image1(x.GetPixelsNum() * nframes); + for (auto &i : image1) i = dist(g1); + + std::vector addr{ + "tcp://127.0.0.1:19011", + "tcp://127.0.0.1:19012" + }; + + std::vector> puller; + for (int i = 0; i < npullers; i++) { + puller.push_back(std::make_unique(addr[i], 64 * 1024 * 1024)); + } + + TCPStreamPusher pusher( + addr, + 64 * 1024 * 1024, + 128 * 1024, + 8192 + ); + + std::atomic sent_fatal{false}; + + std::thread sender([&] { + std::vector serialization_buffer(16 * 1024 * 1024); + CBORStream2Serializer serializer(serialization_buffer.data(), serialization_buffer.size()); + + StartMessage start{ + .images_per_file = images_per_file, + .write_master_file = true + }; + EndMessage end{}; + + pusher.StartDataCollection(start); + + for (int64_t i = 0; i < static_cast(nframes); i++) { + DataMessage data_message; + data_message.number = i; + data_message.image = CompressedImage(image1.data() + i * x.GetPixelsNum(), + x.GetPixelsNum() * sizeof(uint16_t), + x.GetXPixelsNum(), + x.GetYPixelsNum(), + x.GetImageMode(), + x.GetCompressionAlgorithm()); + serializer.SerializeImage(data_message); + (void)pusher.SendImage(serialization_buffer.data(), serializer.GetBufferSize(), i); + } + + REQUIRE_FALSE(pusher.EndDataCollection(end)); + const auto final_msg = pusher.Finalize(); + REQUIRE_THAT(final_msg, Catch::Matchers::ContainsSubstring("quota")); + }); + + std::vector receivers; + receivers.reserve(npullers); + + for (int w = 0; w < npullers; w++) { + receivers.emplace_back([&, w] { + bool seen_end = false; + bool local_fatal_sent = false; + + while (!seen_end) { + auto out = puller[w]->PollImage(std::chrono::seconds(2)); + if (!out.has_value()) { + // Once this receiver has sent a fatal DATA ACK, no END is guaranteed on this stream. + if (local_fatal_sent) + break; + REQUIRE(out.has_value()); + } + + REQUIRE(out->cbor != nullptr); + REQUIRE(out->tcp_msg != nullptr); + + const auto &h = out->tcp_msg->header; + + if (out->cbor->start_message) { + PullerAckMessage ack; + ack.ack_for = TCPFrameType::START; + ack.ok = true; + ack.run_number = h.run_number; + ack.socket_number = h.socket_number; + ack.error_code = TCPAckCode::None; + REQUIRE(puller[w]->SendAck(ack)); + continue; + } + + if (out->cbor->data_message) { + if (w == 0 && !sent_fatal.exchange(true)) { + PullerAckMessage ack; + ack.ack_for = TCPFrameType::DATA; + ack.ok = false; + ack.fatal = true; + ack.run_number = h.run_number; + ack.socket_number = h.socket_number; + ack.image_number = static_cast(out->cbor->data_message->number); + ack.error_code = TCPAckCode::DiskQuotaExceeded; + ack.error_text = "quota exceeded"; + REQUIRE(puller[w]->SendAck(ack)); + local_fatal_sent = true; + } + continue; + } + + if (out->cbor->end_message) { + PullerAckMessage ack; + ack.ack_for = TCPFrameType::END; + ack.ok = true; + ack.run_number = h.run_number; + ack.socket_number = h.socket_number; + ack.error_code = TCPAckCode::None; + REQUIRE(puller[w]->SendAck(ack)); + seen_end = true; + } + } + }); + } + + sender.join(); + for (auto &t : receivers) t.join(); + for (auto &p : puller) p->Disconnect(); } \ No newline at end of file diff --git a/writer/StreamWriter.cpp b/writer/StreamWriter.cpp index 1b4ac22b..372b6abd 100644 --- a/writer/StreamWriter.cpp +++ b/writer/StreamWriter.cpp @@ -20,6 +20,27 @@ StreamWriter::StreamWriter(Logger &in_logger, max_image_number(0) { } +void StreamWriter::NotifyTcpAck(TCPFrameType ack_for, bool ok, bool fatal, TCPAckCode code, const std::string &error_text) { + if (!image_puller.SupportsAck()) + return; + + PullerAckMessage ack; + ack.ack_for = ack_for; + ack.ok = ok; + ack.fatal = fatal; + ack.error_code = code; + ack.error_text = error_text; + ack.run_number = run_number; + ack.socket_number = static_cast(socket_number); + ack.processed_images = processed_images.load(); + + if (image_puller_output.cbor && image_puller_output.cbor->data_message) + ack.image_number = image_puller_output.cbor->data_message->number; + + if (!image_puller.SendAck(ack)) + logger.Warning("Failed to send TCP ACK"); +} + void StreamWriter::ProcessStartMessage() { if (state == StreamWriterState::Finalized) return; // Should not happen (?) @@ -28,6 +49,7 @@ void StreamWriter::ProcessStartMessage() { FinalizeDataCollection(); err = ""; + tcp_data_fatal_sent = false; max_image_number = 0; @@ -51,11 +73,13 @@ void StreamWriter::ProcessStartMessage() { image_puller_output.cbor->start_message->file_prefix, image_puller_output.cbor->start_message->number_of_images); state = StreamWriterState::Started; + NotifyTcpAck(TCPFrameType::START, true, false, TCPAckCode::None); } catch (const JFJochException &e) { logger.ErrorException(e); logger.Error("Error writing start message - switching to error state"); state = StreamWriterState::Error; err = e.what(); + NotifyTcpAck(TCPFrameType::START, false, true, TCPAckCode::StartFailed, err); } } @@ -108,6 +132,10 @@ void StreamWriter::ProcessDataImage() { logger.Warning("Error writing image - switching to error state"); state = StreamWriterState::Error; err = e.what(); + if (!tcp_data_fatal_sent) { + tcp_data_fatal_sent = true; + NotifyTcpAck(TCPFrameType::DATA, false, true, TCPAckCode::DataWriteFailed, err); + } } break; case StreamWriterState::Error: @@ -156,6 +184,11 @@ void StreamWriter::FinalizeDataCollection() { } file_writer.reset(); NotifyReceiverOnFinalizedWrite(writer_notification_zmq_addr); + NotifyTcpAck(TCPFrameType::END, + state != StreamWriterState::Error, + state == StreamWriterState::Error, + state == StreamWriterState::Error ? TCPAckCode::EndFailed : TCPAckCode::None, + state == StreamWriterState::Error ? err : ""); logger.Info("Data writing finished"); state = StreamWriterState::Finalized; } @@ -168,6 +201,21 @@ void StreamWriter::CollectImages() { while (run && state != StreamWriterState::Finalized) { run = WaitForImage(); + if (image_puller_output.tcp_msg && + static_cast(image_puller_output.tcp_msg->header.type) == TCPFrameType::CANCEL) { + logger.Warning("Received TCP CANCEL, finalizing data collection"); + if (state != StreamWriterState::Idle && state != StreamWriterState::Finalized) + FinalizeDataCollection(); + NotifyTcpAck(TCPFrameType::CANCEL, true, false, TCPAckCode::None); + state = StreamWriterState::Finalized; + continue; + } + + if (!image_puller_output.cbor) { + logger.Warning("Missing CBOR payload for non-CANCEL TCP frame"); + continue; + } + if (image_puller_output.cbor->start_message) ProcessStartMessage(); else if (image_puller_output.cbor->calibration) diff --git a/writer/StreamWriter.h b/writer/StreamWriter.h index 045d0d50..301a4e6f 100644 --- a/writer/StreamWriter.h +++ b/writer/StreamWriter.h @@ -55,12 +55,14 @@ class StreamWriter { std::vector hdf5_data_file_statistics; bool debug_skip_write_notification = false; + bool tcp_data_fatal_sent = false; ImagePuller &image_puller; Logger &logger; void CollectImages(); bool WaitForImage(); void NotifyReceiverOnFinalizedWrite(const std::string &detector_update_zmq_addr); + void NotifyTcpAck(TCPFrameType ack_for, bool ok, bool fatal, TCPAckCode code, const std::string &error_text = ""); void ProcessStartMessage(); void ProcessEndMessage(); void ProcessDataImage(); -- 2.49.1 From 94072d362698ba8c6b8f3e4bf64f198dc2c0811f Mon Sep 17 00:00:00 2001 From: Filip Leonarski Date: Wed, 4 Mar 2026 13:37:52 +0100 Subject: [PATCH 04/42] jfjoch_test: Add TCP/IP integration test --- tests/JFJochReceiverProcessingTest.cpp | 81 ++++++++++++++++++++++++++ writer/StreamWriter.cpp | 17 ++++-- 2 files changed, 92 insertions(+), 6 deletions(-) diff --git a/tests/JFJochReceiverProcessingTest.cpp b/tests/JFJochReceiverProcessingTest.cpp index 79b80591..03427553 100644 --- a/tests/JFJochReceiverProcessingTest.cpp +++ b/tests/JFJochReceiverProcessingTest.cpp @@ -14,6 +14,8 @@ #include "../writer/StreamWriter.h" #include "../image_pusher/NonePusher.h" #include "../image_pusher/HDF5FilePusher.h" +#include "../image_pusher/TCPStreamPusher.h" +#include "../image_puller/TCPImagePuller.h" TEST_CASE("JFJochIntegrationTest_ZMQ_lysozyme_spot_and_index", "[JFJochReceiver]") { Logger logger(Catch::getResultCapture().getCurrentTestName()); @@ -1467,3 +1469,82 @@ TEST_CASE("JFJochIntegrationTest_HDF5FilePusher_Raw", "[JFJochReceiver]") { CHECK(receiver_out.status.images_sent == 5); CHECK(!receiver_out.status.cancelled); } + +TEST_CASE("JFJochIntegrationTest_TCP_lysozyme_spot_and_index", "[JFJochReceiver]") { + Logger logger(Catch::getResultCapture().getCurrentTestName()); + + RegisterHDF5Filter(); + + const uint16_t nthreads = 4; + + DiffractionExperiment experiment(DetJF4M()); + experiment.ImagesPerTrigger(5).NumTriggers(1).UseInternalPacketGenerator(true).ImagesPerFile(2) + .FilePrefix("lyso_test_tcp").JungfrauConvPhotonCnt(false).SetFileWriterFormat(FileWriterFormat::NXmxVDS).OverwriteExistingFiles(true) + .DetectorDistance_mm(75).BeamY_pxl(1136).BeamX_pxl(1090).IncidentEnergy_keV(12.4) + .SetUnitCell(UnitCell{.a = 36.9, .b = 78.95, .c = 78.95, .alpha =90, .beta = 90, .gamma = 90}); + experiment.SampleTemperature_K(123.0).RingCurrent_mA(115); + + PixelMask pixel_mask(experiment); + + // Load example image + HDF5ReadOnlyFile data("../../tests/test_data/compression_benchmark.h5"); + HDF5DataSet dataset(data, "/entry/data/data"); + HDF5DataSpace file_space(dataset); + + REQUIRE(file_space.GetDimensions()[2] == experiment.GetXPixelsNum()); + REQUIRE(file_space.GetDimensions()[1] == experiment.GetYPixelsNum()); + std::vector image_conv (file_space.GetDimensions()[1] * file_space.GetDimensions()[2]); + + std::vector start = {4,0,0}; + std::vector file_size = {1, file_space.GetDimensions()[1], file_space.GetDimensions()[2]}; + dataset.ReadVector(image_conv, start, file_size); + + std::vector image_raw_geom(experiment.GetModulesNum() * RAW_MODULE_SIZE); + ConvertedToRawGeometry(experiment, image_raw_geom.data(), image_conv.data()); + logger.Info("Loaded image"); + + // Setup acquisition device + AcquisitionDeviceGroup aq_devices; + std::unique_ptr test = std::make_unique(0, 64); + for (int m = 0; m < experiment.GetModulesNum(); m++) + test->SetInternalGeneratorFrame((uint16_t *) image_raw_geom.data() + m * RAW_MODULE_SIZE, m); + + aq_devices.Add(std::move(test)); + + TCPStreamPusher pusher({"tcp://127.0.0.1:9121"}); + + TCPImagePuller puller("tcp://127.0.0.1:9121"); + StreamWriter writer(logger, puller); + auto writer_future = std::async(std::launch::async, &StreamWriter::Run, &writer); + + JFJochReceiverService service(aq_devices, logger, pusher); + service.NumThreads(nthreads); + service.Indexing(experiment.GetIndexingSettings()); + + // No progress value at the start of measurement + REQUIRE(!service.GetProgress().has_value()); + + SpotFindingSettings settings = DiffractionExperiment::DefaultDataProcessingSettings(); + settings.signal_to_noise_threshold = 2.5; + settings.photon_count_threshold = 5; + settings.min_pix_per_spot = 1; + settings.max_pix_per_spot = 200; + settings.high_resolution_limit = 2.0; + settings.low_resolution_limit = 50.0; + service.SetSpotFindingSettings(settings); + + service.Start(experiment, pixel_mask, nullptr); + auto receiver_out = service.Stop(); + + CHECK(receiver_out.efficiency == 1.0); + REQUIRE(receiver_out.status.indexing_rate); + CHECK(receiver_out.status.indexing_rate.value() == 1.0); + CHECK(receiver_out.status.images_sent == experiment.GetImageNum()); + CHECK(receiver_out.writer_err.empty()); + CHECK(!receiver_out.status.cancelled); + + // No progress value at the end of measurement + REQUIRE(!service.GetProgress().has_value()); + + REQUIRE_NOTHROW(writer_future.get()); +} diff --git a/writer/StreamWriter.cpp b/writer/StreamWriter.cpp index 372b6abd..ff7e7440 100644 --- a/writer/StreamWriter.cpp +++ b/writer/StreamWriter.cpp @@ -91,6 +91,7 @@ void StreamWriter::ProcessCalibrationImage() { } catch (const std::exception &e) { logger.Warning(e.what()); logger.Warning("Error during writing calibration data - skipping"); + NotifyTcpAck(TCPFrameType::CALIBRATION, false, false, TCPAckCode::DataWriteFailed, e.what()); } break; case StreamWriterState::Receiving: @@ -139,6 +140,7 @@ void StreamWriter::ProcessDataImage() { } break; case StreamWriterState::Error: + // Error state => Wait till end only case StreamWriterState::Finalized: break; } @@ -164,7 +166,16 @@ void StreamWriter::ProcessEndMessage() { err = e.what(); } } + bool error_state = (state == StreamWriterState::Error); + FinalizeDataCollection(); + + // Notifications happen only when handling END message + // No end message ==> no need to ACK + NotifyReceiverOnFinalizedWrite(writer_notification_zmq_addr); + NotifyTcpAck(TCPFrameType::END, !error_state, error_state, + error_state ? TCPAckCode::EndFailed : TCPAckCode::None, + error_state ? err : ""); } void StreamWriter::FinalizeDataCollection() { @@ -183,12 +194,6 @@ void StreamWriter::FinalizeDataCollection() { hdf5_data_file_statistics.clear(); } file_writer.reset(); - NotifyReceiverOnFinalizedWrite(writer_notification_zmq_addr); - NotifyTcpAck(TCPFrameType::END, - state != StreamWriterState::Error, - state == StreamWriterState::Error, - state == StreamWriterState::Error ? TCPAckCode::EndFailed : TCPAckCode::None, - state == StreamWriterState::Error ? err : ""); logger.Info("Data writing finished"); state = StreamWriterState::Finalized; } -- 2.49.1 From e6e8ffd83817384ae5fc2f7aee38cbd5c3f1136f Mon Sep 17 00:00:00 2001 From: Filip Leonarski Date: Wed, 4 Mar 2026 13:46:22 +0100 Subject: [PATCH 05/42] TCPStreamPusherSocket: Very minor refactor --- image_pusher/TCPStreamPusherSocket.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/image_pusher/TCPStreamPusherSocket.cpp b/image_pusher/TCPStreamPusherSocket.cpp index f8e7df32..7040dfc3 100644 --- a/image_pusher/TCPStreamPusherSocket.cpp +++ b/image_pusher/TCPStreamPusherSocket.cpp @@ -297,16 +297,16 @@ bool TCPStreamPusherSocket::SendFrame(const uint8_t *data, size_t size, TCPFrame return true; } + bool ok; if (z && zerocopy_threshold && size >= zerocopy_threshold.value()) { - bool ok = SendPayloadZC(data, size, z); + ok = SendPayloadZC(data, size, z); if (!ok) z->release(); - return ok; + } else { + ok = SendAll(data, size); + if (z) + z->release(); } - - bool ok = SendAll(data, size); - if (z) - z->release(); return ok; } -- 2.49.1 From b6db6040e836bb80c0f682924be6c7c32c7f8923 Mon Sep 17 00:00:00 2001 From: Filip Leonarski Date: Wed, 4 Mar 2026 14:07:08 +0100 Subject: [PATCH 06/42] jfjoch_broker: Output version with /status --- broker/JFJochStateMachine.cpp | 2 ++ broker/JFJochStateMachine.h | 11 +------ broker/OpenAPIConvert.cpp | 1 + broker/gen/model/Broker_status.cpp | 33 ++++++++++++++++++-- broker/gen/model/Broker_status.h | 9 ++++++ broker/jfjoch_api.yaml | 4 +++ broker/redoc-static.html | 6 ++-- common/BrokerStatus.h | 18 +++++++++++ common/CMakeLists.txt | 1 + docs/python_client/docs/BrokerStatus.md | 1 + frontend/package-lock.json | 4 +-- frontend/src/openapi/models/broker_status.ts | 4 +++ 12 files changed, 77 insertions(+), 17 deletions(-) create mode 100644 common/BrokerStatus.h diff --git a/broker/JFJochStateMachine.cpp b/broker/JFJochStateMachine.cpp index 9d517dd1..dbc3e2f3 100644 --- a/broker/JFJochStateMachine.cpp +++ b/broker/JFJochStateMachine.cpp @@ -6,6 +6,7 @@ #include "JFJochStateMachine.h" #include "../preview/JFJochTIFF.h" #include "../common/CUDAWrapper.h" +#include "../common/GitInfo.h" JFJochStateMachine::JFJochStateMachine(const DiffractionExperiment& in_experiment, JFJochServices &in_services, @@ -551,6 +552,7 @@ BrokerStatus JFJochStateMachine::GetStatus() const { BrokerStatus ret = broker_status; ret.progress = services.GetReceiverProgress(); ret.gpu_count = gpu_count; + ret.broker_version = jfjoch_version(); return ret; } diff --git a/broker/JFJochStateMachine.h b/broker/JFJochStateMachine.h index 5bd7852c..3296769f 100644 --- a/broker/JFJochStateMachine.h +++ b/broker/JFJochStateMachine.h @@ -15,16 +15,7 @@ #include "JFJochServices.h" #include "../common/ROIMap.h" - -enum class JFJochState {Inactive, Idle, Measuring, Error, Busy, Calibration}; - -struct BrokerStatus { - JFJochState state = JFJochState::Inactive; - std::optional progress; - std::optional message; - enum class MessageSeverity {Error, Info, Warning, Success} message_severity = MessageSeverity::Error; - int64_t gpu_count; -}; +#include "../common/BrokerStatus.h" struct DetectorListElement { std::string description; diff --git a/broker/OpenAPIConvert.cpp b/broker/OpenAPIConvert.cpp index 3b603dcc..84c78c0f 100644 --- a/broker/OpenAPIConvert.cpp +++ b/broker/OpenAPIConvert.cpp @@ -241,6 +241,7 @@ org::openapitools::server::model::Broker_status Convert(const BrokerStatus& inpu ret.setProgress(input.progress.value()); ret.setGpuCount(input.gpu_count); + ret.setBrokerVersion(input.broker_version); return ret; } diff --git a/broker/gen/model/Broker_status.cpp b/broker/gen/model/Broker_status.cpp index 8c1cf391..0bca4129 100644 --- a/broker/gen/model/Broker_status.cpp +++ b/broker/gen/model/Broker_status.cpp @@ -30,6 +30,8 @@ Broker_status::Broker_status() m_Message_severityIsSet = false; m_Gpu_count = 0; m_Gpu_countIsSet = false; + m_Broker_version = ""; + m_Broker_versionIsSet = false; } @@ -71,7 +73,7 @@ bool Broker_status::validate(std::stringstream& msg, const std::string& pathPref } } - + return success; } @@ -93,7 +95,10 @@ bool Broker_status::operator==(const Broker_status& rhs) const ((!messageSeverityIsSet() && !rhs.messageSeverityIsSet()) || (messageSeverityIsSet() && rhs.messageSeverityIsSet() && getMessageSeverity() == rhs.getMessageSeverity())) && - ((!gpuCountIsSet() && !rhs.gpuCountIsSet()) || (gpuCountIsSet() && rhs.gpuCountIsSet() && getGpuCount() == rhs.getGpuCount())) + ((!gpuCountIsSet() && !rhs.gpuCountIsSet()) || (gpuCountIsSet() && rhs.gpuCountIsSet() && getGpuCount() == rhs.getGpuCount())) && + + + ((!brokerVersionIsSet() && !rhs.brokerVersionIsSet()) || (brokerVersionIsSet() && rhs.brokerVersionIsSet() && getBrokerVersion() == rhs.getBrokerVersion())) ; } @@ -115,6 +120,8 @@ void to_json(nlohmann::json& j, const Broker_status& o) j["message_severity"] = o.m_Message_severity; if(o.gpuCountIsSet()) j["gpu_count"] = o.m_Gpu_count; + if(o.brokerVersionIsSet()) + j["broker_version"] = o.m_Broker_version; } @@ -141,6 +148,11 @@ void from_json(const nlohmann::json& j, Broker_status& o) j.at("gpu_count").get_to(o.m_Gpu_count); o.m_Gpu_countIsSet = true; } + if(j.find("broker_version") != j.end()) + { + j.at("broker_version").get_to(o.m_Broker_version); + o.m_Broker_versionIsSet = true; + } } @@ -220,6 +232,23 @@ void Broker_status::unsetGpu_count() { m_Gpu_countIsSet = false; } +std::string Broker_status::getBrokerVersion() const +{ + return m_Broker_version; +} +void Broker_status::setBrokerVersion(std::string const& value) +{ + m_Broker_version = value; + m_Broker_versionIsSet = true; +} +bool Broker_status::brokerVersionIsSet() const +{ + return m_Broker_versionIsSet; +} +void Broker_status::unsetBroker_version() +{ + m_Broker_versionIsSet = false; +} } // namespace org::openapitools::server::model diff --git a/broker/gen/model/Broker_status.h b/broker/gen/model/Broker_status.h index 7b1fcbad..c00d131f 100644 --- a/broker/gen/model/Broker_status.h +++ b/broker/gen/model/Broker_status.h @@ -91,6 +91,13 @@ public: void setGpuCount(int32_t const value); bool gpuCountIsSet() const; void unsetGpu_count(); + /// + /// Version of the jfjoch_broker + /// + std::string getBrokerVersion() const; + void setBrokerVersion(std::string const& value); + bool brokerVersionIsSet() const; + void unsetBroker_version(); friend void to_json(nlohmann::json& j, const Broker_status& o); friend void from_json(const nlohmann::json& j, Broker_status& o); @@ -105,6 +112,8 @@ protected: bool m_Message_severityIsSet; int32_t m_Gpu_count; bool m_Gpu_countIsSet; + std::string m_Broker_version; + bool m_Broker_versionIsSet; }; diff --git a/broker/jfjoch_api.yaml b/broker/jfjoch_api.yaml index 92f500b6..925e168c 100644 --- a/broker/jfjoch_api.yaml +++ b/broker/jfjoch_api.yaml @@ -1334,6 +1334,10 @@ components: type: integer format: int32 description: Number of installed GPUs + broker_version: + type: string + description: Version of the jfjoch_broker + example: "1.0.0-rc.128" plot: type: object required: diff --git a/broker/redoc-static.html b/broker/redoc-static.html index d06d1936..3b90814d 100644 --- a/broker/redoc-static.html +++ b/broker/redoc-static.html @@ -763,7 +763,7 @@ This can only be done when detector is Idle, Error or
http://localhost:5232/config/dark_mask

Response samples

Content type
application/json
{
  • "detector_threshold_keV": 3.5,
  • "frame_time_us": 10000,
  • "number_of_frames": 1000,
  • "max_allowed_pixel_count": 1,
  • "max_frames_with_signal": 10
}

Get Jungfraujoch status

Status of the data acquisition

Responses

Response samples

Content type
application/json
{
  • "state": "Inactive",
  • "progress": 1,
  • "message": "string",
  • "message_severity": "success",
  • "gpu_count": 0
}

Get status of FPGA devices

Responses

Response samples

Content type
application/json
{
  • "state": "Inactive",
  • "progress": 1,
  • "message": "string",
  • "message_severity": "success",
  • "gpu_count": 0,
  • "broker_version": "1.0.0-rc.128"
}

Get status of FPGA devices

Responses

Response samples

Content type
application/json
[
  • {
    }
]

Return XFEL pulse IDs for the current data acquisition

Return array of XFEL pulse IDs - (-1) if image not recorded

Responses

Request samples

Content type
application/json
{
  • "box": {
    },
  • "circle": {
    },
  • "azim": {
    }
}

Response samples

Content type
application/json
{
  • "msg": "Detector in wrong state",
  • "reason": "WrongDAQState"
}

Get general statistics

query Parameters
compression
boolean
Default: false

Enable DEFLATE compression of output data.

Responses

Response samples

Content type
application/json
{
  • "detector": {
    },
  • "detector_list": {
    },
  • "detector_settings": {
    },
  • "image_format_settings": {
    },
  • "instrument_metadata": {
    },
  • "file_writer_settings": {
    },
  • "data_processing_settings": {
    },
  • "measurement": {
    },
  • "broker": {
    },
  • "fpga": [
    ],
  • "calibration": [
    ],
  • "zeromq_preview": {
    },
  • "zeromq_metadata": {
    },
  • "dark_mask": {
    },
  • "pixel_mask": {
    },
  • "roi": {
    },
  • "az_int": {
    },
  • "buffer": {
    },
  • "indexing": {
    }
}

Get data collection statistics

Results of the last data collection

+
http://localhost:5232/statistics

Response samples

Content type
application/json
{
  • "detector": {
    },
  • "detector_list": {
    },
  • "detector_settings": {
    },
  • "image_format_settings": {
    },
  • "instrument_metadata": {
    },
  • "file_writer_settings": {
    },
  • "data_processing_settings": {
    },
  • "measurement": {
    },
  • "broker": {
    },
  • "fpga": [
    ],
  • "calibration": [
    ],
  • "zeromq_preview": {
    },
  • "zeromq_metadata": {
    },
  • "dark_mask": {
    },
  • "pixel_mask": {
    },
  • "roi": {
    },
  • "az_int": {
    },
  • "buffer": {
    },
  • "indexing": {
    }
}

Get data collection statistics

Results of the last data collection

Responses