Skip to main content

mito2/sst/parquet/reader/
stream.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use datafusion_common::cast_column;
19use datafusion_common::format::DEFAULT_CAST_OPTIONS;
20use datatypes::arrow::array::{ArrayRef, new_null_array};
21use datatypes::arrow::datatypes::{DataType, FieldRef, SchemaRef};
22use datatypes::arrow::record_batch::RecordBatch;
23use futures::Stream;
24use futures::stream::BoxStream;
25use snafu::{ResultExt, ensure};
26
27use crate::error::{CastColumnSnafu, NewRecordBatchSnafu, Result, UnexpectedSnafu};
28
29/// Aligns projected batches to the expected output schema for nested projections.
30///
31/// Background
32/// ----------
33/// Nested projection may ask parquet to read leaves under a root column. If none
34/// of the requested leaves exists in the current parquet file, parquet decoding
35/// omits the whole root from the physical [`RecordBatch`].
36///
37/// In addition, after nested-path filtering, returned struct arrays may contain
38/// only a subset of fields. The current output schema is not pruned by nested
39/// paths, so physical struct fields can be a subset of the expected struct
40/// fields, and their nested schema can differ from the expected output schema.
41///
42/// To keep projected batches schema-consistent before entering upper readers:
43/// - Root-column presence alignment restores missing projected root columns by
44///   inserting root-level null arrays.
45/// - Nested struct alignment aligns struct arrays to the expected nested field
46///   layout.
47#[derive(derive_more::Debug)]
48pub struct NestedSchemaAligner<S> {
49    #[debug(skip)]
50    inner: S,
51    /// Output schema expected by the upper reader.
52    output_schema: SchemaRef,
53    /// Whether each projected root exists in the physical batch returned by
54    /// parquet.
55    projected_root_presence: Vec<bool>,
56    /// Number of columns expected from the physical batch returned by parquet.
57    expected_input_col_num: usize,
58    /// Whether all projected roots are present and the stream can pass batches
59    /// through.
60    all_roots_present: bool,
61    /// The cache for whether incoming batches already match output schema.
62    is_schema_matched: Option<bool>,
63}
64
65pub(crate) type ProjectedRecordBatchStream = BoxStream<'static, Result<RecordBatch>>;
66
67impl<S> NestedSchemaAligner<S>
68where
69    S: Stream<Item = Result<RecordBatch>>,
70{
71    pub fn new(
72        inner: S,
73        projected_root_presence: Vec<bool>,
74        output_schema: SchemaRef,
75    ) -> Result<NestedSchemaAligner<S>> {
76        ensure!(
77            projected_root_presence.len() == output_schema.fields().len(),
78            UnexpectedSnafu {
79                reason: format!(
80                    "NestedSchemaAligner projected root presence len {} does not match output schema columns {}",
81                    projected_root_presence.len(),
82                    output_schema.fields().len()
83                ),
84            }
85        );
86
87        let expected_input_col_num = projected_root_presence
88            .iter()
89            .filter(|matched| **matched)
90            .count();
91        let all_roots_present = projected_root_presence.iter().all(|&m| m);
92
93        Ok(NestedSchemaAligner {
94            inner,
95            output_schema,
96            projected_root_presence,
97            expected_input_col_num,
98            all_roots_present,
99            is_schema_matched: None,
100        })
101    }
102}
103
104impl<S> Stream for NestedSchemaAligner<S>
105where
106    S: Stream<Item = Result<RecordBatch>> + Unpin,
107{
108    type Item = Result<RecordBatch>;
109
110    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
111        let this = self.get_mut();
112
113        match Pin::new(&mut this.inner).poll_next(cx) {
114            Poll::Ready(Some(Ok(rb))) => {
115                let rb = if this.all_roots_present {
116                    rb
117                } else {
118                    fill_missing_cols(
119                        rb,
120                        &this.output_schema,
121                        &this.projected_root_presence,
122                        this.expected_input_col_num,
123                    )?
124                };
125
126                let is_schema_matched = *this
127                    .is_schema_matched
128                    .get_or_insert_with(|| rb.schema() == this.output_schema);
129
130                if is_schema_matched {
131                    Poll::Ready(Some(Ok(rb)))
132                } else {
133                    Poll::Ready(Some(align_batch_to_schema(rb, &this.output_schema)))
134                }
135            }
136            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
137            Poll::Ready(None) => Poll::Ready(None),
138            Poll::Pending => Poll::Pending,
139        }
140    }
141}
142
143fn fill_missing_cols(
144    rb: RecordBatch,
145    output_schema: &SchemaRef,
146    projected_root_matches: &[bool],
147    expected_input_col_num: usize,
148) -> Result<RecordBatch> {
149    ensure!(
150        rb.columns().len() == expected_input_col_num,
151        UnexpectedSnafu {
152            reason: format!(
153                "NestedSchemaAligner expected {} input columns but got {}",
154                expected_input_col_num,
155                rb.columns().len()
156            ),
157        }
158    );
159
160    let mut cols = Vec::with_capacity(projected_root_matches.len());
161    let mut idx = 0;
162
163    for (field, matched) in output_schema.fields().iter().zip(projected_root_matches) {
164        if *matched {
165            cols.push(rb.column(idx).clone());
166            idx += 1;
167        } else {
168            cols.push(new_null_array(field.data_type(), rb.num_rows()));
169        }
170    }
171
172    RecordBatch::try_new(output_schema.clone(), cols).context(NewRecordBatchSnafu)
173}
174
175fn align_batch_to_schema(rb: RecordBatch, output_schema: &SchemaRef) -> Result<RecordBatch> {
176    ensure!(
177        rb.num_columns() == output_schema.fields().len(),
178        UnexpectedSnafu {
179            reason: format!(
180                "NestedSchemaAligner expected {} columns but got {}",
181                output_schema.fields().len(),
182                rb.num_columns()
183            ),
184        }
185    );
186
187    let columns = rb
188        .columns()
189        .iter()
190        .zip(output_schema.fields())
191        .map(|(array, field)| align_array(array, field))
192        .collect::<Result<Vec<_>>>()?;
193
194    RecordBatch::try_new(output_schema.clone(), columns).context(NewRecordBatchSnafu)
195}
196
197fn align_array(array: &ArrayRef, field: &FieldRef) -> Result<ArrayRef> {
198    if array.data_type() == field.data_type() {
199        return Ok(array.clone());
200    }
201
202    if !matches!(field.data_type(), DataType::Struct(_)) {
203        return Ok(array.clone());
204    }
205
206    cast_column(array, field.as_ref(), &DEFAULT_CAST_OPTIONS).context(CastColumnSnafu)
207}
208
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::arrow::array::{Array, ArrayRef, Int64Array, StringArray, StructArray};
    use datatypes::arrow::datatypes::{DataType, Field, Fields, Schema};
    use futures::{StreamExt, stream};

    use super::*;

    /// Builds a shared schema from the given fields.
    fn schema(fields: impl IntoIterator<Item = Field>) -> SchemaRef {
        let fields: Vec<Field> = fields.into_iter().collect();
        Arc::new(Schema::new(fields))
    }

    /// Builds a nullable field; every field used by these tests is nullable.
    fn nullable(name: &str, data_type: DataType) -> Field {
        Field::new(name, data_type, true)
    }

    /// Builds an `Int64` array without nulls.
    fn int_array(values: impl IntoIterator<Item = i64>) -> ArrayRef {
        Arc::new(Int64Array::from_iter_values(values))
    }

    /// Builds a `Utf8` array without nulls.
    fn string_array(values: impl IntoIterator<Item = &'static str>) -> ArrayRef {
        Arc::new(StringArray::from_iter_values(values))
    }

    /// A batch that already has the output schema passes through unchanged.
    #[tokio::test]
    async fn test_aligner_with_all_projected_roots_match() {
        let output_schema = schema([
            nullable("a", DataType::Int64),
            nullable("b", DataType::Utf8),
        ]);
        let columns = vec![int_array([1, 2, 3]), string_array(["x", "y", "z"])];
        let input = RecordBatch::try_new(output_schema.clone(), columns).unwrap();

        let mut aligner = NestedSchemaAligner::new(
            stream::iter([Ok(input.clone())]),
            vec![true, true],
            output_schema,
        )
        .unwrap();

        let output = aligner.next().await.unwrap().unwrap();
        assert_eq!(input, output);
        assert!(aligner.next().await.is_none());
    }

    /// Missing primitive roots are restored as all-null columns.
    #[tokio::test]
    async fn test_aligner_with_fills_null_root_columns() {
        let output_schema = schema([
            nullable("a", DataType::Int64),
            nullable("missing", DataType::Utf8),
            nullable("c", DataType::Int64),
        ]);
        let input = RecordBatch::try_new(
            schema([nullable("a", DataType::Int64)]),
            vec![int_array([10, 20])],
        )
        .unwrap();

        let mut aligner = NestedSchemaAligner::new(
            stream::iter([Ok(input)]),
            vec![true, false, false],
            output_schema.clone(),
        )
        .unwrap();
        let output = aligner.next().await.unwrap().unwrap();

        assert_eq!(output_schema, output.schema());
        assert_eq!(3, output.num_columns());

        let present = output
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        let expected: Vec<Option<i64>> = vec![Some(10), Some(20)];
        assert_eq!(expected, present.iter().collect::<Vec<_>>());

        for (idx, expected_type) in [(1, DataType::Utf8), (2, DataType::Int64)] {
            let column = output.column(idx);
            assert_eq!(&expected_type, column.data_type());
            assert_eq!(output.num_rows(), column.null_count());
        }
    }

    /// A missing struct root is restored as an all-null struct column.
    #[tokio::test]
    async fn test_aligner_with_fills_missing_struct_root_column() {
        let struct_type = DataType::Struct(Fields::from(vec![
            nullable("x", DataType::Int64),
            nullable("y", DataType::Utf8),
        ]));
        let output_schema = schema([
            nullable("a", DataType::Int64),
            nullable("missing_struct", struct_type.clone()),
        ]);
        let input = RecordBatch::try_new(
            schema([nullable("a", DataType::Int64)]),
            vec![int_array([10, 20])],
        )
        .unwrap();

        let mut aligner = NestedSchemaAligner::new(
            stream::iter([Ok(input)]),
            vec![true, false],
            output_schema.clone(),
        )
        .unwrap();
        let output = aligner.next().await.unwrap().unwrap();

        assert_eq!(output_schema, output.schema());
        assert_eq!(2, output.num_columns());

        let filled = output.column(1);
        assert_eq!(&struct_type, filled.data_type());
        assert_eq!(output.num_rows(), filled.null_count());
    }

    /// The constructor rejects a presence vector whose length differs from the
    /// number of output columns.
    #[tokio::test]
    async fn test_aligner_reject_projection_len_mismatch() {
        let output_schema = schema([nullable("a", DataType::Int64)]);

        let result = NestedSchemaAligner::new(stream::iter([]), vec![true, false], output_schema);
        let Err(err) = result else {
            panic!("NestedSchemaAligner should reject projection length mismatch");
        };

        let message = err.to_string();
        assert!(
            message.contains("projected root presence len 2 does not match output schema columns 1")
        );
    }

    /// A physical batch with fewer columns than the presence flags promise is
    /// reported as an error.
    #[tokio::test]
    async fn test_aligner_reject_with_input_column_mismatch() {
        let output_schema = schema([
            nullable("a", DataType::Int64),
            nullable("b", DataType::Int64),
            nullable("missing", DataType::Int64),
        ]);
        let input = RecordBatch::try_new(
            schema([nullable("a", DataType::Int64)]),
            vec![int_array([1, 2])],
        )
        .unwrap();

        let mut aligner = NestedSchemaAligner::new(
            stream::iter([Ok(input)]),
            vec![true, true, false],
            output_schema,
        )
        .unwrap();
        let err = aligner.next().await.unwrap().unwrap_err();

        let message = err.to_string();
        assert!(message.contains("expected 2 input columns but got 1"));
    }

    /// A struct column carrying a subset of the expected fields is widened to
    /// the full struct, with the absent field filled by nulls.
    #[tokio::test]
    async fn test_nested_schema_aligner_aligns_struct_field() {
        let output_schema = schema([nullable(
            "nested",
            DataType::Struct(Fields::from(vec![
                nullable("x", DataType::Int64),
                nullable("y", DataType::Utf8),
            ])),
        )]);

        let physical_struct: ArrayRef = Arc::new(StructArray::from(vec![(
            Arc::new(nullable("x", DataType::Int64)),
            int_array([1, 2]),
        )]));
        let input_schema = schema([nullable(
            "nested",
            DataType::Struct(Fields::from(vec![nullable("x", DataType::Int64)])),
        )]);
        let input = RecordBatch::try_new(input_schema, vec![physical_struct]).unwrap();

        let mut aligner =
            NestedSchemaAligner::new(stream::iter([Ok(input)]), vec![true], output_schema.clone())
                .unwrap();
        let output = aligner.next().await.unwrap().unwrap();

        assert_eq!(output_schema, output.schema());
        let nested = output
            .column(0)
            .as_any()
            .downcast_ref::<StructArray>()
            .unwrap();
        assert_eq!(2, nested.columns().len());
        assert_eq!(2, nested.column(1).null_count());
    }
}