log_store/kafka/producer.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use common_telemetry::warn;
18use dashmap::DashMap;
19use rskafka::client::partition::{Compression, OffsetAt, PartitionClient};
20use rskafka::record::Record;
21use store_api::logstore::provider::KafkaProvider;
22use store_api::storage::RegionId;
23use tokio::sync::mpsc::{self, Receiver, Sender};
24
25use crate::error::{self, Result};
26use crate::kafka::index::IndexCollector;
27use crate::kafka::worker::{BackgroundProducerWorker, ProduceResultHandle, WorkerRequest};
28use crate::metrics::{
29    METRIC_KAFKA_CLIENT_BYTES_TOTAL, METRIC_KAFKA_CLIENT_PRODUCE_ELAPSED,
30    METRIC_KAFKA_CLIENT_TRAFFIC_TOTAL,
31};
32
33pub type OrderedBatchProducerRef = Arc<OrderedBatchProducer>;
34
/// Maximum number of requests an `OrderedBatchProducer` handles in one batch.
const REQUEST_BATCH_SIZE: usize = 64;

/// Capacity of the producer's request channel: room for two full request batches.
const PRODUCER_CHANNEL_SIZE: usize = 2 * REQUEST_BATCH_SIZE;
40
41/// [`OrderedBatchProducer`] attempts to aggregate multiple produce requests together
42#[derive(Debug)]
43pub(crate) struct OrderedBatchProducer {
44    pub(crate) sender: Sender<WorkerRequest>,
45}
46
47impl OrderedBatchProducer {
48    pub(crate) fn channel() -> (Sender<WorkerRequest>, Receiver<WorkerRequest>) {
49        mpsc::channel(PRODUCER_CHANNEL_SIZE)
50    }
51
52    /// Constructs a new [`OrderedBatchProducer`].
53    pub(crate) fn new(
54        (tx, rx): (Sender<WorkerRequest>, Receiver<WorkerRequest>),
55        provider: Arc<KafkaProvider>,
56        client: Arc<dyn ProducerClient>,
57        compression: Compression,
58        max_batch_bytes: usize,
59        index_collector: Box<dyn IndexCollector>,
60        high_watermark: Arc<DashMap<Arc<KafkaProvider>, u64>>,
61    ) -> Self {
62        let mut worker = BackgroundProducerWorker {
63            provider,
64            client,
65            compression,
66            receiver: rx,
67            request_batch_size: REQUEST_BATCH_SIZE,
68            max_batch_bytes,
69            index_collector,
70            high_watermark,
71        };
72        tokio::spawn(async move { worker.run().await });
73        Self { sender: tx }
74    }
75
76    /// Writes `data` to the [`OrderedBatchProducer`].
77    ///
78    /// Returns the [ProduceResultHandle], which will receive a result when data has been committed to Kafka
79    /// or an unrecoverable error has been encountered.
80    ///
81    /// ## Panic
82    /// Panic if any [Record]'s `approximate_size` > `max_batch_bytes`.
83    pub(crate) async fn produce(
84        &self,
85        region_id: RegionId,
86        batch: Vec<Record>,
87    ) -> Result<ProduceResultHandle> {
88        let (req, handle) = WorkerRequest::new_produce_request(region_id, batch);
89        if self.sender.send(req).await.is_err() {
90            warn!("OrderedBatchProducer is already exited");
91            return error::OrderedBatchProducerStoppedSnafu {}.fail();
92        }
93
94        Ok(handle)
95    }
96
97    /// Sends an [WorkerRequest::UpdateHighWatermark] request to the producer.
98    /// This is used to update the high watermark for the topic.
99    pub(crate) async fn update_high_watermark(&self) -> Result<()> {
100        if self
101            .sender
102            .send(WorkerRequest::UpdateHighWatermark)
103            .await
104            .is_err()
105        {
106            warn!("OrderedBatchProducer is already exited");
107            return error::OrderedBatchProducerStoppedSnafu {}.fail();
108        }
109        Ok(())
110    }
111}
112
113#[async_trait::async_trait]
114pub trait ProducerClient: std::fmt::Debug + Send + Sync {
115    async fn produce(
116        &self,
117        records: Vec<Record>,
118        compression: Compression,
119    ) -> rskafka::client::error::Result<Vec<i64>>;
120
121    async fn get_offset(&self, at: OffsetAt) -> rskafka::client::error::Result<i64>;
122}
123
124#[async_trait::async_trait]
125impl ProducerClient for PartitionClient {
126    async fn produce(
127        &self,
128        records: Vec<Record>,
129        compression: Compression,
130    ) -> rskafka::client::error::Result<Vec<i64>> {
131        let total_size = records.iter().map(|r| r.approximate_size()).sum::<usize>();
132        let partition = self.partition().to_string();
133        METRIC_KAFKA_CLIENT_BYTES_TOTAL
134            .with_label_values(&[self.topic(), &partition])
135            .inc_by(total_size as u64);
136        METRIC_KAFKA_CLIENT_TRAFFIC_TOTAL
137            .with_label_values(&[self.topic(), &partition])
138            .inc();
139        let _timer = METRIC_KAFKA_CLIENT_PRODUCE_ELAPSED
140            .with_label_values(&[self.topic(), &partition])
141            .start_timer();
142
143        self.produce(records, compression).await
144    }
145
146    async fn get_offset(&self, at: OffsetAt) -> rskafka::client::error::Result<i64> {
147        self.get_offset(at).await
148    }
149}
150
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    use common_base::readable_size::ReadableSize;
    use common_telemetry::debug;
    use futures::stream::FuturesUnordered;
    use futures::{FutureExt, StreamExt};
    use rskafka::client::error::{Error as ClientError, RequestContext};
    use rskafka::client::partition::Compression;
    use rskafka::protocol::error::Error as ProtocolError;
    use rskafka::record::Record;
    use store_api::storage::RegionId;

    use super::*;
    use crate::kafka::index::NoopCollector;
    use crate::kafka::producer::OrderedBatchProducer;
    use crate::kafka::test_util::record;

    /// In-memory [`ProducerClient`] used to drive `OrderedBatchProducer` in tests.
    #[derive(Debug)]
    struct MockClient {
        // When set, every `produce` call fails with this protocol error.
        error: Option<ProtocolError>,
        // When set (and `error` is not), every `produce` call panics with this message.
        panic: Option<String>,
        // Artificial latency applied at the start of every `produce` call.
        delay: Duration,
        // Number of records in each successful `produce` call, in call order.
        batch_sizes: Mutex<Vec<usize>>,
    }

    #[async_trait::async_trait]
    impl ProducerClient for MockClient {
        async fn produce(
            &self,
            records: Vec<Record>,
            _compression: Compression,
        ) -> rskafka::client::error::Result<Vec<i64>> {
            tokio::time::sleep(self.delay).await;

            if let Some(e) = self.error {
                return Err(ClientError::ServerError {
                    protocol_error: e,
                    error_message: None,
                    request: RequestContext::Partition("foo".into(), 1),
                    response: None,
                    is_virtual: false,
                });
            }

            if let Some(p) = self.panic.as_ref() {
                panic!("{}", p);
            }

            // Offsets are contiguous across calls: each batch starts right after
            // the records of all previously recorded batches (starting at 0).
            let mut batch_sizes = self.batch_sizes.lock().unwrap();
            let offset_base = batch_sizes.iter().sum::<usize>();
            let offsets = (0..records.len())
                .map(|x| (x + offset_base) as i64)
                .collect();
            batch_sizes.push(records.len());
            debug!("Return offsets: {offsets:?}");
            Ok(offsets)
        }

        // Not exercised by these tests; panics if called.
        async fn get_offset(&self, _at: OffsetAt) -> rskafka::client::error::Result<i64> {
            todo!()
        }
    }

    // Checks how requests are split into client batches when `max_batch_bytes`
    // fits exactly two records, and the offset each handle resolves to.
    #[tokio::test]
    async fn test_producer() {
        common_telemetry::init_default_ut_logging();
        let record = record();
        let delay = Duration::from_secs(0);
        let client = Arc::new(MockClient {
            error: None,
            panic: None,
            delay,
            batch_sizes: Default::default(),
        });
        let provider = Arc::new(KafkaProvider::new(String::new()));
        let producer = OrderedBatchProducer::new(
            OrderedBatchProducer::channel(),
            provider,
            client.clone(),
            Compression::NoCompression,
            ReadableSize((record.approximate_size() * 2) as u64).as_bytes() as usize,
            Box::new(NoopCollector),
            Arc::new(DashMap::new()),
        );

        let region_id = RegionId::new(1, 1);
        // Produces 3 records
        // Expected: batches of 2 + 1; the handle resolves to 2, the mock offset of the last record.
        let handle = producer
            .produce(
                region_id,
                vec![record.clone(), record.clone(), record.clone()],
            )
            .await
            .unwrap();
        assert_eq!(handle.wait().await.unwrap(), 2);
        assert_eq!(client.batch_sizes.lock().unwrap().as_slice(), &[2, 1]);

        // Produces 2 records
        let handle = producer
            .produce(region_id, vec![record.clone(), record.clone()])
            .await
            .unwrap();
        assert_eq!(handle.wait().await.unwrap(), 4);
        assert_eq!(client.batch_sizes.lock().unwrap().as_slice(), &[2, 1, 2]);

        // Produces 1 records
        let handle = producer
            .produce(region_id, vec![record.clone()])
            .await
            .unwrap();
        assert_eq!(handle.wait().await.unwrap(), 5);
        assert_eq!(client.batch_sizes.lock().unwrap().as_slice(), &[2, 1, 2, 1]);
    }

    // When the client always returns a server error, every outstanding handle
    // must resolve to an error.
    #[tokio::test]
    async fn test_producer_client_error() {
        let record = record();
        let client = Arc::new(MockClient {
            error: Some(ProtocolError::NetworkException),
            panic: None,
            delay: Duration::from_millis(1),
            batch_sizes: Default::default(),
        });
        let provider = Arc::new(KafkaProvider::new(String::new()));
        let producer = OrderedBatchProducer::new(
            OrderedBatchProducer::channel(),
            provider,
            client.clone(),
            Compression::NoCompression,
            ReadableSize((record.approximate_size() * 2) as u64).as_bytes() as usize,
            Box::new(NoopCollector),
            Arc::new(DashMap::new()),
        );

        let region_id = RegionId::new(1, 1);
        let mut futures = FuturesUnordered::new();
        futures.push(
            producer
                .produce(
                    region_id,
                    vec![record.clone(), record.clone(), record.clone()],
                )
                .await
                .unwrap()
                .wait(),
        );
        futures.push(
            producer
                .produce(region_id, vec![record.clone(), record.clone()])
                .await
                .unwrap()
                .wait(),
        );
        futures.push(
            producer
                .produce(region_id, vec![record.clone()])
                .await
                .unwrap()
                .wait(),
        );

        futures.next().await.unwrap().unwrap_err();
        futures.next().await.unwrap().unwrap_err();
        futures.next().await.unwrap().unwrap_err();
    }

    // Dropping a caller's wait future must not cancel the underlying produce:
    // all 4 submitted records still reach the client.
    #[tokio::test]
    async fn test_producer_cancel() {
        let record = record();
        let client = Arc::new(MockClient {
            error: None,
            panic: None,
            delay: Duration::from_millis(1),
            batch_sizes: Default::default(),
        });

        let provider = Arc::new(KafkaProvider::new(String::new()));
        let producer = OrderedBatchProducer::new(
            OrderedBatchProducer::channel(),
            provider,
            client.clone(),
            Compression::NoCompression,
            ReadableSize((record.approximate_size() * 2) as u64).as_bytes() as usize,
            Box::new(NoopCollector),
            Arc::new(DashMap::new()),
        );

        let region_id = RegionId::new(1, 1);
        let a = producer
            .produce(
                region_id,
                vec![record.clone(), record.clone(), record.clone()],
            )
            .await
            .unwrap()
            .wait()
            .fuse();

        let b = producer
            .produce(region_id, vec![record])
            .await
            .unwrap()
            .wait()
            .fuse();

        let mut b = Box::pin(b);

        {
            // Cancel a when it exits this block
            let mut a = Box::pin(a);

            // Select biased to encourage `a` to be the one with the linger that
            // expires first and performs the produce operation
            futures::select_biased! {
                _ = &mut a => panic!("a should not have flushed"),
                _ = &mut b => panic!("b should not have flushed"),
                _ = tokio::time::sleep(Duration::from_millis(1)).fuse() => {},
            }
        }

        // But `b` should still complete successfully
        tokio::time::timeout(Duration::from_secs(1), b)
            .await
            .unwrap()
            .unwrap();

        // 3 records from `a` + 1 from `b`: the cancelled waiter's data was still produced.
        assert_eq!(
            client
                .batch_sizes
                .lock()
                .unwrap()
                .as_slice()
                .iter()
                .sum::<usize>(),
            4
        );
    }
}