common_recordbatch/adapter.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt::{self, Display};
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use datafusion::arrow::compute::cast;
use datafusion::arrow::datatypes::SchemaRef as DfSchemaRef;
use datafusion::error::Result as DfResult;
use datafusion::execution::context::ExecutionProps;
use datafusion::logical_expr::utils::conjunction;
use datafusion::logical_expr::Expr;
use datafusion::physical_expr::create_physical_expr;
use datafusion::physical_plan::metrics::{BaselineMetrics, MetricValue};
use datafusion::physical_plan::{
    accept, DisplayFormatType, ExecutionPlan, ExecutionPlanVisitor, PhysicalExpr,
    RecordBatchStream as DfRecordBatchStream,
};
use datafusion_common::arrow::error::ArrowError;
use datafusion_common::{DataFusionError, ToDFSchema};
use datatypes::arrow::array::Array;
use datatypes::schema::{Schema, SchemaRef};
use futures::ready;
use pin_project::pin_project;
use snafu::ResultExt;

use crate::error::{self, Result};
use crate::filter::batch_filter;
use crate::{
    DfRecordBatch, DfSendableRecordBatchStream, OrderOption, RecordBatch, RecordBatchStream,
    SendableRecordBatchStream, Stream,
};

type FutureStream =
    Pin<Box<dyn std::future::Future<Output = Result<SendableRecordBatchStream>> + Send>>;

/// Casts the `RecordBatch`es of `stream` to the projected output schema, applying an
/// optional projection and filter predicate along the way.
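///
/// # Example
///
/// A minimal construction sketch (illustrative, not part of the original docs). It assumes
/// `source` is some stream of `Result<DfRecordBatch, E>` items:
///
/// ```ignore
/// use std::sync::Arc;
/// use datafusion::arrow::datatypes::{DataType, Field, Schema};
/// use datafusion::prelude::{col, lit};
///
/// // Target schema the adapter casts every batch to.
/// let projected_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
/// // Keep all columns (no explicit projection) and push down a simple predicate.
/// let adapter = RecordBatchStreamTypeAdapter::new(projected_schema, source, None)
///     .with_filter(vec![col("a").gt(lit(1_i64))])?;
/// ```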
#[pin_project]
pub struct RecordBatchStreamTypeAdapter<T, E> {
    #[pin]
    stream: T,
    projected_schema: DfSchemaRef,
    projection: Vec<usize>,
    predicate: Option<Arc<dyn PhysicalExpr>>,
    phantom: PhantomData<E>,
}

impl<T, E> RecordBatchStreamTypeAdapter<T, E>
where
    T: Stream<Item = std::result::Result<DfRecordBatch, E>>,
    E: std::error::Error + Send + Sync + 'static,
{
    /// Creates a new adapter. If `projection` is `None`, all columns of `projected_schema` are kept.
    pub fn new(projected_schema: DfSchemaRef, stream: T, projection: Option<Vec<usize>>) -> Self {
        let projection = if let Some(projection) = projection {
            projection
        } else {
            (0..projected_schema.fields().len()).collect()
        };

        Self {
            stream,
            projected_schema,
            projection,
            predicate: None,
            phantom: Default::default(),
        }
    }

    /// Conjoins the logical `filters` into a single physical predicate applied to every batch.
    pub fn with_filter(mut self, filters: Vec<Expr>) -> Result<Self> {
        let filters = if let Some(expr) = conjunction(filters) {
            let df_schema = self
                .projected_schema
                .clone()
                .to_dfschema_ref()
                .context(error::PhysicalExprSnafu)?;

            let filters = create_physical_expr(&expr, &df_schema, &ExecutionProps::new())
                .context(error::PhysicalExprSnafu)?;
            Some(filters)
        } else {
            None
        };
        self.predicate = filters;
        Ok(self)
    }
}

impl<T, E> DfRecordBatchStream for RecordBatchStreamTypeAdapter<T, E>
where
    T: Stream<Item = std::result::Result<DfRecordBatch, E>>,
    E: std::error::Error + Send + Sync + 'static,
{
    fn schema(&self) -> DfSchemaRef {
        self.projected_schema.clone()
    }
}

impl<T, E> Stream for RecordBatchStreamTypeAdapter<T, E>
where
    T: Stream<Item = std::result::Result<DfRecordBatch, E>>,
    E: std::error::Error + Send + Sync + 'static,
{
    type Item = DfResult<DfRecordBatch>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();

        let batch = futures::ready!(this.stream.poll_next(cx))
            .map(|r| r.map_err(|e| DataFusionError::External(Box::new(e))));

        let projected_schema = this.projected_schema.clone();
        let projection = this.projection.clone();
        let predicate = this.predicate.clone();

        let batch = batch.map(|b| {
            b.and_then(|b| {
                // Keep only the projected columns and make sure the shapes line up.
                let projected_column = b.project(&projection)?;
                if projected_column.schema().fields.len() != projected_schema.fields.len() {
                    return Err(DataFusionError::ArrowError(
                        ArrowError::SchemaError(format!(
                            "Trying to cast a RecordBatch into an incompatible schema. RecordBatch: {}, Target: {}",
                            projected_column.schema(),
                            projected_schema,
                        )),
                        None,
                    ));
                }

                // Cast each column to the target field type when the types differ.
                let mut columns = Vec::with_capacity(projected_schema.fields.len());
                for (idx, field) in projected_schema.fields.iter().enumerate() {
                    let column = projected_column.column(idx);
                    if column.data_type() != field.data_type() {
                        let output = cast(&column, field.data_type())?;
                        columns.push(output)
                    } else {
                        columns.push(column.clone())
                    }
                }
                let record_batch = DfRecordBatch::try_new(projected_schema, columns)?;
                // Apply the pushed-down predicate, if any.
                let record_batch = if let Some(predicate) = predicate {
                    batch_filter(&record_batch, &predicate)?
                } else {
                    record_batch
                };
                Ok(record_batch)
            })
        });

        Poll::Ready(batch)
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}

/// Greptime SendableRecordBatchStream -> DataFusion RecordBatchStream.
/// The reverse one is [RecordBatchStreamAdapter].
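///
/// # Example
///
/// A minimal sketch (illustrative, not from the original docs); `greptime_stream` is assumed
/// to be any [SendableRecordBatchStream]:
///
/// ```ignore
/// // Wrap the Greptime stream so that DataFusion operators can consume it.
/// let df_stream: DfSendableRecordBatchStream =
///     Box::pin(DfRecordBatchStreamAdapter::new(greptime_stream));
/// ```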
pub struct DfRecordBatchStreamAdapter {
    stream: SendableRecordBatchStream,
}

impl DfRecordBatchStreamAdapter {
    pub fn new(stream: SendableRecordBatchStream) -> Self {
        Self { stream }
    }
}

impl DfRecordBatchStream for DfRecordBatchStreamAdapter {
    fn schema(&self) -> DfSchemaRef {
        self.stream.schema().arrow_schema().clone()
    }
}

impl Stream for DfRecordBatchStreamAdapter {
    type Item = DfResult<DfRecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.stream).poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Some(recordbatch)) => match recordbatch {
                Ok(recordbatch) => Poll::Ready(Some(Ok(recordbatch.into_df_record_batch()))),
                Err(e) => Poll::Ready(Some(Err(DataFusionError::External(Box::new(e))))),
            },
            Poll::Ready(None) => Poll::Ready(None),
        }
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}

/// DataFusion [SendableRecordBatchStream](DfSendableRecordBatchStream) -> Greptime [RecordBatchStream].
/// The reverse one is [DfRecordBatchStreamAdapter].
/// It can collect metrics from the DataFusion execution plan.
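///
/// # Example
///
/// An illustrative sketch (not from the original docs); `df_stream` is assumed to be a
/// DataFusion [SendableRecordBatchStream](DfSendableRecordBatchStream) and `plan` the
/// [ExecutionPlan] that produced it:
///
/// ```ignore
/// let mut adapter = RecordBatchStreamAdapter::try_new(df_stream)?;
/// // Resolve plan-level metrics once the stream is exhausted.
/// adapter.set_metrics2(plan);
/// ```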
pub struct RecordBatchStreamAdapter {
    schema: SchemaRef,
    stream: DfSendableRecordBatchStream,
    metrics: Option<BaselineMetrics>,
    /// Aggregated plan-level metrics. Resolved after an [ExecutionPlan] is finished.
    metrics_2: Metrics,
    /// Display plan and metrics in verbose mode.
    explain_verbose: bool,
}

/// JSON-encodable metrics covering the whole plan tree.
enum Metrics {
    Unavailable,
    Unresolved(Arc<dyn ExecutionPlan>),
    Resolved(RecordBatchMetrics),
}

impl RecordBatchStreamAdapter {
    pub fn try_new(stream: DfSendableRecordBatchStream) -> Result<Self> {
        let schema =
            Arc::new(Schema::try_from(stream.schema()).context(error::SchemaConversionSnafu)?);
        Ok(Self {
            schema,
            stream,
            metrics: None,
            metrics_2: Metrics::Unavailable,
            explain_verbose: false,
        })
    }

    pub fn try_new_with_metrics_and_df_plan(
        stream: DfSendableRecordBatchStream,
        metrics: BaselineMetrics,
        df_plan: Arc<dyn ExecutionPlan>,
    ) -> Result<Self> {
        let schema =
            Arc::new(Schema::try_from(stream.schema()).context(error::SchemaConversionSnafu)?);
        Ok(Self {
            schema,
            stream,
            metrics: Some(metrics),
            metrics_2: Metrics::Unresolved(df_plan),
            explain_verbose: false,
        })
    }

    /// Sets the plan whose metrics will be resolved after the stream finishes.
    pub fn set_metrics2(&mut self, plan: Arc<dyn ExecutionPlan>) {
        self.metrics_2 = Metrics::Unresolved(plan)
    }

    /// Sets the verbose mode for displaying plan and metrics.
    pub fn set_explain_verbose(&mut self, verbose: bool) {
        self.explain_verbose = verbose;
    }
}

impl RecordBatchStream for RecordBatchStreamAdapter {
    fn name(&self) -> &str {
        "RecordBatchStreamAdapter"
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        match &self.metrics_2 {
            Metrics::Resolved(metrics) => Some(metrics.clone()),
            Metrics::Unavailable | Metrics::Unresolved(_) => None,
        }
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        None
    }
}

impl Stream for RecordBatchStreamAdapter {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let timer = self
            .metrics
            .as_ref()
            .map(|m| m.elapsed_compute().clone())
            .unwrap_or_default();
        let _guard = timer.timer();
        match Pin::new(&mut self.stream).poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Some(df_record_batch)) => {
                let df_record_batch = df_record_batch?;
                Poll::Ready(Some(RecordBatch::try_from_df_record_batch(
                    self.schema(),
                    df_record_batch,
                )))
            }
            Poll::Ready(None) => {
                // The stream is exhausted; scrape metrics from the whole plan tree.
                if let Metrics::Unresolved(df_plan) = &self.metrics_2 {
                    let mut metric_collector = MetricCollector::new(self.explain_verbose);
                    // The collector's error type is `!`, so `accept` cannot fail here.
                    accept(df_plan.as_ref(), &mut metric_collector).unwrap();
                    self.metrics_2 = Metrics::Resolved(metric_collector.record_batch_metrics);
                }
                Poll::Ready(None)
            }
        }
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}

/// An [ExecutionPlanVisitor] that collects metrics from an [ExecutionPlan].
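///
/// # Example
///
/// A sketch of how the collector is driven (illustrative; `plan` is assumed to be an
/// already-executed [ExecutionPlan]):
///
/// ```ignore
/// let mut collector = MetricCollector::new(false);
/// // Walk the plan tree; the visitor's error type is `!`, so this cannot fail.
/// let _ = accept(plan.as_ref(), &mut collector);
/// let metrics = collector.record_batch_metrics;
/// ```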
pub struct MetricCollector {
    /// Current depth in the plan tree, starting from 0 at the root.
    current_level: usize,
    /// Metrics collected so far.
    pub record_batch_metrics: RecordBatchMetrics,
    /// Whether to render plans in verbose mode.
    verbose: bool,
}

impl MetricCollector {
    pub fn new(verbose: bool) -> Self {
        Self {
            current_level: 0,
            record_batch_metrics: RecordBatchMetrics::default(),
            verbose,
        }
    }
}

impl ExecutionPlanVisitor for MetricCollector {
    // The never type: visiting a plan cannot fail.
    type Error = !;

    fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> std::result::Result<bool, Self::Error> {
        // skip if no metric available
        let Some(metric) = plan.metrics() else {
            self.record_batch_metrics.plan_metrics.push(PlanMetrics {
                plan: std::any::type_name::<Self>().to_string(),
                level: self.current_level,
                metrics: vec![],
            });
            self.current_level += 1;
            return Ok(true);
        };

        // scrape plan metrics
        let metric = metric
            .aggregate_by_name()
            .sorted_for_display()
            .timestamps_removed();
        let mut plan_metric = PlanMetrics {
            plan: one_line(plan, self.verbose).to_string(),
            level: self.current_level,
            metrics: Vec::with_capacity(metric.iter().size_hint().0),
        };
        for m in metric.iter() {
            plan_metric
                .metrics
                .push((m.value().name().to_string(), m.value().as_usize()));

            // aggregate high-level metrics
            match m.value() {
                MetricValue::ElapsedCompute(ec) => {
                    self.record_batch_metrics.elapsed_compute += ec.value()
                }
                MetricValue::CurrentMemoryUsage(m) => {
                    self.record_batch_metrics.memory_usage += m.value()
                }
                _ => {}
            }
        }
        self.record_batch_metrics.plan_metrics.push(plan_metric);

        self.current_level += 1;
        Ok(true)
    }

    fn post_visit(&mut self, _plan: &dyn ExecutionPlan) -> std::result::Result<bool, Self::Error> {
        self.current_level -= 1;
        Ok(true)
    }
}

/// Returns a single-line summary of the root of the plan.
/// If the `verbose` flag is set, it will display detailed information about the plan.
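///
/// For example (illustrative), `one_line(plan, false).to_string()` renders the root operator
/// of `plan` on a single line, followed by a newline.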
fn one_line(plan: &dyn ExecutionPlan, verbose: bool) -> impl fmt::Display + '_ {
    struct Wrapper<'a> {
        plan: &'a dyn ExecutionPlan,
        format_type: DisplayFormatType,
    }

    impl fmt::Display for Wrapper<'_> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            self.plan.fmt_as(self.format_type, f)?;
            writeln!(f)
        }
    }

    let format_type = if verbose {
        DisplayFormatType::Verbose
    } else {
        DisplayFormatType::Default
    };
    Wrapper { plan, format_type }
}

/// [`RecordBatchMetrics`] carries metric values
/// from datanode to frontend through gRPC.
#[derive(serde::Serialize, serde::Deserialize, Default, Debug, Clone)]
pub struct RecordBatchMetrics {
    // High-level aggregated metrics
    /// CPU consumption in nanoseconds
    pub elapsed_compute: usize,
    /// Memory used by the plan in bytes
    pub memory_usage: usize,
    // Detailed per-plan metrics
    /// An ordered list of plan metrics, from the top of the plan tree down (pre-order).
    pub plan_metrics: Vec<PlanMetrics>,
}

/// Displays only `plan_metrics`, indenting each plan by two spaces per level.
impl Display for RecordBatchMetrics {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for metric in &self.plan_metrics {
            write!(
                f,
                "{:indent$}{} metrics=[",
                " ",
                metric.plan.trim_end(),
                indent = metric.level * 2,
            )?;
            for (label, value) in &metric.metrics {
                write!(f, "{}: {}, ", label, value)?;
            }
            writeln!(f, "]")?;
        }

        Ok(())
    }
}

#[derive(serde::Serialize, serde::Deserialize, Default, Debug, Clone)]
pub struct PlanMetrics {
    /// The plan name
    pub plan: String,
    /// The level of the plan, starting from 0
    pub level: usize,
    /// An ordered key-value list of metrics.
    /// Key is the metric label and value is the metric value.
    pub metrics: Vec<(String, usize)>,
}

enum AsyncRecordBatchStreamAdapterState {
    Uninit(FutureStream),
    Ready(SendableRecordBatchStream),
    Failed,
}

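/// A [RecordBatchStream] that is lazily initialized from a `FutureStream`: the inner future
/// is polled to completion first, and the resulting stream is then polled for batches.
///
/// # Example
///
/// A sketch (illustrative, not from the original docs); `schema` is the stream's [SchemaRef]
/// and `future_stream` a `FutureStream` resolving to the actual stream:
///
/// ```ignore
/// let adapter = AsyncRecordBatchStreamAdapter::new(schema, future_stream);
/// let batches = RecordBatches::try_collect(Box::pin(adapter)).await?;
/// ```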
pub struct AsyncRecordBatchStreamAdapter {
    schema: SchemaRef,
    state: AsyncRecordBatchStreamAdapterState,
}

impl AsyncRecordBatchStreamAdapter {
    pub fn new(schema: SchemaRef, stream: FutureStream) -> Self {
        Self {
            schema,
            state: AsyncRecordBatchStreamAdapterState::Uninit(stream),
        }
    }
}

impl RecordBatchStream for AsyncRecordBatchStreamAdapter {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        None
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        None
    }
}

impl Stream for AsyncRecordBatchStreamAdapter {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            match &mut self.state {
                AsyncRecordBatchStreamAdapterState::Uninit(stream_future) => {
                    // Drive the initialization future first; switch to `Ready` on success.
                    match ready!(Pin::new(stream_future).poll(cx)) {
                        Ok(stream) => {
                            self.state = AsyncRecordBatchStreamAdapterState::Ready(stream);
                            continue;
                        }
                        Err(e) => {
                            self.state = AsyncRecordBatchStreamAdapterState::Failed;
                            return Poll::Ready(Some(Err(e)));
                        }
                    };
                }
                AsyncRecordBatchStreamAdapterState::Ready(stream) => {
                    return Poll::Ready(ready!(Pin::new(stream).poll_next(cx)))
                }
                AsyncRecordBatchStreamAdapterState::Failed => return Poll::Ready(None),
            }
        }
    }

    // The size hint is unknown for a lazily initialized stream.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, None)
    }
}

#[cfg(test)]
mod test {
    use common_error::ext::BoxedError;
    use common_error::mock::MockError;
    use common_error::status_code::StatusCode;
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::ColumnSchema;
    use datatypes::vectors::Int32Vector;
    use snafu::IntoError;

    use super::*;
    use crate::error::Error;
    use crate::RecordBatches;

    #[tokio::test]
    async fn test_async_recordbatch_stream_adaptor() {
        struct MaybeErrorRecordBatchStream {
            items: Vec<Result<RecordBatch>>,
        }

        impl RecordBatchStream for MaybeErrorRecordBatchStream {
            fn schema(&self) -> SchemaRef {
                unimplemented!()
            }

            fn output_ordering(&self) -> Option<&[OrderOption]> {
                None
            }

            fn metrics(&self) -> Option<RecordBatchMetrics> {
                None
            }
        }

        impl Stream for MaybeErrorRecordBatchStream {
            type Item = Result<RecordBatch>;

            fn poll_next(
                mut self: Pin<&mut Self>,
                _: &mut Context<'_>,
            ) -> Poll<Option<Self::Item>> {
                if let Some(batch) = self.items.pop() {
                    Poll::Ready(Some(Ok(batch?)))
                } else {
                    Poll::Ready(None)
                }
            }
        }

        fn new_future_stream(
            maybe_recordbatches: Result<Vec<Result<RecordBatch>>>,
        ) -> FutureStream {
            Box::pin(async move {
                maybe_recordbatches
                    .map(|items| Box::pin(MaybeErrorRecordBatchStream { items }) as _)
            })
        }

        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "a",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        let batch1 = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from_slice([1])) as _],
        )
        .unwrap();
        let batch2 = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from_slice([2])) as _],
        )
        .unwrap();

        let success_stream = new_future_stream(Ok(vec![Ok(batch1.clone()), Ok(batch2.clone())]));
        let adapter = AsyncRecordBatchStreamAdapter::new(schema.clone(), success_stream);
        let collected = RecordBatches::try_collect(Box::pin(adapter)).await.unwrap();
        // `MaybeErrorRecordBatchStream` pops from the back, so batches come out in reverse order.
        assert_eq!(
            collected,
            RecordBatches::try_new(schema.clone(), vec![batch2.clone(), batch1.clone()]).unwrap()
        );

        let poll_err_stream = new_future_stream(Ok(vec![
            Ok(batch1.clone()),
            Err(error::ExternalSnafu
                .into_error(BoxedError::new(MockError::new(StatusCode::Unknown)))),
        ]));
        let adapter = AsyncRecordBatchStreamAdapter::new(schema.clone(), poll_err_stream);
        let err = RecordBatches::try_collect(Box::pin(adapter))
            .await
            .unwrap_err();
        assert!(
            matches!(err, Error::External { .. }),
            "unexpected err {err}"
        );

        let failed_to_init_stream =
            new_future_stream(Err(error::ExternalSnafu
                .into_error(BoxedError::new(MockError::new(StatusCode::Internal)))));
        let adapter = AsyncRecordBatchStreamAdapter::new(schema.clone(), failed_to_init_stream);
        let err = RecordBatches::try_collect(Box::pin(adapter))
            .await
            .unwrap_err();
        assert!(
            matches!(err, Error::External { .. }),
            "unexpected err {err}"
        );
    }
}