common_datasource/file_format/
csv.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::str::FromStr;
17
18use arrow::csv::reader::Format;
19use arrow::csv::{self, WriterBuilder};
20use arrow::record_batch::RecordBatch;
21use arrow_schema::Schema;
22use async_trait::async_trait;
23use common_runtime;
24use datafusion::physical_plan::SendableRecordBatchStream;
25use object_store::ObjectStore;
26use snafu::ResultExt;
27use tokio_util::compat::FuturesAsyncReadCompatExt;
28use tokio_util::io::SyncIoBridge;
29
30use crate::buffered_writer::DfRecordBatchEncoder;
31use crate::compression::CompressionType;
32use crate::error::{self, Result};
33use crate::file_format::{self, FileFormat, stream_to_file};
34use crate::share_buffer::SharedBuffer;
35
/// Options controlling how CSV files are read (schema inference) and written.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsvFormat {
    /// Whether the first row is a header row rather than data.
    pub has_header: bool,
    /// Field delimiter as a raw byte (e.g. `b','`).
    pub delimiter: u8,
    /// Maximum number of records to sample when inferring the schema;
    /// `None` means no limit.
    pub schema_infer_max_record: Option<usize>,
    /// Compression applied to the file; decompressed transparently on read.
    pub compression_type: CompressionType,
    /// Optional format string applied to timestamp columns on write.
    pub timestamp_format: Option<String>,
    /// Optional format string applied to time columns on write.
    pub time_format: Option<String>,
    /// Optional format string applied to date columns on write.
    pub date_format: Option<String>,
}
46
47impl TryFrom<&HashMap<String, String>> for CsvFormat {
48    type Error = error::Error;
49
50    fn try_from(value: &HashMap<String, String>) -> Result<Self> {
51        let mut format = CsvFormat::default();
52        if let Some(delimiter) = value.get(file_format::FORMAT_DELIMITER) {
53            // TODO(weny): considers to support parse like "\t" (not only b'\t')
54            format.delimiter = u8::from_str(delimiter).map_err(|_| {
55                error::ParseFormatSnafu {
56                    key: file_format::FORMAT_DELIMITER,
57                    value: delimiter,
58                }
59                .build()
60            })?;
61        };
62        if let Some(compression_type) = value.get(file_format::FORMAT_COMPRESSION_TYPE) {
63            format.compression_type = CompressionType::from_str(compression_type)?;
64        };
65        if let Some(schema_infer_max_record) =
66            value.get(file_format::FORMAT_SCHEMA_INFER_MAX_RECORD)
67        {
68            format.schema_infer_max_record =
69                Some(schema_infer_max_record.parse::<usize>().map_err(|_| {
70                    error::ParseFormatSnafu {
71                        key: file_format::FORMAT_SCHEMA_INFER_MAX_RECORD,
72                        value: schema_infer_max_record,
73                    }
74                    .build()
75                })?);
76        };
77        if let Some(has_header) = value.get(file_format::FORMAT_HAS_HEADER) {
78            format.has_header = has_header.parse().map_err(|_| {
79                error::ParseFormatSnafu {
80                    key: file_format::FORMAT_HAS_HEADER,
81                    value: has_header,
82                }
83                .build()
84            })?;
85        };
86        if let Some(timestamp_format) = value.get(file_format::TIMESTAMP_FORMAT) {
87            format.timestamp_format = Some(timestamp_format.clone());
88        }
89        if let Some(time_format) = value.get(file_format::TIME_FORMAT) {
90            format.time_format = Some(time_format.clone());
91        }
92        if let Some(date_format) = value.get(file_format::DATE_FORMAT) {
93            format.date_format = Some(date_format.clone());
94        }
95        Ok(format)
96    }
97}
98
99impl Default for CsvFormat {
100    fn default() -> Self {
101        Self {
102            has_header: true,
103            delimiter: b',',
104            schema_infer_max_record: Some(file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD),
105            compression_type: CompressionType::Uncompressed,
106            timestamp_format: None,
107            time_format: None,
108            date_format: None,
109        }
110    }
111}
112
#[async_trait]
impl FileFormat for CsvFormat {
    /// Infers an Arrow [`Schema`] from the CSV object at `path` in `store`.
    ///
    /// The object is streamed, decompressed according to
    /// `self.compression_type`, and sampled for at most
    /// `self.schema_infer_max_record` records (`None` means no limit).
    async fn infer_schema(&self, store: &ObjectStore, path: &str) -> Result<Schema> {
        // Stat first: the ranged reader below needs the object's total length.
        let meta = store
            .stat(path)
            .await
            .context(error::ReadObjectSnafu { path })?;

        let reader = store
            .reader(path)
            .await
            .context(error::ReadObjectSnafu { path })?
            .into_futures_async_read(0..meta.content_length())
            .await
            .context(error::ReadObjectSnafu { path })?
            // Bridge the `futures` AsyncRead into a `tokio` AsyncRead.
            .compat();

        // Transparently decompress while reading.
        let decoded = self.compression_type.convert_async_read(reader);

        // Copy the options out of `self` so the `move` closure below
        // doesn't need to borrow it.
        let delimiter = self.delimiter;
        let schema_infer_max_record = self.schema_infer_max_record;
        let has_header = self.has_header;

        // Arrow's CSV schema inference is synchronous, so run it on a
        // blocking task (with the async reader bridged to sync I/O) instead
        // of stalling the async runtime.
        common_runtime::spawn_blocking_global(move || {
            let reader = SyncIoBridge::new(decoded);

            let format = Format::default()
                .with_delimiter(delimiter)
                .with_header(has_header);
            let (schema, _records_read) = format
                .infer_schema(reader, schema_infer_max_record)
                .context(error::InferSchemaSnafu)?;
            Ok(schema)
        })
        .await
        .context(error::JoinHandleSnafu)?
    }
}
151
152pub async fn stream_to_csv(
153    stream: SendableRecordBatchStream,
154    store: ObjectStore,
155    path: &str,
156    threshold: usize,
157    concurrency: usize,
158    format: &CsvFormat,
159) -> Result<usize> {
160    stream_to_file(stream, store, path, threshold, concurrency, |buffer| {
161        let mut builder = WriterBuilder::new();
162        if let Some(timestamp_format) = &format.timestamp_format {
163            builder = builder.with_timestamp_format(timestamp_format.to_owned())
164        }
165        if let Some(date_format) = &format.date_format {
166            builder = builder.with_date_format(date_format.to_owned())
167        }
168        if let Some(time_format) = &format.time_format {
169            builder = builder.with_time_format(time_format.to_owned())
170        }
171        builder.build(buffer)
172    })
173    .await
174}
175
/// Adapts Arrow's CSV writer to the generic [`DfRecordBatchEncoder`]
/// interface used by the buffered file writer.
impl DfRecordBatchEncoder for csv::Writer<SharedBuffer> {
    fn write(&mut self, batch: &RecordBatch) -> Result<()> {
        // `self.write(...)` resolves to the *inherent* `csv::Writer::write`
        // (inherent methods take precedence over trait methods), so this is
        // delegation, not recursion.
        self.write(batch).context(error::WriteRecordBatchSnafu)
    }
}
181
#[cfg(test)]
mod tests {

    use common_test_util::find_workspace_path;

    use super::*;
    use crate::file_format::{
        FORMAT_COMPRESSION_TYPE, FORMAT_DELIMITER, FORMAT_HAS_HEADER,
        FORMAT_SCHEMA_INFER_MAX_RECORD, FileFormat,
    };
    use crate::test_util::{format_schema, test_store};

    /// Root directory of the CSV test fixtures inside the workspace.
    fn test_data_root() -> String {
        find_workspace_path("/src/common/datasource/tests/csv")
            .display()
            .to_string()
    }

    #[tokio::test]
    async fn infer_schema_basic() {
        let format = CsvFormat::default();
        let store = test_store(&test_data_root());

        let inferred = format.infer_schema(&store, "simple.csv").await.unwrap();
        let actual: Vec<_> = format_schema(inferred);

        let expected = vec![
            "c1: Utf8: NULL",
            "c2: Int64: NULL",
            "c3: Int64: NULL",
            "c4: Int64: NULL",
            "c5: Int64: NULL",
            "c6: Int64: NULL",
            "c7: Int64: NULL",
            "c8: Int64: NULL",
            "c9: Int64: NULL",
            "c10: Utf8: NULL",
            "c11: Float64: NULL",
            "c12: Float64: NULL",
            "c13: Utf8: NULL",
        ];
        assert_eq!(expected, actual);
    }

    #[tokio::test]
    async fn infer_schema_with_limit() {
        let store = test_store(&test_data_root());

        // With only 3 records sampled, column `d` still looks numeric.
        let limited = CsvFormat {
            schema_infer_max_record: Some(3),
            ..CsvFormat::default()
        };
        let schema = limited
            .infer_schema(&store, "schema_infer_limit.csv")
            .await
            .unwrap();
        let actual: Vec<_> = format_schema(schema);
        assert_eq!(
            vec![
                "a: Int64: NULL",
                "b: Float64: NULL",
                "c: Int64: NULL",
                "d: Int64: NULL"
            ],
            actual
        );

        // Sampling more records reveals `d` is actually Utf8.
        let unlimited = CsvFormat::default();
        let schema = unlimited
            .infer_schema(&store, "schema_infer_limit.csv")
            .await
            .unwrap();
        let actual: Vec<_> = format_schema(schema);
        assert_eq!(
            vec![
                "a: Int64: NULL",
                "b: Float64: NULL",
                "c: Int64: NULL",
                "d: Utf8: NULL"
            ],
            actual
        );
    }

    #[test]
    fn test_try_from() {
        // An empty option map yields the defaults.
        let empty = HashMap::new();
        assert_eq!(CsvFormat::try_from(&empty).unwrap(), CsvFormat::default());

        // Every recognized key overrides the corresponding field.
        let options: HashMap<String, String> = [
            (FORMAT_SCHEMA_INFER_MAX_RECORD, "2000".to_string()),
            (FORMAT_COMPRESSION_TYPE, "zstd".to_string()),
            (FORMAT_DELIMITER, b'\t'.to_string()),
            (FORMAT_HAS_HEADER, "false".to_string()),
        ]
        .into_iter()
        .map(|(key, val)| (key.to_string(), val))
        .collect();

        assert_eq!(
            CsvFormat::try_from(&options).unwrap(),
            CsvFormat {
                has_header: false,
                delimiter: b'\t',
                schema_infer_max_record: Some(2000),
                compression_type: CompressionType::Zstd,
                timestamp_format: None,
                time_format: None,
                date_format: None
            }
        );
    }
}