1use std::sync::Arc;
16
17use common_telemetry::warn;
18use dashmap::DashMap;
19use rskafka::client::partition::{Compression, OffsetAt, PartitionClient};
20use rskafka::record::Record;
21use store_api::logstore::provider::KafkaProvider;
22use store_api::storage::RegionId;
23use tokio::sync::mpsc::{self, Receiver, Sender};
24
25use crate::error::{self, Result};
26use crate::kafka::index::IndexCollector;
27use crate::kafka::worker::{BackgroundProducerWorker, ProduceResultHandle, WorkerRequest};
28use crate::metrics::{
29 METRIC_KAFKA_CLIENT_BYTES_TOTAL, METRIC_KAFKA_CLIENT_PRODUCE_ELAPSED,
30 METRIC_KAFKA_CLIENT_TRAFFIC_TOTAL,
31};
32
33pub type OrderedBatchProducerRef = Arc<OrderedBatchProducer>;
34
/// Passed to the background worker as its `request_batch_size`; presumably the maximum
/// number of queued requests the worker groups together per round — confirm in `worker`.
const REQUEST_BATCH_SIZE: usize = 64;

/// Capacity of the producer request channel: room for two full request batches.
const PRODUCER_CHANNEL_SIZE: usize = 2 * REQUEST_BATCH_SIZE;
40
41#[derive(Debug)]
43pub(crate) struct OrderedBatchProducer {
44 pub(crate) sender: Sender<WorkerRequest>,
45}
46
47impl OrderedBatchProducer {
48 pub(crate) fn channel() -> (Sender<WorkerRequest>, Receiver<WorkerRequest>) {
49 mpsc::channel(PRODUCER_CHANNEL_SIZE)
50 }
51
52 pub(crate) fn new(
54 (tx, rx): (Sender<WorkerRequest>, Receiver<WorkerRequest>),
55 provider: Arc<KafkaProvider>,
56 client: Arc<dyn ProducerClient>,
57 compression: Compression,
58 max_batch_bytes: usize,
59 index_collector: Box<dyn IndexCollector>,
60 high_watermark: Arc<DashMap<Arc<KafkaProvider>, u64>>,
61 ) -> Self {
62 let mut worker = BackgroundProducerWorker {
63 provider,
64 client,
65 compression,
66 receiver: rx,
67 request_batch_size: REQUEST_BATCH_SIZE,
68 max_batch_bytes,
69 index_collector,
70 high_watermark,
71 };
72 tokio::spawn(async move { worker.run().await });
73 Self { sender: tx }
74 }
75
76 pub(crate) async fn produce(
84 &self,
85 region_id: RegionId,
86 batch: Vec<Record>,
87 ) -> Result<ProduceResultHandle> {
88 let (req, handle) = WorkerRequest::new_produce_request(region_id, batch);
89 if self.sender.send(req).await.is_err() {
90 warn!("OrderedBatchProducer is already exited");
91 return error::OrderedBatchProducerStoppedSnafu {}.fail();
92 }
93
94 Ok(handle)
95 }
96
97 pub(crate) async fn update_high_watermark(&self) -> Result<()> {
100 if self
101 .sender
102 .send(WorkerRequest::UpdateHighWatermark)
103 .await
104 .is_err()
105 {
106 warn!("OrderedBatchProducer is already exited");
107 return error::OrderedBatchProducerStoppedSnafu {}.fail();
108 }
109 Ok(())
110 }
111}
112
113#[async_trait::async_trait]
114pub trait ProducerClient: std::fmt::Debug + Send + Sync {
115 async fn produce(
116 &self,
117 records: Vec<Record>,
118 compression: Compression,
119 ) -> rskafka::client::error::Result<Vec<i64>>;
120
121 async fn get_offset(&self, at: OffsetAt) -> rskafka::client::error::Result<i64>;
122}
123
124#[async_trait::async_trait]
125impl ProducerClient for PartitionClient {
126 async fn produce(
127 &self,
128 records: Vec<Record>,
129 compression: Compression,
130 ) -> rskafka::client::error::Result<Vec<i64>> {
131 let total_size = records.iter().map(|r| r.approximate_size()).sum::<usize>();
132 let partition = self.partition().to_string();
133 METRIC_KAFKA_CLIENT_BYTES_TOTAL
134 .with_label_values(&[self.topic(), &partition])
135 .inc_by(total_size as u64);
136 METRIC_KAFKA_CLIENT_TRAFFIC_TOTAL
137 .with_label_values(&[self.topic(), &partition])
138 .inc();
139 let _timer = METRIC_KAFKA_CLIENT_PRODUCE_ELAPSED
140 .with_label_values(&[self.topic(), &partition])
141 .start_timer();
142
143 self.produce(records, compression).await
144 }
145
146 async fn get_offset(&self, at: OffsetAt) -> rskafka::client::error::Result<i64> {
147 self.get_offset(at).await
148 }
149}
150
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    use common_base::readable_size::ReadableSize;
    use common_telemetry::debug;
    use futures::stream::FuturesUnordered;
    use futures::{FutureExt, StreamExt};
    use rskafka::client::error::{Error as ClientError, RequestContext};
    use rskafka::client::partition::Compression;
    use rskafka::protocol::error::Error as ProtocolError;
    use rskafka::record::Record;
    use store_api::storage::RegionId;

    use super::*;
    use crate::kafka::index::NoopCollector;
    use crate::kafka::producer::OrderedBatchProducer;
    use crate::kafka::test_util::record;

    /// In-memory [`ProducerClient`] used to drive the producer in tests.
    ///
    /// Every `produce` call does the following, in order:
    /// 1. Sleeps for `delay`.
    /// 2. Fails with a `ServerError` carrying `error`, if `error` is set.
    /// 3. Panics with `panic`, if `panic` is set.
    /// 4. Otherwise records the batch length in `batch_sizes` and returns consecutive offsets
    ///    continuing from the total number of records accepted so far.
    #[derive(Debug)]
    struct MockClient {
        error: Option<ProtocolError>,
        panic: Option<String>,
        delay: Duration,
        // Length of every successfully produced batch, in call order.
        batch_sizes: Mutex<Vec<usize>>,
    }

    #[async_trait::async_trait]
    impl ProducerClient for MockClient {
        async fn produce(
            &self,
            records: Vec<Record>,
            _compression: Compression,
        ) -> rskafka::client::error::Result<Vec<i64>> {
            tokio::time::sleep(self.delay).await;

            if let Some(e) = self.error {
                return Err(ClientError::ServerError {
                    protocol_error: e,
                    error_message: None,
                    request: RequestContext::Partition("foo".into(), 1),
                    response: None,
                    is_virtual: false,
                });
            }

            if let Some(p) = self.panic.as_ref() {
                panic!("{}", p);
            }

            // Offsets continue from the number of records accepted by earlier calls,
            // e.g. batches [2, 1] yield offsets [0, 1] then [2].
            let mut batch_sizes = self.batch_sizes.lock().unwrap();
            let offset_base = batch_sizes.iter().sum::<usize>();
            let offsets = (0..records.len())
                .map(|x| (x + offset_base) as i64)
                .collect();
            batch_sizes.push(records.len());
            debug!("Return offsets: {offsets:?}");
            Ok(offsets)
        }

        // Not exercised by these tests.
        async fn get_offset(&self, _at: OffsetAt) -> rskafka::client::error::Result<i64> {
            todo!()
        }
    }

    /// Requests larger than `max_batch_bytes` are split into several produce calls. Each
    /// handle resolves to the last offset assigned to its records.
    #[tokio::test]
    async fn test_producer() {
        common_telemetry::init_default_ut_logging();
        let record = record();
        let delay = Duration::from_secs(0);
        let client = Arc::new(MockClient {
            error: None,
            panic: None,
            delay,
            batch_sizes: Default::default(),
        });
        let provider = Arc::new(KafkaProvider::new(String::new()));
        // `max_batch_bytes` is sized for exactly two test records.
        let producer = OrderedBatchProducer::new(
            OrderedBatchProducer::channel(),
            provider,
            client.clone(),
            Compression::NoCompression,
            ReadableSize((record.approximate_size() * 2) as u64).as_bytes() as usize,
            Box::new(NoopCollector),
            Arc::new(DashMap::new()),
        );

        let region_id = RegionId::new(1, 1);
        // Three records are split into batches of 2 and 1, getting offsets 0..=2.
        let handle = producer
            .produce(
                region_id,
                vec![record.clone(), record.clone(), record.clone()],
            )
            .await
            .unwrap();
        assert_eq!(handle.wait().await.unwrap(), 2);
        assert_eq!(client.batch_sizes.lock().unwrap().as_slice(), &[2, 1]);

        // Two records fit into one batch, getting offsets 3..=4.
        let handle = producer
            .produce(region_id, vec![record.clone(), record.clone()])
            .await
            .unwrap();
        assert_eq!(handle.wait().await.unwrap(), 4);
        assert_eq!(client.batch_sizes.lock().unwrap().as_slice(), &[2, 1, 2]);

        // A single record gets offset 5.
        let handle = producer
            .produce(region_id, vec![record.clone()])
            .await
            .unwrap();
        assert_eq!(handle.wait().await.unwrap(), 5);
        assert_eq!(client.batch_sizes.lock().unwrap().as_slice(), &[2, 1, 2, 1]);
    }

    /// Client errors are propagated to every pending handle.
    #[tokio::test]
    async fn test_producer_client_error() {
        let record = record();
        // Every produce call fails with `NetworkException` after a 1ms delay.
        let client = Arc::new(MockClient {
            error: Some(ProtocolError::NetworkException),
            panic: None,
            delay: Duration::from_millis(1),
            batch_sizes: Default::default(),
        });
        let provider = Arc::new(KafkaProvider::new(String::new()));
        let producer = OrderedBatchProducer::new(
            OrderedBatchProducer::channel(),
            provider,
            client.clone(),
            Compression::NoCompression,
            ReadableSize((record.approximate_size() * 2) as u64).as_bytes() as usize,
            Box::new(NoopCollector),
            Arc::new(DashMap::new()),
        );

        let region_id = RegionId::new(1, 1);
        // Enqueue three requests and wait on all of their handles concurrently.
        let mut futures = FuturesUnordered::new();
        futures.push(
            producer
                .produce(
                    region_id,
                    vec![record.clone(), record.clone(), record.clone()],
                )
                .await
                .unwrap()
                .wait(),
        );
        futures.push(
            producer
                .produce(region_id, vec![record.clone(), record.clone()])
                .await
                .unwrap()
                .wait(),
        );
        futures.push(
            producer
                .produce(region_id, vec![record.clone()])
                .await
                .unwrap()
                .wait(),
        );

        // Each of the three handles must resolve to an error.
        futures.next().await.unwrap().unwrap_err();
        futures.next().await.unwrap().unwrap_err();
        futures.next().await.unwrap().unwrap_err();
    }

    /// Dropping one request's handle while it is in flight must not affect other requests.
    #[tokio::test]
    async fn test_producer_cancel() {
        let record = record();
        // Each produce call takes at least 1ms.
        let client = Arc::new(MockClient {
            error: None,
            panic: None,
            delay: Duration::from_millis(1),
            batch_sizes: Default::default(),
        });

        let provider = Arc::new(KafkaProvider::new(String::new()));
        // `max_batch_bytes` is sized for two test records, so `a` (three records) needs
        // more than one produce call.
        let producer = OrderedBatchProducer::new(
            OrderedBatchProducer::channel(),
            provider,
            client.clone(),
            Compression::NoCompression,
            ReadableSize((record.approximate_size() * 2) as u64).as_bytes() as usize,
            Box::new(NoopCollector),
            Arc::new(DashMap::new()),
        );

        let region_id = RegionId::new(1, 1);
        let a = producer
            .produce(
                region_id,
                vec![record.clone(), record.clone(), record.clone()],
            )
            .await
            .unwrap()
            .wait()
            .fuse();

        let b = producer
            .produce(region_id, vec![record])
            .await
            .unwrap()
            .wait()
            .fuse();

        let mut b = Box::pin(b);

        {
            let mut a = Box::pin(a);

            // Neither handle may resolve within the first 1ms; the sleep branch ends the
            // select. Leaving this scope then drops `a`'s handle while it is still pending.
            futures::select_biased! {
                _ = &mut a => panic!("a should not have flushed"),
                _ = &mut b => panic!("b should not have flushed"),
                _ = tokio::time::sleep(Duration::from_millis(1)).fuse() => {},
            }
        }

        // `b` must still complete successfully after `a` was dropped.
        tokio::time::timeout(Duration::from_secs(1), b)
            .await
            .unwrap()
            .unwrap();

        // All four records (three from `a`, one from `b`) reached the client even though
        // `a`'s handle was dropped.
        assert_eq!(
            client
                .batch_sizes
                .lock()
                .unwrap()
                .as_slice()
                .iter()
                .sum::<usize>(),
            4
        );
    }
}