common_datasource/file_format/csv.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::str::FromStr;

use arrow::csv;
use arrow::csv::reader::Format;
use arrow::record_batch::RecordBatch;
use arrow_schema::Schema;
use async_trait::async_trait;
use common_runtime;
use datafusion::physical_plan::SendableRecordBatchStream;
use object_store::ObjectStore;
use snafu::ResultExt;
use tokio_util::compat::FuturesAsyncReadCompatExt;
use tokio_util::io::SyncIoBridge;

use crate::buffered_writer::DfRecordBatchEncoder;
use crate::compression::CompressionType;
use crate::error::{self, Result};
use crate::file_format::{self, stream_to_file, FileFormat};
use crate::share_buffer::SharedBuffer;

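/// CSV format options: header presence, the field delimiter byte, how many
/// records to scan when inferring the schema, and the compression applied to
/// the source file.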
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CsvFormat {
    pub has_header: bool,
    pub delimiter: u8,
    pub schema_infer_max_record: Option<usize>,
    pub compression_type: CompressionType,
}

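/// Builds a [`CsvFormat`] from string key/value options (the `FORMAT_*`
/// keys); unknown keys are ignored and absent keys keep the defaults from
/// [`CsvFormat::default`]. A minimal sketch (`opts` is hypothetical):
///
/// ```ignore
/// let opts = HashMap::from([(FORMAT_DELIMITER.to_string(), "44".to_string())]);
/// let format = CsvFormat::try_from(&opts)?; // delimiter: b',' == 44
/// ```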
impl TryFrom<&HashMap<String, String>> for CsvFormat {
    type Error = error::Error;

    fn try_from(value: &HashMap<String, String>) -> Result<Self> {
        let mut format = CsvFormat::default();
        if let Some(delimiter) = value.get(file_format::FORMAT_DELIMITER) {
            // TODO(weny): consider supporting escape sequences such as "\t",
            // not only raw byte values like b'\t'
            format.delimiter = u8::from_str(delimiter).map_err(|_| {
                error::ParseFormatSnafu {
                    key: file_format::FORMAT_DELIMITER,
                    value: delimiter,
                }
                .build()
            })?;
        };
        if let Some(compression_type) = value.get(file_format::FORMAT_COMPRESSION_TYPE) {
            format.compression_type = CompressionType::from_str(compression_type)?;
        };
        if let Some(schema_infer_max_record) =
            value.get(file_format::FORMAT_SCHEMA_INFER_MAX_RECORD)
        {
            format.schema_infer_max_record =
                Some(schema_infer_max_record.parse::<usize>().map_err(|_| {
                    error::ParseFormatSnafu {
                        key: file_format::FORMAT_SCHEMA_INFER_MAX_RECORD,
                        value: schema_infer_max_record,
                    }
                    .build()
                })?);
        };
        if let Some(has_header) = value.get(file_format::FORMAT_HAS_HEADER) {
            format.has_header = has_header.parse().map_err(|_| {
                error::ParseFormatSnafu {
                    key: file_format::FORMAT_HAS_HEADER,
                    value: has_header,
                }
                .build()
            })?;
        }
        Ok(format)
    }
}

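/// Defaults match a plain, uncompressed CSV file: comma-delimited with a
/// header row, inferring the schema from the first
/// `DEFAULT_SCHEMA_INFER_MAX_RECORD` records.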
impl Default for CsvFormat {
    fn default() -> Self {
        Self {
            has_header: true,
            delimiter: b',',
            schema_infer_max_record: Some(file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD),
            compression_type: CompressionType::Uncompressed,
        }
    }
}

#[async_trait]
impl FileFormat for CsvFormat {
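    // Infers the schema by scanning up to `schema_infer_max_record` records
    // of the (possibly compressed) object at `path`.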
    async fn infer_schema(&self, store: &ObjectStore, path: &str) -> Result<Schema> {
        let meta = store
            .stat(path)
            .await
            .context(error::ReadObjectSnafu { path })?;

        let reader = store
            .reader(path)
            .await
            .context(error::ReadObjectSnafu { path })?
            .into_futures_async_read(0..meta.content_length())
            .await
            .context(error::ReadObjectSnafu { path })?
            .compat();

        let decoded = self.compression_type.convert_async_read(reader);

        let delimiter = self.delimiter;
        let schema_infer_max_record = self.schema_infer_max_record;
        let has_header = self.has_header;

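        // arrow-csv's `Format::infer_schema` performs blocking reads, so
        // bridge the async reader into a sync one and run inference on a
        // blocking task instead of stalling the async runtime.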
        common_runtime::spawn_blocking_global(move || {
            let reader = SyncIoBridge::new(decoded);

            let format = Format::default()
                .with_delimiter(delimiter)
                .with_header(has_header);
            let (schema, _records_read) = format
                .infer_schema(reader, schema_infer_max_record)
                .context(error::InferSchemaSnafu)?;
            Ok(schema)
        })
        .await
        .context(error::JoinHandleSnafu)?
    }
}

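/// Writes a stream of record batches to `path` in the object store as CSV,
/// returning the number of rows written. `threshold` and `concurrency` are
/// passed through to the underlying buffered writer.
///
/// A minimal usage sketch, assuming `stream` and `store` are already built
/// (both names are hypothetical):
///
/// ```ignore
/// let rows = stream_to_csv(stream, store, "out/data.csv", 8 * 1024 * 1024, 8).await?;
/// ```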
pub async fn stream_to_csv(
    stream: SendableRecordBatchStream,
    store: ObjectStore,
    path: &str,
    threshold: usize,
    concurrency: usize,
) -> Result<usize> {
    stream_to_file(stream, store, path, threshold, concurrency, |buffer| {
        csv::Writer::new(buffer)
    })
    .await
}

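// Adapts the arrow-csv writer to the generic `stream_to_file` machinery.
// Note: the inherent `csv::Writer::write` takes precedence over this trait
// method, so the call below is not recursive.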
impl DfRecordBatchEncoder for csv::Writer<SharedBuffer> {
    fn write(&mut self, batch: &RecordBatch) -> Result<()> {
        self.write(batch).context(error::WriteRecordBatchSnafu)
    }
}

#[cfg(test)]
mod tests {

    use common_test_util::find_workspace_path;

    use super::*;
    use crate::file_format::{
        FileFormat, FORMAT_COMPRESSION_TYPE, FORMAT_DELIMITER, FORMAT_HAS_HEADER,
        FORMAT_SCHEMA_INFER_MAX_RECORD,
    };
    use crate::test_util::{format_schema, test_store};

    fn test_data_root() -> String {
        find_workspace_path("/src/common/datasource/tests/csv")
            .display()
            .to_string()
    }

    #[tokio::test]
    async fn infer_schema_basic() {
        let csv = CsvFormat::default();
        let store = test_store(&test_data_root());
        let schema = csv.infer_schema(&store, "simple.csv").await.unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "c1: Utf8: NULL",
                "c2: Int64: NULL",
                "c3: Int64: NULL",
                "c4: Int64: NULL",
                "c5: Int64: NULL",
                "c6: Int64: NULL",
                "c7: Int64: NULL",
                "c8: Int64: NULL",
                "c9: Int64: NULL",
                "c10: Utf8: NULL",
                "c11: Float64: NULL",
                "c12: Float64: NULL",
                "c13: Utf8: NULL"
            ],
            formatted,
        );
    }

    #[tokio::test]
    async fn infer_schema_with_limit() {
        let csv = CsvFormat {
            schema_infer_max_record: Some(3),
            ..CsvFormat::default()
        };
        let store = test_store(&test_data_root());
        let schema = csv
            .infer_schema(&store, "schema_infer_limit.csv")
            .await
            .unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "a: Int64: NULL",
                "b: Float64: NULL",
                "c: Int64: NULL",
                "d: Int64: NULL"
            ],
            formatted
        );

        let csv = CsvFormat::default();
        let store = test_store(&test_data_root());
        let schema = csv
            .infer_schema(&store, "schema_infer_limit.csv")
            .await
            .unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "a: Int64: NULL",
                "b: Float64: NULL",
                "c: Int64: NULL",
                "d: Utf8: NULL"
            ],
            formatted
        );
    }

    #[test]
    fn test_try_from() {
        let map = HashMap::new();
        let format = CsvFormat::try_from(&map).unwrap();

        assert_eq!(format, CsvFormat::default());

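        // `FORMAT_DELIMITER` is parsed as a raw byte value, so a tab
        // delimiter is passed as "9" (`b'\t'.to_string()` below).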
        let map = HashMap::from([
            (
                FORMAT_SCHEMA_INFER_MAX_RECORD.to_string(),
                "2000".to_string(),
            ),
            (FORMAT_COMPRESSION_TYPE.to_string(), "zstd".to_string()),
            (FORMAT_DELIMITER.to_string(), b'\t'.to_string()),
            (FORMAT_HAS_HEADER.to_string(), "false".to_string()),
        ]);
        let format = CsvFormat::try_from(&map).unwrap();

        assert_eq!(
            format,
            CsvFormat {
                compression_type: CompressionType::Zstd,
                schema_infer_max_record: Some(2000),
                delimiter: b'\t',
                has_header: false,
            }
        );
    }
}