common_datasource/file_format/csv.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

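//! CSV [`FileFormat`] support: schema inference over (optionally compressed)
//! files in an [`ObjectStore`], and streaming export of record batches as CSV.
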
use std::collections::HashMap;
use std::str::FromStr;

use arrow::csv::reader::Format;
use arrow::csv::{self, WriterBuilder};
use arrow::record_batch::RecordBatch;
use arrow_schema::Schema;
use async_trait::async_trait;
use common_runtime;
use datafusion::physical_plan::SendableRecordBatchStream;
use object_store::ObjectStore;
use snafu::ResultExt;
use tokio_util::compat::FuturesAsyncReadCompatExt;
use tokio_util::io::SyncIoBridge;

use crate::buffered_writer::DfRecordBatchEncoder;
use crate::compression::CompressionType;
use crate::error::{self, Result};
use crate::file_format::{self, FileFormat, stream_to_file};
use crate::share_buffer::SharedBuffer;
use crate::util::normalize_infer_schema;

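/// CSV format options used for schema inference and encoding.
///
/// The [`Default`] impl below expects a header row, uses `,` as the delimiter,
/// samples up to `DEFAULT_SCHEMA_INFER_MAX_RECORD` rows during inference, and
/// assumes an uncompressed file.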
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsvFormat {
    pub has_header: bool,
    pub delimiter: u8,
    pub schema_infer_max_record: Option<usize>,
    pub compression_type: CompressionType,
    pub timestamp_format: Option<String>,
    pub time_format: Option<String>,
    pub date_format: Option<String>,
}

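/// Builds a [`CsvFormat`] from string key/value options: the
/// `FORMAT_DELIMITER`, `FORMAT_COMPRESSION_TYPE`, `FORMAT_SCHEMA_INFER_MAX_RECORD`,
/// `FORMAT_HAS_HEADER`, and timestamp/time/date format keys are recognized;
/// unrecognized keys are ignored.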
impl TryFrom<&HashMap<String, String>> for CsvFormat {
    type Error = error::Error;

    fn try_from(value: &HashMap<String, String>) -> Result<Self> {
        let mut format = CsvFormat::default();
        if let Some(delimiter) = value.get(file_format::FORMAT_DELIMITER) {
            // TODO(weny): consider supporting escaped forms like "\t", not only the raw byte value b'\t'.
            format.delimiter = u8::from_str(delimiter).map_err(|_| {
                error::ParseFormatSnafu {
                    key: file_format::FORMAT_DELIMITER,
                    value: delimiter,
                }
                .build()
            })?;
        };
        if let Some(compression_type) = value.get(file_format::FORMAT_COMPRESSION_TYPE) {
            format.compression_type = CompressionType::from_str(compression_type)?;
        };
        if let Some(schema_infer_max_record) =
            value.get(file_format::FORMAT_SCHEMA_INFER_MAX_RECORD)
        {
            format.schema_infer_max_record =
                Some(schema_infer_max_record.parse::<usize>().map_err(|_| {
                    error::ParseFormatSnafu {
                        key: file_format::FORMAT_SCHEMA_INFER_MAX_RECORD,
                        value: schema_infer_max_record,
                    }
                    .build()
                })?);
        };
        if let Some(has_header) = value.get(file_format::FORMAT_HAS_HEADER) {
            format.has_header = has_header.parse().map_err(|_| {
                error::ParseFormatSnafu {
                    key: file_format::FORMAT_HAS_HEADER,
                    value: has_header,
                }
                .build()
            })?;
        };
        if let Some(timestamp_format) = value.get(file_format::TIMESTAMP_FORMAT) {
            format.timestamp_format = Some(timestamp_format.clone());
        }
        if let Some(time_format) = value.get(file_format::TIME_FORMAT) {
            format.time_format = Some(time_format.clone());
        }
        if let Some(date_format) = value.get(file_format::DATE_FORMAT) {
            format.date_format = Some(date_format.clone());
        }
        Ok(format)
    }
}

impl Default for CsvFormat {
    fn default() -> Self {
        Self {
            has_header: true,
            delimiter: b',',
            schema_infer_max_record: Some(file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD),
            compression_type: CompressionType::Uncompressed,
            timestamp_format: None,
            time_format: None,
            date_format: None,
        }
    }
}

#[async_trait]
impl FileFormat for CsvFormat {
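    /// Infers an Arrow [`Schema`] by sampling up to `schema_infer_max_record`
    /// rows of the (possibly compressed) file at `path`.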
    async fn infer_schema(&self, store: &ObjectStore, path: &str) -> Result<Schema> {
        let meta = store
            .stat(path)
            .await
            .context(error::ReadObjectSnafu { path })?;

        let reader = store
            .reader(path)
            .await
            .context(error::ReadObjectSnafu { path })?
            .into_futures_async_read(0..meta.content_length())
            .await
            .context(error::ReadObjectSnafu { path })?
            .compat();

        let decoded = self.compression_type.convert_async_read(reader);

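        // Copy the needed options out of `self` so the closure below can be
        // moved onto another thread without borrowing `self`.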
        let delimiter = self.delimiter;
        let schema_infer_max_record = self.schema_infer_max_record;
        let has_header = self.has_header;

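        // Arrow's CSV schema inference reads synchronously, so bridge the async
        // reader with `SyncIoBridge` and run the inference on the blocking pool.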
        common_runtime::spawn_blocking_global(move || {
            let reader = SyncIoBridge::new(decoded);

            let format = Format::default()
                .with_delimiter(delimiter)
                .with_header(has_header);
            let (schema, _records_read) = format
                .infer_schema(reader, schema_infer_max_record)
                .context(error::InferSchemaSnafu)?;

            Ok(normalize_infer_schema(schema))
        })
        .await
        .context(error::JoinHandleSnafu)?
    }
}

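/// Writes a record-batch stream to `path` in `store` as CSV, applying
/// `format`'s compression and date/time formats, and returns the number of
/// rows written. `threshold` and `concurrency` are forwarded to
/// [`stream_to_file`] to control buffering and upload parallelism.
///
/// A minimal usage sketch (assuming a `store` and a `stream` are already in
/// scope; the `threshold` and `concurrency` values here are arbitrary):
///
/// ```ignore
/// let rows = stream_to_csv(stream, store, "out.csv", 1024, 1, &CsvFormat::default()).await?;
/// ```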
pub async fn stream_to_csv(
    stream: SendableRecordBatchStream,
    store: ObjectStore,
    path: &str,
    threshold: usize,
    concurrency: usize,
    format: &CsvFormat,
) -> Result<usize> {
    stream_to_file(
        stream,
        store,
        path,
        threshold,
        concurrency,
        format.compression_type,
        |buffer| {
            let mut builder = WriterBuilder::new();
            if let Some(timestamp_format) = &format.timestamp_format {
                builder = builder.with_timestamp_format(timestamp_format.to_owned());
            }
            if let Some(date_format) = &format.date_format {
                builder = builder.with_date_format(date_format.to_owned());
            }
            if let Some(time_format) = &format.time_format {
                builder = builder.with_time_format(time_format.to_owned());
            }
            builder.build(buffer)
        },
    )
    .await
}

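// Adapts arrow's CSV writer to the [`DfRecordBatchEncoder`] trait used by the
// buffered writer. Inherent methods take precedence in method resolution, so
// the `self.write(batch)` call below invokes `csv::Writer::write`, not this
// trait method recursively.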
impl DfRecordBatchEncoder for csv::Writer<SharedBuffer> {
    fn write(&mut self, batch: &RecordBatch) -> Result<()> {
        self.write(batch).context(error::WriteRecordBatchSnafu)
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_recordbatch::adapter::DfRecordBatchStreamAdapter;
    use common_recordbatch::{RecordBatch, RecordBatches};
    use common_test_util::find_workspace_path;
    use datafusion::datasource::physical_plan::{CsvSource, FileSource};
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{Float64Vector, StringVector, UInt32Vector, VectorRef};
    use futures::TryStreamExt;

    use super::*;
    use crate::file_format::{
        FORMAT_COMPRESSION_TYPE, FORMAT_DELIMITER, FORMAT_HAS_HEADER,
        FORMAT_SCHEMA_INFER_MAX_RECORD, FileFormat, file_to_stream,
    };
    use crate::test_util::{format_schema, test_store};

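    // Root directory of the CSV fixtures used by the tests below.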
    fn test_data_root() -> String {
        find_workspace_path("/src/common/datasource/tests/csv")
            .display()
            .to_string()
    }

    #[tokio::test]
    async fn infer_schema_basic() {
        let csv = CsvFormat::default();
        let store = test_store(&test_data_root());
        let schema = csv.infer_schema(&store, "simple.csv").await.unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "c1: Utf8: NULL",
                "c2: Int64: NULL",
                "c3: Int64: NULL",
                "c4: Int64: NULL",
                "c5: Int64: NULL",
                "c6: Int64: NULL",
                "c7: Int64: NULL",
                "c8: Int64: NULL",
                "c9: Int64: NULL",
                "c10: Utf8: NULL",
                "c11: Float64: NULL",
                "c12: Float64: NULL",
                "c13: Utf8: NULL"
            ],
            formatted,
        );
    }

    #[tokio::test]
    async fn test_normalize_infer_schema() {
        let csv = CsvFormat {
            schema_infer_max_record: Some(3),
            ..CsvFormat::default()
        };
        let store = test_store(&test_data_root());
        let schema = csv.infer_schema(&store, "max_infer.csv").await.unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "num: Int64: NULL",
                "str: Utf8: NULL",
                "ts: Utf8: NULL",
                "t: Utf8: NULL",
                "date: Date32: NULL"
            ],
            formatted,
        );
    }

    #[tokio::test]
    async fn infer_schema_with_limit() {
        // Sampling only the first 3 records: the values of `d` seen there all
        // parse as integers, so the column is inferred as Int64.
        let csv = CsvFormat {
            schema_infer_max_record: Some(3),
            ..CsvFormat::default()
        };
        let store = test_store(&test_data_root());
        let schema = csv
            .infer_schema(&store, "schema_infer_limit.csv")
            .await
            .unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "a: Int64: NULL",
                "b: Float64: NULL",
                "c: Int64: NULL",
                "d: Int64: NULL"
            ],
            formatted
        );

        // With the default (larger) sampling limit, the later string values of
        // `d` are seen, so the column is widened to Utf8.
        let csv = CsvFormat::default();
        let store = test_store(&test_data_root());
        let schema = csv
            .infer_schema(&store, "schema_infer_limit.csv")
            .await
            .unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "a: Int64: NULL",
                "b: Float64: NULL",
                "c: Int64: NULL",
                "d: Utf8: NULL"
            ],
            formatted
        );
    }

    #[test]
    fn test_try_from() {
        let map = HashMap::new();
        let format: CsvFormat = CsvFormat::try_from(&map).unwrap();

        assert_eq!(format, CsvFormat::default());

        let map = HashMap::from([
            (
                FORMAT_SCHEMA_INFER_MAX_RECORD.to_string(),
                "2000".to_string(),
            ),
            (FORMAT_COMPRESSION_TYPE.to_string(), "zstd".to_string()),
            // `b'\t'.to_string()` is "9": the delimiter option is parsed as a
            // raw byte value (see the TODO in `try_from`).
            (FORMAT_DELIMITER.to_string(), b'\t'.to_string()),
            (FORMAT_HAS_HEADER.to_string(), "false".to_string()),
        ]);
        let format = CsvFormat::try_from(&map).unwrap();

        assert_eq!(
            format,
            CsvFormat {
                compression_type: CompressionType::Zstd,
                schema_infer_max_record: Some(2000),
                delimiter: b'\t',
                has_header: false,
                timestamp_format: None,
                time_format: None,
                date_format: None
            }
        );
    }

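    // Round-trips record batches through each compression type: export to a
    // compressed CSV file, check the compression magic bytes, then read the
    // file back and compare against the original data.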
    #[tokio::test]
    async fn test_compressed_csv() {
        // Create test data
        let column_schemas = vec![
            ColumnSchema::new("id", ConcreteDataType::uint32_datatype(), false),
            ColumnSchema::new("name", ConcreteDataType::string_datatype(), false),
            ColumnSchema::new("value", ConcreteDataType::float64_datatype(), false),
        ];
        let schema = Arc::new(Schema::new(column_schemas));

        // Create multiple record batches with different data
        let batch1_columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3])),
            Arc::new(StringVector::from(vec!["Alice", "Bob", "Charlie"])),
            Arc::new(Float64Vector::from_slice(vec![10.5, 20.3, 30.7])),
        ];
        let batch1 = RecordBatch::new(schema.clone(), batch1_columns).unwrap();

        let batch2_columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![4, 5, 6])),
            Arc::new(StringVector::from(vec!["David", "Eva", "Frank"])),
            Arc::new(Float64Vector::from_slice(vec![40.1, 50.2, 60.3])),
        ];
        let batch2 = RecordBatch::new(schema.clone(), batch2_columns).unwrap();

        let batch3_columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![7, 8, 9])),
            Arc::new(StringVector::from(vec!["Grace", "Henry", "Ivy"])),
            Arc::new(Float64Vector::from_slice(vec![70.4, 80.5, 90.6])),
        ];
        let batch3 = RecordBatch::new(schema.clone(), batch3_columns).unwrap();

        // Combine all batches into a RecordBatches collection
        let recordbatches = RecordBatches::try_new(schema, vec![batch1, batch2, batch3]).unwrap();

        // Test with different compression types
        let compression_types = vec![
            CompressionType::Gzip,
            CompressionType::Bzip2,
            CompressionType::Xz,
            CompressionType::Zstd,
        ];

        // Create a temporary file path
        let temp_dir = common_test_util::temp_dir::create_temp_dir("test_compressed_csv");
        for compression_type in compression_types {
            let format = CsvFormat {
                compression_type,
                ..CsvFormat::default()
            };

            // Name the output file after the compression type's file extension.
            let compressed_file_name =
                format!("test_compressed_csv.{}", compression_type.file_extension());
            let compressed_file_path = temp_dir.path().join(&compressed_file_name);
            let compressed_file_path_str = compressed_file_path.to_str().unwrap();

            // Create a simple file store for testing
            let store = test_store("/");

            // Export CSV with compression
            let rows = stream_to_csv(
                Box::pin(DfRecordBatchStreamAdapter::new(recordbatches.as_stream())),
                store,
                compressed_file_path_str,
                1024,
                1,
                &format,
            )
            .await
            .unwrap();

            assert_eq!(rows, 9);

            // Verify the compressed file was created and has content
            assert!(compressed_file_path.exists());
            let file_size = std::fs::metadata(&compressed_file_path).unwrap().len();
            assert!(file_size > 0);

            // Verify the file is actually compressed
            let file_content = std::fs::read(&compressed_file_path).unwrap();
            // A compressed file should not start with the CSV header; it
            // should begin with the compression format's magic bytes.
            match compression_type {
                CompressionType::Gzip => {
                    // Gzip magic bytes: 0x1f 0x8b
                    assert_eq!(file_content[0], 0x1f, "Gzip file should start with 0x1f");
                    assert_eq!(
                        file_content[1], 0x8b,
                        "Gzip file should have 0x8b as second byte"
                    );
                }
                CompressionType::Bzip2 => {
                    // Bzip2 magic bytes: 'BZ'
                    assert_eq!(file_content[0], b'B', "Bzip2 file should start with 'B'");
                    assert_eq!(
                        file_content[1], b'Z',
                        "Bzip2 file should have 'Z' as second byte"
                    );
                }
                CompressionType::Xz => {
                    // XZ magic bytes: 0xFD '7zXZ'
                    assert_eq!(file_content[0], 0xFD, "XZ file should start with 0xFD");
                }
                CompressionType::Zstd => {
                    // Zstd magic bytes: 0x28 0xB5 0x2F 0xFD
                    assert_eq!(file_content[0], 0x28, "Zstd file should start with 0x28");
                    assert_eq!(
                        file_content[1], 0xB5,
                        "Zstd file should have 0xB5 as second byte"
                    );
                }
                _ => {}
            }


            // Verify the compressed file can be decompressed and its content
            // matches the original data
            let store = test_store("/");
            let schema = Arc::new(
                CsvFormat {
                    compression_type,
                    ..Default::default()
                }
                .infer_schema(&store, compressed_file_path_str)
                .await
                .unwrap(),
            );
            let csv_source = CsvSource::new(schema).with_batch_size(8192);

            let stream = file_to_stream(
                &store,
                compressed_file_path_str,
                csv_source.clone(),
                None,
                compression_type,
            )
            .await
            .unwrap();

            let batches = stream.try_collect::<Vec<_>>().await.unwrap();
            let pretty_print = arrow::util::pretty::pretty_format_batches(&batches)
                .unwrap()
                .to_string();
            let expected = r#"+----+---------+-------+
| id | name    | value |
+----+---------+-------+
| 1  | Alice   | 10.5  |
| 2  | Bob     | 20.3  |
| 3  | Charlie | 30.7  |
| 4  | David   | 40.1  |
| 5  | Eva     | 50.2  |
| 6  | Frank   | 60.3  |
| 7  | Grace   | 70.4  |
| 8  | Henry   | 80.5  |
| 9  | Ivy     | 90.6  |
+----+---------+-------+"#;
            assert_eq!(expected, pretty_print);
        }
    }
}