Skip to main content

mito2/sst/parquet/reader/
stream.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use datafusion_common::cast_column;
19use datafusion_common::format::DEFAULT_CAST_OPTIONS;
20use datatypes::arrow::array::{ArrayRef, new_null_array};
21use datatypes::arrow::datatypes::{DataType, FieldRef, SchemaRef};
22use datatypes::arrow::record_batch::RecordBatch;
23use futures::Stream;
24use futures::stream::BoxStream;
25use snafu::{ResultExt, ensure};
26
27use crate::error::{CastColumnSnafu, NewRecordBatchSnafu, Result, UnexpectedSnafu};
28
/// Aligns projected batches to the expected output schema for nested projections.
///
/// Background
/// ----------
/// Nested projection may ask parquet to read leaves under a root column. If none
/// of the requested leaves exists in the current parquet file, parquet decoding
/// omits the whole root from the physical [`RecordBatch`].
///
/// In addition, after nested-path filtering, returned struct arrays may contain
/// only a subset of fields. The current output schema is not pruned by nested
/// paths, so physical struct fields can be a subset of the expected struct
/// fields, and their nested schema can differ from the expected output schema.
///
/// To keep projected batches schema-consistent before entering upper readers:
/// - Root-column presence alignment restores missing projected root columns by
///   inserting root-level null arrays.
/// - Nested struct alignment aligns struct arrays to the expected nested field
///   layout.
pub struct NestedSchemaAligner<S> {
    /// Wrapped stream that yields the physical batches to align.
    inner: S,
    /// Output schema expected by the upper reader.
    output_schema: SchemaRef,
    /// Whether each projected root exists in the physical batch returned by
    /// parquet. Holds one entry per field of `output_schema`, in schema order.
    projected_root_presence: Vec<bool>,
    /// Number of columns expected from the physical batch returned by parquet,
    /// i.e. the number of `true` entries in `projected_root_presence`.
    expected_input_col_num: usize,
    /// Whether all projected roots are present, in which case root-level null
    /// filling is skipped. Nested struct alignment may still be applied.
    all_roots_present: bool,
    /// Cached result of comparing the first batch's schema against
    /// `output_schema`; `None` until the first batch has been seen.
    is_schema_matched: Option<bool>,
}
62
/// Boxed `'static` stream of fallible [`RecordBatch`]es.
pub(crate) type ProjectedRecordBatchStream = BoxStream<'static, Result<RecordBatch>>;
64
65impl<S> NestedSchemaAligner<S>
66where
67    S: Stream<Item = Result<RecordBatch>>,
68{
69    pub fn new(
70        inner: S,
71        projected_root_presence: Vec<bool>,
72        output_schema: SchemaRef,
73    ) -> Result<NestedSchemaAligner<S>> {
74        ensure!(
75            projected_root_presence.len() == output_schema.fields().len(),
76            UnexpectedSnafu {
77                reason: format!(
78                    "NestedSchemaAligner projected root presence len {} does not match output schema columns {}",
79                    projected_root_presence.len(),
80                    output_schema.fields().len()
81                ),
82            }
83        );
84
85        let expected_input_col_num = projected_root_presence
86            .iter()
87            .filter(|matched| **matched)
88            .count();
89        let all_roots_present = projected_root_presence.iter().all(|&m| m);
90
91        Ok(NestedSchemaAligner {
92            inner,
93            output_schema,
94            projected_root_presence,
95            expected_input_col_num,
96            all_roots_present,
97            is_schema_matched: None,
98        })
99    }
100}
101
102impl<S> Stream for NestedSchemaAligner<S>
103where
104    S: Stream<Item = Result<RecordBatch>> + Unpin,
105{
106    type Item = Result<RecordBatch>;
107
108    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
109        let this = self.get_mut();
110
111        match Pin::new(&mut this.inner).poll_next(cx) {
112            Poll::Ready(Some(Ok(rb))) => {
113                let rb = if this.all_roots_present {
114                    rb
115                } else {
116                    fill_missing_cols(
117                        rb,
118                        &this.output_schema,
119                        &this.projected_root_presence,
120                        this.expected_input_col_num,
121                    )?
122                };
123
124                let is_schema_matched = *this
125                    .is_schema_matched
126                    .get_or_insert_with(|| rb.schema() == this.output_schema);
127
128                if is_schema_matched {
129                    Poll::Ready(Some(Ok(rb)))
130                } else {
131                    Poll::Ready(Some(align_batch_to_schema(rb, &this.output_schema)))
132                }
133            }
134            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
135            Poll::Ready(None) => Poll::Ready(None),
136            Poll::Pending => Poll::Pending,
137        }
138    }
139}
140
141fn fill_missing_cols(
142    rb: RecordBatch,
143    output_schema: &SchemaRef,
144    projected_root_matches: &[bool],
145    expected_input_col_num: usize,
146) -> Result<RecordBatch> {
147    ensure!(
148        rb.columns().len() == expected_input_col_num,
149        UnexpectedSnafu {
150            reason: format!(
151                "NestedSchemaAligner expected {} input columns but got {}",
152                expected_input_col_num,
153                rb.columns().len()
154            ),
155        }
156    );
157
158    let mut cols = Vec::with_capacity(projected_root_matches.len());
159    let mut idx = 0;
160
161    for (field, matched) in output_schema.fields().iter().zip(projected_root_matches) {
162        if *matched {
163            cols.push(rb.column(idx).clone());
164            idx += 1;
165        } else {
166            cols.push(new_null_array(field.data_type(), rb.num_rows()));
167        }
168    }
169
170    RecordBatch::try_new(output_schema.clone(), cols).context(NewRecordBatchSnafu)
171}
172
173fn align_batch_to_schema(rb: RecordBatch, output_schema: &SchemaRef) -> Result<RecordBatch> {
174    ensure!(
175        rb.num_columns() == output_schema.fields().len(),
176        UnexpectedSnafu {
177            reason: format!(
178                "NestedSchemaAligner expected {} columns but got {}",
179                output_schema.fields().len(),
180                rb.num_columns()
181            ),
182        }
183    );
184
185    let columns = rb
186        .columns()
187        .iter()
188        .zip(output_schema.fields())
189        .map(|(array, field)| align_array(array, field))
190        .collect::<Result<Vec<_>>>()?;
191
192    RecordBatch::try_new(output_schema.clone(), columns).context(NewRecordBatchSnafu)
193}
194
195fn align_array(array: &ArrayRef, field: &FieldRef) -> Result<ArrayRef> {
196    if array.data_type() == field.data_type() {
197        return Ok(array.clone());
198    }
199
200    if !matches!(field.data_type(), DataType::Struct(_)) {
201        return Ok(array.clone());
202    }
203
204    cast_column(array, field.as_ref(), &DEFAULT_CAST_OPTIONS).context(CastColumnSnafu)
205}
206
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::arrow::array::{Array, ArrayRef, Int64Array, StringArray, StructArray};
    use datatypes::arrow::datatypes::{DataType, Field, Fields, Schema};
    use futures::{StreamExt, stream};

    use super::*;

    // All roots present and the input schema equals the output schema: the
    // batch must come out unchanged and the stream must end afterwards.
    #[tokio::test]
    async fn test_aligner_with_all_projected_roots_match() {
        let output_schema = schema([
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Utf8, true),
        ]);
        let input = RecordBatch::try_new(
            output_schema.clone(),
            vec![int_array([1, 2, 3]), string_array(["x", "y", "z"])],
        )
        .unwrap();
        let stream = stream::iter([Ok(input.clone())]);

        let mut aligner =
            NestedSchemaAligner::new(stream, vec![true, true], output_schema.clone()).unwrap();
        let output = aligner.next().await.unwrap().unwrap();

        assert_eq!(input, output);
        assert!(aligner.next().await.is_none());
    }

    // Two of three roots are missing from the physical batch: the present
    // column keeps its values and the missing ones become all-null columns of
    // the expected data types.
    #[tokio::test]
    async fn test_aligner_with_fills_null_root_columns() {
        let input_schema = schema([Field::new("a", DataType::Int64, true)]);
        let output_schema = schema([
            Field::new("a", DataType::Int64, true),
            Field::new("missing", DataType::Utf8, true),
            Field::new("c", DataType::Int64, true),
        ]);
        let input = RecordBatch::try_new(input_schema, vec![int_array([10, 20])]).unwrap();
        let stream = stream::iter([Ok(input)]);

        let mut aligner =
            NestedSchemaAligner::new(stream, vec![true, false, false], output_schema.clone())
                .unwrap();
        let output = aligner.next().await.unwrap().unwrap();

        assert_eq!(output_schema, output.schema());
        assert_eq!(3, output.num_columns());
        assert_eq!(
            &[Some(10), Some(20)],
            output
                .column(0)
                .as_any()
                .downcast_ref::<Int64Array>()
                .unwrap()
                .iter()
                .collect::<Vec<_>>()
                .as_slice()
        );
        assert_eq!(DataType::Utf8, *output.column(1).data_type());
        assert_eq!(output.num_rows(), output.column(1).null_count());
        assert_eq!(DataType::Int64, *output.column(2).data_type());
        assert_eq!(output.num_rows(), output.column(2).null_count());
    }

    // A missing root whose expected type is a struct is filled with an
    // all-null array of that struct type.
    #[tokio::test]
    async fn test_aligner_with_fills_missing_struct_root_column() {
        let input_schema = schema([Field::new("a", DataType::Int64, true)]);
        let struct_type = DataType::Struct(Fields::from(vec![
            Field::new("x", DataType::Int64, true),
            Field::new("y", DataType::Utf8, true),
        ]));
        let output_schema = schema([
            Field::new("a", DataType::Int64, true),
            Field::new("missing_struct", struct_type.clone(), true),
        ]);
        let input = RecordBatch::try_new(input_schema, vec![int_array([10, 20])]).unwrap();
        let stream = stream::iter([Ok(input)]);

        let mut aligner =
            NestedSchemaAligner::new(stream, vec![true, false], output_schema.clone()).unwrap();
        let output = aligner.next().await.unwrap().unwrap();

        assert_eq!(output_schema, output.schema());
        assert_eq!(2, output.num_columns());
        assert_eq!(struct_type, output.column(1).data_type().clone());
        assert_eq!(output.num_rows(), output.column(1).null_count());
    }

    // The constructor rejects a presence vector whose length differs from the
    // number of output schema fields.
    #[tokio::test]
    async fn test_aligner_reject_projection_len_mismatch() {
        let output_schema = schema([Field::new("a", DataType::Int64, true)]);
        let stream = stream::iter([]);

        let err = match NestedSchemaAligner::new(stream, vec![true, false], output_schema) {
            Ok(_) => panic!("NestedSchemaAligner should reject projection length mismatch"),
            Err(err) => err,
        };

        assert!(
            err.to_string()
                .contains("projected root presence len 2 does not match output schema columns 1")
        );
    }

    // The presence flags announce two physical columns but the batch carries
    // only one: polling the stream yields an error.
    #[tokio::test]
    async fn test_aligner_reject_with_input_column_mismatch() {
        let input_schema = schema([Field::new("a", DataType::Int64, true)]);
        let output_schema = schema([
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Int64, true),
            Field::new("missing", DataType::Int64, true),
        ]);
        let input = RecordBatch::try_new(input_schema, vec![int_array([1, 2])]).unwrap();
        let stream = stream::iter([Ok(input)]);

        let mut aligner =
            NestedSchemaAligner::new(stream, vec![true, true, false], output_schema).unwrap();
        let err = aligner.next().await.unwrap().unwrap_err();

        assert!(
            err.to_string()
                .contains("expected 2 input columns but got 1")
        );
    }

    // The physical struct only has field `x` while the output schema expects
    // `x` and `y`: after alignment the struct has both fields and `y` is null
    // for every row.
    #[tokio::test]
    async fn test_nested_schema_aligner_aligns_struct_field() {
        let output_schema = schema([Field::new(
            "nested",
            DataType::Struct(Fields::from(vec![
                Field::new("x", DataType::Int64, true),
                Field::new("y", DataType::Utf8, true),
            ])),
            true,
        )]);
        let input = RecordBatch::try_new(
            schema([Field::new(
                "nested",
                DataType::Struct(Fields::from(vec![Field::new("x", DataType::Int64, true)])),
                true,
            )]),
            vec![Arc::new(StructArray::from(vec![(
                Arc::new(Field::new("x", DataType::Int64, true)),
                int_array([1, 2]),
            )]))],
        )
        .unwrap();

        let mut aligner =
            NestedSchemaAligner::new(stream::iter([Ok(input)]), vec![true], output_schema.clone())
                .unwrap();
        let output = aligner.next().await.unwrap().unwrap();

        assert_eq!(output_schema, output.schema());
        let nested = output
            .column(0)
            .as_any()
            .downcast_ref::<StructArray>()
            .unwrap();
        assert_eq!(2, nested.columns().len());
        assert_eq!(2, nested.column(1).null_count());
    }

    // Builds a shared schema from the given fields.
    fn schema(fields: impl IntoIterator<Item = Field>) -> SchemaRef {
        Arc::new(Schema::new(fields.into_iter().collect::<Vec<_>>()))
    }

    // Builds an `Int64Array` column from the given values.
    fn int_array(values: impl IntoIterator<Item = i64>) -> ArrayRef {
        Arc::new(Int64Array::from_iter_values(values))
    }

    // Builds a `StringArray` column from the given values.
    fn string_array(values: impl IntoIterator<Item = &'static str>) -> ArrayRef {
        Arc::new(StringArray::from_iter_values(values))
    }
}