// common_datasource/file_format/csv.rs

use std::collections::HashMap;
use std::str::FromStr;

use arrow::csv;
use arrow::csv::reader::Format;
use arrow::record_batch::RecordBatch;
use arrow_schema::Schema;
use async_trait::async_trait;
use common_runtime;
use datafusion::physical_plan::SendableRecordBatchStream;
use object_store::ObjectStore;
use snafu::ResultExt;
use tokio_util::compat::FuturesAsyncReadCompatExt;
use tokio_util::io::SyncIoBridge;

use crate::buffered_writer::DfRecordBatchEncoder;
use crate::compression::CompressionType;
use crate::error::{self, Result};
use crate::file_format::{self, stream_to_file, FileFormat};
use crate::share_buffer::SharedBuffer;
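
/// CSV format options: whether the file has a header row, the field
/// delimiter, how many records to sample when inferring the schema, and the
/// compression of the source file.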
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CsvFormat {
    pub has_header: bool,
    pub delimiter: u8,
    pub schema_infer_max_record: Option<usize>,
    pub compression_type: CompressionType,
}
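
/// Builds a `CsvFormat` from a string option map, falling back to defaults
/// for missing keys. A minimal sketch (hypothetical option values; note that
/// the delimiter is given as the numeric value of the byte, e.g. `"9"` for a
/// tab):
///
/// ```ignore
/// let opts = HashMap::from([
///     (file_format::FORMAT_DELIMITER.to_string(), b'\t'.to_string()),
///     (file_format::FORMAT_HAS_HEADER.to_string(), "false".to_string()),
/// ]);
/// let format = CsvFormat::try_from(&opts)?;
/// ```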
impl TryFrom<&HashMap<String, String>> for CsvFormat {
    type Error = error::Error;

    fn try_from(value: &HashMap<String, String>) -> Result<Self> {
        let mut format = CsvFormat::default();
        if let Some(delimiter) = value.get(file_format::FORMAT_DELIMITER) {
            format.delimiter = u8::from_str(delimiter).map_err(|_| {
                error::ParseFormatSnafu {
                    key: file_format::FORMAT_DELIMITER,
                    value: delimiter,
                }
                .build()
            })?;
        };
        if let Some(compression_type) = value.get(file_format::FORMAT_COMPRESSION_TYPE) {
            format.compression_type = CompressionType::from_str(compression_type)?;
        };
        if let Some(schema_infer_max_record) =
            value.get(file_format::FORMAT_SCHEMA_INFER_MAX_RECORD)
        {
            format.schema_infer_max_record =
                Some(schema_infer_max_record.parse::<usize>().map_err(|_| {
                    error::ParseFormatSnafu {
                        key: file_format::FORMAT_SCHEMA_INFER_MAX_RECORD,
                        value: schema_infer_max_record,
                    }
                    .build()
                })?);
        };
        if let Some(has_header) = value.get(file_format::FORMAT_HAS_HEADER) {
            format.has_header = has_header.parse().map_err(|_| {
                error::ParseFormatSnafu {
                    key: file_format::FORMAT_HAS_HEADER,
                    value: has_header,
                }
                .build()
            })?;
        }
        Ok(format)
    }
}

impl Default for CsvFormat {
    fn default() -> Self {
        Self {
            has_header: true,
            delimiter: b',',
            schema_infer_max_record: Some(file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD),
            compression_type: CompressionType::Uncompressed,
        }
    }
}
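
/// Schema inference streams the object from `store`, decompresses it
/// according to `compression_type`, and feeds the bytes to arrow's
/// synchronous CSV inference on a blocking task.
///
/// A minimal usage sketch (assumes `store` serves a `data.csv` object):
///
/// ```ignore
/// let schema = CsvFormat::default().infer_schema(&store, "data.csv").await?;
/// ```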
#[async_trait]
impl FileFormat for CsvFormat {
    async fn infer_schema(&self, store: &ObjectStore, path: &str) -> Result<Schema> {
        let meta = store
            .stat(path)
            .await
            .context(error::ReadObjectSnafu { path })?;

        let reader = store
            .reader(path)
            .await
            .context(error::ReadObjectSnafu { path })?
            .into_futures_async_read(0..meta.content_length())
            .await
            .context(error::ReadObjectSnafu { path })?
            .compat();

        let decoded = self.compression_type.convert_async_read(reader);

        let delimiter = self.delimiter;
        let schema_infer_max_record = self.schema_infer_max_record;
        let has_header = self.has_header;

        // Arrow's `Format::infer_schema` is synchronous, so bridge the async
        // reader into a `std::io::Read` and run inference on a blocking task.
        common_runtime::spawn_blocking_global(move || {
            let reader = SyncIoBridge::new(decoded);

            let format = Format::default()
                .with_delimiter(delimiter)
                .with_header(has_header);
            let (schema, _records_read) = format
                .infer_schema(reader, schema_infer_max_record)
                .context(error::InferSchemaSnafu)?;
            Ok(schema)
        })
        .await
        .context(error::JoinHandleSnafu)?
    }
}
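
/// Encodes `stream` as CSV and writes it to `path` in `store`, returning the
/// row count reported by `stream_to_file`. `threshold` and `concurrency` are
/// forwarded unchanged to `stream_to_file`.
///
/// A minimal usage sketch (hypothetical `stream`, `store`, and tuning values):
///
/// ```ignore
/// let rows = stream_to_csv(stream, store, "out/data.csv", 8 * 1024 * 1024, 8).await?;
/// ```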
pub async fn stream_to_csv(
    stream: SendableRecordBatchStream,
    store: ObjectStore,
    path: &str,
    threshold: usize,
    concurrency: usize,
) -> Result<usize> {
    stream_to_file(stream, store, path, threshold, concurrency, |buffer| {
        csv::Writer::new(buffer)
    })
    .await
}
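
/// Adapts arrow's CSV writer to this crate's `DfRecordBatchEncoder`, so
/// `stream_to_file` can drive it over a `SharedBuffer`.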
impl DfRecordBatchEncoder for csv::Writer<SharedBuffer> {
    fn write(&mut self, batch: &RecordBatch) -> Result<()> {
        self.write(batch).context(error::WriteRecordBatchSnafu)
    }
}

#[cfg(test)]
mod tests {

    use common_test_util::find_workspace_path;

    use super::*;
    use crate::file_format::{
        FileFormat, FORMAT_COMPRESSION_TYPE, FORMAT_DELIMITER, FORMAT_HAS_HEADER,
        FORMAT_SCHEMA_INFER_MAX_RECORD,
    };
    use crate::test_util::{format_schema, test_store};

    fn test_data_root() -> String {
        find_workspace_path("/src/common/datasource/tests/csv")
            .display()
            .to_string()
    }

    #[tokio::test]
    async fn infer_schema_basic() {
        let csv = CsvFormat::default();
        let store = test_store(&test_data_root());
        let schema = csv.infer_schema(&store, "simple.csv").await.unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "c1: Utf8: NULL",
                "c2: Int64: NULL",
                "c3: Int64: NULL",
                "c4: Int64: NULL",
                "c5: Int64: NULL",
                "c6: Int64: NULL",
                "c7: Int64: NULL",
                "c8: Int64: NULL",
                "c9: Int64: NULL",
                "c10: Utf8: NULL",
                "c11: Float64: NULL",
                "c12: Float64: NULL",
                "c13: Utf8: NULL"
            ],
            formatted,
        );
    }

    #[tokio::test]
    async fn infer_schema_with_limit() {
        let csv = CsvFormat {
            schema_infer_max_record: Some(3),
            ..CsvFormat::default()
        };
        let store = test_store(&test_data_root());
        let schema = csv
            .infer_schema(&store, "schema_infer_limit.csv")
            .await
            .unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "a: Int64: NULL",
                "b: Float64: NULL",
                "c: Int64: NULL",
                "d: Int64: NULL"
            ],
            formatted
        );

        // Without the limit, the full file is sampled and column `d` is
        // inferred as Utf8 instead.
        let csv = CsvFormat::default();
        let store = test_store(&test_data_root());
        let schema = csv
            .infer_schema(&store, "schema_infer_limit.csv")
            .await
            .unwrap();
        let formatted: Vec<_> = format_schema(schema);

        assert_eq!(
            vec![
                "a: Int64: NULL",
                "b: Float64: NULL",
                "c: Int64: NULL",
                "d: Utf8: NULL"
            ],
            formatted
        );
    }

    #[test]
    fn test_try_from() {
        let map = HashMap::new();
        let format: CsvFormat = CsvFormat::try_from(&map).unwrap();

        assert_eq!(format, CsvFormat::default());

        let map = HashMap::from([
            (
                FORMAT_SCHEMA_INFER_MAX_RECORD.to_string(),
                "2000".to_string(),
            ),
            (FORMAT_COMPRESSION_TYPE.to_string(), "zstd".to_string()),
            (FORMAT_DELIMITER.to_string(), b'\t'.to_string()),
            (FORMAT_HAS_HEADER.to_string(), "false".to_string()),
        ]);
        let format = CsvFormat::try_from(&map).unwrap();

        assert_eq!(
            format,
            CsvFormat {
                compression_type: CompressionType::Zstd,
                schema_infer_max_record: Some(2000),
                delimiter: b'\t',
                has_header: false,
            }
        );
    }
}
271}