common_datasource/file_format/csv.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::HashMap;
16use std::str::FromStr;
17
18use arrow::csv::reader::Format;
19use arrow::csv::{self, WriterBuilder};
20use arrow::record_batch::RecordBatch;
21use arrow_schema::Schema;
22use async_trait::async_trait;
23use common_runtime;
24use datafusion::physical_plan::SendableRecordBatchStream;
25use object_store::ObjectStore;
26use snafu::ResultExt;
27use tokio_util::compat::FuturesAsyncReadCompatExt;
28use tokio_util::io::SyncIoBridge;
29
30use crate::buffered_writer::DfRecordBatchEncoder;
31use crate::compression::CompressionType;
32use crate::error::{self, Result};
33use crate::file_format::{self, FileFormat, stream_to_file};
34use crate::share_buffer::SharedBuffer;
35
36#[derive(Debug, Clone, PartialEq, Eq)]
37pub struct CsvFormat {
38    pub has_header: bool,
39    pub delimiter: u8,
40    pub schema_infer_max_record: Option<usize>,
41    pub compression_type: CompressionType,
42    pub timestamp_format: Option<String>,
43    pub time_format: Option<String>,
44    pub date_format: Option<String>,
45}
46
47impl TryFrom<&HashMap<String, String>> for CsvFormat {
48    type Error = error::Error;
49
50    fn try_from(value: &HashMap<String, String>) -> Result<Self> {
51        let mut format = CsvFormat::default();
52        if let Some(delimiter) = value.get(file_format::FORMAT_DELIMITER) {
53            // TODO(weny): considers to support parse like "\t" (not only b'\t')
54            format.delimiter = u8::from_str(delimiter).map_err(|_| {
55                error::ParseFormatSnafu {
56                    key: file_format::FORMAT_DELIMITER,
57                    value: delimiter,
58                }
59                .build()
60            })?;
61        };
62        if let Some(compression_type) = value.get(file_format::FORMAT_COMPRESSION_TYPE) {
63            format.compression_type = CompressionType::from_str(compression_type)?;
64        };
65        if let Some(schema_infer_max_record) =
66            value.get(file_format::FORMAT_SCHEMA_INFER_MAX_RECORD)
67        {
68            format.schema_infer_max_record =
69                Some(schema_infer_max_record.parse::<usize>().map_err(|_| {
70                    error::ParseFormatSnafu {
71                        key: file_format::FORMAT_SCHEMA_INFER_MAX_RECORD,
72                        value: schema_infer_max_record,
73                    }
74                    .build()
75                })?);
76        };
77        if let Some(has_header) = value.get(file_format::FORMAT_HAS_HEADER) {
78            format.has_header = has_header.parse().map_err(|_| {
79                error::ParseFormatSnafu {
80                    key: file_format::FORMAT_HAS_HEADER,
81                    value: has_header,
82                }
83                .build()
84            })?;
85        };
86        if let Some(timestamp_format) = value.get(file_format::TIMESTAMP_FORMAT) {
87            format.timestamp_format = Some(timestamp_format.clone());
88        }
89        if let Some(time_format) = value.get(file_format::TIME_FORMAT) {
90            format.time_format = Some(time_format.clone());
91        }
92        if let Some(date_format) = value.get(file_format::DATE_FORMAT) {
93            format.date_format = Some(date_format.clone());
94        }
95        Ok(format)
96    }
97}
98
99impl Default for CsvFormat {
100    fn default() -> Self {
101        Self {
102            has_header: true,
103            delimiter: b',',
104            schema_infer_max_record: Some(file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD),
105            compression_type: CompressionType::Uncompressed,
106            timestamp_format: None,
107            time_format: None,
108            date_format: None,
109        }
110    }
111}
112
113#[async_trait]
114impl FileFormat for CsvFormat {
115    async fn infer_schema(&self, store: &ObjectStore, path: &str) -> Result<Schema> {
116        let meta = store
117            .stat(path)
118            .await
119            .context(error::ReadObjectSnafu { path })?;
120
121        let reader = store
122            .reader(path)
123            .await
124            .context(error::ReadObjectSnafu { path })?
125            .into_futures_async_read(0..meta.content_length())
126            .await
127            .context(error::ReadObjectSnafu { path })?
128            .compat();
129
130        let decoded = self.compression_type.convert_async_read(reader);
131
132        let delimiter = self.delimiter;
133        let schema_infer_max_record = self.schema_infer_max_record;
134        let has_header = self.has_header;
135
136        common_runtime::spawn_blocking_global(move || {
137            let reader = SyncIoBridge::new(decoded);
138
139            let format = Format::default()
140                .with_delimiter(delimiter)
141                .with_header(has_header);
142            let (schema, _records_read) = format
143                .infer_schema(reader, schema_infer_max_record)
144                .context(error::InferSchemaSnafu)?;
145            Ok(schema)
146        })
147        .await
148        .context(error::JoinHandleSnafu)?
149    }
150}
151
152pub async fn stream_to_csv(
153    stream: SendableRecordBatchStream,
154    store: ObjectStore,
155    path: &str,
156    threshold: usize,
157    concurrency: usize,
158    format: &CsvFormat,
159) -> Result<usize> {
160    stream_to_file(
161        stream,
162        store,
163        path,
164        threshold,
165        concurrency,
166        format.compression_type,
167        |buffer| {
168            let mut builder = WriterBuilder::new();
169            if let Some(timestamp_format) = &format.timestamp_format {
170                builder = builder.with_timestamp_format(timestamp_format.to_owned())
171            }
172            if let Some(date_format) = &format.date_format {
173                builder = builder.with_date_format(date_format.to_owned())
174            }
175            if let Some(time_format) = &format.time_format {
176                builder = builder.with_time_format(time_format.to_owned())
177            }
178            builder.build(buffer)
179        },
180    )
181    .await
182}
183
184impl DfRecordBatchEncoder for csv::Writer<SharedBuffer> {
185    fn write(&mut self, batch: &RecordBatch) -> Result<()> {
186        self.write(batch).context(error::WriteRecordBatchSnafu)
187    }
188}
189
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_recordbatch::adapter::DfRecordBatchStreamAdapter;
    use common_recordbatch::{RecordBatch, RecordBatches};
    use common_test_util::find_workspace_path;
    use datafusion::datasource::physical_plan::{CsvSource, FileSource};
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{Float64Vector, StringVector, UInt32Vector, VectorRef};
    use futures::TryStreamExt;

    use super::*;
    use crate::file_format::{
        DATE_FORMAT, FORMAT_COMPRESSION_TYPE, FORMAT_DELIMITER, FORMAT_HAS_HEADER,
        FORMAT_SCHEMA_INFER_MAX_RECORD, FileFormat, TIME_FORMAT, TIMESTAMP_FORMAT,
        file_to_stream,
    };
    use crate::test_util::{format_schema, test_store};

    /// Directory holding the CSV fixtures used by these tests.
    fn test_data_root() -> String {
        find_workspace_path("/src/common/datasource/tests/csv")
            .display()
            .to_string()
    }

    #[tokio::test]
    async fn infer_schema_basic() {
        let csv = CsvFormat::default();
        let store = test_store(&test_data_root());
        let schema = csv.infer_schema(&store, "simple.csv").await.unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "c1: Utf8: NULL",
                "c2: Int64: NULL",
                "c3: Int64: NULL",
                "c4: Int64: NULL",
                "c5: Int64: NULL",
                "c6: Int64: NULL",
                "c7: Int64: NULL",
                "c8: Int64: NULL",
                "c9: Int64: NULL",
                "c10: Utf8: NULL",
                "c11: Float64: NULL",
                "c12: Float64: NULL",
                "c13: Utf8: NULL"
            ],
            formatted,
        );
    }

    #[tokio::test]
    async fn infer_schema_with_limit() {
        // Sampling only the first 3 records infers `d` as Int64 ...
        let csv = CsvFormat {
            schema_infer_max_record: Some(3),
            ..CsvFormat::default()
        };
        let store = test_store(&test_data_root());
        let schema = csv
            .infer_schema(&store, "schema_infer_limit.csv")
            .await
            .unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "a: Int64: NULL",
                "b: Float64: NULL",
                "c: Int64: NULL",
                "d: Int64: NULL"
            ],
            formatted
        );

        // ... while the default limit sees later records and widens `d` to Utf8.
        let csv = CsvFormat::default();
        let store = test_store(&test_data_root());
        let schema = csv
            .infer_schema(&store, "schema_infer_limit.csv")
            .await
            .unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "a: Int64: NULL",
                "b: Float64: NULL",
                "c: Int64: NULL",
                "d: Utf8: NULL"
            ],
            formatted
        );
    }

    #[test]
    fn test_try_from() {
        let map = HashMap::new();
        let format: CsvFormat = CsvFormat::try_from(&map).unwrap();

        assert_eq!(format, CsvFormat::default());

        let map = HashMap::from([
            (
                FORMAT_SCHEMA_INFER_MAX_RECORD.to_string(),
                "2000".to_string(),
            ),
            (FORMAT_COMPRESSION_TYPE.to_string(), "zstd".to_string()),
            (FORMAT_DELIMITER.to_string(), b'\t'.to_string()),
            (FORMAT_HAS_HEADER.to_string(), "false".to_string()),
        ]);
        let format = CsvFormat::try_from(&map).unwrap();

        assert_eq!(
            format,
            CsvFormat {
                compression_type: CompressionType::Zstd,
                schema_infer_max_record: Some(2000),
                delimiter: b'\t',
                has_header: false,
                timestamp_format: None,
                time_format: None,
                date_format: None
            }
        );

        // The temporal format options are copied through verbatim.
        let map = HashMap::from([
            (
                TIMESTAMP_FORMAT.to_string(),
                "%Y-%m-%d %H:%M:%S".to_string(),
            ),
            (TIME_FORMAT.to_string(), "%H:%M:%S".to_string()),
            (DATE_FORMAT.to_string(), "%Y-%m-%d".to_string()),
        ]);
        let format = CsvFormat::try_from(&map).unwrap();

        assert_eq!(
            format,
            CsvFormat {
                timestamp_format: Some("%Y-%m-%d %H:%M:%S".to_string()),
                time_format: Some("%H:%M:%S".to_string()),
                date_format: Some("%Y-%m-%d".to_string()),
                ..CsvFormat::default()
            }
        );
    }

    #[tokio::test]
    async fn test_compressed_csv() {
        // Create test data
        let column_schemas = vec![
            ColumnSchema::new("id", ConcreteDataType::uint32_datatype(), false),
            ColumnSchema::new("name", ConcreteDataType::string_datatype(), false),
            ColumnSchema::new("value", ConcreteDataType::float64_datatype(), false),
        ];
        let schema = Arc::new(Schema::new(column_schemas));

        // Create multiple record batches with different data
        let batch1_columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3])),
            Arc::new(StringVector::from(vec!["Alice", "Bob", "Charlie"])),
            Arc::new(Float64Vector::from_slice(vec![10.5, 20.3, 30.7])),
        ];
        let batch1 = RecordBatch::new(schema.clone(), batch1_columns).unwrap();

        let batch2_columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![4, 5, 6])),
            Arc::new(StringVector::from(vec!["David", "Eva", "Frank"])),
            Arc::new(Float64Vector::from_slice(vec![40.1, 50.2, 60.3])),
        ];
        let batch2 = RecordBatch::new(schema.clone(), batch2_columns).unwrap();

        let batch3_columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![7, 8, 9])),
            Arc::new(StringVector::from(vec!["Grace", "Henry", "Ivy"])),
            Arc::new(Float64Vector::from_slice(vec![70.4, 80.5, 90.6])),
        ];
        let batch3 = RecordBatch::new(schema.clone(), batch3_columns).unwrap();

        // Combine all batches into a RecordBatches collection
        let recordbatches = RecordBatches::try_new(schema, vec![batch1, batch2, batch3]).unwrap();

        // Each compression type paired with the magic bytes its container format
        // always starts with.
        let cases: [(CompressionType, &[u8]); 4] = [
            // Gzip: 0x1f 0x8b
            (CompressionType::Gzip, &[0x1f, 0x8b]),
            // Bzip2: "BZh"
            (CompressionType::Bzip2, b"BZh"),
            // XZ: 0xFD "7zXZ" 0x00
            (CompressionType::Xz, &[0xFD, b'7', b'z', b'X', b'Z', 0x00]),
            // Zstd: 0x28 0xB5 0x2F 0xFD
            (CompressionType::Zstd, &[0x28, 0xB5, 0x2F, 0xFD]),
        ];

        // Create a temporary directory for the exported files
        let temp_dir = common_test_util::temp_dir::create_temp_dir("test_compressed_csv");
        for (compression_type, magic) in cases {
            let format = CsvFormat {
                compression_type,
                ..CsvFormat::default()
            };

            let compressed_file_name =
                format!("test_compressed_csv.{}", compression_type.file_extension());
            let compressed_file_path = temp_dir.path().join(&compressed_file_name);
            let compressed_file_path_str = compressed_file_path.to_str().unwrap();

            // Export CSV with compression through a store rooted at "/"
            let store = test_store("/");
            let rows = stream_to_csv(
                Box::pin(DfRecordBatchStreamAdapter::new(recordbatches.as_stream())),
                store,
                compressed_file_path_str,
                1024,
                1,
                &format,
            )
            .await
            .unwrap();

            assert_eq!(rows, 9);

            // The file must exist, be non-empty, and start with the codec's magic
            // bytes rather than a plain-text CSV header.
            let file_content = std::fs::read(&compressed_file_path).unwrap();
            assert!(!file_content.is_empty());
            assert!(
                file_content.starts_with(magic),
                "{compression_type:?} output should start with {magic:02x?}, got {:02x?}",
                &file_content[..file_content.len().min(magic.len())]
            );

            // Verify the compressed file can be decompressed and content matches
            // the original data, reading it back with the same format it was
            // written with.
            let store = test_store("/");
            let schema = Arc::new(
                format
                    .infer_schema(&store, compressed_file_path_str)
                    .await
                    .unwrap(),
            );
            let csv_source = CsvSource::new(format.has_header, format.delimiter, b'"')
                .with_schema(schema.clone())
                .with_batch_size(8192);

            let stream = file_to_stream(
                &store,
                compressed_file_path_str,
                schema,
                csv_source,
                None,
                compression_type,
            )
            .await
            .unwrap();

            let batches = stream.try_collect::<Vec<_>>().await.unwrap();
            let pretty_print = arrow::util::pretty::pretty_format_batches(&batches)
                .unwrap()
                .to_string();
            let expected = r#"+----+---------+-------+
| id | name    | value |
+----+---------+-------+
| 1  | Alice   | 10.5  |
| 2  | Bob     | 20.3  |
| 3  | Charlie | 30.7  |
| 4  | David   | 40.1  |
| 5  | Eva     | 50.2  |
| 6  | Frank   | 60.3  |
| 7  | Grace   | 70.4  |
| 8  | Henry   | 80.5  |
| 9  | Ivy     | 90.6  |
+----+---------+-------+"#;
            assert_eq!(expected, pretty_print);
        }
    }
}