log_store/kafka/
producer.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use common_telemetry::warn;
use dashmap::DashMap;
use rskafka::client::partition::{Compression, OffsetAt, PartitionClient};
use rskafka::record::Record;
use store_api::logstore::provider::KafkaProvider;
use store_api::storage::RegionId;
use tokio::sync::mpsc::{self, Receiver, Sender};

use crate::error::{self, Result};
use crate::kafka::index::IndexCollector;
use crate::kafka::log_store::TopicStat;
use crate::kafka::worker::{BackgroundProducerWorker, ProduceResultHandle, WorkerRequest};
use crate::metrics::{
    METRIC_KAFKA_CLIENT_BYTES_TOTAL, METRIC_KAFKA_CLIENT_PRODUCE_ELAPSED,
    METRIC_KAFKA_CLIENT_TRAFFIC_TOTAL,
};

pub type OrderedBatchProducerRef = Arc<OrderedBatchProducer>;

// Max batch size for an `OrderedBatchProducer` to handle requests.
const REQUEST_BATCH_SIZE: usize = 64;

// Producer channel size
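// Sized at two request batches, presumably so senders can keep enqueueing
// while the worker drains a full batch.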
const PRODUCER_CHANNEL_SIZE: usize = REQUEST_BATCH_SIZE * 2;

/// [`OrderedBatchProducer`] attempts to aggregate multiple produce requests together.
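///
/// A minimal usage sketch (doc-only illustration; `provider`, `client`,
/// `index_collector`, and `topic_stats` are assumed to be constructed elsewhere):
///
/// ```ignore
/// let producer = OrderedBatchProducer::new(
///     OrderedBatchProducer::channel(),
///     provider,
///     client,
///     Compression::NoCompression,
///     max_batch_bytes,
///     index_collector,
///     topic_stats,
/// );
/// // `wait` resolves once the records have been committed to Kafka.
/// let handle = producer.produce(region_id, records).await?;
/// let last_offset = handle.wait().await?;
/// ```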
#[derive(Debug)]
pub(crate) struct OrderedBatchProducer {
    pub(crate) sender: Sender<WorkerRequest>,
}

impl OrderedBatchProducer {
    pub(crate) fn channel() -> (Sender<WorkerRequest>, Receiver<WorkerRequest>) {
        mpsc::channel(PRODUCER_CHANNEL_SIZE)
    }

    /// Constructs a new [`OrderedBatchProducer`].
    pub(crate) fn new(
        (tx, rx): (Sender<WorkerRequest>, Receiver<WorkerRequest>),
        provider: Arc<KafkaProvider>,
        client: Arc<dyn ProducerClient>,
        compression: Compression,
        max_batch_bytes: usize,
        index_collector: Box<dyn IndexCollector>,
        topic_stats: Arc<DashMap<Arc<KafkaProvider>, TopicStat>>,
    ) -> Self {
        let mut worker = BackgroundProducerWorker {
            provider,
            client,
            compression,
            receiver: rx,
            request_batch_size: REQUEST_BATCH_SIZE,
            max_batch_bytes,
            index_collector,
            topic_stats,
        };
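        // Drive the worker on a background task so `produce` only has to
        // enqueue requests onto the channel.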
        tokio::spawn(async move { worker.run().await });
        Self { sender: tx }
    }

    /// Writes `data` to the [`OrderedBatchProducer`].
    ///
    /// Returns a [`ProduceResultHandle`], which will receive a result once the data
    /// has been committed to Kafka or an unrecoverable error has been encountered.
    ///
    /// # Panics
    /// Panics if any [`Record`]'s `approximate_size` exceeds `max_batch_bytes`.
    pub(crate) async fn produce(
        &self,
        region_id: RegionId,
        batch: Vec<Record>,
    ) -> Result<ProduceResultHandle> {
        let (req, handle) = WorkerRequest::new_produce_request(region_id, batch);
        if self.sender.send(req).await.is_err() {
            warn!("OrderedBatchProducer has already exited");
            return error::OrderedBatchProducerStoppedSnafu {}.fail();
        }

        Ok(handle)
    }

    /// Sends a [`WorkerRequest::FetchLatestOffset`] request to the producer.
    /// This is used to fetch the latest offset for the topic.
    pub(crate) async fn fetch_latest_offset(&self) -> Result<()> {
        if self
            .sender
            .send(WorkerRequest::FetchLatestOffset)
            .await
            .is_err()
        {
            warn!("OrderedBatchProducer has already exited");
            return error::OrderedBatchProducerStoppedSnafu {}.fail();
        }
        Ok(())
    }
}

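/// Outcome of a successful produce call: the offsets assigned to the records
/// and the size in bytes of the encoded produce request.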
pub struct ProduceResult {
    pub offsets: Vec<i64>,
    pub encoded_request_size: usize,
}

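/// Abstraction over [`PartitionClient`] so that the background worker can be
/// exercised against a mock client in tests (see `MockClient` below).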
#[async_trait::async_trait]
pub trait ProducerClient: std::fmt::Debug + Send + Sync {
    async fn produce(
        &self,
        records: Vec<Record>,
        compression: Compression,
    ) -> rskafka::client::error::Result<ProduceResult>;

    async fn get_offset(&self, at: OffsetAt) -> rskafka::client::error::Result<i64>;
}

#[async_trait::async_trait]
impl ProducerClient for PartitionClient {
    async fn produce(
        &self,
        records: Vec<Record>,
        compression: Compression,
    ) -> rskafka::client::error::Result<ProduceResult> {
        let partition = self.partition().to_string();
        let _timer = METRIC_KAFKA_CLIENT_PRODUCE_ELAPSED
            .with_label_values(&[self.topic(), &partition])
            .start_timer();

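        // Calls the inherent `PartitionClient::produce` (inherent methods take
        // precedence over trait methods), so this does not recurse.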
        let result = self.produce(records, compression).await?;

        METRIC_KAFKA_CLIENT_BYTES_TOTAL
            .with_label_values(&[self.topic(), &partition])
            .inc_by(result.encoded_request_size as u64);
        METRIC_KAFKA_CLIENT_TRAFFIC_TOTAL
            .with_label_values(&[self.topic(), &partition])
            .inc();

        Ok(ProduceResult {
            offsets: result.offsets,
            encoded_request_size: result.encoded_request_size,
        })
    }

    async fn get_offset(&self, at: OffsetAt) -> rskafka::client::error::Result<i64> {
        self.get_offset(at).await
    }
}

#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    use common_base::readable_size::ReadableSize;
    use common_telemetry::debug;
    use futures::stream::FuturesUnordered;
    use futures::{FutureExt, StreamExt};
    use rskafka::client::error::{Error as ClientError, RequestContext};
    use rskafka::client::partition::Compression;
    use rskafka::protocol::error::Error as ProtocolError;
    use rskafka::record::Record;
    use store_api::storage::RegionId;

    use super::*;
    use crate::kafka::index::NoopCollector;
    use crate::kafka::producer::OrderedBatchProducer;
    use crate::kafka::test_util::record;

    #[derive(Debug)]
    struct MockClient {
        error: Option<ProtocolError>,
        panic: Option<String>,
        delay: Duration,
        batch_sizes: Mutex<Vec<usize>>,
    }

    #[async_trait::async_trait]
    impl ProducerClient for MockClient {
        async fn produce(
            &self,
            records: Vec<Record>,
            _compression: Compression,
        ) -> rskafka::client::error::Result<ProduceResult> {
            tokio::time::sleep(self.delay).await;

            if let Some(e) = self.error {
                return Err(ClientError::ServerError {
                    protocol_error: e,
                    error_message: None,
                    request: RequestContext::Partition("foo".into(), 1),
                    response: None,
                    is_virtual: false,
                });
            }

            if let Some(p) = self.panic.as_ref() {
                panic!("{}", p);
            }

            let mut batch_sizes = self.batch_sizes.lock().unwrap();
            let offset_base = batch_sizes.iter().sum::<usize>();
            let offsets = (0..records.len())
                .map(|x| (x + offset_base) as i64)
                .collect();
            batch_sizes.push(records.len());
            debug!("Return offsets: {offsets:?}");
            Ok(ProduceResult {
                offsets,
                encoded_request_size: 0,
            })
        }

        async fn get_offset(&self, _at: OffsetAt) -> rskafka::client::error::Result<i64> {
            todo!()
        }
    }

    #[tokio::test]
    async fn test_producer() {
        common_telemetry::init_default_ut_logging();
        let record = record();
        let delay = Duration::from_secs(0);
        let client = Arc::new(MockClient {
            error: None,
            panic: None,
            delay,
            batch_sizes: Default::default(),
        });
        let provider = Arc::new(KafkaProvider::new(String::new()));
        let producer = OrderedBatchProducer::new(
            OrderedBatchProducer::channel(),
            provider,
            client.clone(),
            Compression::NoCompression,
            ReadableSize((record.approximate_size() * 2) as u64).as_bytes() as usize,
            Box::new(NoopCollector),
            Arc::new(DashMap::new()),
        );

        let region_id = RegionId::new(1, 1);
        // Produces 3 records
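        // `max_batch_bytes` holds exactly two records, so three records flush
        // as a batch of 2 followed by a batch of 1 (see `batch_sizes` below).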
        let handle = producer
            .produce(
                region_id,
                vec![record.clone(), record.clone(), record.clone()],
            )
            .await
            .unwrap();
        assert_eq!(handle.wait().await.unwrap(), 2);
        assert_eq!(client.batch_sizes.lock().unwrap().as_slice(), &[2, 1]);

        // Produces 2 records
        let handle = producer
            .produce(region_id, vec![record.clone(), record.clone()])
            .await
            .unwrap();
        assert_eq!(handle.wait().await.unwrap(), 4);
        assert_eq!(client.batch_sizes.lock().unwrap().as_slice(), &[2, 1, 2]);

        // Produces 1 record
        let handle = producer
            .produce(region_id, vec![record.clone()])
            .await
            .unwrap();
        assert_eq!(handle.wait().await.unwrap(), 5);
        assert_eq!(client.batch_sizes.lock().unwrap().as_slice(), &[2, 1, 2, 1]);
    }

    #[tokio::test]
    async fn test_producer_client_error() {
        let record = record();
        let client = Arc::new(MockClient {
            error: Some(ProtocolError::NetworkException),
            panic: None,
            delay: Duration::from_millis(1),
            batch_sizes: Default::default(),
        });
        let provider = Arc::new(KafkaProvider::new(String::new()));
        let producer = OrderedBatchProducer::new(
            OrderedBatchProducer::channel(),
            provider,
            client.clone(),
            Compression::NoCompression,
            ReadableSize((record.approximate_size() * 2) as u64).as_bytes() as usize,
            Box::new(NoopCollector),
            Arc::new(DashMap::new()),
        );

        let region_id = RegionId::new(1, 1);
        let mut futures = FuturesUnordered::new();
        futures.push(
            producer
                .produce(
                    region_id,
                    vec![record.clone(), record.clone(), record.clone()],
                )
                .await
                .unwrap()
                .wait(),
        );
        futures.push(
            producer
                .produce(region_id, vec![record.clone(), record.clone()])
                .await
                .unwrap()
                .wait(),
        );
        futures.push(
            producer
                .produce(region_id, vec![record.clone()])
                .await
                .unwrap()
                .wait(),
        );

        futures.next().await.unwrap().unwrap_err();
        futures.next().await.unwrap().unwrap_err();
        futures.next().await.unwrap().unwrap_err();
    }

    #[tokio::test]
    async fn test_producer_cancel() {
        let record = record();
        let client = Arc::new(MockClient {
            error: None,
            panic: None,
            delay: Duration::from_millis(1),
            batch_sizes: Default::default(),
        });

        let provider = Arc::new(KafkaProvider::new(String::new()));
        let producer = OrderedBatchProducer::new(
            OrderedBatchProducer::channel(),
            provider,
            client.clone(),
            Compression::NoCompression,
            ReadableSize((record.approximate_size() * 2) as u64).as_bytes() as usize,
            Box::new(NoopCollector),
            Arc::new(DashMap::new()),
        );

        let region_id = RegionId::new(1, 1);
        let a = producer
            .produce(
                region_id,
                vec![record.clone(), record.clone(), record.clone()],
            )
            .await
            .unwrap()
            .wait()
            .fuse();

        let b = producer
            .produce(region_id, vec![record])
            .await
            .unwrap()
            .wait()
            .fuse();

        let mut b = Box::pin(b);

        {
            // Cancel a when it exits this block
            let mut a = Box::pin(a);

            // Select biased to encourage `a` to be polled first, so that it is
            // the future that performs the produce operation
            futures::select_biased! {
                _ = &mut a => panic!("a should not have flushed"),
                _ = &mut b => panic!("b should not have flushed"),
                _ = tokio::time::sleep(Duration::from_millis(1)).fuse() => {},
            }
        }

        // But `b` should still complete successfully
        tokio::time::timeout(Duration::from_secs(1), b)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(
            client
                .batch_sizes
                .lock()
                .unwrap()
                .as_slice()
                .iter()
                .sum::<usize>(),
            4
        );
    }
}