common_recordbatch/
adapter.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt::{self, Display};
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::{Context, Poll};

use common_base::readable_size::ReadableSize;
use common_time::util::format_nanoseconds_human_readable;
use datafusion::arrow::compute::cast;
use datafusion::arrow::datatypes::SchemaRef as DfSchemaRef;
use datafusion::error::Result as DfResult;
use datafusion::execution::context::ExecutionProps;
use datafusion::logical_expr::Expr;
use datafusion::logical_expr::utils::conjunction;
use datafusion::physical_expr::create_physical_expr;
use datafusion::physical_plan::metrics::{BaselineMetrics, MetricValue};
use datafusion::physical_plan::{
    DisplayFormatType, ExecutionPlan, ExecutionPlanVisitor, PhysicalExpr,
    RecordBatchStream as DfRecordBatchStream, accept,
};
use datafusion_common::arrow::error::ArrowError;
use datafusion_common::{DataFusionError, ToDFSchema};
use datatypes::arrow::array::Array;
use datatypes::arrow::datatypes::DataType as ArrowDataType;
use datatypes::schema::{ColumnExtType, Schema, SchemaRef};
use futures::ready;
use pin_project::pin_project;
use snafu::ResultExt;

use crate::error::{self, Result};
use crate::filter::batch_filter;
use crate::{
    DfRecordBatch, DfSendableRecordBatchStream, OrderOption, RecordBatch, RecordBatchStream,
    SendableRecordBatchStream, Stream,
};

type FutureStream = Pin<Box<dyn Future<Output = Result<SendableRecordBatchStream>> + Send>>;
/// Casts the `RecordBatch`es of `stream` to the projected output schema,
/// optionally applying a pushed-down filter.
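///
/// A minimal usage sketch (`ignore`d in doctests since it needs a concrete
/// DataFusion batch stream; `schema`, `df_stream`, and `filters` are
/// placeholders):
///
/// ```ignore
/// let adapter = RecordBatchStreamTypeAdapter::new(schema, df_stream, Some(vec![0, 2]))
///     .with_filter(filters)?;
/// ```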
#[pin_project]
pub struct RecordBatchStreamTypeAdapter<T, E> {
    #[pin]
    stream: T,
    projected_schema: DfSchemaRef,
    projection: Vec<usize>,
    predicate: Option<Arc<dyn PhysicalExpr>>,
    phantom: PhantomData<E>,
}

impl<T, E> RecordBatchStreamTypeAdapter<T, E>
where
    T: Stream<Item = std::result::Result<DfRecordBatch, E>>,
    E: std::error::Error + Send + Sync + 'static,
{
    pub fn new(projected_schema: DfSchemaRef, stream: T, projection: Option<Vec<usize>>) -> Self {
        let projection = if let Some(projection) = projection {
            projection
        } else {
            (0..projected_schema.fields().len()).collect()
        };

        Self {
            stream,
            projected_schema,
            projection,
            predicate: None,
            phantom: Default::default(),
        }
    }

    pub fn with_filter(mut self, filters: Vec<Expr>) -> Result<Self> {
        let filters = if let Some(expr) = conjunction(filters) {
            let df_schema = self
                .projected_schema
                .clone()
                .to_dfschema_ref()
                .context(error::PhysicalExprSnafu)?;

            let filters = create_physical_expr(&expr, &df_schema, &ExecutionProps::new())
                .context(error::PhysicalExprSnafu)?;
            Some(filters)
        } else {
            None
        };
        self.predicate = filters;
        Ok(self)
    }
}

impl<T, E> DfRecordBatchStream for RecordBatchStreamTypeAdapter<T, E>
where
    T: Stream<Item = std::result::Result<DfRecordBatch, E>>,
    E: std::error::Error + Send + Sync + 'static,
{
    fn schema(&self) -> DfSchemaRef {
        self.projected_schema.clone()
    }
}

impl<T, E> Stream for RecordBatchStreamTypeAdapter<T, E>
where
    T: Stream<Item = std::result::Result<DfRecordBatch, E>>,
    E: std::error::Error + Send + Sync + 'static,
{
    type Item = DfResult<DfRecordBatch>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();

        let batch = ready!(this.stream.poll_next(cx))
            .map(|r| r.map_err(|e| DataFusionError::External(Box::new(e))));

        let projected_schema = this.projected_schema.clone();
        let projection = this.projection.clone();
        let predicate = this.predicate.clone();

        let batch = batch.map(|b| {
            b.and_then(|b| {
                let projected_column = b.project(&projection)?;
                if projected_column.schema().fields.len() != projected_schema.fields.len() {
                    return Err(DataFusionError::ArrowError(
                        Box::new(ArrowError::SchemaError(format!(
                            "Trying to cast a RecordBatch into an incompatible schema. RecordBatch: {}, Target: {}",
                            projected_column.schema(),
                            projected_schema,
                        ))),
                        None,
                    ));
                }

                let mut columns = Vec::with_capacity(projected_schema.fields.len());
                for (idx, field) in projected_schema.fields.iter().enumerate() {
                    let column = projected_column.column(idx);
                    let extype = field
                        .metadata()
                        .get("greptime:type")
                        .and_then(|s| ColumnExtType::from_str(s).ok());
                    let output = custom_cast(column, field.data_type(), extype)?;
                    columns.push(output);
                }
                let record_batch = DfRecordBatch::try_new(projected_schema, columns)?;
                let record_batch = if let Some(predicate) = predicate {
                    batch_filter(&record_batch, &predicate)?
                } else {
                    record_batch
                };
                Ok(record_batch)
            })
        });

        Poll::Ready(batch)
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}

/// Greptime [SendableRecordBatchStream] -> DataFusion [RecordBatchStream](DfRecordBatchStream).
/// The adapter for the reverse direction is [RecordBatchStreamAdapter].
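///
/// A wrapping sketch (`greptime_stream` is a placeholder
/// [SendableRecordBatchStream]):
///
/// ```ignore
/// let df_stream: DfSendableRecordBatchStream =
///     Box::pin(DfRecordBatchStreamAdapter::new(greptime_stream));
/// ```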
pub struct DfRecordBatchStreamAdapter {
    stream: SendableRecordBatchStream,
}

impl DfRecordBatchStreamAdapter {
    pub fn new(stream: SendableRecordBatchStream) -> Self {
        Self { stream }
    }
}

impl DfRecordBatchStream for DfRecordBatchStreamAdapter {
    fn schema(&self) -> DfSchemaRef {
        self.stream.schema().arrow_schema().clone()
    }
}

impl Stream for DfRecordBatchStreamAdapter {
    type Item = DfResult<DfRecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.stream).poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Some(recordbatch)) => match recordbatch {
                Ok(recordbatch) => Poll::Ready(Some(Ok(recordbatch.into_df_record_batch()))),
                Err(e) => Poll::Ready(Some(Err(DataFusionError::External(Box::new(e))))),
            },
            Poll::Ready(None) => Poll::Ready(None),
        }
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}

/// DataFusion [SendableRecordBatchStream](DfSendableRecordBatchStream) -> Greptime [RecordBatchStream].
/// The adapter for the reverse direction is [DfRecordBatchStreamAdapter].
/// It can also collect metrics from the underlying DataFusion [ExecutionPlan].
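///
/// A construction sketch (`df_stream`, `baseline_metrics`, and `df_plan` are
/// placeholders):
///
/// ```ignore
/// let stream = RecordBatchStreamAdapter::try_new_with_metrics_and_df_plan(
///     df_stream,
///     baseline_metrics,
///     df_plan,
/// )?;
/// ```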
pub struct RecordBatchStreamAdapter {
    schema: SchemaRef,
    stream: DfSendableRecordBatchStream,
    metrics: Option<BaselineMetrics>,
    /// Aggregated plan-level metrics. Resolved after an [ExecutionPlan] is finished.
    metrics_2: Metrics,
    /// Display plan and metrics in verbose mode.
    explain_verbose: bool,
}

/// JSON-encoded metrics, covering the whole plan tree.
enum Metrics {
    Unavailable,
    Unresolved(Arc<dyn ExecutionPlan>),
    PartialResolved(Arc<dyn ExecutionPlan>, RecordBatchMetrics),
    Resolved(RecordBatchMetrics),
}

impl RecordBatchStreamAdapter {
    pub fn try_new(stream: DfSendableRecordBatchStream) -> Result<Self> {
        let schema =
            Arc::new(Schema::try_from(stream.schema()).context(error::SchemaConversionSnafu)?);
        Ok(Self {
            schema,
            stream,
            metrics: None,
            metrics_2: Metrics::Unavailable,
            explain_verbose: false,
        })
    }

    pub fn try_new_with_metrics_and_df_plan(
        stream: DfSendableRecordBatchStream,
        metrics: BaselineMetrics,
        df_plan: Arc<dyn ExecutionPlan>,
    ) -> Result<Self> {
        let schema =
            Arc::new(Schema::try_from(stream.schema()).context(error::SchemaConversionSnafu)?);
        Ok(Self {
            schema,
            stream,
            metrics: Some(metrics),
            metrics_2: Metrics::Unresolved(df_plan),
            explain_verbose: false,
        })
    }

    pub fn set_metrics2(&mut self, plan: Arc<dyn ExecutionPlan>) {
        self.metrics_2 = Metrics::Unresolved(plan);
    }

    /// Set the verbose mode for displaying plan and metrics.
    pub fn set_explain_verbose(&mut self, verbose: bool) {
        self.explain_verbose = verbose;
    }
}

impl RecordBatchStream for RecordBatchStreamAdapter {
    fn name(&self) -> &str {
        "RecordBatchStreamAdapter"
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        match &self.metrics_2 {
            Metrics::Resolved(metrics) | Metrics::PartialResolved(_, metrics) => {
                Some(metrics.clone())
            }
            Metrics::Unavailable | Metrics::Unresolved(_) => None,
        }
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        None
    }
}

impl Stream for RecordBatchStreamAdapter {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let timer = self
            .metrics
            .as_ref()
            .map(|m| m.elapsed_compute().clone())
            .unwrap_or_default();
        let _guard = timer.timer();
        match Pin::new(&mut self.stream).poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Some(df_record_batch)) => {
                let df_record_batch = df_record_batch?;
                if let Metrics::Unresolved(df_plan) | Metrics::PartialResolved(df_plan, _) =
                    &self.metrics_2
                {
                    let mut metric_collector = MetricCollector::new(self.explain_verbose);
                    // Infallible: `MetricCollector`'s error type is `!`.
                    accept(df_plan.as_ref(), &mut metric_collector).unwrap();
                    self.metrics_2 = Metrics::PartialResolved(
                        df_plan.clone(),
                        metric_collector.record_batch_metrics,
                    );
                }
                Poll::Ready(Some(RecordBatch::try_from_df_record_batch(
                    self.schema(),
                    df_record_batch,
                )))
            }
            Poll::Ready(None) => {
                if let Metrics::Unresolved(df_plan) | Metrics::PartialResolved(df_plan, _) =
                    &self.metrics_2
                {
                    let mut metric_collector = MetricCollector::new(self.explain_verbose);
                    // Infallible: `MetricCollector`'s error type is `!`.
                    accept(df_plan.as_ref(), &mut metric_collector).unwrap();
                    self.metrics_2 = Metrics::Resolved(metric_collector.record_batch_metrics);
                }
                Poll::Ready(None)
            }
        }
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}

/// An [ExecutionPlanVisitor] that collects metrics from an [ExecutionPlan].
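///
/// A usage sketch mirroring how this module drives it (`df_plan` is a
/// placeholder plan; the `unwrap` is infallible since the visitor cannot fail):
///
/// ```ignore
/// let mut collector = MetricCollector::new(false);
/// accept(df_plan.as_ref(), &mut collector).unwrap();
/// let metrics: RecordBatchMetrics = collector.record_batch_metrics;
/// ```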
pub struct MetricCollector {
    current_level: usize,
    pub record_batch_metrics: RecordBatchMetrics,
    verbose: bool,
}

impl MetricCollector {
    pub fn new(verbose: bool) -> Self {
        Self {
            current_level: 0,
            record_batch_metrics: RecordBatchMetrics::default(),
            verbose,
        }
    }
}

impl ExecutionPlanVisitor for MetricCollector {
    type Error = !;

    fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> std::result::Result<bool, Self::Error> {
        // No metrics on this node: record it with an empty metric list and
        // keep descending into its children.
        let Some(metric) = plan.metrics() else {
            self.record_batch_metrics.plan_metrics.push(PlanMetrics {
                plan: plan.name().to_string(),
                level: self.current_level,
                metrics: vec![],
            });
            self.current_level += 1;
            return Ok(true);
        };

        // scrape plan metrics
        let metric = metric
            .aggregate_by_name()
            .sorted_for_display()
            .timestamps_removed();
        let mut plan_metric = PlanMetrics {
            plan: one_line(plan, self.verbose).to_string(),
            level: self.current_level,
            metrics: Vec::with_capacity(metric.iter().size_hint().0),
        };
        for m in metric.iter() {
            plan_metric
                .metrics
                .push((m.value().name().to_string(), m.value().as_usize()));

            // aggregate high-level metrics
            match m.value() {
                MetricValue::ElapsedCompute(ec) => {
                    self.record_batch_metrics.elapsed_compute += ec.value()
                }
                MetricValue::CurrentMemoryUsage(m) => {
                    self.record_batch_metrics.memory_usage += m.value()
                }
                _ => {}
            }
        }
        self.record_batch_metrics.plan_metrics.push(plan_metric);

        self.current_level += 1;
        Ok(true)
    }

    fn post_visit(&mut self, _plan: &dyn ExecutionPlan) -> std::result::Result<bool, Self::Error> {
        self.current_level -= 1;
        Ok(true)
    }
}

/// Returns a single-line summary of the root of the plan.
/// If the `verbose` flag is set, it will display detailed information about the plan.
fn one_line(plan: &dyn ExecutionPlan, verbose: bool) -> impl fmt::Display + '_ {
    struct Wrapper<'a> {
        plan: &'a dyn ExecutionPlan,
        format_type: DisplayFormatType,
    }

    impl fmt::Display for Wrapper<'_> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            self.plan.fmt_as(self.format_type, f)?;
            writeln!(f)
        }
    }

    let format_type = if verbose {
        DisplayFormatType::Verbose
    } else {
        DisplayFormatType::Default
    };
    Wrapper { plan, format_type }
}

/// [`RecordBatchMetrics`] carries metrics values
/// from datanode to frontend through gRPC.
#[derive(serde::Serialize, serde::Deserialize, Default, Debug, Clone)]
pub struct RecordBatchMetrics {
    // High-level aggregated metrics
    /// CPU consumption in nanoseconds
    pub elapsed_compute: usize,
    /// Memory used by the plan in bytes
    pub memory_usage: usize,
    // Detailed per-plan metrics
    /// An ordered list of plan metrics, from top to bottom in pre-order
    /// (each plan precedes its children).
    pub plan_metrics: Vec<PlanMetrics>,
}

/// Determines if a metric name represents a time measurement that should be formatted.
fn is_time_metric(metric_name: &str) -> bool {
    metric_name.contains("elapsed") || metric_name.contains("time") || metric_name.contains("cost")
}

/// Determines if a metric name represents a bytes measurement that should be formatted.
fn is_bytes_metric(metric_name: &str) -> bool {
    metric_name.contains("bytes") || metric_name.contains("mem")
}

fn format_bytes_human_readable(bytes: usize) -> String {
    format!("{}", ReadableSize(bytes as u64))
}

/// Displays only `plan_metrics`, indenting each plan by two spaces per level.
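///
/// An illustrative rendering (plan names and metric values are hypothetical):
///
/// ```text
/// ProjectionExec: expr=[a] metrics=[elapsed_compute: 1ms, ]
///   SeqScan: region=1 metrics=[elapsed_compute: 2ms, memory_usage: 1KiB, ]
/// ```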
impl Display for RecordBatchMetrics {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for metric in &self.plan_metrics {
            write!(
                f,
                "{:indent$}{} metrics=[",
                " ",
                metric.plan.trim_end(),
                indent = metric.level * 2,
            )?;
            for (label, value) in &metric.metrics {
                if is_time_metric(label) {
                    write!(
                        f,
                        "{}: {}, ",
                        label,
                        format_nanoseconds_human_readable(*value),
                    )?;
                } else if is_bytes_metric(label) {
                    write!(f, "{}: {}, ", label, format_bytes_human_readable(*value))?;
                } else {
                    write!(f, "{}: {}, ", label, value)?;
                }
            }
            writeln!(f, "]")?;
        }

        Ok(())
    }
}

#[derive(serde::Serialize, serde::Deserialize, Default, Debug, Clone)]
pub struct PlanMetrics {
    /// The plan name
    pub plan: String,
    /// The level of the plan, starts from 0
    pub level: usize,
    /// An ordered key-value list of metrics.
    /// Key is metric label and value is metric value.
    pub metrics: Vec<(String, usize)>,
}

enum AsyncRecordBatchStreamAdapterState {
    Uninit(FutureStream),
    Ready(SendableRecordBatchStream),
    Failed,
}

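/// A [RecordBatchStream] built from a future of a stream: the future is
/// driven on first poll, after which polling is delegated to the resolved
/// inner stream.
///
/// A construction sketch (`make_stream` is a placeholder async fn returning
/// `Result<SendableRecordBatchStream>`):
///
/// ```ignore
/// let future: FutureStream = Box::pin(async move { make_stream().await });
/// let adapter = AsyncRecordBatchStreamAdapter::new(schema, future);
/// ```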
pub struct AsyncRecordBatchStreamAdapter {
    schema: SchemaRef,
    state: AsyncRecordBatchStreamAdapterState,
}

impl AsyncRecordBatchStreamAdapter {
    pub fn new(schema: SchemaRef, stream: FutureStream) -> Self {
        Self {
            schema,
            state: AsyncRecordBatchStreamAdapterState::Uninit(stream),
        }
    }
}

impl RecordBatchStream for AsyncRecordBatchStreamAdapter {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        None
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        None
    }
}

impl Stream for AsyncRecordBatchStreamAdapter {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            match &mut self.state {
                AsyncRecordBatchStreamAdapterState::Uninit(stream_future) => {
                    match ready!(Pin::new(stream_future).poll(cx)) {
                        Ok(stream) => {
                            self.state = AsyncRecordBatchStreamAdapterState::Ready(stream);
                            continue;
                        }
                        Err(e) => {
                            self.state = AsyncRecordBatchStreamAdapterState::Failed;
                            return Poll::Ready(Some(Err(e)));
                        }
                    };
                }
                AsyncRecordBatchStreamAdapterState::Ready(stream) => {
                    return Poll::Ready(ready!(Pin::new(stream).poll_next(cx)));
                }
                AsyncRecordBatchStreamAdapterState::Failed => return Poll::Ready(None),
            }
        }
    }

    // The inner stream is unavailable until the future resolves, so no size
    // hint can be given for this lazy stream.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, None)
    }
}

/// Casts `array` to `target_type`, special-casing `Map` -> `Binary` to emit
/// JSON (or JSONB) documents; all other combinations delegate to Arrow's [cast].
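///
/// A sketch of the two paths (`map_array` and `utf8_array` are placeholder
/// inputs):
///
/// ```ignore
/// // Map -> Binary takes the JSON conversion path.
/// let json = custom_cast(&map_array, &ArrowDataType::Binary, None)?;
/// // Anything else falls through to the standard Arrow cast kernel.
/// let ints = custom_cast(&utf8_array, &ArrowDataType::Int64, None)?;
/// ```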
fn custom_cast(
    array: &dyn Array,
    target_type: &ArrowDataType,
    extype: Option<ColumnExtType>,
) -> std::result::Result<Arc<dyn Array>, ArrowError> {
    if let ArrowDataType::Map(_, _) = array.data_type()
        && let ArrowDataType::Binary = target_type
    {
        return convert_map_to_json_binary(array, extype);
    }

    cast(array, target_type)
}

/// Converts a `Map` array into a `Binary` array whose values are JSON
/// documents (JSONB when `extype` is `ColumnExtType::Json`).
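///
/// For example, a non-null map row with entries `a -> "1"` and `b -> null`
/// becomes the bytes of `{"a":"1","b":null}`; null map rows stay null.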
fn convert_map_to_json_binary(
    array: &dyn Array,
    extype: Option<ColumnExtType>,
) -> std::result::Result<Arc<dyn Array>, ArrowError> {
    use datatypes::arrow::array::{BinaryArray, MapArray};
    use serde_json::Value;

    let map_array = array
        .as_any()
        .downcast_ref::<MapArray>()
        .ok_or_else(|| ArrowError::CastError("Failed to downcast to MapArray".to_string()))?;

    let mut json_values = Vec::with_capacity(map_array.len());

    for i in 0..map_array.len() {
        if map_array.is_null(i) {
            json_values.push(None);
        } else {
            // Extract the map entry at index i
            let map_entry = map_array.value(i);
            let key_value_array = map_entry
                .as_any()
                .downcast_ref::<datatypes::arrow::array::StructArray>()
                .ok_or_else(|| {
                    ArrowError::CastError("Failed to downcast to StructArray".to_string())
                })?;

            // Convert to JSON object
            let mut json_obj = serde_json::Map::with_capacity(key_value_array.len());

            for j in 0..key_value_array.len() {
                if key_value_array.is_null(j) {
                    continue;
                }
                let key_field = key_value_array.column(0);
                let value_field = key_value_array.column(1);

                if key_field.is_null(j) {
                    continue;
                }

                let key = key_field
                    .as_any()
                    .downcast_ref::<datatypes::arrow::array::StringArray>()
                    .ok_or_else(|| {
                        ArrowError::CastError("Failed to downcast key to StringArray".to_string())
                    })?
                    .value(j);

                let value = if value_field.is_null(j) {
                    Value::Null
                } else {
                    let value_str = value_field
                        .as_any()
                        .downcast_ref::<datatypes::arrow::array::StringArray>()
                        .ok_or_else(|| {
                            ArrowError::CastError(
                                "Failed to downcast value to StringArray".to_string(),
                            )
                        })?
                        .value(j);
                    Value::String(value_str.to_string())
                };

                json_obj.insert(key.to_string(), value);
            }

            let json_value = Value::Object(json_obj);
            let json_bytes = match extype {
                Some(ColumnExtType::Json) => {
                    // Round-trip through a JSON string to produce JSONB bytes.
                    let json_string = serde_json::to_string(&json_value).map_err(|e| {
                        ArrowError::CastError(format!("Failed to serialize JSON: {}", e))
                    })?;
                    jsonb::parse_value(json_string.as_bytes())
                        .map_err(|e| {
                            ArrowError::CastError(format!("Failed to parse JSONB: {}", e))
                        })?
                        .to_vec()
                }
                _ => serde_json::to_vec(&json_value).map_err(|e| {
                    ArrowError::CastError(format!("Failed to serialize JSON: {}", e))
                })?,
            };
            json_values.push(Some(json_bytes));
        }
    }

    let binary_array = BinaryArray::from_iter(json_values);
    Ok(Arc::new(binary_array))
}

#[cfg(test)]
mod test {
    use common_error::ext::BoxedError;
    use common_error::mock::MockError;
    use common_error::status_code::StatusCode;
    use datatypes::arrow::array::{ArrayRef, MapArray, StringArray, StructArray};
    use datatypes::arrow::buffer::OffsetBuffer;
    use datatypes::arrow::datatypes::Field;
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::ColumnSchema;
    use datatypes::vectors::Int32Vector;
    use snafu::IntoError;

    use super::*;
    use crate::RecordBatches;
    use crate::error::Error;

    #[tokio::test]
    async fn test_async_recordbatch_stream_adaptor() {
        struct MaybeErrorRecordBatchStream {
            items: Vec<Result<RecordBatch>>,
        }

        impl RecordBatchStream for MaybeErrorRecordBatchStream {
            fn schema(&self) -> SchemaRef {
                unimplemented!()
            }

            fn output_ordering(&self) -> Option<&[OrderOption]> {
                None
            }

            fn metrics(&self) -> Option<RecordBatchMetrics> {
                None
            }
        }

        impl Stream for MaybeErrorRecordBatchStream {
            type Item = Result<RecordBatch>;

            fn poll_next(
                mut self: Pin<&mut Self>,
                _: &mut Context<'_>,
            ) -> Poll<Option<Self::Item>> {
                if let Some(batch) = self.items.pop() {
                    Poll::Ready(Some(Ok(batch?)))
                } else {
                    Poll::Ready(None)
                }
            }
        }

        fn new_future_stream(
            maybe_recordbatches: Result<Vec<Result<RecordBatch>>>,
        ) -> FutureStream {
            Box::pin(async move {
                maybe_recordbatches
                    .map(|items| Box::pin(MaybeErrorRecordBatchStream { items }) as _)
            })
        }

        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "a",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        let batch1 = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from_slice([1])) as _],
        )
        .unwrap();
        let batch2 = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from_slice([2])) as _],
        )
        .unwrap();

        let success_stream = new_future_stream(Ok(vec![Ok(batch1.clone()), Ok(batch2.clone())]));
        let adapter = AsyncRecordBatchStreamAdapter::new(schema.clone(), success_stream);
        let collected = RecordBatches::try_collect(Box::pin(adapter)).await.unwrap();
        // `MaybeErrorRecordBatchStream` pops items from the back, so the
        // batches come out in reverse order.
        assert_eq!(
            collected,
            RecordBatches::try_new(schema.clone(), vec![batch2.clone(), batch1.clone()]).unwrap()
        );

        let poll_err_stream = new_future_stream(Ok(vec![
            Ok(batch1.clone()),
            Err(error::ExternalSnafu
                .into_error(BoxedError::new(MockError::new(StatusCode::Unknown)))),
        ]));
        let adapter = AsyncRecordBatchStreamAdapter::new(schema.clone(), poll_err_stream);
        let err = RecordBatches::try_collect(Box::pin(adapter))
            .await
            .unwrap_err();
        assert!(
            matches!(err, Error::External { .. }),
            "unexpected err {err}"
        );

        let failed_to_init_stream =
            new_future_stream(Err(error::ExternalSnafu
                .into_error(BoxedError::new(MockError::new(StatusCode::Internal)))));
        let adapter = AsyncRecordBatchStreamAdapter::new(schema.clone(), failed_to_init_stream);
        let err = RecordBatches::try_collect(Box::pin(adapter))
            .await
            .unwrap_err();
        assert!(
            matches!(err, Error::External { .. }),
            "unexpected err {err}"
        );
    }

    #[test]
    fn test_convert_map_to_json_binary() {
        let keys = StringArray::from(vec![Some("a"), Some("b"), Some("c"), Some("x")]);
        let values = StringArray::from(vec![Some("1"), None, Some("3"), Some("42")]);
        let key_field = Arc::new(Field::new("key", ArrowDataType::Utf8, false));
        let value_field = Arc::new(Field::new("value", ArrowDataType::Utf8, true));
        let struct_type = ArrowDataType::Struct(vec![key_field, value_field].into());

        let entries_field = Arc::new(Field::new("entries", struct_type, false));

        let struct_array = StructArray::from(vec![
            (
                Arc::new(Field::new("key", ArrowDataType::Utf8, false)),
                Arc::new(keys) as ArrayRef,
            ),
            (
                Arc::new(Field::new("value", ArrowDataType::Utf8, true)),
                Arc::new(values) as ArrayRef,
            ),
        ]);

        // Three map rows: {a,b,c}, a null row, and {x}.
        let offsets = OffsetBuffer::from_lengths([3, 0, 1]);
        let nulls = datatypes::arrow::buffer::NullBuffer::from(vec![true, false, true]);

        let map_array = MapArray::new(entries_field, offsets, struct_array, Some(nulls), false);

        let result = convert_map_to_json_binary(&map_array, None).unwrap();
        let binary_array = result
            .as_any()
            .downcast_ref::<datatypes::arrow::array::BinaryArray>()
            .unwrap();

        let expected_jsons = [
            Some(r#"{"a":"1","b":null,"c":"3"}"#),
            None,
            Some(r#"{"x":"42"}"#),
        ];

        for (i, expected) in expected_jsons.iter().enumerate() {
            if let Some(expected) = expected {
                assert!(!binary_array.is_null(i));
                let actual_bytes = binary_array.value(i);
                let actual_str = std::str::from_utf8(actual_bytes).unwrap();
                assert_eq!(actual_str, *expected);
            } else {
                assert!(binary_array.is_null(i));
            }
        }

        // With `ColumnExtType::Json` the output is JSONB rather than plain
        // JSON text, so the bytes differ from the JSON strings above.
        let result_json =
            convert_map_to_json_binary(&map_array, Some(ColumnExtType::Json)).unwrap();
        let binary_array_json = result_json
            .as_any()
            .downcast_ref::<datatypes::arrow::array::BinaryArray>()
            .unwrap();

        for (i, expected) in expected_jsons.iter().enumerate() {
            if let Some(expected) = expected {
                assert!(!binary_array_json.is_null(i));
                let actual_bytes = binary_array_json.value(i);
                assert_ne!(actual_bytes, expected.as_bytes());
            } else {
                assert!(binary_array_json.is_null(i));
            }
        }
    }
}