common_datasource/
buffered_writer.rs1use std::future::Future;
16
17use arrow::record_batch::RecordBatch;
18use async_trait::async_trait;
19use datafusion::parquet::format::FileMetaData;
20use snafu::{OptionExt, ResultExt};
21use tokio::io::{AsyncWrite, AsyncWriteExt};
22
23use crate::error::{self, Result};
24use crate::share_buffer::SharedBuffer;
25
26pub struct LazyBufferedWriter<T, U, F> {
27 path: String,
28 writer_factory: F,
29 writer: Option<T>,
30 encoder: Option<U>,
32 buffer: SharedBuffer,
33 rows_written: usize,
34 bytes_written: u64,
35 threshold: usize,
36}
37
38pub trait DfRecordBatchEncoder {
39 fn write(&mut self, batch: &RecordBatch) -> Result<()>;
40}
41
42#[async_trait]
43pub trait ArrowWriterCloser {
44 async fn close(mut self) -> Result<FileMetaData>;
45}
46
47impl<
48 T: AsyncWrite + Send + Unpin,
49 U: DfRecordBatchEncoder + ArrowWriterCloser,
50 F: Fn(String) -> Fut,
51 Fut: Future<Output = Result<T>>,
52 > LazyBufferedWriter<T, U, F>
53{
54 pub async fn close_with_arrow_writer(mut self) -> Result<(FileMetaData, u64)> {
57 let encoder = self
58 .encoder
59 .take()
60 .context(error::BufferedWriterClosedSnafu)?;
61 let metadata = encoder.close().await?;
62
63 self.close_inner_writer().await?;
65 Ok((metadata, self.bytes_written))
66 }
67}
68
69impl<
70 T: AsyncWrite + Send + Unpin,
71 U: DfRecordBatchEncoder,
72 F: Fn(String) -> Fut,
73 Fut: Future<Output = Result<T>>,
74 > LazyBufferedWriter<T, U, F>
75{
76 pub async fn close_inner_writer(&mut self) -> Result<()> {
78 if self.rows_written != 0 {
82 self.bytes_written += self.try_flush(true).await?;
83 }
84
85 if let Some(writer) = &mut self.writer {
86 writer.shutdown().await.context(error::AsyncWriteSnafu)?;
87 }
88 Ok(())
89 }
90
91 pub fn new(
92 threshold: usize,
93 buffer: SharedBuffer,
94 encoder: U,
95 path: impl AsRef<str>,
96 writer_factory: F,
97 ) -> Self {
98 Self {
99 path: path.as_ref().to_string(),
100 threshold,
101 encoder: Some(encoder),
102 buffer,
103 rows_written: 0,
104 bytes_written: 0,
105 writer_factory,
106 writer: None,
107 }
108 }
109
110 pub async fn write(&mut self, batch: &RecordBatch) -> Result<()> {
111 let encoder = self
112 .encoder
113 .as_mut()
114 .context(error::BufferedWriterClosedSnafu)?;
115 encoder.write(batch)?;
116 self.rows_written += batch.num_rows();
117 self.bytes_written += self.try_flush(false).await?;
118 Ok(())
119 }
120
121 async fn try_flush(&mut self, all: bool) -> Result<u64> {
122 let mut bytes_written: u64 = 0;
123
124 while self.buffer.buffer.lock().unwrap().len() >= self.threshold {
127 let chunk = {
128 let mut buffer = self.buffer.buffer.lock().unwrap();
129 buffer.split_to(self.threshold)
130 };
131 let size = chunk.len();
132
133 self.maybe_init_writer()
134 .await?
135 .write_all(&chunk)
136 .await
137 .context(error::AsyncWriteSnafu)?;
138
139 bytes_written += size as u64;
140 }
141
142 if all {
143 bytes_written += self.try_flush_all().await?;
144 }
145 Ok(bytes_written)
146 }
147
148 async fn maybe_init_writer(&mut self) -> Result<&mut T> {
150 if let Some(ref mut writer) = self.writer {
151 Ok(writer)
152 } else {
153 let writer = (self.writer_factory)(self.path.to_string()).await?;
154 Ok(self.writer.insert(writer))
155 }
156 }
157
158 async fn try_flush_all(&mut self) -> Result<u64> {
159 let remain = self.buffer.buffer.lock().unwrap().split();
160 let size = remain.len();
161 self.maybe_init_writer()
162 .await?
163 .write_all(&remain)
164 .await
165 .context(error::AsyncWriteSnafu)?;
166 Ok(size as u64)
167 }
168}