use std::collections::HashMap;
use std::str::FromStr;

use arrow::csv::reader::Format;
use arrow::csv::{self, WriterBuilder};
use arrow::record_batch::RecordBatch;
use arrow_schema::Schema;
use async_trait::async_trait;
use common_runtime;
use datafusion::physical_plan::SendableRecordBatchStream;
use object_store::ObjectStore;
use snafu::ResultExt;
use tokio_util::compat::FuturesAsyncReadCompatExt;
use tokio_util::io::SyncIoBridge;

use crate::buffered_writer::DfRecordBatchEncoder;
use crate::compression::CompressionType;
use crate::error::{self, Result};
use crate::file_format::{self, FileFormat, stream_to_file};
use crate::share_buffer::SharedBuffer;
use crate::util::normalize_infer_schema;

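/// CSV [`FileFormat`] options: header handling, delimiter, schema inference
/// limit, compression, and optional output formats for timestamp, time, and
/// date columns.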
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsvFormat {
    pub has_header: bool,
    pub delimiter: u8,
    pub schema_infer_max_record: Option<usize>,
    pub compression_type: CompressionType,
    pub timestamp_format: Option<String>,
    pub time_format: Option<String>,
    pub date_format: Option<String>,
}

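/// Parses a [`CsvFormat`] from string key-value options, starting from the
/// defaults and overriding only the keys that are present.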
impl TryFrom<&HashMap<String, String>> for CsvFormat {
    type Error = error::Error;

    fn try_from(value: &HashMap<String, String>) -> Result<Self> {
        let mut format = CsvFormat::default();
        if let Some(delimiter) = value.get(file_format::FORMAT_DELIMITER) {
            format.delimiter = u8::from_str(delimiter).map_err(|_| {
                error::ParseFormatSnafu {
                    key: file_format::FORMAT_DELIMITER,
                    value: delimiter,
                }
                .build()
            })?;
        };
        if let Some(compression_type) = value.get(file_format::FORMAT_COMPRESSION_TYPE) {
            format.compression_type = CompressionType::from_str(compression_type)?;
        };
        if let Some(schema_infer_max_record) =
            value.get(file_format::FORMAT_SCHEMA_INFER_MAX_RECORD)
        {
            format.schema_infer_max_record =
                Some(schema_infer_max_record.parse::<usize>().map_err(|_| {
                    error::ParseFormatSnafu {
                        key: file_format::FORMAT_SCHEMA_INFER_MAX_RECORD,
                        value: schema_infer_max_record,
                    }
                    .build()
                })?);
        };
        if let Some(has_header) = value.get(file_format::FORMAT_HAS_HEADER) {
            format.has_header = has_header.parse().map_err(|_| {
                error::ParseFormatSnafu {
                    key: file_format::FORMAT_HAS_HEADER,
                    value: has_header,
                }
                .build()
            })?;
        };
        if let Some(timestamp_format) = value.get(file_format::TIMESTAMP_FORMAT) {
            format.timestamp_format = Some(timestamp_format.clone());
        }
        if let Some(time_format) = value.get(file_format::TIME_FORMAT) {
            format.time_format = Some(time_format.clone());
        }
        if let Some(date_format) = value.get(file_format::DATE_FORMAT) {
            format.date_format = Some(date_format.clone());
        }
        Ok(format)
    }
}

impl Default for CsvFormat {
    fn default() -> Self {
        Self {
            has_header: true,
            delimiter: b',',
            schema_infer_max_record: Some(file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD),
            compression_type: CompressionType::Uncompressed,
            timestamp_format: None,
            time_format: None,
            date_format: None,
        }
    }
}

#[async_trait]
impl FileFormat for CsvFormat {
    async fn infer_schema(&self, store: &ObjectStore, path: &str) -> Result<Schema> {
        let meta = store
            .stat(path)
            .await
            .context(error::ReadObjectSnafu { path })?;

        let reader = store
            .reader(path)
            .await
            .context(error::ReadObjectSnafu { path })?
            .into_futures_async_read(0..meta.content_length())
            .await
            .context(error::ReadObjectSnafu { path })?
            .compat();

        let decoded = self.compression_type.convert_async_read(reader);

        let delimiter = self.delimiter;
        let schema_infer_max_record = self.schema_infer_max_record;
        let has_header = self.has_header;

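        // Arrow's CSV schema inference is synchronous, so bridge the async reader
        // with `SyncIoBridge` and run the inference on a blocking thread.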
        common_runtime::spawn_blocking_global(move || {
            let reader = SyncIoBridge::new(decoded);

            let format = Format::default()
                .with_delimiter(delimiter)
                .with_header(has_header);
            let (schema, _records_read) = format
                .infer_schema(reader, schema_infer_max_record)
                .context(error::InferSchemaSnafu)?;

            Ok(normalize_infer_schema(schema))
        })
        .await
        .context(error::JoinHandleSnafu)?
    }
}

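/// Encodes the record batch stream as CSV and writes it to `path` in `store`,
/// compressing the output according to `format.compression_type`. Returns the
/// number of rows written.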
pub async fn stream_to_csv(
    stream: SendableRecordBatchStream,
    store: ObjectStore,
    path: &str,
    threshold: usize,
    concurrency: usize,
    format: &CsvFormat,
) -> Result<usize> {
    stream_to_file(
        stream,
        store,
        path,
        threshold,
        concurrency,
        format.compression_type,
        |buffer| {
            let mut builder = WriterBuilder::new();
            if let Some(timestamp_format) = &format.timestamp_format {
                builder = builder.with_timestamp_format(timestamp_format.to_owned())
            }
            if let Some(date_format) = &format.date_format {
                builder = builder.with_date_format(date_format.to_owned())
            }
            if let Some(time_format) = &format.time_format {
                builder = builder.with_time_format(time_format.to_owned())
            }
            builder.build(buffer)
        },
    )
    .await
}

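/// Allows the Arrow CSV writer to be used as a [`DfRecordBatchEncoder`] by the
/// buffered writer.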
impl DfRecordBatchEncoder for csv::Writer<SharedBuffer> {
    fn write(&mut self, batch: &RecordBatch) -> Result<()> {
        self.write(batch).context(error::WriteRecordBatchSnafu)
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_recordbatch::adapter::DfRecordBatchStreamAdapter;
    use common_recordbatch::{RecordBatch, RecordBatches};
    use common_test_util::find_workspace_path;
    use datafusion::datasource::physical_plan::{CsvSource, FileSource};
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{Float64Vector, StringVector, UInt32Vector, VectorRef};
    use futures::TryStreamExt;

    use super::*;
    use crate::file_format::{
        FORMAT_COMPRESSION_TYPE, FORMAT_DELIMITER, FORMAT_HAS_HEADER,
        FORMAT_SCHEMA_INFER_MAX_RECORD, FileFormat, file_to_stream,
    };
    use crate::test_util::{format_schema, test_store};

    fn test_data_root() -> String {
        find_workspace_path("/src/common/datasource/tests/csv")
            .display()
            .to_string()
    }

    #[tokio::test]
    async fn infer_schema_basic() {
        let csv = CsvFormat::default();
        let store = test_store(&test_data_root());
        let schema = csv.infer_schema(&store, "simple.csv").await.unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "c1: Utf8: NULL",
                "c2: Int64: NULL",
                "c3: Int64: NULL",
                "c4: Int64: NULL",
                "c5: Int64: NULL",
                "c6: Int64: NULL",
                "c7: Int64: NULL",
                "c8: Int64: NULL",
                "c9: Int64: NULL",
                "c10: Utf8: NULL",
                "c11: Float64: NULL",
                "c12: Float64: NULL",
                "c13: Utf8: NULL"
            ],
            formatted,
        );
    }

    #[tokio::test]
    async fn normalize_infer_schema() {
        let csv = CsvFormat {
            schema_infer_max_record: Some(3),
            ..CsvFormat::default()
        };
        let store = test_store(&test_data_root());
        let schema = csv.infer_schema(&store, "max_infer.csv").await.unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "num: Int64: NULL",
                "str: Utf8: NULL",
                "ts: Utf8: NULL",
                "t: Utf8: NULL",
                "date: Date32: NULL"
            ],
            formatted,
        );
    }

    #[tokio::test]
    async fn infer_schema_with_limit() {
        let csv = CsvFormat {
            schema_infer_max_record: Some(3),
            ..CsvFormat::default()
        };
        let store = test_store(&test_data_root());
        let schema = csv
            .infer_schema(&store, "schema_infer_limit.csv")
            .await
            .unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "a: Int64: NULL",
                "b: Float64: NULL",
                "c: Int64: NULL",
                "d: Int64: NULL"
            ],
            formatted
        );

        let csv = CsvFormat::default();
        let store = test_store(&test_data_root());
        let schema = csv
            .infer_schema(&store, "schema_infer_limit.csv")
            .await
            .unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "a: Int64: NULL",
                "b: Float64: NULL",
                "c: Int64: NULL",
                "d: Utf8: NULL"
            ],
            formatted
        );
    }

    #[test]
    fn test_try_from() {
        let map = HashMap::new();
        let format: CsvFormat = CsvFormat::try_from(&map).unwrap();

        assert_eq!(format, CsvFormat::default());

        let map = HashMap::from([
            (
                FORMAT_SCHEMA_INFER_MAX_RECORD.to_string(),
                "2000".to_string(),
            ),
            (FORMAT_COMPRESSION_TYPE.to_string(), "zstd".to_string()),
            (FORMAT_DELIMITER.to_string(), b'\t'.to_string()),
            (FORMAT_HAS_HEADER.to_string(), "false".to_string()),
        ]);
        let format = CsvFormat::try_from(&map).unwrap();

        assert_eq!(
            format,
            CsvFormat {
                compression_type: CompressionType::Zstd,
                schema_infer_max_record: Some(2000),
                delimiter: b'\t',
                has_header: false,
                timestamp_format: None,
                time_format: None,
                date_format: None
            }
        );
    }

    #[tokio::test]
    async fn test_compressed_csv() {
        let column_schemas = vec![
            ColumnSchema::new("id", ConcreteDataType::uint32_datatype(), false),
            ColumnSchema::new("name", ConcreteDataType::string_datatype(), false),
            ColumnSchema::new("value", ConcreteDataType::float64_datatype(), false),
        ];
        let schema = Arc::new(Schema::new(column_schemas));

        let batch1_columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3])),
            Arc::new(StringVector::from(vec!["Alice", "Bob", "Charlie"])),
            Arc::new(Float64Vector::from_slice(vec![10.5, 20.3, 30.7])),
        ];
        let batch1 = RecordBatch::new(schema.clone(), batch1_columns).unwrap();

        let batch2_columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![4, 5, 6])),
            Arc::new(StringVector::from(vec!["David", "Eva", "Frank"])),
            Arc::new(Float64Vector::from_slice(vec![40.1, 50.2, 60.3])),
        ];
        let batch2 = RecordBatch::new(schema.clone(), batch2_columns).unwrap();

        let batch3_columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![7, 8, 9])),
            Arc::new(StringVector::from(vec!["Grace", "Henry", "Ivy"])),
            Arc::new(Float64Vector::from_slice(vec![70.4, 80.5, 90.6])),
        ];
        let batch3 = RecordBatch::new(schema.clone(), batch3_columns).unwrap();

        let recordbatches = RecordBatches::try_new(schema, vec![batch1, batch2, batch3]).unwrap();

        // Write the batches with each compression codec and read them back.
        let compression_types = vec![
            CompressionType::Gzip,
            CompressionType::Bzip2,
            CompressionType::Xz,
            CompressionType::Zstd,
        ];

        let temp_dir = common_test_util::temp_dir::create_temp_dir("test_compressed_csv");
        for compression_type in compression_types {
            let format = CsvFormat {
                compression_type,
                ..CsvFormat::default()
            };

            let compressed_file_name =
                format!("test_compressed_csv.{}", compression_type.file_extension());
            let compressed_file_path = temp_dir.path().join(&compressed_file_name);
            let compressed_file_path_str = compressed_file_path.to_str().unwrap();

            let store = test_store("/");

            // Write all nine rows to a compressed CSV file via `stream_to_csv`.
            let rows = stream_to_csv(
                Box::pin(DfRecordBatchStreamAdapter::new(recordbatches.as_stream())),
                store,
                compressed_file_path_str,
                1024,
                1,
                &format,
            )
            .await
            .unwrap();

            assert_eq!(rows, 9);

            assert!(compressed_file_path.exists());
            let file_size = std::fs::metadata(&compressed_file_path).unwrap().len();
            assert!(file_size > 0);

            let file_content = std::fs::read(&compressed_file_path).unwrap();
            // Check the codec's magic bytes at the start of the written file.
            match compression_type {
                CompressionType::Gzip => {
                    assert_eq!(file_content[0], 0x1f, "Gzip file should start with 0x1f");
                    assert_eq!(
                        file_content[1], 0x8b,
                        "Gzip file should have 0x8b as second byte"
                    );
                }
                CompressionType::Bzip2 => {
                    assert_eq!(file_content[0], b'B', "Bzip2 file should start with 'B'");
                    assert_eq!(
                        file_content[1], b'Z',
                        "Bzip2 file should have 'Z' as second byte"
                    );
                }
                CompressionType::Xz => {
                    assert_eq!(file_content[0], 0xFD, "XZ file should start with 0xFD");
                }
                CompressionType::Zstd => {
                    assert_eq!(file_content[0], 0x28, "Zstd file should start with 0x28");
                    assert_eq!(
                        file_content[1], 0xB5,
                        "Zstd file should have 0xB5 as second byte"
                    );
                }
                _ => {}
            }

            // Re-infer the schema from the compressed file and read it back to
            // verify the contents survive the round trip.
            let store = test_store("/");
            let schema = Arc::new(
                CsvFormat {
                    compression_type,
                    ..Default::default()
                }
                .infer_schema(&store, compressed_file_path_str)
                .await
                .unwrap(),
            );
            let csv_source = CsvSource::new(schema).with_batch_size(8192);

            let stream = file_to_stream(
                &store,
                compressed_file_path_str,
                csv_source.clone(),
                None,
                compression_type,
            )
            .await
            .unwrap();

            let batches = stream.try_collect::<Vec<_>>().await.unwrap();
            let pretty_print = arrow::util::pretty::pretty_format_batches(&batches)
                .unwrap()
                .to_string();
            let expected = r#"+----+---------+-------+
| id | name    | value |
+----+---------+-------+
| 1  | Alice   | 10.5  |
| 2  | Bob     | 20.3  |
| 3  | Charlie | 30.7  |
| 4  | David   | 40.1  |
| 5  | Eva     | 50.2  |
| 6  | Frank   | 60.3  |
| 7  | Grace   | 70.4  |
| 8  | Henry   | 80.5  |
| 9  | Ivy     | 90.6  |
+----+---------+-------+"#;
            assert_eq!(expected, pretty_print);
        }
    }
}