1use std::collections::HashMap;
16use std::sync::Arc;
17
18use async_stream::stream;
19use common_telemetry::{debug, error};
20use futures::future::join_all;
21use snafu::OptionExt;
22use store_api::logstore::entry::Entry;
23use store_api::logstore::provider::Provider;
24use store_api::storage::RegionId;
25use tokio::sync::mpsc::{self, Receiver, Sender};
26use tokio::sync::oneshot;
27use tokio_stream::StreamExt;
28
29use crate::error::{self, Result};
30use crate::wal::entry_reader::{decode_raw_entry, WalEntryReader};
31use crate::wal::raw_entry_reader::RawEntryReader;
32use crate::wal::{EntryId, WalEntryStream};
33
/// Reads raw WAL entries through a single [`RawEntryReader`] and fans them
/// out to per-region channels.
///
/// Instances are assembled by [`build_wal_entry_distributor_and_receivers`]
/// and consumed by [`WalEntryDistributor::distribute`].
pub(crate) struct WalEntryDistributor {
    /// Source of the raw entries; `distribute` calls `read` on it once.
    raw_wal_reader: Arc<dyn RawEntryReader>,
    /// Provider passed to `raw_wal_reader.read`.
    provider: Provider,
    /// Sending half of each region's entry channel, keyed by region id.
    senders: HashMap<RegionId, Sender<Entry>>,
    /// One-shot channels on which each region's receiver reports the entry
    /// id it wants to start reading from.
    arg_receivers: Vec<(RegionId, oneshot::Receiver<EntryId>)>,
}
43
44impl WalEntryDistributor {
45 pub async fn distribute(mut self) -> Result<()> {
47 let arg_futures = self
48 .arg_receivers
49 .iter_mut()
50 .map(|(region_id, receiver)| async { (*region_id, receiver.await.ok()) });
51 let args = join_all(arg_futures)
52 .await
53 .into_iter()
54 .filter_map(|(region_id, start_id)| start_id.map(|start_id| (region_id, start_id)))
55 .collect::<Vec<_>>();
56
57 if args.is_empty() {
59 return Ok(());
60 }
61 let min_start_id = args.iter().map(|(_, start_id)| *start_id).min().unwrap();
63 let receivers: HashMap<_, _> = args
64 .into_iter()
65 .map(|(region_id, start_id)| {
66 (
67 region_id,
68 EntryReceiver {
69 start_id,
70 sender: self.senders[®ion_id].clone(),
71 },
72 )
73 })
74 .collect();
75
76 let mut stream = self.raw_wal_reader.read(&self.provider, min_start_id)?;
77 while let Some(entry) = stream.next().await {
78 let entry = entry?;
79 let entry_id = entry.entry_id();
80 let region_id = entry.region_id();
81
82 if let Some(EntryReceiver { sender, start_id }) = receivers.get(®ion_id) {
83 if entry_id >= *start_id {
84 if let Err(err) = sender.send(entry).await {
85 error!(err; "Failed to distribute raw entry, entry_id:{}, region_id: {}", entry_id, region_id);
86 }
87 }
88 } else {
89 debug!("Subscriber not found, region_id: {}", region_id);
90 }
91 }
92
93 Ok(())
94 }
95}
96
/// Receiving side of a [`WalEntryDistributor`]; one is created per region by
/// [`build_wal_entry_distributor_and_receivers`].
#[derive(Debug)]
pub(crate) struct WalEntryReceiver {
    /// Channel carrying the raw entries; taken by the first `read` call.
    entry_receiver: Option<Receiver<Entry>>,
    /// One-shot channel `read` uses to report the start entry id; `None`
    /// once `read` has been called.
    arg_sender: Option<oneshot::Sender<EntryId>>,
}
105
106impl WalEntryReceiver {
107 pub fn new(entry_receiver: Receiver<Entry>, arg_sender: oneshot::Sender<EntryId>) -> Self {
108 Self {
109 entry_receiver: Some(entry_receiver),
110 arg_sender: Some(arg_sender),
111 }
112 }
113}
114
115impl WalEntryReader for WalEntryReceiver {
116 fn read(&mut self, _provider: &Provider, start_id: EntryId) -> Result<WalEntryStream<'static>> {
117 let arg_sender =
118 self.arg_sender
119 .take()
120 .with_context(|| error::InvalidWalReadRequestSnafu {
121 reason: format!("Call WalEntryReceiver multiple time, start_id: {start_id}"),
122 })?;
123 let mut entry_receiver = self.entry_receiver.take().unwrap();
125
126 if arg_sender.send(start_id).is_err() {
127 return error::InvalidWalReadRequestSnafu {
128 reason: format!(
129 "WalEntryDistributor is dropped, failed to send arg, start_id: {start_id}"
130 ),
131 }
132 .fail();
133 }
134
135 let stream = stream! {
136 let mut buffered_entry = None;
137 while let Some(next_entry) = entry_receiver.recv().await {
138 match buffered_entry.take() {
139 Some(entry) => {
140 yield decode_raw_entry(entry);
141 buffered_entry = Some(next_entry);
142 },
143 None => {
144 buffered_entry = Some(next_entry);
145 }
146 };
147 }
148 if let Some(entry) = buffered_entry {
149 if entry.is_complete() {
151 yield decode_raw_entry(entry);
152 }
153 }
154 };
155
156 Ok(Box::pin(stream))
157 }
158}
159
/// Per-region routing state used inside `WalEntryDistributor::distribute`.
struct EntryReceiver {
    /// Entries with an id below this value are not forwarded.
    start_id: EntryId,
    /// Channel on which the region's entries are sent.
    sender: Sender<Entry>,
}
164
/// Default entry receiver buffer size.
///
/// NOTE(review): not referenced in the visible part of this file; presumably
/// callers pass it as `buffer_size` to
/// [`build_wal_entry_distributor_and_receivers`] — confirm at the call sites.
pub const DEFAULT_ENTRY_RECEIVER_BUFFER_SIZE: usize = 2048;
167
168pub fn build_wal_entry_distributor_and_receivers(
185 provider: Provider,
186 raw_wal_reader: Arc<dyn RawEntryReader>,
187 region_ids: &[RegionId],
188 buffer_size: usize,
189) -> (WalEntryDistributor, Vec<WalEntryReceiver>) {
190 let mut senders = HashMap::with_capacity(region_ids.len());
191 let mut readers = Vec::with_capacity(region_ids.len());
192 let mut arg_receivers = Vec::with_capacity(region_ids.len());
193
194 for ®ion_id in region_ids {
195 let (entry_sender, entry_receiver) = mpsc::channel(buffer_size);
196 let (arg_sender, arg_receiver) = oneshot::channel();
197
198 senders.insert(region_id, entry_sender);
199 arg_receivers.push((region_id, arg_receiver));
200 readers.push(WalEntryReceiver::new(entry_receiver, arg_sender));
201 }
202
203 (
204 WalEntryDistributor {
205 provider,
206 raw_wal_reader,
207 senders,
208 arg_receivers,
209 },
210 readers,
211 )
212}
213
214#[cfg(test)]
215mod tests {
216 use std::assert_matches::assert_matches;
217
218 use api::v1::{Mutation, OpType, WalEntry};
219 use futures::{stream, TryStreamExt};
220 use prost::Message;
221 use store_api::logstore::entry::{Entry, MultiplePartEntry, MultiplePartHeader, NaiveEntry};
222
223 use super::*;
224 use crate::test_util::wal_util::generate_tail_corrupted_stream;
225 use crate::wal::raw_entry_reader::{EntryStream, RawEntryReader};
226 use crate::wal::EntryId;
227
    /// Test double for [`RawEntryReader`] backed by a fixed list of entries.
    struct MockRawEntryReader {
        /// Entries cloned into the stream on every `read` call.
        entries: Vec<Entry>,
    }
231
232 impl MockRawEntryReader {
233 pub fn new(entries: Vec<Entry>) -> MockRawEntryReader {
234 Self { entries }
235 }
236 }
237
238 impl RawEntryReader for MockRawEntryReader {
239 fn read(&self, _provider: &Provider, _start_id: EntryId) -> Result<EntryStream<'static>> {
240 let stream = stream::iter(self.entries.clone().into_iter().map(Ok));
241 Ok(Box::pin(stream))
242 }
243 }
244
245 #[tokio::test]
246 async fn test_wal_entry_distributor_without_receivers() {
247 let provider = Provider::kafka_provider("my_topic".to_string());
248 let reader = Arc::new(MockRawEntryReader::new(vec![Entry::Naive(NaiveEntry {
249 region_id: RegionId::new(1024, 1),
250 provider: provider.clone(),
251 entry_id: 1,
252 data: vec![1],
253 })]));
254
255 let (distributor, receivers) = build_wal_entry_distributor_and_receivers(
256 provider,
257 reader,
258 &[RegionId::new(1024, 1), RegionId::new(1025, 1)],
259 128,
260 );
261
262 drop(receivers);
264 distributor.distribute().await.unwrap();
266 }
267
268 #[tokio::test]
269 async fn test_wal_entry_distributor() {
270 common_telemetry::init_default_ut_logging();
271 let provider = Provider::kafka_provider("my_topic".to_string());
272 let reader = Arc::new(MockRawEntryReader::new(vec![
273 Entry::Naive(NaiveEntry {
274 provider: provider.clone(),
275 region_id: RegionId::new(1024, 1),
276 entry_id: 1,
277 data: WalEntry {
278 mutations: vec![Mutation {
279 op_type: OpType::Put as i32,
280 sequence: 1u64,
281 rows: None,
282 write_hint: None,
283 }],
284 }
285 .encode_to_vec(),
286 }),
287 Entry::Naive(NaiveEntry {
288 provider: provider.clone(),
289 region_id: RegionId::new(1024, 2),
290 entry_id: 2,
291 data: WalEntry {
292 mutations: vec![Mutation {
293 op_type: OpType::Put as i32,
294 sequence: 2u64,
295 rows: None,
296 write_hint: None,
297 }],
298 }
299 .encode_to_vec(),
300 }),
301 Entry::Naive(NaiveEntry {
302 provider: provider.clone(),
303 region_id: RegionId::new(1024, 3),
304 entry_id: 3,
305 data: WalEntry {
306 mutations: vec![Mutation {
307 op_type: OpType::Put as i32,
308 sequence: 3u64,
309 rows: None,
310 write_hint: None,
311 }],
312 }
313 .encode_to_vec(),
314 }),
315 ]));
316
317 let (distributor, mut receivers) = build_wal_entry_distributor_and_receivers(
319 provider.clone(),
320 reader,
321 &[
322 RegionId::new(1024, 1),
323 RegionId::new(1024, 2),
324 RegionId::new(1024, 3),
325 ],
326 128,
327 );
328 assert_eq!(receivers.len(), 3);
329
330 let last = receivers.pop().unwrap();
332 drop(last);
333
334 let mut streams = receivers
335 .iter_mut()
336 .map(|receiver| receiver.read(&provider, 0).unwrap())
337 .collect::<Vec<_>>();
338 distributor.distribute().await.unwrap();
339 let entries = streams
340 .get_mut(0)
341 .unwrap()
342 .try_collect::<Vec<_>>()
343 .await
344 .unwrap();
345 assert_eq!(
346 entries,
347 vec![(
348 1,
349 WalEntry {
350 mutations: vec![Mutation {
351 op_type: OpType::Put as i32,
352 sequence: 1u64,
353 rows: None,
354 write_hint: None,
355 }],
356 }
357 )]
358 );
359 let entries = streams
360 .get_mut(1)
361 .unwrap()
362 .try_collect::<Vec<_>>()
363 .await
364 .unwrap();
365 assert_eq!(
366 entries,
367 vec![(
368 2,
369 WalEntry {
370 mutations: vec![Mutation {
371 op_type: OpType::Put as i32,
372 sequence: 2u64,
373 rows: None,
374 write_hint: None,
375 }],
376 }
377 )]
378 );
379 }
380
381 #[tokio::test]
382 async fn test_tail_corrupted_stream() {
383 let mut entries = vec![];
384 let region1 = RegionId::new(1, 1);
385 let region1_expected_wal_entry = WalEntry {
386 mutations: vec![Mutation {
387 op_type: OpType::Put as i32,
388 sequence: 1u64,
389 rows: None,
390 write_hint: None,
391 }],
392 };
393 let region2 = RegionId::new(1, 2);
394 let region2_expected_wal_entry = WalEntry {
395 mutations: vec![Mutation {
396 op_type: OpType::Put as i32,
397 sequence: 3u64,
398 rows: None,
399 write_hint: None,
400 }],
401 };
402 let region3 = RegionId::new(1, 3);
403 let region3_expected_wal_entry = WalEntry {
404 mutations: vec![Mutation {
405 op_type: OpType::Put as i32,
406 sequence: 3u64,
407 rows: None,
408 write_hint: None,
409 }],
410 };
411 let provider = Provider::kafka_provider("my_topic".to_string());
412 entries.extend(generate_tail_corrupted_stream(
413 provider.clone(),
414 region1,
415 ®ion1_expected_wal_entry,
416 3,
417 ));
418 entries.extend(generate_tail_corrupted_stream(
419 provider.clone(),
420 region2,
421 ®ion2_expected_wal_entry,
422 2,
423 ));
424 entries.extend(generate_tail_corrupted_stream(
425 provider.clone(),
426 region3,
427 ®ion3_expected_wal_entry,
428 4,
429 ));
430
431 let corrupted_stream = MockRawEntryReader { entries };
432 let (distributor, mut receivers) = build_wal_entry_distributor_and_receivers(
434 provider.clone(),
435 Arc::new(corrupted_stream),
436 &[region1, region2, region3],
437 128,
438 );
439 assert_eq!(receivers.len(), 3);
440 let mut streams = receivers
441 .iter_mut()
442 .map(|receiver| receiver.read(&provider, 0).unwrap())
443 .collect::<Vec<_>>();
444 distributor.distribute().await.unwrap();
445
446 assert_eq!(
447 streams
448 .get_mut(0)
449 .unwrap()
450 .try_collect::<Vec<_>>()
451 .await
452 .unwrap(),
453 vec![(0, region1_expected_wal_entry)]
454 );
455
456 assert_eq!(
457 streams
458 .get_mut(1)
459 .unwrap()
460 .try_collect::<Vec<_>>()
461 .await
462 .unwrap(),
463 vec![(0, region2_expected_wal_entry)]
464 );
465
466 assert_eq!(
467 streams
468 .get_mut(2)
469 .unwrap()
470 .try_collect::<Vec<_>>()
471 .await
472 .unwrap(),
473 vec![(0, region3_expected_wal_entry)]
474 );
475 }
476
477 #[tokio::test]
478 async fn test_part_corrupted_stream() {
479 let mut entries = vec![];
480 let region1 = RegionId::new(1, 1);
481 let region1_expected_wal_entry = WalEntry {
482 mutations: vec![Mutation {
483 op_type: OpType::Put as i32,
484 sequence: 1u64,
485 rows: None,
486 write_hint: None,
487 }],
488 };
489 let region2 = RegionId::new(1, 2);
490 let provider = Provider::kafka_provider("my_topic".to_string());
491 entries.extend(generate_tail_corrupted_stream(
492 provider.clone(),
493 region1,
494 ®ion1_expected_wal_entry,
495 3,
496 ));
497 entries.extend(vec![
498 Entry::MultiplePart(MultiplePartEntry {
500 provider: provider.clone(),
501 region_id: region2,
502 entry_id: 0,
503 headers: vec![MultiplePartHeader::First],
504 parts: vec![vec![1; 100]],
505 }),
506 Entry::MultiplePart(MultiplePartEntry {
507 provider: provider.clone(),
508 region_id: region2,
509 entry_id: 0,
510 headers: vec![MultiplePartHeader::First],
511 parts: vec![vec![1; 100]],
512 }),
513 ]);
514
515 let corrupted_stream = MockRawEntryReader { entries };
516 let (distributor, mut receivers) = build_wal_entry_distributor_and_receivers(
518 provider.clone(),
519 Arc::new(corrupted_stream),
520 &[region1, region2],
521 128,
522 );
523 assert_eq!(receivers.len(), 2);
524 let mut streams = receivers
525 .iter_mut()
526 .map(|receiver| receiver.read(&provider, 0).unwrap())
527 .collect::<Vec<_>>();
528 distributor.distribute().await.unwrap();
529 assert_eq!(
530 streams
531 .get_mut(0)
532 .unwrap()
533 .try_collect::<Vec<_>>()
534 .await
535 .unwrap(),
536 vec![(0, region1_expected_wal_entry)]
537 );
538
539 assert_matches!(
540 streams
541 .get_mut(1)
542 .unwrap()
543 .try_collect::<Vec<_>>()
544 .await
545 .unwrap_err(),
546 error::Error::CorruptedEntry { .. }
547 );
548 }
549
550 #[tokio::test]
551 async fn test_wal_entry_receiver_start_id() {
552 let provider = Provider::kafka_provider("my_topic".to_string());
553 let reader = Arc::new(MockRawEntryReader::new(vec![
554 Entry::Naive(NaiveEntry {
555 provider: provider.clone(),
556 region_id: RegionId::new(1024, 1),
557 entry_id: 1,
558 data: WalEntry {
559 mutations: vec![Mutation {
560 op_type: OpType::Put as i32,
561 sequence: 1u64,
562 rows: None,
563 write_hint: None,
564 }],
565 }
566 .encode_to_vec(),
567 }),
568 Entry::Naive(NaiveEntry {
569 provider: provider.clone(),
570 region_id: RegionId::new(1024, 2),
571 entry_id: 2,
572 data: WalEntry {
573 mutations: vec![Mutation {
574 op_type: OpType::Put as i32,
575 sequence: 2u64,
576 rows: None,
577 write_hint: None,
578 }],
579 }
580 .encode_to_vec(),
581 }),
582 Entry::Naive(NaiveEntry {
583 provider: provider.clone(),
584 region_id: RegionId::new(1024, 1),
585 entry_id: 3,
586 data: WalEntry {
587 mutations: vec![Mutation {
588 op_type: OpType::Put as i32,
589 sequence: 3u64,
590 rows: None,
591 write_hint: None,
592 }],
593 }
594 .encode_to_vec(),
595 }),
596 Entry::Naive(NaiveEntry {
597 provider: provider.clone(),
598 region_id: RegionId::new(1024, 2),
599 entry_id: 4,
600 data: WalEntry {
601 mutations: vec![Mutation {
602 op_type: OpType::Put as i32,
603 sequence: 4u64,
604 rows: None,
605 write_hint: None,
606 }],
607 }
608 .encode_to_vec(),
609 }),
610 ]));
611
612 let (distributor, mut receivers) = build_wal_entry_distributor_and_receivers(
614 provider.clone(),
615 reader,
616 &[RegionId::new(1024, 1), RegionId::new(1024, 2)],
617 128,
618 );
619 assert_eq!(receivers.len(), 2);
620 let mut streams = receivers
621 .iter_mut()
622 .map(|receiver| receiver.read(&provider, 4).unwrap())
623 .collect::<Vec<_>>();
624 distributor.distribute().await.unwrap();
625
626 assert_eq!(
627 streams
628 .get_mut(1)
629 .unwrap()
630 .try_collect::<Vec<_>>()
631 .await
632 .unwrap(),
633 vec![(
634 4,
635 WalEntry {
636 mutations: vec![Mutation {
637 op_type: OpType::Put as i32,
638 sequence: 4u64,
639 rows: None,
640 write_hint: None,
641 }],
642 }
643 )]
644 );
645 }
646}