1use std::fmt::{self, Display};
16use std::future::Future;
17use std::marker::PhantomData;
18use std::pin::Pin;
19use std::sync::Arc;
20use std::task::{Context, Poll};
21
22use datafusion::arrow::compute::cast;
23use datafusion::arrow::datatypes::SchemaRef as DfSchemaRef;
24use datafusion::error::Result as DfResult;
25use datafusion::execution::context::ExecutionProps;
26use datafusion::logical_expr::utils::conjunction;
27use datafusion::logical_expr::Expr;
28use datafusion::physical_expr::create_physical_expr;
29use datafusion::physical_plan::metrics::{BaselineMetrics, MetricValue};
30use datafusion::physical_plan::{
31 accept, DisplayFormatType, ExecutionPlan, ExecutionPlanVisitor, PhysicalExpr,
32 RecordBatchStream as DfRecordBatchStream,
33};
34use datafusion_common::arrow::error::ArrowError;
35use datafusion_common::{DataFusionError, ToDFSchema};
36use datatypes::arrow::array::Array;
37use datatypes::schema::{Schema, SchemaRef};
38use futures::ready;
39use pin_project::pin_project;
40use snafu::ResultExt;
41
42use crate::error::{self, Result};
43use crate::filter::batch_filter;
44use crate::{
45 DfRecordBatch, DfSendableRecordBatchStream, OrderOption, RecordBatch, RecordBatchStream,
46 SendableRecordBatchStream, Stream,
47};
48
49type FutureStream =
50 Pin<Box<dyn std::future::Future<Output = Result<SendableRecordBatchStream>> + Send>>;
51
52#[pin_project]
54pub struct RecordBatchStreamTypeAdapter<T, E> {
55 #[pin]
56 stream: T,
57 projected_schema: DfSchemaRef,
58 projection: Vec<usize>,
59 predicate: Option<Arc<dyn PhysicalExpr>>,
60 phantom: PhantomData<E>,
61}
62
63impl<T, E> RecordBatchStreamTypeAdapter<T, E>
64where
65 T: Stream<Item = std::result::Result<DfRecordBatch, E>>,
66 E: std::error::Error + Send + Sync + 'static,
67{
68 pub fn new(projected_schema: DfSchemaRef, stream: T, projection: Option<Vec<usize>>) -> Self {
69 let projection = if let Some(projection) = projection {
70 projection
71 } else {
72 (0..projected_schema.fields().len()).collect()
73 };
74
75 Self {
76 stream,
77 projected_schema,
78 projection,
79 predicate: None,
80 phantom: Default::default(),
81 }
82 }
83
84 pub fn with_filter(mut self, filters: Vec<Expr>) -> Result<Self> {
85 let filters = if let Some(expr) = conjunction(filters) {
86 let df_schema = self
87 .projected_schema
88 .clone()
89 .to_dfschema_ref()
90 .context(error::PhysicalExprSnafu)?;
91
92 let filters = create_physical_expr(&expr, &df_schema, &ExecutionProps::new())
93 .context(error::PhysicalExprSnafu)?;
94 Some(filters)
95 } else {
96 None
97 };
98 self.predicate = filters;
99 Ok(self)
100 }
101}
102
103impl<T, E> DfRecordBatchStream for RecordBatchStreamTypeAdapter<T, E>
104where
105 T: Stream<Item = std::result::Result<DfRecordBatch, E>>,
106 E: std::error::Error + Send + Sync + 'static,
107{
108 fn schema(&self) -> DfSchemaRef {
109 self.projected_schema.clone()
110 }
111}
112
113impl<T, E> Stream for RecordBatchStreamTypeAdapter<T, E>
114where
115 T: Stream<Item = std::result::Result<DfRecordBatch, E>>,
116 E: std::error::Error + Send + Sync + 'static,
117{
118 type Item = DfResult<DfRecordBatch>;
119
120 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
121 let this = self.project();
122
123 let batch = futures::ready!(this.stream.poll_next(cx))
124 .map(|r| r.map_err(|e| DataFusionError::External(Box::new(e))));
125
126 let projected_schema = this.projected_schema.clone();
127 let projection = this.projection.clone();
128 let predicate = this.predicate.clone();
129
130 let batch = batch.map(|b| {
131 b.and_then(|b| {
132 let projected_column = b.project(&projection)?;
133 if projected_column.schema().fields.len() != projected_schema.fields.len() {
134 return Err(DataFusionError::ArrowError(ArrowError::SchemaError(format!(
135 "Trying to cast a RecordBatch into an incompatible schema. RecordBatch: {}, Target: {}",
136 projected_column.schema(),
137 projected_schema,
138 )), None));
139 }
140
141 let mut columns = Vec::with_capacity(projected_schema.fields.len());
142 for (idx,field) in projected_schema.fields.iter().enumerate() {
143 let column = projected_column.column(idx);
144 if column.data_type() != field.data_type() {
145 let output = cast(&column, field.data_type())?;
146 columns.push(output)
147 } else {
148 columns.push(column.clone())
149 }
150 }
151 let record_batch = DfRecordBatch::try_new(projected_schema, columns)?;
152 let record_batch = if let Some(predicate) = predicate {
153 batch_filter(&record_batch, &predicate)?
154 } else {
155 record_batch
156 };
157 Ok(record_batch)
158 })
159 });
160
161 Poll::Ready(batch)
162 }
163
164 #[inline]
165 fn size_hint(&self) -> (usize, Option<usize>) {
166 self.stream.size_hint()
167 }
168}
169
170pub struct DfRecordBatchStreamAdapter {
173 stream: SendableRecordBatchStream,
174}
175
176impl DfRecordBatchStreamAdapter {
177 pub fn new(stream: SendableRecordBatchStream) -> Self {
178 Self { stream }
179 }
180}
181
182impl DfRecordBatchStream for DfRecordBatchStreamAdapter {
183 fn schema(&self) -> DfSchemaRef {
184 self.stream.schema().arrow_schema().clone()
185 }
186}
187
188impl Stream for DfRecordBatchStreamAdapter {
189 type Item = DfResult<DfRecordBatch>;
190
191 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
192 match Pin::new(&mut self.stream).poll_next(cx) {
193 Poll::Pending => Poll::Pending,
194 Poll::Ready(Some(recordbatch)) => match recordbatch {
195 Ok(recordbatch) => Poll::Ready(Some(Ok(recordbatch.into_df_record_batch()))),
196 Err(e) => Poll::Ready(Some(Err(DataFusionError::External(Box::new(e))))),
197 },
198 Poll::Ready(None) => Poll::Ready(None),
199 }
200 }
201
202 #[inline]
203 fn size_hint(&self) -> (usize, Option<usize>) {
204 self.stream.size_hint()
205 }
206}
207
208pub struct RecordBatchStreamAdapter {
212 schema: SchemaRef,
213 stream: DfSendableRecordBatchStream,
214 metrics: Option<BaselineMetrics>,
215 metrics_2: Metrics,
217 explain_verbose: bool,
219}
220
221enum Metrics {
223 Unavailable,
224 Unresolved(Arc<dyn ExecutionPlan>),
225 Resolved(RecordBatchMetrics),
226}
227
228impl RecordBatchStreamAdapter {
229 pub fn try_new(stream: DfSendableRecordBatchStream) -> Result<Self> {
230 let schema =
231 Arc::new(Schema::try_from(stream.schema()).context(error::SchemaConversionSnafu)?);
232 Ok(Self {
233 schema,
234 stream,
235 metrics: None,
236 metrics_2: Metrics::Unavailable,
237 explain_verbose: false,
238 })
239 }
240
241 pub fn try_new_with_metrics_and_df_plan(
242 stream: DfSendableRecordBatchStream,
243 metrics: BaselineMetrics,
244 df_plan: Arc<dyn ExecutionPlan>,
245 ) -> Result<Self> {
246 let schema =
247 Arc::new(Schema::try_from(stream.schema()).context(error::SchemaConversionSnafu)?);
248 Ok(Self {
249 schema,
250 stream,
251 metrics: Some(metrics),
252 metrics_2: Metrics::Unresolved(df_plan),
253 explain_verbose: false,
254 })
255 }
256
257 pub fn set_metrics2(&mut self, plan: Arc<dyn ExecutionPlan>) {
258 self.metrics_2 = Metrics::Unresolved(plan)
259 }
260
261 pub fn set_explain_verbose(&mut self, verbose: bool) {
263 self.explain_verbose = verbose;
264 }
265}
266
267impl RecordBatchStream for RecordBatchStreamAdapter {
268 fn name(&self) -> &str {
269 "RecordBatchStreamAdapter"
270 }
271
272 fn schema(&self) -> SchemaRef {
273 self.schema.clone()
274 }
275
276 fn metrics(&self) -> Option<RecordBatchMetrics> {
277 match &self.metrics_2 {
278 Metrics::Resolved(metrics) => Some(metrics.clone()),
279 Metrics::Unavailable | Metrics::Unresolved(_) => None,
280 }
281 }
282
283 fn output_ordering(&self) -> Option<&[OrderOption]> {
284 None
285 }
286}
287
288impl Stream for RecordBatchStreamAdapter {
289 type Item = Result<RecordBatch>;
290
291 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
292 let timer = self
293 .metrics
294 .as_ref()
295 .map(|m| m.elapsed_compute().clone())
296 .unwrap_or_default();
297 let _guard = timer.timer();
298 match Pin::new(&mut self.stream).poll_next(cx) {
299 Poll::Pending => Poll::Pending,
300 Poll::Ready(Some(df_record_batch)) => {
301 let df_record_batch = df_record_batch?;
302 Poll::Ready(Some(RecordBatch::try_from_df_record_batch(
303 self.schema(),
304 df_record_batch,
305 )))
306 }
307 Poll::Ready(None) => {
308 if let Metrics::Unresolved(df_plan) = &self.metrics_2 {
309 let mut metric_collector = MetricCollector::new(self.explain_verbose);
310 accept(df_plan.as_ref(), &mut metric_collector).unwrap();
311 self.metrics_2 = Metrics::Resolved(metric_collector.record_batch_metrics);
312 }
313 Poll::Ready(None)
314 }
315 }
316 }
317
318 #[inline]
319 fn size_hint(&self) -> (usize, Option<usize>) {
320 self.stream.size_hint()
321 }
322}
323
324pub struct MetricCollector {
326 current_level: usize,
327 pub record_batch_metrics: RecordBatchMetrics,
328 verbose: bool,
329}
330
331impl MetricCollector {
332 pub fn new(verbose: bool) -> Self {
333 Self {
334 current_level: 0,
335 record_batch_metrics: RecordBatchMetrics::default(),
336 verbose,
337 }
338 }
339}
340
341impl ExecutionPlanVisitor for MetricCollector {
342 type Error = !;
343
344 fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> std::result::Result<bool, Self::Error> {
345 let Some(metric) = plan.metrics() else {
347 self.record_batch_metrics.plan_metrics.push(PlanMetrics {
348 plan: std::any::type_name::<Self>().to_string(),
349 level: self.current_level,
350 metrics: vec![],
351 });
352 self.current_level += 1;
353 return Ok(true);
354 };
355
356 let metric = metric
358 .aggregate_by_name()
359 .sorted_for_display()
360 .timestamps_removed();
361 let mut plan_metric = PlanMetrics {
362 plan: one_line(plan, self.verbose).to_string(),
363 level: self.current_level,
364 metrics: Vec::with_capacity(metric.iter().size_hint().0),
365 };
366 for m in metric.iter() {
367 plan_metric
368 .metrics
369 .push((m.value().name().to_string(), m.value().as_usize()));
370
371 match m.value() {
373 MetricValue::ElapsedCompute(ec) => {
374 self.record_batch_metrics.elapsed_compute += ec.value()
375 }
376 MetricValue::CurrentMemoryUsage(m) => {
377 self.record_batch_metrics.memory_usage += m.value()
378 }
379 _ => {}
380 }
381 }
382 self.record_batch_metrics.plan_metrics.push(plan_metric);
383
384 self.current_level += 1;
385 Ok(true)
386 }
387
388 fn post_visit(&mut self, _plan: &dyn ExecutionPlan) -> std::result::Result<bool, Self::Error> {
389 self.current_level -= 1;
390 Ok(true)
391 }
392}
393
394fn one_line(plan: &dyn ExecutionPlan, verbose: bool) -> impl fmt::Display + '_ {
397 struct Wrapper<'a> {
398 plan: &'a dyn ExecutionPlan,
399 format_type: DisplayFormatType,
400 }
401
402 impl fmt::Display for Wrapper<'_> {
403 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
404 self.plan.fmt_as(self.format_type, f)?;
405 writeln!(f)
406 }
407 }
408
409 let format_type = if verbose {
410 DisplayFormatType::Verbose
411 } else {
412 DisplayFormatType::Default
413 };
414 Wrapper { plan, format_type }
415}
416
417#[derive(serde::Serialize, serde::Deserialize, Default, Debug, Clone)]
420pub struct RecordBatchMetrics {
421 pub elapsed_compute: usize,
424 pub memory_usage: usize,
426 pub plan_metrics: Vec<PlanMetrics>,
429}
430
431impl Display for RecordBatchMetrics {
433 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
434 for metric in &self.plan_metrics {
435 write!(
436 f,
437 "{:indent$}{} metrics=[",
438 " ",
439 metric.plan.trim_end(),
440 indent = metric.level * 2,
441 )?;
442 for (label, value) in &metric.metrics {
443 write!(f, "{}: {}, ", label, value)?;
444 }
445 writeln!(f, "]")?;
446 }
447
448 Ok(())
449 }
450}
451
452#[derive(serde::Serialize, serde::Deserialize, Default, Debug, Clone)]
453pub struct PlanMetrics {
454 pub plan: String,
456 pub level: usize,
458 pub metrics: Vec<(String, usize)>,
461}
462
463enum AsyncRecordBatchStreamAdapterState {
464 Uninit(FutureStream),
465 Ready(SendableRecordBatchStream),
466 Failed,
467}
468
469pub struct AsyncRecordBatchStreamAdapter {
470 schema: SchemaRef,
471 state: AsyncRecordBatchStreamAdapterState,
472}
473
474impl AsyncRecordBatchStreamAdapter {
475 pub fn new(schema: SchemaRef, stream: FutureStream) -> Self {
476 Self {
477 schema,
478 state: AsyncRecordBatchStreamAdapterState::Uninit(stream),
479 }
480 }
481}
482
483impl RecordBatchStream for AsyncRecordBatchStreamAdapter {
484 fn schema(&self) -> SchemaRef {
485 self.schema.clone()
486 }
487
488 fn output_ordering(&self) -> Option<&[OrderOption]> {
489 None
490 }
491
492 fn metrics(&self) -> Option<RecordBatchMetrics> {
493 None
494 }
495}
496
497impl Stream for AsyncRecordBatchStreamAdapter {
498 type Item = Result<RecordBatch>;
499
500 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
501 loop {
502 match &mut self.state {
503 AsyncRecordBatchStreamAdapterState::Uninit(stream_future) => {
504 match ready!(Pin::new(stream_future).poll(cx)) {
505 Ok(stream) => {
506 self.state = AsyncRecordBatchStreamAdapterState::Ready(stream);
507 continue;
508 }
509 Err(e) => {
510 self.state = AsyncRecordBatchStreamAdapterState::Failed;
511 return Poll::Ready(Some(Err(e)));
512 }
513 };
514 }
515 AsyncRecordBatchStreamAdapterState::Ready(stream) => {
516 return Poll::Ready(ready!(Pin::new(stream).poll_next(cx)))
517 }
518 AsyncRecordBatchStreamAdapterState::Failed => return Poll::Ready(None),
519 }
520 }
521 }
522
523 #[inline]
525 fn size_hint(&self) -> (usize, Option<usize>) {
526 (0, None)
527 }
528}
529
#[cfg(test)]
mod test {
    use common_error::ext::BoxedError;
    use common_error::mock::MockError;
    use common_error::status_code::StatusCode;
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::ColumnSchema;
    use datatypes::vectors::Int32Vector;
    use snafu::IntoError;

    use super::*;
    use crate::error::Error;
    use crate::RecordBatches;

    /// In-memory stream that hands out `items` from the back, i.e. in reverse order.
    struct MaybeErrorRecordBatchStream {
        items: Vec<Result<RecordBatch>>,
    }

    impl RecordBatchStream for MaybeErrorRecordBatchStream {
        fn schema(&self) -> SchemaRef {
            unimplemented!()
        }

        fn output_ordering(&self) -> Option<&[OrderOption]> {
            None
        }

        fn metrics(&self) -> Option<RecordBatchMetrics> {
            None
        }
    }

    impl Stream for MaybeErrorRecordBatchStream {
        type Item = Result<RecordBatch>;

        fn poll_next(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            // Never pending: yields the last remaining item, or ends once empty.
            Poll::Ready(self.items.pop())
        }
    }

    /// Wraps an already-known stream-construction outcome in a [`FutureStream`].
    fn new_future_stream(
        maybe_recordbatches: Result<Vec<Result<RecordBatch>>>,
    ) -> FutureStream {
        Box::pin(async move {
            maybe_recordbatches
                .map(|items| Box::pin(MaybeErrorRecordBatchStream { items }) as _)
        })
    }

    /// Builds an `Error::External` wrapping a mock error with `code`.
    fn mock_external_error(code: StatusCode) -> Error {
        error::ExternalSnafu.into_error(BoxedError::new(MockError::new(code)))
    }

    /// Builds a one-row batch holding `value` in column `a`.
    fn single_value_batch(schema: &SchemaRef, value: i32) -> RecordBatch {
        RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from_slice([value])) as _],
        )
        .unwrap()
    }

    /// Asserts that draining an adapter over `stream` fails with `Error::External`.
    async fn assert_collect_fails_with_external(schema: SchemaRef, stream: FutureStream) {
        let adapter = AsyncRecordBatchStreamAdapter::new(schema, stream);
        let err = RecordBatches::try_collect(Box::pin(adapter))
            .await
            .unwrap_err();
        assert!(
            matches!(err, Error::External { .. }),
            "unexpected err {err}"
        );
    }

    #[tokio::test]
    async fn test_async_recordbatch_stream_adaptor() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "a",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        let batch1 = single_value_batch(&schema, 1);
        let batch2 = single_value_batch(&schema, 2);

        // Successful creation: every batch is forwarded (reversed by the mock stream).
        let adapter = AsyncRecordBatchStreamAdapter::new(
            schema.clone(),
            new_future_stream(Ok(vec![Ok(batch1.clone()), Ok(batch2.clone())])),
        );
        let collected = RecordBatches::try_collect(Box::pin(adapter)).await.unwrap();
        let expected =
            RecordBatches::try_new(schema.clone(), vec![batch2.clone(), batch1.clone()]).unwrap();
        assert_eq!(collected, expected);

        // The inner stream yields an error while being polled.
        assert_collect_fails_with_external(
            schema.clone(),
            new_future_stream(Ok(vec![
                Ok(batch1.clone()),
                Err(mock_external_error(StatusCode::Unknown)),
            ])),
        )
        .await;

        // The future that should create the inner stream fails.
        assert_collect_fails_with_external(
            schema.clone(),
            new_future_stream(Err(mock_external_error(StatusCode::Internal))),
        )
        .await;
    }
}