// common_datasource/buffered_writer.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::future::Future;

use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion::parquet::format::FileMetaData;
use snafu::{OptionExt, ResultExt};
use tokio::io::{AsyncWrite, AsyncWriteExt};

use crate::error::{self, Result};
use crate::share_buffer::SharedBuffer;

26pub struct LazyBufferedWriter<T, U, F> {
27    path: String,
28    writer_factory: F,
29    writer: Option<T>,
30    /// None stands for [`LazyBufferedWriter`] closed.
31    encoder: Option<U>,
32    buffer: SharedBuffer,
33    rows_written: usize,
34    bytes_written: u64,
35    threshold: usize,
36}
38pub trait DfRecordBatchEncoder {
39    fn write(&mut self, batch: &RecordBatch) -> Result<()>;
40}
42#[async_trait]
43pub trait ArrowWriterCloser {
44    async fn close(mut self) -> Result<FileMetaData>;
45}
47impl<
48        T: AsyncWrite + Send + Unpin,
49        U: DfRecordBatchEncoder + ArrowWriterCloser,
50        F: Fn(String) -> Fut,
51        Fut: Future<Output = Result<T>>,
52    > LazyBufferedWriter<T, U, F>
53{
54    /// Closes `LazyBufferedWriter` and optionally flushes all data to underlying storage
55    /// if any row's been written.
56    pub async fn close_with_arrow_writer(mut self) -> Result<(FileMetaData, u64)> {
57        let encoder = self
58            .encoder
59            .take()
60            .context(error::BufferedWriterClosedSnafu)?;
61        let metadata = encoder.close().await?;
62
63        // It's important to shut down! flushes all pending writes
64        self.close_inner_writer().await?;
65        Ok((metadata, self.bytes_written))
66    }
67}
69impl<
70        T: AsyncWrite + Send + Unpin,
71        U: DfRecordBatchEncoder,
72        F: Fn(String) -> Fut,
73        Fut: Future<Output = Result<T>>,
74    > LazyBufferedWriter<T, U, F>
75{
76    /// Closes the writer and flushes the buffer data.
77    pub async fn close_inner_writer(&mut self) -> Result<()> {
78        // Use `rows_written` to keep a track of if any rows have been written.
79        // If no row's been written, then we can simply close the underlying
80        // writer without flush so that no file will be actually created.
81        if self.rows_written != 0 {
82            self.bytes_written += self.try_flush(true).await?;
83        }
84
85        if let Some(writer) = &mut self.writer {
86            writer.shutdown().await.context(error::AsyncWriteSnafu)?;
87        }
88        Ok(())
89    }
90
91    pub fn new(
92        threshold: usize,
93        buffer: SharedBuffer,
94        encoder: U,
95        path: impl AsRef<str>,
96        writer_factory: F,
97    ) -> Self {
98        Self {
99            path: path.as_ref().to_string(),
100            threshold,
101            encoder: Some(encoder),
102            buffer,
103            rows_written: 0,
104            bytes_written: 0,
105            writer_factory,
106            writer: None,
107        }
108    }
109
110    pub async fn write(&mut self, batch: &RecordBatch) -> Result<()> {
111        let encoder = self
112            .encoder
113            .as_mut()
114            .context(error::BufferedWriterClosedSnafu)?;
115        encoder.write(batch)?;
116        self.rows_written += batch.num_rows();
117        self.bytes_written += self.try_flush(false).await?;
118        Ok(())
119    }
120
121    async fn try_flush(&mut self, all: bool) -> Result<u64> {
122        let mut bytes_written: u64 = 0;
123
124        // Once buffered data size reaches threshold, split the data in chunks (typically 4MB)
125        // and write to underlying storage.
126        while self.buffer.buffer.lock().unwrap().len() >= self.threshold {
127            let chunk = {
128                let mut buffer = self.buffer.buffer.lock().unwrap();
129                buffer.split_to(self.threshold)
130            };
131            let size = chunk.len();
132
133            self.maybe_init_writer()
134                .await?
135                .write_all(&chunk)
136                .await
137                .context(error::AsyncWriteSnafu)?;
138
139            bytes_written += size as u64;
140        }
141
142        if all {
143            bytes_written += self.try_flush_all().await?;
144        }
145        Ok(bytes_written)
146    }
147
148    /// Only initiates underlying file writer when rows have been written.
149    async fn maybe_init_writer(&mut self) -> Result<&mut T> {
150        if let Some(ref mut writer) = self.writer {
151            Ok(writer)
152        } else {
153            let writer = (self.writer_factory)(self.path.to_string()).await?;
154            Ok(self.writer.insert(writer))
155        }
156    }
157
158    async fn try_flush_all(&mut self) -> Result<u64> {
159        let remain = self.buffer.buffer.lock().unwrap().split();
160        let size = remain.len();
161        self.maybe_init_writer()
162            .await?
163            .write_all(&remain)
164            .await
165            .context(error::AsyncWriteSnafu)?;
166        Ok(size as u64)
167    }
168}