// common_datasource/file_format/csv.rs

use std::collections::HashMap;
use std::str::FromStr;

use arrow::csv::reader::Format;
use arrow::csv::{self, WriterBuilder};
use arrow::record_batch::RecordBatch;
use arrow_schema::Schema;
use async_trait::async_trait;
use common_runtime;
use datafusion::physical_plan::SendableRecordBatchStream;
use object_store::ObjectStore;
use snafu::ResultExt;
use tokio_util::compat::FuturesAsyncReadCompatExt;
use tokio_util::io::SyncIoBridge;

use crate::buffered_writer::DfRecordBatchEncoder;
use crate::compression::CompressionType;
use crate::error::{self, Result};
use crate::file_format::{self, FileFormat, stream_to_file};
use crate::share_buffer::SharedBuffer;

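/// CSV read/write options: header handling, delimiter, schema inference
/// limit, compression, and optional output formats for timestamp, time, and
/// date columns.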
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsvFormat {
    pub has_header: bool,
    pub delimiter: u8,
    pub schema_infer_max_record: Option<usize>,
    pub compression_type: CompressionType,
    pub timestamp_format: Option<String>,
    pub time_format: Option<String>,
    pub date_format: Option<String>,
}

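// Parses `CsvFormat` options from a string key-value map, starting from the
// defaults. Unrecognized keys are ignored; recognized keys whose values fail
// to parse produce a `ParseFormat` error. Note the delimiter is given as the
// decimal value of the byte, e.g. "9" for b'\t' (see `test_try_from` below).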
impl TryFrom<&HashMap<String, String>> for CsvFormat {
    type Error = error::Error;

    fn try_from(value: &HashMap<String, String>) -> Result<Self> {
        let mut format = CsvFormat::default();
        if let Some(delimiter) = value.get(file_format::FORMAT_DELIMITER) {
            format.delimiter = u8::from_str(delimiter).map_err(|_| {
                error::ParseFormatSnafu {
                    key: file_format::FORMAT_DELIMITER,
                    value: delimiter,
                }
                .build()
            })?;
        };
        if let Some(compression_type) = value.get(file_format::FORMAT_COMPRESSION_TYPE) {
            format.compression_type = CompressionType::from_str(compression_type)?;
        };
        if let Some(schema_infer_max_record) =
            value.get(file_format::FORMAT_SCHEMA_INFER_MAX_RECORD)
        {
            format.schema_infer_max_record =
                Some(schema_infer_max_record.parse::<usize>().map_err(|_| {
                    error::ParseFormatSnafu {
                        key: file_format::FORMAT_SCHEMA_INFER_MAX_RECORD,
                        value: schema_infer_max_record,
                    }
                    .build()
                })?);
        };
        if let Some(has_header) = value.get(file_format::FORMAT_HAS_HEADER) {
            format.has_header = has_header.parse().map_err(|_| {
                error::ParseFormatSnafu {
                    key: file_format::FORMAT_HAS_HEADER,
                    value: has_header,
                }
                .build()
            })?;
        };
        if let Some(timestamp_format) = value.get(file_format::TIMESTAMP_FORMAT) {
            format.timestamp_format = Some(timestamp_format.clone());
        }
        if let Some(time_format) = value.get(file_format::TIME_FORMAT) {
            format.time_format = Some(time_format.clone());
        }
        if let Some(date_format) = value.get(file_format::DATE_FORMAT) {
            format.date_format = Some(date_format.clone());
        }
        Ok(format)
    }
}

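// Defaults: comma-delimited, with a header row, uncompressed, and schema
// inference capped at `DEFAULT_SCHEMA_INFER_MAX_RECORD` records.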
impl Default for CsvFormat {
    fn default() -> Self {
        Self {
            has_header: true,
            delimiter: b',',
            schema_infer_max_record: Some(file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD),
            compression_type: CompressionType::Uncompressed,
            timestamp_format: None,
            time_format: None,
            date_format: None,
        }
    }
}

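// Infers the schema by reading the whole object range through the configured
// decompression and running arrow's CSV schema inference over the bytes.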
#[async_trait]
impl FileFormat for CsvFormat {
    async fn infer_schema(&self, store: &ObjectStore, path: &str) -> Result<Schema> {
        let meta = store
            .stat(path)
            .await
            .context(error::ReadObjectSnafu { path })?;

        let reader = store
            .reader(path)
            .await
            .context(error::ReadObjectSnafu { path })?
            .into_futures_async_read(0..meta.content_length())
            .await
            .context(error::ReadObjectSnafu { path })?
            .compat();

        let decoded = self.compression_type.convert_async_read(reader);

        let delimiter = self.delimiter;
        let schema_infer_max_record = self.schema_infer_max_record;
        let has_header = self.has_header;

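        // Arrow's schema inference is synchronous, and `SyncIoBridge` blocks
        // on the async reader, so run the inference on a dedicated blocking
        // thread instead of stalling the async runtime.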
        common_runtime::spawn_blocking_global(move || {
            let reader = SyncIoBridge::new(decoded);

            let format = Format::default()
                .with_delimiter(delimiter)
                .with_header(has_header);
            let (schema, _records_read) = format
                .infer_schema(reader, schema_infer_max_record)
                .context(error::InferSchemaSnafu)?;
            Ok(schema)
        })
        .await
        .context(error::JoinHandleSnafu)?
    }
}

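/// Writes a stream of record batches to a CSV file at `path` in `store`.
/// `threshold` and `concurrency` are forwarded to the shared `stream_to_file`
/// writer, which handles buffering and concurrent uploads; the optional
/// date/time/timestamp output formats come from `format`.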
pub async fn stream_to_csv(
    stream: SendableRecordBatchStream,
    store: ObjectStore,
    path: &str,
    threshold: usize,
    concurrency: usize,
    format: &CsvFormat,
) -> Result<usize> {
    stream_to_file(stream, store, path, threshold, concurrency, |buffer| {
        let mut builder = WriterBuilder::new();
        if let Some(timestamp_format) = &format.timestamp_format {
            builder = builder.with_timestamp_format(timestamp_format.to_owned());
        }
        if let Some(date_format) = &format.date_format {
            builder = builder.with_date_format(date_format.to_owned());
        }
        if let Some(time_format) = &format.time_format {
            builder = builder.with_time_format(time_format.to_owned());
        }
        builder.build(buffer)
    })
    .await
}

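// Drives arrow's CSV writer from the shared buffered-writer machinery. The
// inner call resolves to `csv::Writer`'s inherent `write`, not this trait
// method, so it is not recursive.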
impl DfRecordBatchEncoder for csv::Writer<SharedBuffer> {
    fn write(&mut self, batch: &RecordBatch) -> Result<()> {
        self.write(batch).context(error::WriteRecordBatchSnafu)
    }
}

#[cfg(test)]
mod tests {

    use common_test_util::find_workspace_path;

    use super::*;
    use crate::file_format::{
        FORMAT_COMPRESSION_TYPE, FORMAT_DELIMITER, FORMAT_HAS_HEADER,
        FORMAT_SCHEMA_INFER_MAX_RECORD, FileFormat,
    };
    use crate::test_util::{format_schema, test_store};

    fn test_data_root() -> String {
        find_workspace_path("/src/common/datasource/tests/csv")
            .display()
            .to_string()
    }

    #[tokio::test]
    async fn infer_schema_basic() {
        let csv = CsvFormat::default();
        let store = test_store(&test_data_root());
        let schema = csv.infer_schema(&store, "simple.csv").await.unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "c1: Utf8: NULL",
                "c2: Int64: NULL",
                "c3: Int64: NULL",
                "c4: Int64: NULL",
                "c5: Int64: NULL",
                "c6: Int64: NULL",
                "c7: Int64: NULL",
                "c8: Int64: NULL",
                "c9: Int64: NULL",
                "c10: Utf8: NULL",
                "c11: Float64: NULL",
                "c12: Float64: NULL",
                "c13: Utf8: NULL"
            ],
            formatted,
        );
    }

    #[tokio::test]
    async fn infer_schema_with_limit() {
        let csv = CsvFormat {
            schema_infer_max_record: Some(3),
            ..CsvFormat::default()
        };
        let store = test_store(&test_data_root());
        let schema = csv
            .infer_schema(&store, "schema_infer_limit.csv")
            .await
            .unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "a: Int64: NULL",
                "b: Float64: NULL",
                "c: Int64: NULL",
                "d: Int64: NULL"
            ],
            formatted
        );

        // Without the record limit, rows beyond the first three make
        // column `d` infer as Utf8 instead of Int64.
        let csv = CsvFormat::default();
        let store = test_store(&test_data_root());
        let schema = csv
            .infer_schema(&store, "schema_infer_limit.csv")
            .await
            .unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "a: Int64: NULL",
                "b: Float64: NULL",
                "c: Int64: NULL",
                "d: Utf8: NULL"
            ],
            formatted
        );
    }

    #[test]
    fn test_try_from() {
        let map = HashMap::new();
        let format: CsvFormat = CsvFormat::try_from(&map).unwrap();

        assert_eq!(format, CsvFormat::default());

        let map = HashMap::from([
            (
                FORMAT_SCHEMA_INFER_MAX_RECORD.to_string(),
                "2000".to_string(),
            ),
            (FORMAT_COMPRESSION_TYPE.to_string(), "zstd".to_string()),
            (FORMAT_DELIMITER.to_string(), b'\t'.to_string()),
            (FORMAT_HAS_HEADER.to_string(), "false".to_string()),
        ]);
        let format = CsvFormat::try_from(&map).unwrap();

        assert_eq!(
            format,
            CsvFormat {
                compression_type: CompressionType::Zstd,
                schema_infer_max_record: Some(2000),
                delimiter: b'\t',
                has_header: false,
                timestamp_format: None,
                time_format: None,
                date_format: None
            }
        );
    }
}