// log_store/kafka/consumer.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::VecDeque;
16use std::ops::Range;
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
21use common_telemetry::debug;
22use derive_builder::Builder;
23use futures::future::{BoxFuture, Fuse, FusedFuture};
24use futures::{FutureExt, Stream};
25use pin_project::pin_project;
26use rskafka::client::partition::PartitionClient;
27use rskafka::record::RecordAndOffset;
28
29use crate::kafka::index::{NextBatchHint, RegionWalIndexIterator};
30
/// Abstraction over a Kafka partition client used to fetch records.
///
/// Exists so [`Consumer`] can be driven by a mock client in tests instead of
/// a real [`PartitionClient`].
#[async_trait::async_trait]
pub trait FetchClient: std::fmt::Debug + Send + Sync {
    /// Fetch records.
    ///
    /// Arguments are identical to [`PartitionClient::fetch_records`]:
    /// start `offset`, a min..max `bytes` budget, and the max wait in
    /// milliseconds. Returns the fetched records and the high watermark.
    async fn fetch_records(
        &self,
        offset: i64,
        bytes: Range<i32>,
        max_wait_ms: i32,
    ) -> rskafka::client::error::Result<(Vec<RecordAndOffset>, i64)>;
}
43
#[async_trait::async_trait]
impl FetchClient for PartitionClient {
    async fn fetch_records(
        &self,
        offset: i64,
        bytes: Range<i32>,
        max_wait_ms: i32,
    ) -> rskafka::client::error::Result<(Vec<RecordAndOffset>, i64)> {
        // Delegates to the inherent `PartitionClient::fetch_records`; the
        // explicit `self.` call resolves to the inherent method, not this one.
        self.fetch_records(offset, bytes, max_wait_ms).await
    }
}
55
/// Outcome of one fetch request, carried from the boxed fetch future back
/// into [`Consumer`]'s `poll_next`.
struct FetchResult {
    // Records returned by the broker together with their offsets.
    records_and_offsets: Vec<RecordAndOffset>,
    // Expected number of records for this batch (from the index hint);
    // used only for debug logging.
    batch_size: usize,
    // Byte budget that was requested for this fetch; used to refine the
    // running average record size.
    fetch_bytes: i32,
    // High watermark reported by the broker for this partition.
    watermark: i64,
    // The offset this fetch started at; used only for debug logging.
    used_offset: i64,
}
63
/// Default upper bound, in bytes, for a single fetch request (50 MiB).
const MAX_BATCH_SIZE: usize = 50 * 1024 * 1024;
/// Initial assumed average record size (256 KiB); the consumer refines this
/// estimate after each non-empty fetch.
const AVG_RECORD_SIZE: usize = 256 * 1024;
66
/// The [`Consumer`] struct represents a Kafka consumer that fetches messages from
/// a Kafka cluster. Yielding records respecting the [`RegionWalIndexIterator`].
#[pin_project]
#[derive(Builder)]
#[builder(pattern = "owned")]
pub struct Consumer {
    /// The last high watermark reported by the broker; `-1` until the first
    /// successful fetch. Yielded alongside every record.
    #[builder(default = "-1")]
    last_high_watermark: i64,

    /// The client is used to fetch records from kafka topic.
    client: Arc<dyn FetchClient>,

    /// The max batch size in a single fetch request.
    #[builder(default = "MAX_BATCH_SIZE")]
    max_batch_size: usize,

    /// The max wait milliseconds passed to each fetch request.
    #[builder(default = "500")]
    max_wait_ms: u32,

    /// The avg record size; refined after each non-empty fetch and used to
    /// size the next fetch request.
    #[builder(default = "AVG_RECORD_SIZE")]
    avg_record_size: usize,

    /// Termination flag; set after a fetch error, after which the stream
    /// only yields `None`.
    #[builder(default = "false")]
    terminated: bool,

    /// The buffer of records not yet yielded, filtered by the index.
    buffer: RecordsBuffer,

    /// The in-flight fetch future; `Fuse::terminated()` when no fetch is
    /// pending.
    #[builder(default = "Fuse::terminated()")]
    fetch_fut: Fuse<BoxFuture<'static, rskafka::client::error::Result<FetchResult>>>,
}
102
/// Buffers fetched records and yields only those whose offsets appear in the
/// region WAL index.
pub(crate) struct RecordsBuffer {
    // Fetched records awaiting consumption, in offset order.
    buffer: VecDeque<RecordAndOffset>,

    // Iterator over the offsets this consumer should actually yield.
    index: Box<dyn RegionWalIndexIterator>,
}
108
109impl RecordsBuffer {
110    /// Creates an empty [`RecordsBuffer`]
111    pub fn new(index: Box<dyn RegionWalIndexIterator>) -> Self {
112        RecordsBuffer {
113            buffer: VecDeque::new(),
114            index,
115        }
116    }
117}
118
impl RecordsBuffer {
    /// Pops the next record whose offset matches the next wanted index
    /// position, advancing the index past it.
    ///
    /// Buffered records whose offsets are *not* in the index are silently
    /// discarded while scanning. Returns `None` when the buffer runs dry
    /// (more records must be fetched) or when the index is exhausted, in
    /// which case any leftover buffered records are dropped.
    fn pop_front(&mut self) -> Option<RecordAndOffset> {
        while let Some(index) = self.index.peek() {
            if let Some(record_and_offset) = self.buffer.pop_front() {
                if index == record_and_offset.offset as u64 {
                    // Wanted record: consume the index position and yield it.
                    self.index.next();
                    return Some(record_and_offset);
                }
                // Offset not wanted by the index: the record is dropped and
                // the scan continues with the next buffered record.
            } else {
                // Buffer empty but index still has positions: caller must
                // fetch more records.
                return None;
            }
        }

        // Index exhausted: discard whatever is left in the buffer.
        self.buffer.clear();
        None
    }

    /// Appends freshly fetched records to the buffer.
    ///
    /// Asserts that the fetch did not start *past* the next wanted index
    /// position — that would mean wanted records were skipped entirely.
    fn extend(&mut self, records: Vec<RecordAndOffset>) {
        if let (Some(first), Some(index)) = (records.first(), self.index.peek()) {
            // TODO(weny): throw an error?
            assert!(
                index <= first.offset as u64,
                "index: {index}, first offset: {}",
                first.offset
            );
        }
        self.buffer.extend(records);
    }
}
148
impl Stream for Consumer {
    /// Each item is a record plus the last observed high watermark.
    type Item = rskafka::client::error::Result<(RecordAndOffset, i64)>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();

        loop {
            // A previous fetch error permanently ends the stream.
            if *this.terminated {
                return Poll::Ready(None);
            }

            // Index exhausted: nothing more is wanted.
            if this.buffer.index.peek().is_none() {
                return Poll::Ready(None);
            }

            // Serve from the buffer first; `pop_front` also skips records
            // whose offsets are not in the index.
            if let Some(x) = this.buffer.pop_front() {
                debug!("Yielding record with offset: {}", x.offset);
                return Poll::Ready(Some(Ok((x, *this.last_high_watermark))));
            }

            // Buffer drained: start a new fetch if none is in flight.
            if this.fetch_fut.is_terminated() {
                match this.buffer.index.peek() {
                    Some(next_offset) => {
                        let client = Arc::clone(this.client);
                        let max_wait_ms = *this.max_wait_ms as i32;
                        let offset = next_offset as i64;
                        // Size the next fetch from the index hint; fall back
                        // to a single average-sized record.
                        let NextBatchHint { bytes, len } = this
                            .buffer
                            .index
                            .next_batch_hint(*this.avg_record_size)
                            .unwrap_or(NextBatchHint {
                                bytes: *this.avg_record_size,
                                len: 1,
                            });

                        // +1 so the hinted bytes fit strictly inside the
                        // half-open range, clamped to `max_batch_size`.
                        // NOTE(review): the `as i32` cast assumes the clamped
                        // value fits in i32 — TODO confirm for huge hints.
                        let fetch_range =
                            1i32..(bytes.saturating_add(1).min(*this.max_batch_size) as i32);
                        *this.fetch_fut = FutureExt::fuse(Box::pin(async move {
                            let (records_and_offsets, watermark) = client
                                .fetch_records(offset, fetch_range, max_wait_ms)
                                .await?;

                            Ok(FetchResult {
                                records_and_offsets,
                                watermark,
                                used_offset: offset,
                                fetch_bytes: bytes as i32,
                                batch_size: len,
                            })
                        }));
                    }
                    None => {
                        return Poll::Ready(None);
                    }
                }
            }

            // Drive the in-flight fetch; returns `Pending` until it resolves.
            let data = futures::ready!(this.fetch_fut.poll_unpin(cx));

            match data {
                Ok(FetchResult {
                    mut records_and_offsets,
                    watermark,
                    used_offset,
                    fetch_bytes,
                    batch_size,
                }) => {
                    // Sort records by offset in case they aren't in order
                    records_and_offsets.sort_unstable_by_key(|x| x.offset);
                    *this.last_high_watermark = watermark;
                    if !records_and_offsets.is_empty() {
                        // Refine the running average record size from the
                        // actual batch that came back.
                        *this.avg_record_size = fetch_bytes as usize / records_and_offsets.len();
                        debug!("set avg_record_size: {}", *this.avg_record_size);
                    }

                    debug!(
                        "Fetch result: {:?}, used_offset: {used_offset}, max_batch_size: {fetch_bytes}, expected batch_num: {batch_size}, actual batch_num: {}",
                        records_and_offsets
                            .iter()
                            .map(|record| record.offset)
                            .collect::<Vec<_>>(),
                        records_and_offsets
                            .len()
                    );
                    this.buffer.extend(records_and_offsets);
                    // Loop back to serve the newly buffered records.
                    continue;
                }
                Err(e) => {
                    // Surface the error once, then terminate the stream.
                    *this.terminated = true;

                    return Poll::Ready(Some(Err(e)));
                }
            }
        }
    }
}
245
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;
    use std::ops::Range;
    use std::sync::Arc;

    use chrono::{TimeZone, Utc};
    use futures::future::Fuse;
    use futures::TryStreamExt;
    use rskafka::record::{Record, RecordAndOffset};

    use super::FetchClient;
    use crate::kafka::consumer::{Consumer, RecordsBuffer};
    use crate::kafka::index::{MultipleRegionWalIndexIterator, RegionWalRange, RegionWalVecIndex};

    /// A [`FetchClient`] that fabricates records instead of talking to Kafka.
    #[derive(Debug)]
    struct MockFetchClient {
        record: Record,
    }

    #[async_trait::async_trait]
    impl FetchClient for MockFetchClient {
        /// Returns as many clones of `self.record` as fit in the requested
        /// byte budget (at least one), with consecutive offsets from `offset`.
        async fn fetch_records(
            &self,
            offset: i64,
            bytes: Range<i32>,
            _max_wait_ms: i32,
        ) -> rskafka::client::error::Result<(Vec<RecordAndOffset>, i64)> {
            let record_size = self.record.approximate_size();
            let num = (bytes.end.unsigned_abs() as usize / record_size).max(1);

            let records = (0..num)
                .map(|idx| RecordAndOffset {
                    record: self.record.clone(),
                    offset: offset + idx as i64,
                })
                .collect::<Vec<_>>();
            let max_offset = offset + records.len() as i64;
            Ok((records, max_offset))
        }
    }

    /// Builds a small fixed record used by all tests.
    fn test_record() -> Record {
        Record {
            key: Some(vec![0; 4]),
            value: Some(vec![0; 6]),
            headers: Default::default(),
            timestamp: Utc.timestamp_millis_opt(1337).unwrap(),
        }
    }

    /// A sparse index must yield exactly the indexed offsets, skipping the rest.
    #[tokio::test]
    async fn test_consumer_with_index() {
        common_telemetry::init_default_ut_logging();
        let record = test_record();
        let record_size = record.approximate_size();
        let mock_client = MockFetchClient {
            record: record.clone(),
        };
        let index = RegionWalVecIndex::new([1, 3, 4, 8, 10, 12], record_size * 3);
        let consumer = Consumer {
            last_high_watermark: -1,
            client: Arc::new(mock_client),
            max_batch_size: usize::MAX,
            max_wait_ms: 500,
            avg_record_size: record_size,
            terminated: false,
            buffer: RecordsBuffer {
                buffer: VecDeque::new(),
                index: Box::new(index),
            },
            fetch_fut: Fuse::terminated(),
        };

        let records = consumer.try_collect::<Vec<_>>().await.unwrap();
        assert_eq!(
            records
                .into_iter()
                .map(|(x, _)| x.offset)
                .collect::<Vec<_>>(),
            vec![1, 3, 4, 8, 10, 12]
        )
    }

    /// A dense range index must yield every offset in the range, in order.
    #[tokio::test]
    async fn test_consumer_without_index() {
        common_telemetry::init_default_ut_logging();
        let record = test_record();
        let mock_client = MockFetchClient {
            record: record.clone(),
        };
        let index = RegionWalRange::new(0..30, 1024);
        let consumer = Consumer {
            last_high_watermark: -1,
            client: Arc::new(mock_client),
            max_batch_size: usize::MAX,
            max_wait_ms: 500,
            avg_record_size: record.approximate_size(),
            terminated: false,
            buffer: RecordsBuffer {
                buffer: VecDeque::new(),
                index: Box::new(index),
            },
            fetch_fut: Fuse::terminated(),
        };

        let records = consumer.try_collect::<Vec<_>>().await.unwrap();
        assert_eq!(
            records
                .into_iter()
                .map(|(x, _)| x.offset)
                .collect::<Vec<_>>(),
            (0..30).collect::<Vec<_>>()
        )
    }

    /// Chained iterators (including empty ones) must yield the concatenation
    /// of their offsets.
    #[tokio::test]
    async fn test_consumer_with_multiple_index() {
        common_telemetry::init_default_ut_logging();
        let record = test_record();
        let mock_client = MockFetchClient {
            record: record.clone(),
        };

        // iter0 and iter2 are empty ranges; iter3 starts far past iter1.
        let iter0 = Box::new(RegionWalRange::new(0..0, 1024)) as _;
        let iter1 = Box::new(RegionWalVecIndex::new(
            [0, 1, 2, 7, 8, 11],
            record.approximate_size() * 4,
        )) as _;
        let iter2 = Box::new(RegionWalRange::new(12..12, 1024)) as _;
        let iter3 = Box::new(RegionWalRange::new(1024..1028, 1024)) as _;
        let iter = MultipleRegionWalIndexIterator::new([iter0, iter1, iter2, iter3]);

        let consumer = Consumer {
            last_high_watermark: -1,
            client: Arc::new(mock_client),
            max_batch_size: usize::MAX,
            max_wait_ms: 500,
            avg_record_size: record.approximate_size(),
            terminated: false,
            buffer: RecordsBuffer {
                buffer: VecDeque::new(),
                index: Box::new(iter),
            },
            fetch_fut: Fuse::terminated(),
        };

        let records = consumer.try_collect::<Vec<_>>().await.unwrap();
        assert_eq!(
            records
                .into_iter()
                .map(|(x, _)| x.offset)
                .collect::<Vec<_>>(),
            [0, 1, 2, 7, 8, 11, 1024, 1025, 1026, 1027]
        )
    }
}