use std::sync::Arc;

use common_telemetry::warn;
use dashmap::DashMap;
use rskafka::client::partition::{Compression, OffsetAt, PartitionClient};
use rskafka::record::Record;
use store_api::logstore::provider::KafkaProvider;
use store_api::storage::RegionId;
use tokio::sync::mpsc::{self, Receiver, Sender};

use crate::error::{self, Result};
use crate::kafka::index::IndexCollector;
use crate::kafka::log_store::TopicStat;
use crate::kafka::worker::{BackgroundProducerWorker, ProduceResultHandle, WorkerRequest};
use crate::metrics::{
    METRIC_KAFKA_CLIENT_BYTES_TOTAL, METRIC_KAFKA_CLIENT_PRODUCE_ELAPSED,
    METRIC_KAFKA_CLIENT_TRAFFIC_TOTAL,
};

pub type OrderedBatchProducerRef = Arc<OrderedBatchProducer>;

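/// The maximum number of [`WorkerRequest`]s the background worker pulls from
/// the channel in a single batch.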
const REQUEST_BATCH_SIZE: usize = 64;

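/// The capacity of the producer's request channel.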
const PRODUCER_CHANNEL_SIZE: usize = REQUEST_BATCH_SIZE * 2;

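/// A producer that forwards produce requests for a single topic to a
/// [`BackgroundProducerWorker`], which aggregates records into batched Kafka
/// produce calls while preserving request order.
///
/// A sketch of typical usage (arguments elided):
///
/// ```ignore
/// let producer = OrderedBatchProducer::new(OrderedBatchProducer::channel(), /* ... */);
/// let handle = producer.produce(region_id, records).await?;
/// let last_offset = handle.wait().await?;
/// ```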
#[derive(Debug)]
pub(crate) struct OrderedBatchProducer {
    pub(crate) sender: Sender<WorkerRequest>,
}

impl OrderedBatchProducer {
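    /// Creates the bounded channel that connects the producer to its
    /// background worker.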
    pub(crate) fn channel() -> (Sender<WorkerRequest>, Receiver<WorkerRequest>) {
        mpsc::channel(PRODUCER_CHANNEL_SIZE)
    }

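    /// Spawns a [`BackgroundProducerWorker`] on the tokio runtime and returns
    /// a producer holding the sending half of the request channel.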
    pub(crate) fn new(
        (tx, rx): (Sender<WorkerRequest>, Receiver<WorkerRequest>),
        provider: Arc<KafkaProvider>,
        client: Arc<dyn ProducerClient>,
        compression: Compression,
        max_batch_bytes: usize,
        index_collector: Box<dyn IndexCollector>,
        topic_stats: Arc<DashMap<Arc<KafkaProvider>, TopicStat>>,
    ) -> Self {
        let mut worker = BackgroundProducerWorker {
            provider,
            client,
            compression,
            receiver: rx,
            request_batch_size: REQUEST_BATCH_SIZE,
            max_batch_bytes,
            index_collector,
            topic_stats,
        };
        tokio::spawn(async move { worker.run().await });
        Self { sender: tx }
    }

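    /// Writes `batch` for `region_id` via the background worker.
    ///
    /// The worker may aggregate multiple pending requests into a single Kafka
    /// produce call. The returned [`ProduceResultHandle`] resolves to the
    /// offset of the last record once the batch has been flushed.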
    pub(crate) async fn produce(
        &self,
        region_id: RegionId,
        batch: Vec<Record>,
    ) -> Result<ProduceResultHandle> {
        let (req, handle) = WorkerRequest::new_produce_request(region_id, batch);
        if self.sender.send(req).await.is_err() {
            warn!("OrderedBatchProducer is already exited");
            return error::OrderedBatchProducerStoppedSnafu {}.fail();
        }

        Ok(handle)
    }

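    /// Asks the background worker to fetch the latest offset of the topic.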
    pub(crate) async fn fetch_latest_offset(&self) -> Result<()> {
        if self
            .sender
            .send(WorkerRequest::FetchLatestOffset)
            .await
            .is_err()
        {
            warn!("OrderedBatchProducer is already exited");
            return error::OrderedBatchProducerStoppedSnafu {}.fail();
        }
        Ok(())
    }
}

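/// The outcome of a produce call: the offset assigned to each record and the
/// size in bytes of the encoded request.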
pub struct ProduceResult {
    pub offsets: Vec<i64>,
    pub encoded_request_size: usize,
}

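/// An abstraction over the Kafka partition client so that tests can substitute
/// a mock implementation.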
#[async_trait::async_trait]
pub trait ProducerClient: std::fmt::Debug + Send + Sync {
    async fn produce(
        &self,
        records: Vec<Record>,
        compression: Compression,
    ) -> rskafka::client::error::Result<ProduceResult>;

    async fn get_offset(&self, at: OffsetAt) -> rskafka::client::error::Result<i64>;
}

#[async_trait::async_trait]
impl ProducerClient for PartitionClient {
    async fn produce(
        &self,
        records: Vec<Record>,
        compression: Compression,
    ) -> rskafka::client::error::Result<ProduceResult> {
        let partition = self.partition().to_string();
        let _timer = METRIC_KAFKA_CLIENT_PRODUCE_ELAPSED
            .with_label_values(&[self.topic(), &partition])
            .start_timer();

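        // Inherent methods take precedence over trait methods, so this calls
        // `PartitionClient::produce` rather than recursing into this impl.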
        let result = self.produce(records, compression).await?;

        METRIC_KAFKA_CLIENT_BYTES_TOTAL
            .with_label_values(&[self.topic(), &partition])
            .inc_by(result.encoded_request_size as u64);
        METRIC_KAFKA_CLIENT_TRAFFIC_TOTAL
            .with_label_values(&[self.topic(), &partition])
            .inc();

        Ok(ProduceResult {
            offsets: result.offsets,
            encoded_request_size: result.encoded_request_size,
        })
    }

    async fn get_offset(&self, at: OffsetAt) -> rskafka::client::error::Result<i64> {
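        // Likewise resolves to the inherent `PartitionClient::get_offset`.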
        self.get_offset(at).await
    }
}

#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    use common_base::readable_size::ReadableSize;
    use common_telemetry::debug;
    use futures::stream::FuturesUnordered;
    use futures::{FutureExt, StreamExt};
    use rskafka::client::error::{Error as ClientError, RequestContext};
    use rskafka::client::partition::Compression;
    use rskafka::protocol::error::Error as ProtocolError;
    use rskafka::record::Record;
    use store_api::storage::RegionId;

    use super::*;
    use crate::kafka::index::NoopCollector;
    use crate::kafka::producer::OrderedBatchProducer;
    use crate::kafka::test_util::record;

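    /// A mock [`ProducerClient`] that records the size of each produced batch
    /// and can inject an error, a panic, or an artificial delay.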
    #[derive(Debug)]
    struct MockClient {
        error: Option<ProtocolError>,
        panic: Option<String>,
        delay: Duration,
        batch_sizes: Mutex<Vec<usize>>,
    }

    #[async_trait::async_trait]
    impl ProducerClient for MockClient {
        async fn produce(
            &self,
            records: Vec<Record>,
            _compression: Compression,
        ) -> rskafka::client::error::Result<ProduceResult> {
            tokio::time::sleep(self.delay).await;

            if let Some(e) = self.error {
                return Err(ClientError::ServerError {
                    protocol_error: e,
                    error_message: None,
                    request: RequestContext::Partition("foo".into(), 1),
                    response: None,
                    is_virtual: false,
                });
            }

            if let Some(p) = self.panic.as_ref() {
                panic!("{}", p);
            }

            // Assign offsets sequentially, continuing from the total number of
            // records produced so far.
            let mut batch_sizes = self.batch_sizes.lock().unwrap();
            let offset_base = batch_sizes.iter().sum::<usize>();
            let offsets = (0..records.len())
                .map(|x| (x + offset_base) as i64)
                .collect();
            batch_sizes.push(records.len());
            debug!("Return offsets: {offsets:?}");
            Ok(ProduceResult {
                offsets,
                encoded_request_size: 0,
            })
        }

        async fn get_offset(&self, _at: OffsetAt) -> rskafka::client::error::Result<i64> {
            todo!()
        }
    }

    #[tokio::test]
    async fn test_producer() {
        common_telemetry::init_default_ut_logging();
        let record = record();
        let delay = Duration::from_secs(0);
        let client = Arc::new(MockClient {
            error: None,
            panic: None,
            delay,
            batch_sizes: Default::default(),
        });
        let provider = Arc::new(KafkaProvider::new(String::new()));
        let producer = OrderedBatchProducer::new(
            OrderedBatchProducer::channel(),
            provider,
            client.clone(),
            Compression::NoCompression,
            ReadableSize((record.approximate_size() * 2) as u64).as_bytes() as usize,
            Box::new(NoopCollector),
            Arc::new(DashMap::new()),
        );

        let region_id = RegionId::new(1, 1);
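        // Three records with room for only two per batch flush as [2, 1]; the
        // handle resolves to the offset of the last record.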
        let handle = producer
            .produce(
                region_id,
                vec![record.clone(), record.clone(), record.clone()],
            )
            .await
            .unwrap();
        assert_eq!(handle.wait().await.unwrap(), 2);
        assert_eq!(client.batch_sizes.lock().unwrap().as_slice(), &[2, 1]);

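        // Two records fit in a single batch; offsets continue at 3 and 4.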
        let handle = producer
            .produce(region_id, vec![record.clone(), record.clone()])
            .await
            .unwrap();
        assert_eq!(handle.wait().await.unwrap(), 4);
        assert_eq!(client.batch_sizes.lock().unwrap().as_slice(), &[2, 1, 2]);

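        // A single record forms its own batch at offset 5.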
        let handle = producer
            .produce(region_id, vec![record.clone()])
            .await
            .unwrap();
        assert_eq!(handle.wait().await.unwrap(), 5);
        assert_eq!(client.batch_sizes.lock().unwrap().as_slice(), &[2, 1, 2, 1]);
    }

    #[tokio::test]
    async fn test_producer_client_error() {
        let record = record();
        let client = Arc::new(MockClient {
            error: Some(ProtocolError::NetworkException),
            panic: None,
            delay: Duration::from_millis(1),
            batch_sizes: Default::default(),
        });
        let provider = Arc::new(KafkaProvider::new(String::new()));
        let producer = OrderedBatchProducer::new(
            OrderedBatchProducer::channel(),
            provider,
            client.clone(),
            Compression::NoCompression,
            ReadableSize((record.approximate_size() * 2) as u64).as_bytes() as usize,
            Box::new(NoopCollector),
            Arc::new(DashMap::new()),
        );

        let region_id = RegionId::new(1, 1);
        let mut futures = FuturesUnordered::new();
        futures.push(
            producer
                .produce(
                    region_id,
                    vec![record.clone(), record.clone(), record.clone()],
                )
                .await
                .unwrap()
                .wait(),
        );
        futures.push(
            producer
                .produce(region_id, vec![record.clone(), record.clone()])
                .await
                .unwrap()
                .wait(),
        );
        futures.push(
            producer
                .produce(region_id, vec![record.clone()])
                .await
                .unwrap()
                .wait(),
        );

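        // Every pending handle observes the injected NetworkException.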
        futures.next().await.unwrap().unwrap_err();
        futures.next().await.unwrap().unwrap_err();
        futures.next().await.unwrap().unwrap_err();
    }

    #[tokio::test]
    async fn test_producer_cancel() {
        let record = record();
        let client = Arc::new(MockClient {
            error: None,
            panic: None,
            delay: Duration::from_millis(1),
            batch_sizes: Default::default(),
        });

        let provider = Arc::new(KafkaProvider::new(String::new()));
        let producer = OrderedBatchProducer::new(
            OrderedBatchProducer::channel(),
            provider,
            client.clone(),
            Compression::NoCompression,
            ReadableSize((record.approximate_size() * 2) as u64).as_bytes() as usize,
            Box::new(NoopCollector),
            Arc::new(DashMap::new()),
        );

        let region_id = RegionId::new(1, 1);
        let a = producer
            .produce(
                region_id,
                vec![record.clone(), record.clone(), record.clone()],
            )
            .await
            .unwrap()
            .wait()
            .fuse();

        let b = producer
            .produce(region_id, vec![record])
            .await
            .unwrap()
            .wait()
            .fuse();

        let mut b = Box::pin(b);

        {
            let mut a = Box::pin(a);

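            // Given the client's artificial delay, neither handle should
            // resolve before the sleep fires.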
            futures::select_biased! {
                _ = &mut a => panic!("a should not have flushed"),
                _ = &mut b => panic!("b should not have flushed"),
                _ = tokio::time::sleep(Duration::from_millis(1)).fuse() => {},
            }
        }

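        // Dropping `a` cancels the wait on its handle, not the write itself;
        // `b` should still complete normally.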
        tokio::time::timeout(Duration::from_secs(1), b)
            .await
            .unwrap()
            .unwrap();

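        // All four records were still written even though `a` was cancelled.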
        assert_eq!(
            client
                .batch_sizes
                .lock()
                .unwrap()
                .as_slice()
                .iter()
                .sum::<usize>(),
            4
        );
    }
}