1use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use datafusion_common::cast_column;
19use datafusion_common::format::DEFAULT_CAST_OPTIONS;
20use datatypes::arrow::array::{ArrayRef, new_null_array};
21use datatypes::arrow::datatypes::{DataType, FieldRef, SchemaRef};
22use datatypes::arrow::record_batch::RecordBatch;
23use futures::Stream;
24use futures::stream::BoxStream;
25use snafu::{ResultExt, ensure};
26
27use crate::error::{CastColumnSnafu, NewRecordBatchSnafu, Result, UnexpectedSnafu};
28
29#[derive(derive_more::Debug)]
48pub struct NestedSchemaAligner<S> {
49 #[debug(skip)]
50 inner: S,
51 output_schema: SchemaRef,
53 projected_root_presence: Vec<bool>,
56 expected_input_col_num: usize,
58 all_roots_present: bool,
61 is_schema_matched: Option<bool>,
63}
64
65pub(crate) type ProjectedRecordBatchStream = BoxStream<'static, Result<RecordBatch>>;
66
67impl<S> NestedSchemaAligner<S>
68where
69 S: Stream<Item = Result<RecordBatch>>,
70{
71 pub fn new(
72 inner: S,
73 projected_root_presence: Vec<bool>,
74 output_schema: SchemaRef,
75 ) -> Result<NestedSchemaAligner<S>> {
76 ensure!(
77 projected_root_presence.len() == output_schema.fields().len(),
78 UnexpectedSnafu {
79 reason: format!(
80 "NestedSchemaAligner projected root presence len {} does not match output schema columns {}",
81 projected_root_presence.len(),
82 output_schema.fields().len()
83 ),
84 }
85 );
86
87 let expected_input_col_num = projected_root_presence
88 .iter()
89 .filter(|matched| **matched)
90 .count();
91 let all_roots_present = projected_root_presence.iter().all(|&m| m);
92
93 Ok(NestedSchemaAligner {
94 inner,
95 output_schema,
96 projected_root_presence,
97 expected_input_col_num,
98 all_roots_present,
99 is_schema_matched: None,
100 })
101 }
102}
103
104impl<S> Stream for NestedSchemaAligner<S>
105where
106 S: Stream<Item = Result<RecordBatch>> + Unpin,
107{
108 type Item = Result<RecordBatch>;
109
110 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
111 let this = self.get_mut();
112
113 match Pin::new(&mut this.inner).poll_next(cx) {
114 Poll::Ready(Some(Ok(rb))) => {
115 let rb = if this.all_roots_present {
116 rb
117 } else {
118 fill_missing_cols(
119 rb,
120 &this.output_schema,
121 &this.projected_root_presence,
122 this.expected_input_col_num,
123 )?
124 };
125
126 let is_schema_matched = *this
127 .is_schema_matched
128 .get_or_insert_with(|| rb.schema() == this.output_schema);
129
130 if is_schema_matched {
131 Poll::Ready(Some(Ok(rb)))
132 } else {
133 Poll::Ready(Some(align_batch_to_schema(rb, &this.output_schema)))
134 }
135 }
136 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
137 Poll::Ready(None) => Poll::Ready(None),
138 Poll::Pending => Poll::Pending,
139 }
140 }
141}
142
143fn fill_missing_cols(
144 rb: RecordBatch,
145 output_schema: &SchemaRef,
146 projected_root_matches: &[bool],
147 expected_input_col_num: usize,
148) -> Result<RecordBatch> {
149 ensure!(
150 rb.columns().len() == expected_input_col_num,
151 UnexpectedSnafu {
152 reason: format!(
153 "NestedSchemaAligner expected {} input columns but got {}",
154 expected_input_col_num,
155 rb.columns().len()
156 ),
157 }
158 );
159
160 let mut cols = Vec::with_capacity(projected_root_matches.len());
161 let mut idx = 0;
162
163 for (field, matched) in output_schema.fields().iter().zip(projected_root_matches) {
164 if *matched {
165 cols.push(rb.column(idx).clone());
166 idx += 1;
167 } else {
168 cols.push(new_null_array(field.data_type(), rb.num_rows()));
169 }
170 }
171
172 RecordBatch::try_new(output_schema.clone(), cols).context(NewRecordBatchSnafu)
173}
174
175fn align_batch_to_schema(rb: RecordBatch, output_schema: &SchemaRef) -> Result<RecordBatch> {
176 ensure!(
177 rb.num_columns() == output_schema.fields().len(),
178 UnexpectedSnafu {
179 reason: format!(
180 "NestedSchemaAligner expected {} columns but got {}",
181 output_schema.fields().len(),
182 rb.num_columns()
183 ),
184 }
185 );
186
187 let columns = rb
188 .columns()
189 .iter()
190 .zip(output_schema.fields())
191 .map(|(array, field)| align_array(array, field))
192 .collect::<Result<Vec<_>>>()?;
193
194 RecordBatch::try_new(output_schema.clone(), columns).context(NewRecordBatchSnafu)
195}
196
197fn align_array(array: &ArrayRef, field: &FieldRef) -> Result<ArrayRef> {
198 if array.data_type() == field.data_type() {
199 return Ok(array.clone());
200 }
201
202 if !matches!(field.data_type(), DataType::Struct(_)) {
203 return Ok(array.clone());
204 }
205
206 cast_column(array, field.as_ref(), &DEFAULT_CAST_OPTIONS).context(CastColumnSnafu)
207}
208
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::arrow::array::{Array, ArrayRef, Int64Array, StringArray, StructArray};
    use datatypes::arrow::datatypes::{DataType, Field, Fields, Schema};
    use futures::{StreamExt, stream};

    use super::*;

    /// Wraps the given fields into a shared schema.
    fn schema(fields: impl IntoIterator<Item = Field>) -> SchemaRef {
        let fields: Vec<Field> = fields.into_iter().collect();
        Arc::new(Schema::new(fields))
    }

    /// Builds a non-null `Int64` column from the given values.
    fn int_array(values: impl IntoIterator<Item = i64>) -> ArrayRef {
        let array = Int64Array::from_iter_values(values);
        Arc::new(array)
    }

    /// Builds a non-null `Utf8` column from the given values.
    fn string_array(values: impl IntoIterator<Item = &'static str>) -> ArrayRef {
        let array = StringArray::from_iter_values(values);
        Arc::new(array)
    }

    /// A batch that already has the output schema passes through untouched.
    #[tokio::test]
    async fn test_aligner_with_all_projected_roots_match() {
        let output_schema = schema([
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Utf8, true),
        ]);
        let columns = vec![int_array([1, 2, 3]), string_array(["x", "y", "z"])];
        let input = RecordBatch::try_new(output_schema.clone(), columns).unwrap();

        let mut aligner = NestedSchemaAligner::new(
            stream::iter([Ok(input.clone())]),
            vec![true, true],
            output_schema.clone(),
        )
        .unwrap();

        let output = aligner.next().await.unwrap().unwrap();
        assert_eq!(input, output);
        assert!(aligner.next().await.is_none());
    }

    /// Roots flagged as absent are emitted as all-null columns of the output type.
    #[tokio::test]
    async fn test_aligner_with_fills_null_root_columns() {
        let output_schema = schema([
            Field::new("a", DataType::Int64, true),
            Field::new("missing", DataType::Utf8, true),
            Field::new("c", DataType::Int64, true),
        ]);
        let input = RecordBatch::try_new(
            schema([Field::new("a", DataType::Int64, true)]),
            vec![int_array([10, 20])],
        )
        .unwrap();

        let mut aligner = NestedSchemaAligner::new(
            stream::iter([Ok(input)]),
            vec![true, false, false],
            output_schema.clone(),
        )
        .unwrap();
        let output = aligner.next().await.unwrap().unwrap();

        assert_eq!(output_schema, output.schema());
        assert_eq!(3, output.num_columns());

        let present = output
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        let present_values: Vec<Option<i64>> = present.iter().collect();
        assert_eq!(vec![Some(10), Some(20)], present_values);

        for (idx, expected_type) in [(1, DataType::Utf8), (2, DataType::Int64)] {
            let filled = output.column(idx);
            assert_eq!(expected_type, *filled.data_type());
            assert_eq!(output.num_rows(), filled.null_count());
        }
    }

    /// An absent struct root is emitted as an all-null struct column.
    #[tokio::test]
    async fn test_aligner_with_fills_missing_struct_root_column() {
        let struct_type = DataType::Struct(Fields::from(vec![
            Field::new("x", DataType::Int64, true),
            Field::new("y", DataType::Utf8, true),
        ]));
        let output_schema = schema([
            Field::new("a", DataType::Int64, true),
            Field::new("missing_struct", struct_type.clone(), true),
        ]);
        let input = RecordBatch::try_new(
            schema([Field::new("a", DataType::Int64, true)]),
            vec![int_array([10, 20])],
        )
        .unwrap();

        let mut aligner = NestedSchemaAligner::new(
            stream::iter([Ok(input)]),
            vec![true, false],
            output_schema.clone(),
        )
        .unwrap();
        let output = aligner.next().await.unwrap().unwrap();

        assert_eq!(output_schema, output.schema());
        assert_eq!(2, output.num_columns());

        let filled = output.column(1);
        assert_eq!(struct_type, filled.data_type().clone());
        assert_eq!(output.num_rows(), filled.null_count());
    }

    /// The constructor refuses a flag vector whose length differs from the schema.
    #[tokio::test]
    async fn test_aligner_reject_projection_len_mismatch() {
        let output_schema = schema([Field::new("a", DataType::Int64, true)]);

        let Err(err) = NestedSchemaAligner::new(stream::iter([]), vec![true, false], output_schema)
        else {
            panic!("NestedSchemaAligner should reject projection length mismatch");
        };

        let message = err.to_string();
        assert!(
            message.contains("projected root presence len 2 does not match output schema columns 1")
        );
    }

    /// A batch with fewer columns than flagged-present roots yields an error item.
    #[tokio::test]
    async fn test_aligner_reject_with_input_column_mismatch() {
        let output_schema = schema([
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Int64, true),
            Field::new("missing", DataType::Int64, true),
        ]);
        let input = RecordBatch::try_new(
            schema([Field::new("a", DataType::Int64, true)]),
            vec![int_array([1, 2])],
        )
        .unwrap();

        let mut aligner = NestedSchemaAligner::new(
            stream::iter([Ok(input)]),
            vec![true, true, false],
            output_schema,
        )
        .unwrap();
        let err = aligner.next().await.unwrap().unwrap_err();

        let message = err.to_string();
        assert!(message.contains("expected 2 input columns but got 1"));
    }

    /// A struct column lacking a nested field gains that field as an all-null child.
    #[tokio::test]
    async fn test_nested_schema_aligner_aligns_struct_field() {
        let output_struct = DataType::Struct(Fields::from(vec![
            Field::new("x", DataType::Int64, true),
            Field::new("y", DataType::Utf8, true),
        ]));
        let output_schema = schema([Field::new("nested", output_struct, true)]);

        let input_struct =
            DataType::Struct(Fields::from(vec![Field::new("x", DataType::Int64, true)]));
        let input_column = StructArray::from(vec![(
            Arc::new(Field::new("x", DataType::Int64, true)),
            int_array([1, 2]),
        )]);
        let input = RecordBatch::try_new(
            schema([Field::new("nested", input_struct, true)]),
            vec![Arc::new(input_column)],
        )
        .unwrap();

        let mut aligner =
            NestedSchemaAligner::new(stream::iter([Ok(input)]), vec![true], output_schema.clone())
                .unwrap();
        let output = aligner.next().await.unwrap().unwrap();

        assert_eq!(output_schema, output.schema());
        let nested = output
            .column(0)
            .as_any()
            .downcast_ref::<StructArray>()
            .unwrap();
        assert_eq!(2, nested.columns().len());
        assert_eq!(2, nested.column(1).null_count());
    }
}