From 1de0060f89a436b9f3d80d69e7d3b5f7897ecf16 Mon Sep 17 00:00:00 2001 From: Dominik Werder Date: Tue, 12 Nov 2024 13:38:49 +0100 Subject: [PATCH] Factored proto in separate crate --- netfetch/Cargo.toml | 1 + netfetch/src/ca.rs | 1 - netfetch/src/ca/conn.rs | 34 ++++---- netfetch/src/ca/conn/enumfetch.rs | 12 +-- netfetch/src/ca/conn2/channel.rs | 7 +- netfetch/src/ca/conn2/conn.rs | 3 +- netfetch/src/ca/findioc.rs | 20 ++--- netfetch/src/ca/proto.rs | 132 ++++++++++++++++++++++-------- netfetch/src/lib.rs | 1 + netfetch/src/tcpasyncwriteread.rs | 54 ++++++++++++ 10 files changed, 190 insertions(+), 75 deletions(-) create mode 100644 netfetch/src/tcpasyncwriteread.rs diff --git a/netfetch/Cargo.toml b/netfetch/Cargo.toml index dbad665..bd1a8b6 100644 --- a/netfetch/Cargo.toml +++ b/netfetch/Cargo.toml @@ -51,6 +51,7 @@ netpod = { path = "../../daqbuf-netpod", package = "daqbuf-netpod" } items_0 = { path = "../../daqbuf-items-0", package = "daqbuf-items-0" } items_2 = { path = "../../daqbuf-items-2", package = "daqbuf-items-2" } streams = { path = "../../daqbuf-streams", package = "daqbuf-streams" } +ca_proto = { path = "../../daqbuf-ca-proto", package = "daqbuf-ca-proto" } taskrun = { path = "../../daqbuffer/crates/taskrun" } #bitshuffle = { path = "../../daqbuffer/crates/bitshuffle" } mrucache = { path = "../mrucache" } diff --git a/netfetch/src/ca.rs b/netfetch/src/ca.rs index 0dfed89..10eb155 100644 --- a/netfetch/src/ca.rs +++ b/netfetch/src/ca.rs @@ -5,7 +5,6 @@ pub mod connset; pub mod connset_input_merge; pub mod finder; pub mod findioc; -pub mod proto; pub mod search; pub mod statemap; diff --git a/netfetch/src/ca/conn.rs b/netfetch/src/ca/conn.rs index e1d8297..b20d2b5 100644 --- a/netfetch/src/ca/conn.rs +++ b/netfetch/src/ca/conn.rs @@ -1,16 +1,12 @@ mod enumfetch; -use super::proto; -use super::proto::CaDataValue; -use super::proto::CaEventValue; -use super::proto::ReadNotify; -use crate::ca::proto::ChannelClose; -use crate::ca::proto::EventCancel; use crate::conf::ChannelConfig; use crate::metrics::status::StorageUsage; +use crate::tcpasyncwriteread::TcpAsyncWriteRead; use crate::throttletrace::ThrottleTrace; use async_channel::Receiver; use async_channel::Sender; +use ca_proto::ca::proto; use core::fmt; use dbpg::seriesbychannel::ChannelInfoQuery; use dbpg::seriesbychannel::ChannelInfoResult; @@ -34,12 +30,17 @@ use netpod::Shape; use netpod::TsMs; use netpod::TsNano; use netpod::EMIT_ACCOUNTING_SNAP; +use proto::CaDataValue; +use proto::CaEventValue; use proto::CaItem; use proto::CaMsg; use proto::CaMsgTy; use proto::CaProto; +use proto::ChannelClose; use proto::CreateChan; use proto::EventAdd; +use proto::EventCancel; +use proto::ReadNotify; use scywr::insertqueues::InsertDeques; use scywr::insertqueues::InsertQueuesTx; use scywr::insertqueues::InsertSenderPolling; @@ -169,7 +170,7 @@ pub enum Error { NoProtocol, ProtocolError, IocIssue, - Protocol(#[from] crate::ca::proto::Error), + Protocol(#[from] proto::Error), RtWriter(#[from] serieswriter::rtwriter::Error), BinWriter(#[from] serieswriter::binwriter::Error), SeriesWriter(#[from] serieswriter::writer::Error), @@ -2200,7 +2201,7 @@ impl CaConn { rng: &mut Xoshiro128PlusPlus, ) -> Result<(), Error> { { - use proto::CaMetaValue::*; + use ca_proto::ca::proto::CaMetaValue::*; match &value.meta { CaMetaTime(meta) => { if meta.status != 0 { @@ -2286,8 +2287,8 @@ impl CaConn { } fn check_ev_value_data(data: &proto::CaDataValue, scalar_type: &ScalarType) -> Result<(), Error> { - use crate::ca::proto::CaDataScalarValue; - use crate::ca::proto::CaDataValue; + use ca_proto::ca::proto::CaDataScalarValue; + use ca_proto::ca::proto::CaDataValue; match data { CaDataValue::Scalar(x) => match &x { CaDataScalarValue::F32(..) => match &scalar_type { @@ -2921,10 +2922,11 @@ impl CaConn { })?; self.backoff_reset(); let proto = CaProto::new( - tcp, + TcpAsyncWriteRead::from(tcp), self.remote_addr_dbg.to_string(), self.opts.array_truncate, - self.ca_proto_stats.clone(), + // self.ca_proto_stats.clone(), + (), ); self.state = CaConnState::Init; self.proto = Some(proto); @@ -3675,7 +3677,7 @@ impl CaWriterValue { fn new(val: CaEventValue, crst: &CreatedState) -> Self { let valstr = match &val.data { CaDataValue::Scalar(val) => { - use super::proto::CaDataScalarValue; + use ca_proto::ca::proto::CaDataScalarValue; match val { CaDataScalarValue::Enum(x) => { let x = *x; @@ -3748,11 +3750,11 @@ impl EmittableType for CaWriterValue { // debug!("diff_data emit {:?}", state.series_data); let (ts_msp, ts_lsp, ts_msp_chg) = state.msp_split_data.split(ts, self.byte_size()); let data_value = { - use super::proto::CaDataValue; + use ca_proto::ca::proto::CaDataValue; use scywr::iteminsertqueue::DataValue; let ret = match self.0.data { CaDataValue::Scalar(val) => DataValue::Scalar({ - use super::proto::CaDataScalarValue; + use ca_proto::ca::proto::CaDataScalarValue; use scywr::iteminsertqueue::ScalarValue; match val { CaDataScalarValue::I8(x) => ScalarValue::I8(x), @@ -3772,7 +3774,7 @@ impl EmittableType for CaWriterValue { } }), CaDataValue::Array(val) => DataValue::Array({ - use super::proto::CaDataArrayValue; + use ca_proto::ca::proto::CaDataArrayValue; use scywr::iteminsertqueue::ArrayValue; match val { CaDataArrayValue::I8(x) => ArrayValue::I8(x), diff --git a/netfetch/src/ca/conn/enumfetch.rs b/netfetch/src/ca/conn/enumfetch.rs index d9ac921..ac9887d 100644 --- a/netfetch/src/ca/conn/enumfetch.rs +++ b/netfetch/src/ca/conn/enumfetch.rs @@ -1,12 +1,13 @@ use super::CaConn; use super::CreatedState; use super::Ioid; -use crate::ca::proto::CaMsg; -use crate::ca::proto::ReadNotify; +use ca_proto::ca::proto; use dbpg::seriesbychannel::ChannelInfoQuery; use err::thiserror; use err::ThisError; use log::*; +use proto::CaMsg; +use proto::ReadNotify; use series::SeriesId; use std::pin::Pin; use std::time::Instant; @@ -32,7 +33,7 @@ impl EnumFetch { // info!("EnumFetch::new name {}", created_state.name()); let dbr_ctrl_enum = 31; let ioid = conn.ioid_next(); - let ty = crate::ca::proto::CaMsgTy::ReadNotify(ReadNotify { + let ty = proto::CaMsgTy::ReadNotify(ReadNotify { data_type: dbr_ctrl_enum, data_count: 0, sid: created_state.sid.to_u32(), @@ -53,10 +54,9 @@ impl ConnFuture for EnumFetch { fn camsg(mut self: Pin<&mut Self>, camsg: CaMsg, conn: &mut CaConn) -> Result<(), Error> { let tsnow = Instant::now(); let crst = &mut self.created_state; - // info!("EnumFetch::poll name {}", crst.name()); match camsg.ty { - crate::ca::proto::CaMsgTy::ReadNotifyRes(msg2) => match msg2.value.meta { - super::proto::CaMetaValue::CaMetaVariants(meta) => { + proto::CaMsgTy::ReadNotifyRes(msg2) => match msg2.value.meta { + proto::CaMetaValue::CaMetaVariants(meta) => { crst.enum_str_table = Some(meta.variants); } _ => { diff --git a/netfetch/src/ca/conn2/channel.rs b/netfetch/src/ca/conn2/channel.rs index 7937e1b..af4cfd8 100644 --- a/netfetch/src/ca/conn2/channel.rs +++ b/netfetch/src/ca/conn2/channel.rs @@ -1,13 +1,12 @@ -use err::thiserror; -use err::ThisError; +use ca_proto::ca::proto; -#[derive(Debug, ThisError)] +#[derive(Debug, thiserror::Error)] #[cstm(name = "ConnChannelError")] pub enum Error {} trait Channel { fn can_accept_ca_msg(&self) -> bool; - fn process_ca_msg(&mut self, msg: crate::ca::proto::CaMsg) -> Result<(), Error>; + fn process_ca_msg(&mut self, msg: proto::CaMsg) -> Result<(), Error>; } struct ChannelAny {} diff --git a/netfetch/src/ca/conn2/conn.rs b/netfetch/src/ca/conn2/conn.rs index 0dcde3f..9938b77 100644 --- a/netfetch/src/ca/conn2/conn.rs +++ b/netfetch/src/ca/conn2/conn.rs @@ -2,8 +2,8 @@ use super::conncmd::ConnCommand; use super::connevent::CaConnEvent; use super::connevent::EndOfStreamReason; use crate::ca::conn::CaConnOpts; -use crate::ca::proto::CaProto; use async_channel::Sender; +use ca_proto::ca::proto; use dbpg::seriesbychannel::ChannelInfoQuery; use futures_util::Future; use futures_util::FutureExt; @@ -11,6 +11,7 @@ use futures_util::Stream; use futures_util::StreamExt; use hashbrown::HashMap; use log::*; +use proto::CaProto; use scywr::insertqueues::InsertDeques; use scywr::insertqueues::InsertQueuesTx; use scywr::iteminsertqueue::QueryItem; diff --git a/netfetch/src/ca/findioc.rs b/netfetch/src/ca/findioc.rs index 165d168..3bd017a 100644 --- a/netfetch/src/ca/findioc.rs +++ b/netfetch/src/ca/findioc.rs @@ -1,13 +1,14 @@ -use crate::ca::proto::CaMsg; -use crate::ca::proto::CaMsgTy; -use crate::ca::proto::HeadInfo; use crate::throttletrace::ThrottleTrace; use async_channel::Receiver; +use ca_proto::ca::proto; use futures_util::Future; use futures_util::FutureExt; use futures_util::Stream; use libc::c_int; use log::*; +use proto::CaMsg; +use proto::CaMsgTy; +use proto::HeadInfo; use stats::IocFinderStats; use std::collections::BTreeMap; use std::collections::VecDeque; @@ -35,7 +36,7 @@ pub enum Error { SendFailure, ReadFailure, ReadEmpty, - Proto(#[from] crate::ca::proto::Error), + Proto(#[from] proto::Error), Slidebuf(#[from] slidebuf::Error), IO(#[from] std::io::Error), } @@ -669,21 +670,10 @@ impl Stream for FindIocStream { } if !self.channels_input.is_closed() { while self.in_flight.len() < self.in_flight_max { - #[cfg(DISABLED)] - { - let n1 = self.in_flight.len(); - self.thr_msg_1.trigger("FindIocStream while A {}", &[&n1]); - } let chns = self.get_input_up_to_batch_max(cx); if chns.len() == 0 { break; } else { - #[cfg(DISABLED)] - { - let n1 = self.in_flight.len(); - let n2 = chns.len(); - self.thr_msg_2.trigger("FindIocStream while B {} {}", &[&n1, &n2]); - } self.create_in_flight(chns); have_progress = true; } diff --git a/netfetch/src/ca/proto.rs b/netfetch/src/ca/proto.rs index 13ad68e..0d36c65 100644 --- a/netfetch/src/ca/proto.rs +++ b/netfetch/src/ca/proto.rs @@ -1,27 +1,19 @@ -use crate::netbuf; -use err::thiserror; -use err::ThisError; +use futures_util::AsyncRead; +use futures_util::AsyncWrite; use futures_util::Stream; use log::*; use netpod::timeunits::*; use slidebuf::SlideBuf; -use stats::CaProtoStats; use std::collections::VecDeque; use std::io; use std::pin::Pin; -use std::sync::Arc; use std::task::Context; use std::task::Poll; use std::time::Instant; -use taskrun::tokio; -use tokio::io::AsyncRead; -use tokio::io::AsyncWrite; -use tokio::io::ReadBuf; -#[derive(Debug, ThisError)] -#[cstm(name = "NetfetchCaProto")] +#[derive(Debug, thiserror::Error)] +#[cstm(name = "CaProto")] pub enum Error { - NetBuf(#[from] netbuf::Error), SlideBuf(#[from] slidebuf::Error), #[error("BufferTooSmallForNeedMin({0}, {1})")] BufferTooSmallForNeedMin(usize, usize), @@ -59,6 +51,85 @@ const TESTING_EVENT_ADD_RES_MAX: u32 = 3; const TESTING_PROTOCOL_ERROR_TODO_REMOVE: bool = false; const TESTING_PROTOCOL_ERROR_AFTER_BYTES: u32 = 400; +pub trait StatsCounter { + fn inc(&mut self); +} + +pub trait StatsCumulative { + fn add(&mut self, v: u64); +} + +pub trait StatsHisto { + fn ingest(&mut self, v: u32); +} + +impl StatsCounter for () { + fn inc(&mut self) {} +} + +impl StatsCumulative for () { + fn add(&mut self, _v: u64) {} +} + +impl StatsHisto for () { + fn ingest(&mut self, _v: u32) {} +} + +pub trait CaProtoStatsRecv: Unpin { + fn out_msg_placed(&mut self) -> &mut dyn StatsCounter; + fn out_bytes(&mut self) -> &mut dyn StatsCumulative; + fn outbuf_len(&mut self) -> &mut dyn StatsHisto; + fn tcp_recv_count(&mut self) -> &mut dyn StatsCounter; + fn tcp_recv_bytes(&mut self) -> &mut dyn StatsCumulative; + fn payload_ext_very_large(&mut self) -> &mut dyn StatsCounter; + fn payload_ext_but_small(&mut self) -> &mut dyn StatsCounter; + fn payload_size(&mut self) -> &mut dyn StatsHisto; + fn protocol_issue(&mut self) -> &mut dyn StatsCounter; + fn data_count(&mut self) -> &mut dyn StatsHisto; +} + +impl CaProtoStatsRecv for () { + fn out_msg_placed(&mut self) -> &mut dyn StatsCounter { + self + } + + fn out_bytes(&mut self) -> &mut dyn StatsCumulative { + self + } + + fn outbuf_len(&mut self) -> &mut dyn StatsHisto { + self + } + + fn tcp_recv_count(&mut self) -> &mut dyn StatsCounter { + self + } + + fn tcp_recv_bytes(&mut self) -> &mut dyn StatsCumulative { + self + } + + fn payload_ext_very_large(&mut self) -> &mut dyn StatsCounter { + self + } + + fn payload_ext_but_small(&mut self) -> &mut dyn StatsCounter { + self + } + + fn payload_size(&mut self) -> &mut dyn StatsHisto { + self + } + + fn protocol_issue(&mut self) -> &mut dyn StatsCounter { + self + } + + fn data_count(&mut self) -> &mut dyn StatsHisto { + self + } +} + #[derive(Debug)] pub struct Search { pub id: u32, @@ -1173,7 +1244,7 @@ pub trait AsyncWriteRead: AsyncWrite + AsyncRead + Send + 'static {} impl AsyncWriteRead for T where T: AsyncWrite + AsyncRead + Send + 'static {} -pub struct CaProto { +pub struct CaProto { tcp: Pin>, tcp_eof: bool, remote_name: String, @@ -1182,19 +1253,17 @@ pub struct CaProto { outbuf: SlideBuf, out: VecDeque, array_truncate: usize, - stats: Arc, + stats: STATS, resqu: VecDeque, event_add_res_cnt: u32, bytes_recv_testing: u32, } -impl CaProto { - pub fn new( - tcp: T, - remote_name: String, - array_truncate: usize, - stats: Arc, - ) -> Self { +impl CaProto +where + STATS: CaProtoStatsRecv, +{ + pub fn new(tcp: T, remote_name: String, array_truncate: usize, stats: STATS) -> Self { Self { tcp: Box::pin(tcp), tcp_eof: false, @@ -1310,23 +1379,22 @@ impl CaProto { let this = self.as_mut().get_mut(); let tcp = Pin::new(&mut this.tcp); let buf = this.buf.available_writable_area(need_min)?; - let mut rbuf = ReadBuf::new(buf); - if rbuf.remaining() == 0 { + if buf.len() == 0 { return Err(Error::NoReadBufferSpace); } - break match tcp.poll_read(cx, &mut rbuf) { + break match tcp.poll_read(cx, buf) { Ready(k) => match k { - Ok(()) => { - let nf = rbuf.filled().len(); + Ok(nf) => { + // let nf = rbuf.filled().len(); if nf == 0 { debug!("peer done {:?} {:?}", self.remote_name, self.state); self.tcp_eof = true; } else { - if false { - debug!("received {} bytes", rbuf.filled().len()); - let t = rbuf.filled().len().min(32); - debug!("received data {:?}", &rbuf.filled()[0..t]); - } + // if false { + // debug!("received {} bytes", nf); + // let t = nf.min(32); + // debug!("received data {:?}", &rbuf.filled()[0..t]); + // } if TESTING_PROTOCOL_ERROR_TODO_REMOVE { self.bytes_recv_testing = self.bytes_recv_testing.saturating_add(nf as u32); if self.bytes_recv_testing <= TESTING_PROTOCOL_ERROR_AFTER_BYTES { @@ -1343,8 +1411,8 @@ impl CaProto { self.buf.wadv(nf)?; } have_progress = true; - self.stats.tcp_recv_bytes().add(nf as _); self.stats.tcp_recv_count().inc(); + self.stats.tcp_recv_bytes().add(nf as _); continue; } } diff --git a/netfetch/src/lib.rs b/netfetch/src/lib.rs index 2edae26..6b58354 100644 --- a/netfetch/src/lib.rs +++ b/netfetch/src/lib.rs @@ -8,6 +8,7 @@ pub mod netbuf; pub mod polltimer; pub mod ratelimit; pub mod rt; +pub mod tcpasyncwriteread; #[cfg(test)] pub mod test; pub mod throttletrace; diff --git a/netfetch/src/tcpasyncwriteread.rs b/netfetch/src/tcpasyncwriteread.rs new file mode 100644 index 0000000..33df685 --- /dev/null +++ b/netfetch/src/tcpasyncwriteread.rs @@ -0,0 +1,54 @@ +use futures_util::AsyncRead; +use futures_util::AsyncWrite; +use std::io; +use std::pin::Pin; +use std::task::Context; +use std::task::Poll; +use taskrun::tokio::io::ReadBuf; +use taskrun::tokio::net::TcpStream; + +#[pin_project::pin_project] +pub struct TcpAsyncWriteRead { + #[pin] + tcp: TcpStream, +} + +impl From for TcpAsyncWriteRead { + fn from(value: TcpStream) -> Self { + Self { tcp: value } + } +} + +impl AsyncWrite for TcpAsyncWriteRead { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + use taskrun::tokio::io::AsyncWrite; + let this = self.project(); + this.tcp.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + use taskrun::tokio::io::AsyncWrite; + let this = self.project(); + this.tcp.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + use taskrun::tokio::io::AsyncWrite; + let this = self.project(); + this.tcp.poll_shutdown(cx) + } +} + +impl AsyncRead for TcpAsyncWriteRead { + fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { + use taskrun::tokio::io::AsyncRead; + use Poll::*; + let this = self.project(); + let mut readbuf = ReadBuf::new(buf); + match this.tcp.poll_read(cx, &mut readbuf) { + Ready(Ok(())) => Ready(Ok(readbuf.filled().len())), + Ready(Err(e)) => Ready(Err(e)), + Pending => Pending, + } + } +}