common_recordbatch/adapter.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt::{self, Display};
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::{Context, Poll};

use common_base::readable_size::ReadableSize;
use common_telemetry::tracing::{Span, info_span};
use common_time::util::format_nanoseconds_human_readable;
use datafusion::arrow::compute::cast;
use datafusion::arrow::datatypes::SchemaRef as DfSchemaRef;
use datafusion::error::Result as DfResult;
use datafusion::execution::context::ExecutionProps;
use datafusion::logical_expr::Expr;
use datafusion::logical_expr::utils::conjunction;
use datafusion::physical_expr::create_physical_expr;
use datafusion::physical_plan::metrics::{BaselineMetrics, MetricValue};
use datafusion::physical_plan::{
    DisplayFormatType, ExecutionPlan, ExecutionPlanVisitor, PhysicalExpr,
    RecordBatchStream as DfRecordBatchStream, accept,
};
use datafusion_common::arrow::error::ArrowError;
use datafusion_common::{DataFusionError, ToDFSchema};
use datatypes::arrow::array::Array;
use datatypes::arrow::datatypes::DataType as ArrowDataType;
use datatypes::schema::{ColumnExtType, Schema, SchemaRef};
use futures::ready;
use jsonb;
use pin_project::pin_project;
use snafu::ResultExt;

use crate::error::{self, Result};
use crate::filter::batch_filter;
use crate::{
    DfRecordBatch, DfSendableRecordBatchStream, OrderOption, RecordBatch, RecordBatchStream,
    SendableRecordBatchStream, Stream,
};

type FutureStream =
    Pin<Box<dyn std::future::Future<Output = Result<SendableRecordBatchStream>> + Send>>;

/// Casts the `RecordBatch`es of `stream` against the projected schema.
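///
/// A minimal usage sketch (a `df_stream` yielding `Result<DfRecordBatch, E>`,
/// a matching Arrow `schema`, and `filters: Vec<Expr>` are assumed to be in scope):
///
/// ```ignore
/// // Keep all columns, then push the filters down as a physical predicate.
/// let adapter = RecordBatchStreamTypeAdapter::new(schema, df_stream, None)
///     .with_filter(filters)?;
/// ```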
#[pin_project]
pub struct RecordBatchStreamTypeAdapter<T, E> {
    #[pin]
    stream: T,
    projected_schema: DfSchemaRef,
    projection: Vec<usize>,
    predicate: Option<Arc<dyn PhysicalExpr>>,
    phantom: PhantomData<E>,
}

impl<T, E> RecordBatchStreamTypeAdapter<T, E>
where
    T: Stream<Item = std::result::Result<DfRecordBatch, E>>,
    E: std::error::Error + Send + Sync + 'static,
{
    pub fn new(projected_schema: DfSchemaRef, stream: T, projection: Option<Vec<usize>>) -> Self {
        let projection = if let Some(projection) = projection {
            projection
        } else {
            (0..projected_schema.fields().len()).collect()
        };

        Self {
            stream,
            projected_schema,
            projection,
            predicate: None,
            phantom: Default::default(),
        }
    }

    pub fn with_filter(mut self, filters: Vec<Expr>) -> Result<Self> {
        let filters = if let Some(expr) = conjunction(filters) {
            let df_schema = self
                .projected_schema
                .clone()
                .to_dfschema_ref()
                .context(error::PhysicalExprSnafu)?;

            let filters = create_physical_expr(&expr, &df_schema, &ExecutionProps::new())
                .context(error::PhysicalExprSnafu)?;
            Some(filters)
        } else {
            None
        };
        self.predicate = filters;
        Ok(self)
    }
}

impl<T, E> DfRecordBatchStream for RecordBatchStreamTypeAdapter<T, E>
where
    T: Stream<Item = std::result::Result<DfRecordBatch, E>>,
    E: std::error::Error + Send + Sync + 'static,
{
    fn schema(&self) -> DfSchemaRef {
        self.projected_schema.clone()
    }
}

impl<T, E> Stream for RecordBatchStreamTypeAdapter<T, E>
where
    T: Stream<Item = std::result::Result<DfRecordBatch, E>>,
    E: std::error::Error + Send + Sync + 'static,
{
    type Item = DfResult<DfRecordBatch>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();

        let batch = futures::ready!(this.stream.poll_next(cx))
            .map(|r| r.map_err(|e| DataFusionError::External(Box::new(e))));

        let projected_schema = this.projected_schema.clone();
        let projection = this.projection.clone();
        let predicate = this.predicate.clone();

        let batch = batch.map(|b| {
            b.and_then(|b| {
                let projected_column = b.project(&projection)?;
                if projected_column.schema().fields.len() != projected_schema.fields.len() {
                    return Err(DataFusionError::ArrowError(Box::new(ArrowError::SchemaError(format!(
                        "Trying to cast a RecordBatch into an incompatible schema. RecordBatch: {}, Target: {}",
                        projected_column.schema(),
                        projected_schema,
                    ))), None));
                }

                let mut columns = Vec::with_capacity(projected_schema.fields.len());
                for (idx, field) in projected_schema.fields.iter().enumerate() {
                    let column = projected_column.column(idx);
                    let extype = field.metadata().get("greptime:type").and_then(|s| ColumnExtType::from_str(s).ok());
                    let output = custom_cast(column, field.data_type(), extype)?;
                    columns.push(output);
                }
                let record_batch = DfRecordBatch::try_new(projected_schema, columns)?;
                let record_batch = if let Some(predicate) = predicate {
                    batch_filter(&record_batch, &predicate)?
                } else {
                    record_batch
                };
                Ok(record_batch)
            })
        });

        Poll::Ready(batch)
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}

/// Greptime SendableRecordBatchStream -> DataFusion RecordBatchStream.
/// The reverse one is [RecordBatchStreamAdapter].
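///
/// A minimal sketch (assuming `stream` is a Greptime [SendableRecordBatchStream]):
///
/// ```ignore
/// let df_stream: DfSendableRecordBatchStream =
///     Box::pin(DfRecordBatchStreamAdapter::new(stream));
/// ```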
pub struct DfRecordBatchStreamAdapter {
    stream: SendableRecordBatchStream,
}

impl DfRecordBatchStreamAdapter {
    pub fn new(stream: SendableRecordBatchStream) -> Self {
        Self { stream }
    }
}

impl DfRecordBatchStream for DfRecordBatchStreamAdapter {
    fn schema(&self) -> DfSchemaRef {
        self.stream.schema().arrow_schema().clone()
    }
}

impl Stream for DfRecordBatchStreamAdapter {
    type Item = DfResult<DfRecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.stream).poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Some(recordbatch)) => match recordbatch {
                Ok(recordbatch) => Poll::Ready(Some(Ok(recordbatch.into_df_record_batch()))),
                Err(e) => Poll::Ready(Some(Err(DataFusionError::External(Box::new(e))))),
            },
            Poll::Ready(None) => Poll::Ready(None),
        }
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}

/// DataFusion [SendableRecordBatchStream](DfSendableRecordBatchStream) -> Greptime [RecordBatchStream].
/// The reverse one is [DfRecordBatchStreamAdapter].
/// It can collect metrics from the DataFusion [ExecutionPlan].
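///
/// A minimal sketch (assuming `df_stream` was produced by executing `plan`):
///
/// ```ignore
/// let mut adapter = RecordBatchStreamAdapter::try_new(df_stream)?;
/// // Opt in to plan-level metrics; they are resolved as the stream is polled.
/// adapter.set_metrics2(plan.clone());
/// ```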
pub struct RecordBatchStreamAdapter {
    schema: SchemaRef,
    stream: DfSendableRecordBatchStream,
    metrics: Option<BaselineMetrics>,
    /// Aggregated plan-level metrics. Resolved after an [ExecutionPlan] is finished.
    metrics_2: Metrics,
    /// Display plan and metrics in verbose mode.
    explain_verbose: bool,
    span: Span,
}

/// JSON-encodable metrics. Contains metrics from the whole plan tree.
enum Metrics {
    Unavailable,
    Unresolved(Arc<dyn ExecutionPlan>),
    PartialResolved(Arc<dyn ExecutionPlan>, RecordBatchMetrics),
    Resolved(RecordBatchMetrics),
}

impl RecordBatchStreamAdapter {
    pub fn try_new(stream: DfSendableRecordBatchStream) -> Result<Self> {
        let schema =
            Arc::new(Schema::try_from(stream.schema()).context(error::SchemaConversionSnafu)?);
        Ok(Self {
            schema,
            stream,
            metrics: None,
            metrics_2: Metrics::Unavailable,
            explain_verbose: false,
            span: Span::current(),
        })
    }

    pub fn try_new_with_span(stream: DfSendableRecordBatchStream, span: Span) -> Result<Self> {
        let schema =
            Arc::new(Schema::try_from(stream.schema()).context(error::SchemaConversionSnafu)?);
        let subspan = info_span!(parent: &span, "RecordBatchStreamAdapter");
        Ok(Self {
            schema,
            stream,
            metrics: None,
            metrics_2: Metrics::Unavailable,
            explain_verbose: false,
            span: subspan,
        })
    }

    pub fn set_metrics2(&mut self, plan: Arc<dyn ExecutionPlan>) {
        self.metrics_2 = Metrics::Unresolved(plan);
    }

    /// Set the verbose mode for displaying plan and metrics.
    pub fn set_explain_verbose(&mut self, verbose: bool) {
        self.explain_verbose = verbose;
    }
}

impl RecordBatchStream for RecordBatchStreamAdapter {
    fn name(&self) -> &str {
        "RecordBatchStreamAdapter"
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        match &self.metrics_2 {
            Metrics::Resolved(metrics) | Metrics::PartialResolved(_, metrics) => {
                Some(metrics.clone())
            }
            Metrics::Unavailable | Metrics::Unresolved(_) => None,
        }
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        None
    }
}

impl Stream for RecordBatchStreamAdapter {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let timer = self
            .metrics
            .as_ref()
            .map(|m| m.elapsed_compute().clone())
            .unwrap_or_default();
        let _guard = timer.timer();
        let poll_span = info_span!(parent: &self.span, "poll_next");
        let _entered = poll_span.enter();
        match Pin::new(&mut self.stream).poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Some(df_record_batch)) => {
                let df_record_batch = df_record_batch?;
                if let Metrics::Unresolved(df_plan) | Metrics::PartialResolved(df_plan, _) =
                    &self.metrics_2
                {
                    let mut metric_collector = MetricCollector::new(self.explain_verbose);
                    accept(df_plan.as_ref(), &mut metric_collector).unwrap();
                    self.metrics_2 = Metrics::PartialResolved(
                        df_plan.clone(),
                        metric_collector.record_batch_metrics,
                    );
                }
                Poll::Ready(Some(Ok(RecordBatch::from_df_record_batch(
                    self.schema(),
                    df_record_batch,
                ))))
            }
            Poll::Ready(None) => {
                if let Metrics::Unresolved(df_plan) | Metrics::PartialResolved(df_plan, _) =
                    &self.metrics_2
                {
                    let mut metric_collector = MetricCollector::new(self.explain_verbose);
                    accept(df_plan.as_ref(), &mut metric_collector).unwrap();
                    self.metrics_2 = Metrics::Resolved(metric_collector.record_batch_metrics);
                }
                Poll::Ready(None)
            }
        }
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}

/// An [ExecutionPlanVisitor] to collect metrics from an [ExecutionPlan].
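///
/// A minimal sketch (assuming `plan` is an executed `Arc<dyn ExecutionPlan>`):
///
/// ```ignore
/// let mut collector = MetricCollector::new(false);
/// accept(plan.as_ref(), &mut collector).unwrap();
/// let metrics: RecordBatchMetrics = collector.record_batch_metrics;
/// ```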
pub struct MetricCollector {
    current_level: usize,
    pub record_batch_metrics: RecordBatchMetrics,
    verbose: bool,
}

impl MetricCollector {
    pub fn new(verbose: bool) -> Self {
        Self {
            current_level: 0,
            record_batch_metrics: RecordBatchMetrics::default(),
            verbose,
        }
    }
}

impl ExecutionPlanVisitor for MetricCollector {
    type Error = !;

    fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> std::result::Result<bool, Self::Error> {
        // skip if no metrics are available
        let Some(metric) = plan.metrics() else {
            self.record_batch_metrics.plan_metrics.push(PlanMetrics {
                plan: plan.name().to_string(),
                level: self.current_level,
                metrics: vec![],
            });
            self.current_level += 1;
            return Ok(true);
        };

        // scrape plan metrics
        let metric = metric
            .aggregate_by_name()
            .sorted_for_display()
            .timestamps_removed();
        let mut plan_metric = PlanMetrics {
            plan: one_line(plan, self.verbose).to_string(),
            level: self.current_level,
            metrics: Vec::with_capacity(metric.iter().size_hint().0),
        };
        for m in metric.iter() {
            plan_metric
                .metrics
                .push((m.value().name().to_string(), m.value().as_usize()));

            // aggregate high-level metrics
            match m.value() {
                MetricValue::ElapsedCompute(ec) => {
                    self.record_batch_metrics.elapsed_compute += ec.value()
                }
                MetricValue::CurrentMemoryUsage(m) => {
                    self.record_batch_metrics.memory_usage += m.value()
                }
                _ => {}
            }
        }
        self.record_batch_metrics.plan_metrics.push(plan_metric);

        self.current_level += 1;
        Ok(true)
    }

    fn post_visit(&mut self, _plan: &dyn ExecutionPlan) -> std::result::Result<bool, Self::Error> {
        self.current_level -= 1;
        Ok(true)
    }
}

/// Returns a single-line summary of the root of the plan.
/// If the `verbose` flag is set, it will display detailed information about the plan.
fn one_line(plan: &dyn ExecutionPlan, verbose: bool) -> impl fmt::Display + '_ {
    struct Wrapper<'a> {
        plan: &'a dyn ExecutionPlan,
        format_type: DisplayFormatType,
    }

    impl fmt::Display for Wrapper<'_> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            self.plan.fmt_as(self.format_type, f)?;
            writeln!(f)
        }
    }

    let format_type = if verbose {
        DisplayFormatType::Verbose
    } else {
        DisplayFormatType::Default
    };
    Wrapper { plan, format_type }
}

/// [`RecordBatchMetrics`] carries metrics values
/// from datanode to frontend through gRPC.
#[derive(serde::Serialize, serde::Deserialize, Default, Debug, Clone)]
pub struct RecordBatchMetrics {
    // High-level aggregated metrics
    /// CPU consumption in nanoseconds
    pub elapsed_compute: usize,
    /// Memory used by the plan in bytes
    pub memory_usage: usize,
    // Detailed per-plan metrics
    /// An ordered list of plan metrics, from top to bottom in pre-order.
    pub plan_metrics: Vec<PlanMetrics>,
}

/// Determines if a metric name represents a time measurement that should be formatted.
fn is_time_metric(metric_name: &str) -> bool {
    metric_name.contains("elapsed") || metric_name.contains("time") || metric_name.contains("cost")
}

/// Determines if a metric name represents a bytes measurement that should be formatted.
fn is_bytes_metric(metric_name: &str) -> bool {
    metric_name.contains("bytes") || metric_name.contains("mem")
}

fn format_bytes_human_readable(bytes: usize) -> String {
    format!("{}", ReadableSize(bytes as u64))
}

/// Only displays `plan_metrics`, indenting each level by two spaces.
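///
/// Illustrative shape of the output (plan names and values are made up;
/// time and bytes metrics are rendered human-readable, indentation approximate):
///
/// ```text
/// ProjectionExec metrics=[output_rows: 8, ]
///   FilterExec metrics=[output_rows: 8, ]
/// ```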
impl Display for RecordBatchMetrics {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for metric in &self.plan_metrics {
            write!(
                f,
                "{:indent$}{} metrics=[",
                " ",
                metric.plan.trim_end(),
                indent = metric.level * 2,
            )?;
            for (label, value) in &metric.metrics {
                if is_time_metric(label) {
                    write!(
                        f,
                        "{}: {}, ",
                        label,
                        format_nanoseconds_human_readable(*value),
                    )?;
                } else if is_bytes_metric(label) {
                    write!(f, "{}: {}, ", label, format_bytes_human_readable(*value))?;
                } else {
                    write!(f, "{}: {}, ", label, value)?;
                }
            }
            writeln!(f, "]")?;
        }

        Ok(())
    }
}

496
497#[derive(serde::Serialize, serde::Deserialize, Default, Debug, Clone)]
498pub struct PlanMetrics {
499    /// The plan name
500    pub plan: String,
501    /// The level of the plan, starts from 0
502    pub level: usize,
503    /// An ordered key-value list of metrics.
504    /// Key is metric label and value is metric value.
505    pub metrics: Vec<(String, usize)>,
506}
507
508enum AsyncRecordBatchStreamAdapterState {
509    Uninit(FutureStream),
510    Ready(SendableRecordBatchStream),
511    Failed,
512}
513
pub struct AsyncRecordBatchStreamAdapter {
    schema: SchemaRef,
    state: AsyncRecordBatchStreamAdapterState,
}

impl AsyncRecordBatchStreamAdapter {
    pub fn new(schema: SchemaRef, stream: FutureStream) -> Self {
        Self {
            schema,
            state: AsyncRecordBatchStreamAdapterState::Uninit(stream),
        }
    }
}

impl RecordBatchStream for AsyncRecordBatchStreamAdapter {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        None
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        None
    }
}

impl Stream for AsyncRecordBatchStreamAdapter {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            match &mut self.state {
                AsyncRecordBatchStreamAdapterState::Uninit(stream_future) => {
                    match ready!(Pin::new(stream_future).poll(cx)) {
                        Ok(stream) => {
                            self.state = AsyncRecordBatchStreamAdapterState::Ready(stream);
                            continue;
                        }
                        Err(e) => {
                            self.state = AsyncRecordBatchStreamAdapterState::Failed;
                            return Poll::Ready(Some(Err(e)));
                        }
                    };
                }
                AsyncRecordBatchStreamAdapterState::Ready(stream) => {
                    return Poll::Ready(ready!(Pin::new(stream).poll_next(cx)));
                }
                AsyncRecordBatchStreamAdapterState::Failed => return Poll::Ready(None),
            }
        }
    }

    // The size hint is not supported for a lazy stream.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, None)
    }
}

/// Custom cast function that handles Map -> Binary (JSON) conversion.
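///
/// A minimal sketch (assuming `map_array` is a Map-typed Arrow array whose
/// field is tagged with the `greptime:type` JSON extension):
///
/// ```ignore
/// let json_binary =
///     custom_cast(&map_array, &ArrowDataType::Binary, Some(ColumnExtType::Json))?;
/// ```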
fn custom_cast(
    array: &dyn Array,
    target_type: &ArrowDataType,
    extype: Option<ColumnExtType>,
) -> std::result::Result<Arc<dyn Array>, ArrowError> {
    if let ArrowDataType::Map(_, _) = array.data_type()
        && let ArrowDataType::Binary = target_type
    {
        return convert_map_to_json_binary(array, extype);
    }

    cast(array, target_type)
}

/// Convert a Map array to a Binary array containing JSON data.
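///
/// Only `Utf8` keys and values are supported; other entry types fail with a
/// [ArrowError::CastError]. With [ColumnExtType::Json] the values are encoded
/// as JSONB, otherwise as plain JSON text bytes.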
fn convert_map_to_json_binary(
    array: &dyn Array,
    extype: Option<ColumnExtType>,
) -> std::result::Result<Arc<dyn Array>, ArrowError> {
    use datatypes::arrow::array::{BinaryArray, MapArray};
    use serde_json::Value;

    let map_array = array
        .as_any()
        .downcast_ref::<MapArray>()
        .ok_or_else(|| ArrowError::CastError("Failed to downcast to MapArray".to_string()))?;

    let mut json_values = Vec::with_capacity(map_array.len());

    for i in 0..map_array.len() {
        if map_array.is_null(i) {
            json_values.push(None);
        } else {
            // Extract the map entry at index i
            let map_entry = map_array.value(i);
            let key_value_array = map_entry
                .as_any()
                .downcast_ref::<datatypes::arrow::array::StructArray>()
                .ok_or_else(|| {
                    ArrowError::CastError("Failed to downcast to StructArray".to_string())
                })?;

            // Convert to JSON object
            let mut json_obj = serde_json::Map::with_capacity(key_value_array.len());

            for j in 0..key_value_array.len() {
                if key_value_array.is_null(j) {
                    continue;
                }
                let key_field = key_value_array.column(0);
                let value_field = key_value_array.column(1);

                if key_field.is_null(j) {
                    continue;
                }

                let key = key_field
                    .as_any()
                    .downcast_ref::<datatypes::arrow::array::StringArray>()
                    .ok_or_else(|| {
                        ArrowError::CastError("Failed to downcast key to StringArray".to_string())
                    })?
                    .value(j);

                let value = if value_field.is_null(j) {
                    Value::Null
                } else {
                    let value_str = value_field
                        .as_any()
                        .downcast_ref::<datatypes::arrow::array::StringArray>()
                        .ok_or_else(|| {
                            ArrowError::CastError(
                                "Failed to downcast value to StringArray".to_string(),
                            )
                        })?
                        .value(j);
                    Value::String(value_str.to_string())
                };

                json_obj.insert(key.to_string(), value);
            }

            let json_value = Value::Object(json_obj);
            let json_bytes = match extype {
                Some(ColumnExtType::Json) => {
                    let json_string = serde_json::to_string(&json_value).map_err(|e| {
                        ArrowError::CastError(format!("Failed to serialize JSON: {}", e))
                    })?;
                    jsonb::parse_value(json_string.as_bytes())
                        .map(|jsonb_value| jsonb_value.to_vec())
                        .map_err(|e| {
                            ArrowError::CastError(format!("Failed to serialize JSONB: {}", e))
                        })?
                }
                _ => serde_json::to_vec(&json_value).map_err(|e| {
                    ArrowError::CastError(format!("Failed to serialize JSON: {}", e))
                })?,
            };
            json_values.push(Some(json_bytes));
        }
    }

    let binary_array = BinaryArray::from_iter(json_values);
    Ok(Arc::new(binary_array))
}

#[cfg(test)]
mod test {
    use common_error::ext::BoxedError;
    use common_error::mock::MockError;
    use common_error::status_code::StatusCode;
    use datatypes::arrow::array::{ArrayRef, MapArray, StringArray, StructArray};
    use datatypes::arrow::buffer::OffsetBuffer;
    use datatypes::arrow::datatypes::Field;
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::ColumnSchema;
    use datatypes::vectors::Int32Vector;
    use snafu::IntoError;

    use super::*;
    use crate::RecordBatches;
    use crate::error::Error;

    #[tokio::test]
    async fn test_async_recordbatch_stream_adaptor() {
        struct MaybeErrorRecordBatchStream {
            items: Vec<Result<RecordBatch>>,
        }

        impl RecordBatchStream for MaybeErrorRecordBatchStream {
            fn schema(&self) -> SchemaRef {
                unimplemented!()
            }

            fn output_ordering(&self) -> Option<&[OrderOption]> {
                None
            }

            fn metrics(&self) -> Option<RecordBatchMetrics> {
                None
            }
        }

        impl Stream for MaybeErrorRecordBatchStream {
            type Item = Result<RecordBatch>;

            fn poll_next(
                mut self: Pin<&mut Self>,
                _: &mut Context<'_>,
            ) -> Poll<Option<Self::Item>> {
                if let Some(batch) = self.items.pop() {
                    Poll::Ready(Some(Ok(batch?)))
                } else {
                    Poll::Ready(None)
                }
            }
        }

        fn new_future_stream(
            maybe_recordbatches: Result<Vec<Result<RecordBatch>>>,
        ) -> FutureStream {
            Box::pin(async move {
                maybe_recordbatches
                    .map(|items| Box::pin(MaybeErrorRecordBatchStream { items }) as _)
            })
        }

        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "a",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        let batch1 = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from_slice([1])) as _],
        )
        .unwrap();
        let batch2 = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from_slice([2])) as _],
        )
        .unwrap();

        let success_stream = new_future_stream(Ok(vec![Ok(batch1.clone()), Ok(batch2.clone())]));
        let adapter = AsyncRecordBatchStreamAdapter::new(schema.clone(), success_stream);
        let collected = RecordBatches::try_collect(Box::pin(adapter)).await.unwrap();
        // `MaybeErrorRecordBatchStream` pops from the back of `items`, so the
        // batches are yielded in reverse order.
        assert_eq!(
            collected,
            RecordBatches::try_new(schema.clone(), vec![batch2.clone(), batch1.clone()]).unwrap()
        );

        let poll_err_stream = new_future_stream(Ok(vec![
            Ok(batch1.clone()),
            Err(error::ExternalSnafu
                .into_error(BoxedError::new(MockError::new(StatusCode::Unknown)))),
        ]));
        let adapter = AsyncRecordBatchStreamAdapter::new(schema.clone(), poll_err_stream);
        let err = RecordBatches::try_collect(Box::pin(adapter))
            .await
            .unwrap_err();
        assert!(
            matches!(err, Error::External { .. }),
            "unexpected err {err}"
        );

        let failed_to_init_stream =
            new_future_stream(Err(error::ExternalSnafu
                .into_error(BoxedError::new(MockError::new(StatusCode::Internal)))));
        let adapter = AsyncRecordBatchStreamAdapter::new(schema.clone(), failed_to_init_stream);
        let err = RecordBatches::try_collect(Box::pin(adapter))
            .await
            .unwrap_err();
        assert!(
            matches!(err, Error::External { .. }),
            "unexpected err {err}"
        );
    }

    #[test]
    fn test_convert_map_to_json_binary() {
        let keys = StringArray::from(vec![Some("a"), Some("b"), Some("c"), Some("x")]);
        let values = StringArray::from(vec![Some("1"), None, Some("3"), Some("42")]);
        let key_field = Arc::new(Field::new("key", ArrowDataType::Utf8, false));
        let value_field = Arc::new(Field::new("value", ArrowDataType::Utf8, true));
        let struct_type = ArrowDataType::Struct(vec![key_field, value_field].into());

        let entries_field = Arc::new(Field::new("entries", struct_type, false));

        let struct_array = StructArray::from(vec![
            (
                Arc::new(Field::new("key", ArrowDataType::Utf8, false)),
                Arc::new(keys) as ArrayRef,
            ),
            (
                Arc::new(Field::new("value", ArrowDataType::Utf8, true)),
                Arc::new(values) as ArrayRef,
            ),
        ]);

        // Three map rows: {a, b, c}, a null row, and {x}.
        let offsets = OffsetBuffer::from_lengths([3, 0, 1]);
        let nulls = datatypes::arrow::buffer::NullBuffer::from(vec![true, false, true]);

        let map_array = MapArray::new(entries_field, offsets, struct_array, Some(nulls), false);

        let result = convert_map_to_json_binary(&map_array, None).unwrap();
        let binary_array = result
            .as_any()
            .downcast_ref::<datatypes::arrow::array::BinaryArray>()
            .unwrap();

        let expected_jsons = [
            Some(r#"{"a":"1","b":null,"c":"3"}"#),
            None,
            Some(r#"{"x":"42"}"#),
        ];

        for (i, expected) in expected_jsons.iter().enumerate() {
            if let Some(expected) = expected {
                assert!(!binary_array.is_null(i));
                let actual_bytes = binary_array.value(i);
                let actual_str = std::str::from_utf8(actual_bytes).unwrap();
                assert_eq!(actual_str, *expected);
            } else {
                assert!(binary_array.is_null(i));
            }
        }

        // With the JSON extension type the output is JSONB, which differs
        // from the plain JSON text encoding.
        let result_json =
            convert_map_to_json_binary(&map_array, Some(ColumnExtType::Json)).unwrap();
        let binary_array_json = result_json
            .as_any()
            .downcast_ref::<datatypes::arrow::array::BinaryArray>()
            .unwrap();

        for (i, expected) in expected_jsons.iter().enumerate() {
            if let Some(expected) = expected {
                assert!(!binary_array_json.is_null(i));
                let actual_bytes = binary_array_json.value(i);
                assert_ne!(actual_bytes, expected.as_bytes());
            } else {
                assert!(binary_array_json.is_null(i));
            }
        }
    }
}
882}