use std::collections::HashMap;
use std::str::FromStr;

use arrow::csv::reader::Format;
use arrow::csv::{self, WriterBuilder};
use arrow::record_batch::RecordBatch;
use arrow_schema::Schema;
use async_trait::async_trait;
use common_runtime;
use datafusion::physical_plan::SendableRecordBatchStream;
use object_store::ObjectStore;
use snafu::ResultExt;
use tokio_util::compat::FuturesAsyncReadCompatExt;
use tokio_util::io::SyncIoBridge;

use crate::buffered_writer::DfRecordBatchEncoder;
use crate::compression::CompressionType;
use crate::error::{self, Result};
use crate::file_format::{self, FileFormat, stream_to_file};
use crate::share_buffer::SharedBuffer;

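/// CSV format options: header handling, delimiter, schema-inference record
/// limit, compression, and optional output formats for timestamp, time, and
/// date columns.
///
/// A minimal sketch of building a format from user-supplied options via the
/// `TryFrom` impl below; the delimiter is passed as the decimal value of the
/// byte, mirroring the `test_try_from` test at the bottom of this file:
///
/// ```ignore
/// let options = HashMap::from([
///     (file_format::FORMAT_DELIMITER.to_string(), b'\t'.to_string()),
///     (file_format::FORMAT_HAS_HEADER.to_string(), "false".to_string()),
/// ]);
/// let format = CsvFormat::try_from(&options)?;
/// assert_eq!(format.delimiter, b'\t');
/// ```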
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsvFormat {
    pub has_header: bool,
    pub delimiter: u8,
    pub schema_infer_max_record: Option<usize>,
    pub compression_type: CompressionType,
    pub timestamp_format: Option<String>,
    pub time_format: Option<String>,
    pub date_format: Option<String>,
}

impl TryFrom<&HashMap<String, String>> for CsvFormat {
    type Error = error::Error;

    fn try_from(value: &HashMap<String, String>) -> Result<Self> {
        let mut format = CsvFormat::default();
        if let Some(delimiter) = value.get(file_format::FORMAT_DELIMITER) {
            format.delimiter = u8::from_str(delimiter).map_err(|_| {
                error::ParseFormatSnafu {
                    key: file_format::FORMAT_DELIMITER,
                    value: delimiter,
                }
                .build()
            })?;
        };
        if let Some(compression_type) = value.get(file_format::FORMAT_COMPRESSION_TYPE) {
            format.compression_type = CompressionType::from_str(compression_type)?;
        };
        if let Some(schema_infer_max_record) =
            value.get(file_format::FORMAT_SCHEMA_INFER_MAX_RECORD)
        {
            format.schema_infer_max_record =
                Some(schema_infer_max_record.parse::<usize>().map_err(|_| {
                    error::ParseFormatSnafu {
                        key: file_format::FORMAT_SCHEMA_INFER_MAX_RECORD,
                        value: schema_infer_max_record,
                    }
                    .build()
                })?);
        };
        if let Some(has_header) = value.get(file_format::FORMAT_HAS_HEADER) {
            format.has_header = has_header.parse().map_err(|_| {
                error::ParseFormatSnafu {
                    key: file_format::FORMAT_HAS_HEADER,
                    value: has_header,
                }
                .build()
            })?;
        };
        if let Some(timestamp_format) = value.get(file_format::TIMESTAMP_FORMAT) {
            format.timestamp_format = Some(timestamp_format.clone());
        }
        if let Some(time_format) = value.get(file_format::TIME_FORMAT) {
            format.time_format = Some(time_format.clone());
        }
        if let Some(date_format) = value.get(file_format::DATE_FORMAT) {
            format.date_format = Some(date_format.clone());
        }
        Ok(format)
    }
}

impl Default for CsvFormat {
    fn default() -> Self {
        Self {
            has_header: true,
            delimiter: b',',
            schema_infer_max_record: Some(file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD),
            compression_type: CompressionType::Uncompressed,
            timestamp_format: None,
            time_format: None,
            date_format: None,
        }
    }
}

#[async_trait]
impl FileFormat for CsvFormat {
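    /// Infers the schema by reading the object at `path`, decompressing it
    /// according to `compression_type`, and running arrow's synchronous CSV
    /// inference on a blocking thread (the async reader is bridged with
    /// `SyncIoBridge`), scanning at most `schema_infer_max_record` records.
    ///
    /// A minimal usage sketch (`test_store` is the crate test helper used in
    /// the tests below; the path is an assumption of the example):
    ///
    /// ```ignore
    /// let store = test_store("/path/to/data");
    /// let schema = CsvFormat::default().infer_schema(&store, "simple.csv").await?;
    /// ```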
    async fn infer_schema(&self, store: &ObjectStore, path: &str) -> Result<Schema> {
        let meta = store
            .stat(path)
            .await
            .context(error::ReadObjectSnafu { path })?;

        let reader = store
            .reader(path)
            .await
            .context(error::ReadObjectSnafu { path })?
            .into_futures_async_read(0..meta.content_length())
            .await
            .context(error::ReadObjectSnafu { path })?
            .compat();

        let decoded = self.compression_type.convert_async_read(reader);

        let delimiter = self.delimiter;
        let schema_infer_max_record = self.schema_infer_max_record;
        let has_header = self.has_header;

        common_runtime::spawn_blocking_global(move || {
            let reader = SyncIoBridge::new(decoded);

            let format = Format::default()
                .with_delimiter(delimiter)
                .with_header(has_header);
            let (schema, _records_read) = format
                .infer_schema(reader, schema_infer_max_record)
                .context(error::InferSchemaSnafu)?;
            Ok(schema)
        })
        .await
        .context(error::JoinHandleSnafu)?
    }
}

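/// Streams record batches to a single CSV object at `path` in `store`,
/// compressing with `format.compression_type`; `threshold` and `concurrency`
/// are passed through to `stream_to_file` to control write buffering and
/// concurrent uploads. Returns the total number of rows written.
///
/// A minimal sketch, mirroring the round trip in the tests below (`stream`
/// and `store` are assumptions of the example):
///
/// ```ignore
/// let rows = stream_to_csv(stream, store, "out.csv", 1024, 1, &CsvFormat::default()).await?;
/// assert_eq!(rows, expected_row_count);
/// ```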
pub async fn stream_to_csv(
    stream: SendableRecordBatchStream,
    store: ObjectStore,
    path: &str,
    threshold: usize,
    concurrency: usize,
    format: &CsvFormat,
) -> Result<usize> {
    stream_to_file(
        stream,
        store,
        path,
        threshold,
        concurrency,
        format.compression_type,
        |buffer| {
            let mut builder = WriterBuilder::new();
            if let Some(timestamp_format) = &format.timestamp_format {
                builder = builder.with_timestamp_format(timestamp_format.to_owned())
            }
            if let Some(date_format) = &format.date_format {
                builder = builder.with_date_format(date_format.to_owned())
            }
            if let Some(time_format) = &format.time_format {
                builder = builder.with_time_format(time_format.to_owned())
            }
            builder.build(buffer)
        },
    )
    .await
}

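// Note: `self.write(batch)` below resolves to arrow's *inherent*
// `csv::Writer::write`, which takes precedence over this trait method during
// method resolution, so the delegation is not recursive.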
impl DfRecordBatchEncoder for csv::Writer<SharedBuffer> {
    fn write(&mut self, batch: &RecordBatch) -> Result<()> {
        self.write(batch).context(error::WriteRecordBatchSnafu)
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_recordbatch::adapter::DfRecordBatchStreamAdapter;
    use common_recordbatch::{RecordBatch, RecordBatches};
    use common_test_util::find_workspace_path;
    use datafusion::datasource::physical_plan::{CsvSource, FileSource};
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{Float64Vector, StringVector, UInt32Vector, VectorRef};
    use futures::TryStreamExt;

    use super::*;
    use crate::file_format::{
        FORMAT_COMPRESSION_TYPE, FORMAT_DELIMITER, FORMAT_HAS_HEADER,
        FORMAT_SCHEMA_INFER_MAX_RECORD, FileFormat, file_to_stream,
    };
    use crate::test_util::{format_schema, test_store};

    fn test_data_root() -> String {
        find_workspace_path("/src/common/datasource/tests/csv")
            .display()
            .to_string()
    }

    #[tokio::test]
    async fn infer_schema_basic() {
        let csv = CsvFormat::default();
        let store = test_store(&test_data_root());
        let schema = csv.infer_schema(&store, "simple.csv").await.unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "c1: Utf8: NULL",
                "c2: Int64: NULL",
                "c3: Int64: NULL",
                "c4: Int64: NULL",
                "c5: Int64: NULL",
                "c6: Int64: NULL",
                "c7: Int64: NULL",
                "c8: Int64: NULL",
                "c9: Int64: NULL",
                "c10: Utf8: NULL",
                "c11: Float64: NULL",
                "c12: Float64: NULL",
                "c13: Utf8: NULL"
            ],
            formatted,
        );
    }

    #[tokio::test]
    async fn infer_schema_with_limit() {
        let csv = CsvFormat {
            schema_infer_max_record: Some(3),
            ..CsvFormat::default()
        };
        let store = test_store(&test_data_root());
        let schema = csv
            .infer_schema(&store, "schema_infer_limit.csv")
            .await
            .unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "a: Int64: NULL",
                "b: Float64: NULL",
                "c: Int64: NULL",
                "d: Int64: NULL"
            ],
            formatted
        );

        let csv = CsvFormat::default();
        let store = test_store(&test_data_root());
        let schema = csv
            .infer_schema(&store, "schema_infer_limit.csv")
            .await
            .unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "a: Int64: NULL",
                "b: Float64: NULL",
                "c: Int64: NULL",
                "d: Utf8: NULL"
            ],
            formatted
        );
    }

    #[test]
    fn test_try_from() {
        let map = HashMap::new();
        let format: CsvFormat = CsvFormat::try_from(&map).unwrap();

        assert_eq!(format, CsvFormat::default());

        let map = HashMap::from([
            (
                FORMAT_SCHEMA_INFER_MAX_RECORD.to_string(),
                "2000".to_string(),
            ),
            (FORMAT_COMPRESSION_TYPE.to_string(), "zstd".to_string()),
            (FORMAT_DELIMITER.to_string(), b'\t'.to_string()),
            (FORMAT_HAS_HEADER.to_string(), "false".to_string()),
        ]);
        let format = CsvFormat::try_from(&map).unwrap();

        assert_eq!(
            format,
            CsvFormat {
                compression_type: CompressionType::Zstd,
                schema_infer_max_record: Some(2000),
                delimiter: b'\t',
                has_header: false,
                timestamp_format: None,
                time_format: None,
                date_format: None
            }
        );
    }

    #[tokio::test]
    async fn test_compressed_csv() {
        let column_schemas = vec![
            ColumnSchema::new("id", ConcreteDataType::uint32_datatype(), false),
            ColumnSchema::new("name", ConcreteDataType::string_datatype(), false),
            ColumnSchema::new("value", ConcreteDataType::float64_datatype(), false),
        ];
        let schema = Arc::new(Schema::new(column_schemas));

        // Build three record batches sharing the same schema.
        let batch1_columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3])),
            Arc::new(StringVector::from(vec!["Alice", "Bob", "Charlie"])),
            Arc::new(Float64Vector::from_slice(vec![10.5, 20.3, 30.7])),
        ];
        let batch1 = RecordBatch::new(schema.clone(), batch1_columns).unwrap();

        let batch2_columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![4, 5, 6])),
            Arc::new(StringVector::from(vec!["David", "Eva", "Frank"])),
            Arc::new(Float64Vector::from_slice(vec![40.1, 50.2, 60.3])),
        ];
        let batch2 = RecordBatch::new(schema.clone(), batch2_columns).unwrap();

        let batch3_columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![7, 8, 9])),
            Arc::new(StringVector::from(vec!["Grace", "Henry", "Ivy"])),
            Arc::new(Float64Vector::from_slice(vec![70.4, 80.5, 90.6])),
        ];
        let batch3 = RecordBatch::new(schema.clone(), batch3_columns).unwrap();

        let recordbatches = RecordBatches::try_new(schema, vec![batch1, batch2, batch3]).unwrap();

        let compression_types = vec![
            CompressionType::Gzip,
            CompressionType::Bzip2,
            CompressionType::Xz,
            CompressionType::Zstd,
        ];

        let temp_dir = common_test_util::temp_dir::create_temp_dir("test_compressed_csv");

        for compression_type in compression_types {
            let format = CsvFormat {
                compression_type,
                ..CsvFormat::default()
            };

            let compressed_file_name =
                format!("test_compressed_csv.{}", compression_type.file_extension());
            let compressed_file_path = temp_dir.path().join(&compressed_file_name);
            let compressed_file_path_str = compressed_file_path.to_str().unwrap();

            let store = test_store("/");

            // Write the batches to a compressed CSV file.
            let rows = stream_to_csv(
                Box::pin(DfRecordBatchStreamAdapter::new(recordbatches.as_stream())),
                store,
                compressed_file_path_str,
                1024,
                1,
                &format,
            )
            .await
            .unwrap();

            assert_eq!(rows, 9);

            assert!(compressed_file_path.exists());
            let file_size = std::fs::metadata(&compressed_file_path).unwrap().len();
            assert!(file_size > 0);

            // Verify the magic bytes expected for each compression format.
            let file_content = std::fs::read(&compressed_file_path).unwrap();
            match compression_type {
                CompressionType::Gzip => {
                    assert_eq!(file_content[0], 0x1f, "Gzip file should start with 0x1f");
                    assert_eq!(
                        file_content[1], 0x8b,
                        "Gzip file should have 0x8b as second byte"
                    );
                }
                CompressionType::Bzip2 => {
                    assert_eq!(file_content[0], b'B', "Bzip2 file should start with 'B'");
                    assert_eq!(
                        file_content[1], b'Z',
                        "Bzip2 file should have 'Z' as second byte"
                    );
                }
                CompressionType::Xz => {
                    assert_eq!(file_content[0], 0xFD, "XZ file should start with 0xFD");
                }
                CompressionType::Zstd => {
                    assert_eq!(file_content[0], 0x28, "Zstd file should start with 0x28");
                    assert_eq!(
                        file_content[1], 0xB5,
                        "Zstd file should have 0xB5 as second byte"
                    );
                }
                _ => {}
            }

            // Read the file back, decompressing on the fly, and compare contents.
            let store = test_store("/");
            let schema = Arc::new(
                CsvFormat {
                    compression_type,
                    ..Default::default()
                }
                .infer_schema(&store, compressed_file_path_str)
                .await
                .unwrap(),
            );
            let csv_source = CsvSource::new(true, b',', b'"')
                .with_schema(schema.clone())
                .with_batch_size(8192);

            let stream = file_to_stream(
                &store,
                compressed_file_path_str,
                schema.clone(),
                csv_source.clone(),
                None,
                compression_type,
            )
            .await
            .unwrap();

            let batches = stream.try_collect::<Vec<_>>().await.unwrap();
            let pretty_print = arrow::util::pretty::pretty_format_batches(&batches)
                .unwrap()
                .to_string();
            let expected = r#"+----+---------+-------+
| id | name    | value |
+----+---------+-------+
| 1  | Alice   | 10.5  |
| 2  | Bob     | 20.3  |
| 3  | Charlie | 30.7  |
| 4  | David   | 40.1  |
| 5  | Eva     | 50.2  |
| 6  | Frank   | 60.3  |
| 7  | Grace   | 70.4  |
| 8  | Henry   | 80.5  |
| 9  | Ivy     | 90.6  |
+----+---------+-------+"#;
            assert_eq!(expected, pretty_print);
        }
    }
}