1use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use datafusion_common::cast_column;
19use datafusion_common::format::DEFAULT_CAST_OPTIONS;
20use datatypes::arrow::array::{ArrayRef, new_null_array};
21use datatypes::arrow::datatypes::{DataType, FieldRef, SchemaRef};
22use datatypes::arrow::record_batch::RecordBatch;
23use futures::Stream;
24use futures::stream::BoxStream;
25use snafu::{ResultExt, ensure};
26
27use crate::error::{CastColumnSnafu, NewRecordBatchSnafu, Result, UnexpectedSnafu};
28
29pub struct NestedSchemaAligner<S> {
48 inner: S,
49 output_schema: SchemaRef,
51 projected_root_presence: Vec<bool>,
54 expected_input_col_num: usize,
56 all_roots_present: bool,
59 is_schema_matched: Option<bool>,
61}
62
63pub(crate) type ProjectedRecordBatchStream = BoxStream<'static, Result<RecordBatch>>;
64
65impl<S> NestedSchemaAligner<S>
66where
67 S: Stream<Item = Result<RecordBatch>>,
68{
69 pub fn new(
70 inner: S,
71 projected_root_presence: Vec<bool>,
72 output_schema: SchemaRef,
73 ) -> Result<NestedSchemaAligner<S>> {
74 ensure!(
75 projected_root_presence.len() == output_schema.fields().len(),
76 UnexpectedSnafu {
77 reason: format!(
78 "NestedSchemaAligner projected root presence len {} does not match output schema columns {}",
79 projected_root_presence.len(),
80 output_schema.fields().len()
81 ),
82 }
83 );
84
85 let expected_input_col_num = projected_root_presence
86 .iter()
87 .filter(|matched| **matched)
88 .count();
89 let all_roots_present = projected_root_presence.iter().all(|&m| m);
90
91 Ok(NestedSchemaAligner {
92 inner,
93 output_schema,
94 projected_root_presence,
95 expected_input_col_num,
96 all_roots_present,
97 is_schema_matched: None,
98 })
99 }
100}
101
102impl<S> Stream for NestedSchemaAligner<S>
103where
104 S: Stream<Item = Result<RecordBatch>> + Unpin,
105{
106 type Item = Result<RecordBatch>;
107
108 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
109 let this = self.get_mut();
110
111 match Pin::new(&mut this.inner).poll_next(cx) {
112 Poll::Ready(Some(Ok(rb))) => {
113 let rb = if this.all_roots_present {
114 rb
115 } else {
116 fill_missing_cols(
117 rb,
118 &this.output_schema,
119 &this.projected_root_presence,
120 this.expected_input_col_num,
121 )?
122 };
123
124 let is_schema_matched = *this
125 .is_schema_matched
126 .get_or_insert_with(|| rb.schema() == this.output_schema);
127
128 if is_schema_matched {
129 Poll::Ready(Some(Ok(rb)))
130 } else {
131 Poll::Ready(Some(align_batch_to_schema(rb, &this.output_schema)))
132 }
133 }
134 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
135 Poll::Ready(None) => Poll::Ready(None),
136 Poll::Pending => Poll::Pending,
137 }
138 }
139}
140
141fn fill_missing_cols(
142 rb: RecordBatch,
143 output_schema: &SchemaRef,
144 projected_root_matches: &[bool],
145 expected_input_col_num: usize,
146) -> Result<RecordBatch> {
147 ensure!(
148 rb.columns().len() == expected_input_col_num,
149 UnexpectedSnafu {
150 reason: format!(
151 "NestedSchemaAligner expected {} input columns but got {}",
152 expected_input_col_num,
153 rb.columns().len()
154 ),
155 }
156 );
157
158 let mut cols = Vec::with_capacity(projected_root_matches.len());
159 let mut idx = 0;
160
161 for (field, matched) in output_schema.fields().iter().zip(projected_root_matches) {
162 if *matched {
163 cols.push(rb.column(idx).clone());
164 idx += 1;
165 } else {
166 cols.push(new_null_array(field.data_type(), rb.num_rows()));
167 }
168 }
169
170 RecordBatch::try_new(output_schema.clone(), cols).context(NewRecordBatchSnafu)
171}
172
173fn align_batch_to_schema(rb: RecordBatch, output_schema: &SchemaRef) -> Result<RecordBatch> {
174 ensure!(
175 rb.num_columns() == output_schema.fields().len(),
176 UnexpectedSnafu {
177 reason: format!(
178 "NestedSchemaAligner expected {} columns but got {}",
179 output_schema.fields().len(),
180 rb.num_columns()
181 ),
182 }
183 );
184
185 let columns = rb
186 .columns()
187 .iter()
188 .zip(output_schema.fields())
189 .map(|(array, field)| align_array(array, field))
190 .collect::<Result<Vec<_>>>()?;
191
192 RecordBatch::try_new(output_schema.clone(), columns).context(NewRecordBatchSnafu)
193}
194
195fn align_array(array: &ArrayRef, field: &FieldRef) -> Result<ArrayRef> {
196 if array.data_type() == field.data_type() {
197 return Ok(array.clone());
198 }
199
200 if !matches!(field.data_type(), DataType::Struct(_)) {
201 return Ok(array.clone());
202 }
203
204 cast_column(array, field.as_ref(), &DEFAULT_CAST_OPTIONS).context(CastColumnSnafu)
205}
206
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::arrow::array::{Array, ArrayRef, Int64Array, StringArray, StructArray};
    use datatypes::arrow::datatypes::{DataType, Field, Fields, Schema};
    use futures::{StreamExt, stream};

    use super::*;

    /// Builds a shared schema from the given fields.
    fn schema(fields: impl IntoIterator<Item = Field>) -> SchemaRef {
        let fields: Vec<Field> = fields.into_iter().collect();
        Arc::new(Schema::new(fields))
    }

    /// Builds a non-null `Int64` array.
    fn int_array(values: impl IntoIterator<Item = i64>) -> ArrayRef {
        Arc::new(Int64Array::from_iter_values(values))
    }

    /// Builds a non-null `Utf8` array.
    fn string_array(values: impl IntoIterator<Item = &'static str>) -> ArrayRef {
        Arc::new(StringArray::from_iter_values(values))
    }

    /// Builds a nullable field.
    fn nullable(name: &str, data_type: DataType) -> Field {
        Field::new(name, data_type, true)
    }

    /// Struct type with the nullable children `x: Int64` and `y: Utf8`.
    fn xy_struct_type() -> DataType {
        DataType::Struct(Fields::from(vec![
            nullable("x", DataType::Int64),
            nullable("y", DataType::Utf8),
        ]))
    }

    /// Asserts that `column` has `expected_type` and that all `num_rows` rows are null.
    fn assert_all_null(column: &ArrayRef, expected_type: &DataType, num_rows: usize) {
        assert_eq!(expected_type, column.data_type());
        assert_eq!(num_rows, column.null_count());
    }

    #[tokio::test]
    async fn test_aligner_with_all_projected_roots_match() {
        let output_schema = schema([
            nullable("a", DataType::Int64),
            nullable("b", DataType::Utf8),
        ]);
        let columns = vec![int_array([1, 2, 3]), string_array(["x", "y", "z"])];
        let expected = RecordBatch::try_new(output_schema.clone(), columns).unwrap();
        let batches = stream::iter([Ok(expected.clone())]);

        let mut aligner =
            NestedSchemaAligner::new(batches, vec![true, true], output_schema.clone()).unwrap();
        let actual = aligner.next().await.unwrap().unwrap();

        // A batch that already matches must pass through untouched.
        assert_eq!(expected, actual);
        assert!(aligner.next().await.is_none());
    }

    #[tokio::test]
    async fn test_aligner_with_fills_null_root_columns() {
        let output_schema = schema([
            nullable("a", DataType::Int64),
            nullable("missing", DataType::Utf8),
            nullable("c", DataType::Int64),
        ]);
        let input = RecordBatch::try_new(
            schema([nullable("a", DataType::Int64)]),
            vec![int_array([10, 20])],
        )
        .unwrap();
        let batches = stream::iter([Ok(input)]);

        let mut aligner =
            NestedSchemaAligner::new(batches, vec![true, false, false], output_schema.clone())
                .unwrap();
        let output = aligner.next().await.unwrap().unwrap();

        assert_eq!(output_schema, output.schema());
        assert_eq!(3, output.num_columns());

        let first = output
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        let values: Vec<Option<i64>> = first.iter().collect();
        assert_eq!(vec![Some(10), Some(20)], values);

        assert_all_null(output.column(1), &DataType::Utf8, output.num_rows());
        assert_all_null(output.column(2), &DataType::Int64, output.num_rows());
    }

    #[tokio::test]
    async fn test_aligner_with_fills_missing_struct_root_column() {
        let struct_type = xy_struct_type();
        let output_schema = schema([
            nullable("a", DataType::Int64),
            nullable("missing_struct", struct_type.clone()),
        ]);
        let input = RecordBatch::try_new(
            schema([nullable("a", DataType::Int64)]),
            vec![int_array([10, 20])],
        )
        .unwrap();
        let batches = stream::iter([Ok(input)]);

        let mut aligner =
            NestedSchemaAligner::new(batches, vec![true, false], output_schema.clone()).unwrap();
        let output = aligner.next().await.unwrap().unwrap();

        assert_eq!(output_schema, output.schema());
        assert_eq!(2, output.num_columns());
        assert_all_null(output.column(1), &struct_type, output.num_rows());
    }

    #[tokio::test]
    async fn test_aligner_reject_projection_len_mismatch() {
        let output_schema = schema([nullable("a", DataType::Int64)]);
        let batches = stream::iter([]);

        let Err(err) = NestedSchemaAligner::new(batches, vec![true, false], output_schema) else {
            panic!("NestedSchemaAligner should reject projection length mismatch");
        };

        let message = err.to_string();
        assert!(
            message.contains("projected root presence len 2 does not match output schema columns 1")
        );
    }

    #[tokio::test]
    async fn test_aligner_reject_with_input_column_mismatch() {
        let output_schema = schema([
            nullable("a", DataType::Int64),
            nullable("b", DataType::Int64),
            nullable("missing", DataType::Int64),
        ]);
        let input = RecordBatch::try_new(
            schema([nullable("a", DataType::Int64)]),
            vec![int_array([1, 2])],
        )
        .unwrap();
        let batches = stream::iter([Ok(input)]);

        let mut aligner =
            NestedSchemaAligner::new(batches, vec![true, true, false], output_schema).unwrap();
        let err = aligner.next().await.unwrap().unwrap_err();

        let message = err.to_string();
        assert!(message.contains("expected 2 input columns but got 1"));
    }

    #[tokio::test]
    async fn test_nested_schema_aligner_aligns_struct_field() {
        let output_schema = schema([nullable("nested", xy_struct_type())]);

        // The input struct only carries `x`; `y` has to be added by the aligner.
        let x_field = nullable("x", DataType::Int64);
        let input_schema = schema([nullable(
            "nested",
            DataType::Struct(Fields::from(vec![x_field.clone()])),
        )]);
        let nested_input: ArrayRef = Arc::new(StructArray::from(vec![(
            Arc::new(x_field),
            int_array([1, 2]),
        )]));
        let input = RecordBatch::try_new(input_schema, vec![nested_input]).unwrap();

        let mut aligner =
            NestedSchemaAligner::new(stream::iter([Ok(input)]), vec![true], output_schema.clone())
                .unwrap();
        let output = aligner.next().await.unwrap().unwrap();

        assert_eq!(output_schema, output.schema());
        let nested = output
            .column(0)
            .as_any()
            .downcast_ref::<StructArray>()
            .unwrap();
        assert_eq!(2, nested.columns().len());
        assert_eq!(2, nested.column(1).null_count());
    }
}