1use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use datafusion_common::cast_column;
19use datafusion_common::format::DEFAULT_CAST_OPTIONS;
20use datatypes::arrow::array::{ArrayRef, new_null_array};
21use datatypes::arrow::datatypes::{DataType, FieldRef, SchemaRef};
22use datatypes::arrow::record_batch::RecordBatch;
23use futures::Stream;
24use futures::stream::BoxStream;
25use parquet::arrow::async_reader::ParquetRecordBatchStream;
26use snafu::{IntoError, ResultExt, ensure};
27
28use crate::error::{
29 CastColumnSnafu, NewRecordBatchSnafu, ReadParquetSnafu, Result, UnexpectedSnafu,
30};
31use crate::sst::parquet::async_reader::SstAsyncFileReader;
32
/// Stream adapter that reshapes every record batch produced by `inner` so it
/// conforms to `output_schema`.
///
/// Root columns flagged as absent are filled with all-null arrays, and batches
/// whose schema differs from `output_schema` are rebuilt column by column.
pub struct NestedSchemaAligner<S> {
    /// Wrapped stream that yields the input record batches.
    inner: S,
    /// Schema every emitted batch is aligned to.
    output_schema: SchemaRef,
    /// One flag per field of `output_schema`: `true` when input batches carry
    /// that root column, `false` when it has to be filled with nulls.
    projected_root_presence: Vec<bool>,
    /// Number of `true` entries in `projected_root_presence`; input batches are
    /// checked against this column count before missing roots are filled.
    expected_input_col_num: usize,
    /// `true` when every entry of `projected_root_presence` is set, in which
    /// case the null-filling step is skipped.
    all_roots_present: bool,
    /// Cached outcome of comparing a batch schema with `output_schema`.
    /// `None` until the first batch has been seen; later batches reuse it.
    is_schema_matched: Option<bool>,
}
66
/// Boxed, `'static` stream whose items are `Result<RecordBatch>`.
// NOTE(review): the name suggests the batches are already projected; confirm at the use sites.
pub(crate) type ProjectedRecordBatchStream = BoxStream<'static, Result<RecordBatch>>;
68
69impl<S> NestedSchemaAligner<S>
70where
71 S: Stream<Item = Result<RecordBatch>>,
72{
73 pub fn new(
74 inner: S,
75 projected_root_presence: Vec<bool>,
76 output_schema: SchemaRef,
77 ) -> Result<NestedSchemaAligner<S>> {
78 ensure!(
79 projected_root_presence.len() == output_schema.fields().len(),
80 UnexpectedSnafu {
81 reason: format!(
82 "NestedSchemaAligner projected root presence len {} does not match output schema columns {}",
83 projected_root_presence.len(),
84 output_schema.fields().len()
85 ),
86 }
87 );
88
89 let expected_input_col_num = projected_root_presence
90 .iter()
91 .filter(|matched| **matched)
92 .count();
93 let all_roots_present = projected_root_presence.iter().all(|&m| m);
94
95 Ok(NestedSchemaAligner {
96 inner,
97 output_schema,
98 projected_root_presence,
99 expected_input_col_num,
100 all_roots_present,
101 is_schema_matched: None,
102 })
103 }
104}
105
106impl<S> Stream for NestedSchemaAligner<S>
107where
108 S: Stream<Item = Result<RecordBatch>> + Unpin,
109{
110 type Item = Result<RecordBatch>;
111
112 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
113 let this = self.get_mut();
114
115 match Pin::new(&mut this.inner).poll_next(cx) {
116 Poll::Ready(Some(Ok(rb))) => {
117 let rb = if this.all_roots_present {
118 rb
119 } else {
120 fill_missing_cols(
121 rb,
122 &this.output_schema,
123 &this.projected_root_presence,
124 this.expected_input_col_num,
125 )?
126 };
127
128 let is_schema_matched = *this
129 .is_schema_matched
130 .get_or_insert_with(|| rb.schema() == this.output_schema);
131
132 if is_schema_matched {
133 Poll::Ready(Some(Ok(rb)))
134 } else {
135 Poll::Ready(Some(align_batch_to_schema(rb, &this.output_schema)))
136 }
137 }
138 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
139 Poll::Ready(None) => Poll::Ready(None),
140 Poll::Pending => Poll::Pending,
141 }
142 }
143}
144
145fn fill_missing_cols(
146 rb: RecordBatch,
147 output_schema: &SchemaRef,
148 projected_root_matches: &[bool],
149 expected_input_col_num: usize,
150) -> Result<RecordBatch> {
151 ensure!(
152 rb.columns().len() == expected_input_col_num,
153 UnexpectedSnafu {
154 reason: format!(
155 "NestedSchemaAligner expected {} input columns but got {}",
156 expected_input_col_num,
157 rb.columns().len()
158 ),
159 }
160 );
161
162 let mut cols = Vec::with_capacity(projected_root_matches.len());
163 let mut idx = 0;
164
165 for (field, matched) in output_schema.fields().iter().zip(projected_root_matches) {
166 if *matched {
167 cols.push(rb.column(idx).clone());
168 idx += 1;
169 } else {
170 cols.push(new_null_array(field.data_type(), rb.num_rows()));
171 }
172 }
173
174 RecordBatch::try_new(output_schema.clone(), cols).context(NewRecordBatchSnafu)
175}
176
177fn align_batch_to_schema(rb: RecordBatch, output_schema: &SchemaRef) -> Result<RecordBatch> {
178 ensure!(
179 rb.num_columns() == output_schema.fields().len(),
180 UnexpectedSnafu {
181 reason: format!(
182 "NestedSchemaAligner expected {} columns but got {}",
183 output_schema.fields().len(),
184 rb.num_columns()
185 ),
186 }
187 );
188
189 let columns = rb
190 .columns()
191 .iter()
192 .zip(output_schema.fields())
193 .map(|(array, field)| align_array(array, field))
194 .collect::<Result<Vec<_>>>()?;
195
196 RecordBatch::try_new(output_schema.clone(), columns).context(NewRecordBatchSnafu)
197}
198
199fn align_array(array: &ArrayRef, field: &FieldRef) -> Result<ArrayRef> {
200 if array.data_type() == field.data_type() {
201 return Ok(array.clone());
202 }
203
204 if !matches!(field.data_type(), DataType::Struct(_)) {
205 return Ok(array.clone());
206 }
207
208 cast_column(array, field.as_ref(), &DEFAULT_CAST_OPTIONS).context(CastColumnSnafu)
209}
210
/// Stream adapter over a `ParquetRecordBatchStream` that converts the reader's
/// errors into this crate's error type, attaching `path` to each of them.
pub(crate) struct ParquetErrorAdapter {
    /// Underlying parquet record batch stream.
    inner: ParquetRecordBatchStream<SstAsyncFileReader>,
    /// Path reported in the `ReadParquet` error built for a failed read.
    path: String,
}
216
217impl ParquetErrorAdapter {
218 pub(crate) fn new(inner: ParquetRecordBatchStream<SstAsyncFileReader>, path: String) -> Self {
219 Self { inner, path }
220 }
221}
222
223impl Stream for ParquetErrorAdapter {
224 type Item = Result<RecordBatch>;
225
226 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
227 let this = self.get_mut();
228
229 match Pin::new(&mut this.inner).poll_next(cx) {
230 Poll::Ready(Some(Ok(rb))) => Poll::Ready(Some(Ok(rb))),
231 Poll::Ready(Some(Err(err))) => {
232 Poll::Ready(Some(Err(
233 ReadParquetSnafu { path: &this.path }.into_error(err)
234 )))
235 }
236 Poll::Ready(None) => Poll::Ready(None),
237 Poll::Pending => Poll::Pending,
238 }
239 }
240}
241
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::arrow::array::{Array, ArrayRef, Int64Array, StringArray, StructArray};
    use datatypes::arrow::datatypes::{DataType, Field, Fields, Schema};
    use futures::{StreamExt, stream};

    use super::*;

    /// When every root is present and the schema already matches, the batch is
    /// emitted unchanged and the stream ends afterwards.
    #[tokio::test]
    async fn test_aligner_with_all_projected_roots_match() {
        let output_schema = schema([
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Utf8, true),
        ]);
        let input = RecordBatch::try_new(
            output_schema.clone(),
            vec![int_array([1, 2, 3]), string_array(["x", "y", "z"])],
        )
        .unwrap();
        let stream = stream::iter([Ok(input.clone())]);

        let mut aligner =
            NestedSchemaAligner::new(stream, vec![true, true], output_schema.clone()).unwrap();
        let output = aligner.next().await.unwrap().unwrap();

        assert_eq!(input, output);
        assert!(aligner.next().await.is_none());
    }

    /// Roots flagged as absent are emitted as all-null columns of the output
    /// field's type, while the present column keeps its values.
    #[tokio::test]
    async fn test_aligner_with_fills_null_root_columns() {
        let input_schema = schema([Field::new("a", DataType::Int64, true)]);
        let output_schema = schema([
            Field::new("a", DataType::Int64, true),
            Field::new("missing", DataType::Utf8, true),
            Field::new("c", DataType::Int64, true),
        ]);
        let input = RecordBatch::try_new(input_schema, vec![int_array([10, 20])]).unwrap();
        let stream = stream::iter([Ok(input)]);

        let mut aligner =
            NestedSchemaAligner::new(stream, vec![true, false, false], output_schema.clone())
                .unwrap();
        let output = aligner.next().await.unwrap().unwrap();

        assert_eq!(output_schema, output.schema());
        assert_eq!(3, output.num_columns());
        assert_eq!(
            &[Some(10), Some(20)],
            output
                .column(0)
                .as_any()
                .downcast_ref::<Int64Array>()
                .unwrap()
                .iter()
                .collect::<Vec<_>>()
                .as_slice()
        );
        assert_eq!(DataType::Utf8, *output.column(1).data_type());
        assert_eq!(output.num_rows(), output.column(1).null_count());
        assert_eq!(DataType::Int64, *output.column(2).data_type());
        assert_eq!(output.num_rows(), output.column(2).null_count());
    }

    /// An absent struct root is filled with an all-null struct column.
    #[tokio::test]
    async fn test_aligner_with_fills_missing_struct_root_column() {
        let input_schema = schema([Field::new("a", DataType::Int64, true)]);
        let struct_type = DataType::Struct(Fields::from(vec![
            Field::new("x", DataType::Int64, true),
            Field::new("y", DataType::Utf8, true),
        ]));
        let output_schema = schema([
            Field::new("a", DataType::Int64, true),
            Field::new("missing_struct", struct_type.clone(), true),
        ]);
        let input = RecordBatch::try_new(input_schema, vec![int_array([10, 20])]).unwrap();
        let stream = stream::iter([Ok(input)]);

        let mut aligner =
            NestedSchemaAligner::new(stream, vec![true, false], output_schema.clone()).unwrap();
        let output = aligner.next().await.unwrap().unwrap();

        assert_eq!(output_schema, output.schema());
        assert_eq!(2, output.num_columns());
        assert_eq!(struct_type, output.column(1).data_type().clone());
        assert_eq!(output.num_rows(), output.column(1).null_count());
    }

    /// Construction fails when the presence mask length differs from the
    /// number of output fields.
    #[tokio::test]
    async fn test_aligner_reject_projection_len_mismatch() {
        let output_schema = schema([Field::new("a", DataType::Int64, true)]);
        let stream = stream::iter([]);

        let err = match NestedSchemaAligner::new(stream, vec![true, false], output_schema) {
            Ok(_) => panic!("NestedSchemaAligner should reject projection length mismatch"),
            Err(err) => err,
        };

        assert!(
            err.to_string()
                .contains("projected root presence len 2 does not match output schema columns 1")
        );
    }

    /// A batch with fewer columns than the presence mask expects yields an
    /// error item instead of a batch.
    #[tokio::test]
    async fn test_aligner_reject_with_input_column_mismatch() {
        let input_schema = schema([Field::new("a", DataType::Int64, true)]);
        let output_schema = schema([
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Int64, true),
            Field::new("missing", DataType::Int64, true),
        ]);
        let input = RecordBatch::try_new(input_schema, vec![int_array([1, 2])]).unwrap();
        let stream = stream::iter([Ok(input)]);

        let mut aligner =
            NestedSchemaAligner::new(stream, vec![true, true, false], output_schema).unwrap();
        let err = aligner.next().await.unwrap().unwrap_err();

        assert!(
            err.to_string()
                .contains("expected 2 input columns but got 1")
        );
    }

    /// A present struct root that lacks a nested field is cast to the output
    /// struct type; the nested field it did not have comes back all-null.
    #[tokio::test]
    async fn test_nested_schema_aligner_aligns_struct_field() {
        let output_schema = schema([Field::new(
            "nested",
            DataType::Struct(Fields::from(vec![
                Field::new("x", DataType::Int64, true),
                Field::new("y", DataType::Utf8, true),
            ])),
            true,
        )]);
        let input = RecordBatch::try_new(
            schema([Field::new(
                "nested",
                DataType::Struct(Fields::from(vec![Field::new("x", DataType::Int64, true)])),
                true,
            )]),
            vec![Arc::new(StructArray::from(vec![(
                Arc::new(Field::new("x", DataType::Int64, true)),
                int_array([1, 2]),
            )]))],
        )
        .unwrap();

        let mut aligner =
            NestedSchemaAligner::new(stream::iter([Ok(input)]), vec![true], output_schema.clone())
                .unwrap();
        let output = aligner.next().await.unwrap().unwrap();

        assert_eq!(output_schema, output.schema());
        let nested = output
            .column(0)
            .as_any()
            .downcast_ref::<StructArray>()
            .unwrap();
        assert_eq!(2, nested.columns().len());
        assert_eq!(2, nested.column(1).null_count());
    }

    /// Builds a shared schema from the given fields.
    fn schema(fields: impl IntoIterator<Item = Field>) -> SchemaRef {
        Arc::new(Schema::new(fields.into_iter().collect::<Vec<_>>()))
    }

    /// Builds an `Int64Array` without nulls from the given values.
    fn int_array(values: impl IntoIterator<Item = i64>) -> ArrayRef {
        Arc::new(Int64Array::from_iter_values(values))
    }

    /// Builds a `StringArray` without nulls from the given values.
    fn string_array(values: impl IntoIterator<Item = &'static str>) -> ArrayRef {
        Arc::new(StringArray::from_iter_values(values))
    }
}