promql/extension_plan/
union_distinct_on.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::task::{Context, Poll};
19
20use ahash::{HashMap, RandomState};
21use datafusion::arrow::array::UInt64Array;
22use datafusion::arrow::datatypes::SchemaRef;
23use datafusion::arrow::record_batch::RecordBatch;
24use datafusion::common::{DFSchema, DFSchemaRef};
25use datafusion::error::{DataFusionError, Result as DataFusionResult};
26use datafusion::execution::context::TaskContext;
27use datafusion::logical_expr::{EmptyRelation, Expr, LogicalPlan, UserDefinedLogicalNodeCore};
28use datafusion::physical_expr::EquivalenceProperties;
29use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
30use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
31use datafusion::physical_plan::{
32    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PlanProperties,
33    RecordBatchStream, SendableRecordBatchStream, hash_utils,
34};
35use datafusion_expr::col;
36use datatypes::arrow::compute;
37use futures::future::BoxFuture;
38use futures::{Stream, StreamExt, TryStreamExt, ready};
39use greptime_proto::substrait_extension as pb;
40use prost::Message;
41use snafu::ResultExt;
42
43use crate::error::{DeserializeSnafu, Result};
44use crate::extension_plan::{resolve_column_name, serialize_column_index};
45
46/// A special kind of `UNION`(`OR` in PromQL) operator, for PromQL specific use case.
47///
48/// This operator is similar to `UNION` from SQL, but it only accepts two inputs. The
49/// most different part is that it treat left child and right child differently:
50/// - All columns from left child will be outputted.
51/// - Only check collisions (when not distinct) on the columns specified by `compare_keys`.
52/// - When there is a collision:
53///   - If the collision is from right child itself, only the first observed row will be
54///     preserved. All others are discarded.
55///   - If the collision is from left child, the row in right child will be discarded.
56/// - The output order is not maintained. This plan will output left child first, then right child.
57/// - The output schema contains all columns from left or right child plans.
58///
59/// From the implementation perspective, this operator is similar to `HashJoin`, but the
60/// probe side is the right child, and the build side is the left child. Another difference
61/// is that the probe is opting-out.
62///
63/// This plan will exhaust the right child first to build probe hash table, then streaming
64/// on left side, and use the left side to "mask" the hash table.
65#[derive(Debug, PartialEq, Eq, Hash)]
66pub struct UnionDistinctOn {
67    left: LogicalPlan,
68    right: LogicalPlan,
69    /// The columns to compare for equality.
70    /// TIME INDEX is included.
71    compare_keys: Vec<String>,
72    ts_col: String,
73    output_schema: DFSchemaRef,
74    unfix: Option<UnfixIndices>,
75}
76
77#[derive(Debug, PartialEq, Eq, Hash, PartialOrd)]
78struct UnfixIndices {
79    pub compare_key_indices: Vec<u64>,
80    pub ts_col_idx: u64,
81}
82
83impl UnionDistinctOn {
84    pub fn name() -> &'static str {
85        "UnionDistinctOn"
86    }
87
88    pub fn new(
89        left: LogicalPlan,
90        right: LogicalPlan,
91        compare_keys: Vec<String>,
92        ts_col: String,
93        output_schema: DFSchemaRef,
94    ) -> Self {
95        Self {
96            left,
97            right,
98            compare_keys,
99            ts_col,
100            output_schema,
101            unfix: None,
102        }
103    }
104
105    pub fn to_execution_plan(
106        &self,
107        left_exec: Arc<dyn ExecutionPlan>,
108        right_exec: Arc<dyn ExecutionPlan>,
109    ) -> Arc<dyn ExecutionPlan> {
110        let output_schema: SchemaRef = self.output_schema.inner().clone();
111        let properties = Arc::new(PlanProperties::new(
112            EquivalenceProperties::new(output_schema.clone()),
113            Partitioning::UnknownPartitioning(1),
114            EmissionType::Incremental,
115            Boundedness::Bounded,
116        ));
117        Arc::new(UnionDistinctOnExec {
118            left: left_exec,
119            right: right_exec,
120            compare_keys: self.compare_keys.clone(),
121            ts_col: self.ts_col.clone(),
122            output_schema,
123            metric: ExecutionPlanMetricsSet::new(),
124            properties,
125            random_state: RandomState::new(),
126        })
127    }
128
129    pub fn serialize(&self) -> Vec<u8> {
130        let compare_key_indices = self
131            .compare_keys
132            .iter()
133            .map(|name| serialize_column_index(&self.output_schema, name))
134            .collect::<Vec<u64>>();
135
136        let ts_col_idx = serialize_column_index(&self.output_schema, &self.ts_col);
137
138        pb::UnionDistinctOn {
139            compare_key_indices,
140            ts_col_idx,
141        }
142        .encode_to_vec()
143    }
144
145    pub fn deserialize(bytes: &[u8]) -> Result<Self> {
146        let pb_union = pb::UnionDistinctOn::decode(bytes).context(DeserializeSnafu)?;
147        let placeholder_plan = LogicalPlan::EmptyRelation(EmptyRelation {
148            produce_one_row: false,
149            schema: Arc::new(DFSchema::empty()),
150        });
151
152        let unfix = UnfixIndices {
153            compare_key_indices: pb_union.compare_key_indices.clone(),
154            ts_col_idx: pb_union.ts_col_idx,
155        };
156
157        Ok(Self {
158            left: placeholder_plan.clone(),
159            right: placeholder_plan,
160            compare_keys: Vec::new(),
161            ts_col: String::new(),
162            output_schema: Arc::new(DFSchema::empty()),
163            unfix: Some(unfix),
164        })
165    }
166}
167
168impl PartialOrd for UnionDistinctOn {
169    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
170        // Compare fields in order excluding output_schema
171        match self.left.partial_cmp(&other.left) {
172            Some(core::cmp::Ordering::Equal) => {}
173            ord => return ord,
174        }
175        match self.right.partial_cmp(&other.right) {
176            Some(core::cmp::Ordering::Equal) => {}
177            ord => return ord,
178        }
179        match self.compare_keys.partial_cmp(&other.compare_keys) {
180            Some(core::cmp::Ordering::Equal) => {}
181            ord => return ord,
182        }
183        self.ts_col.partial_cmp(&other.ts_col)
184    }
185}
186
187impl UserDefinedLogicalNodeCore for UnionDistinctOn {
188    fn name(&self) -> &str {
189        Self::name()
190    }
191
192    fn inputs(&self) -> Vec<&LogicalPlan> {
193        vec![&self.left, &self.right]
194    }
195
196    fn schema(&self) -> &DFSchemaRef {
197        &self.output_schema
198    }
199
200    fn expressions(&self) -> Vec<Expr> {
201        if self.unfix.is_some() {
202            return vec![];
203        }
204
205        let mut exprs: Vec<Expr> = self.compare_keys.iter().map(col).collect();
206        if !self.compare_keys.iter().any(|key| key == &self.ts_col) {
207            exprs.push(col(&self.ts_col));
208        }
209        exprs
210    }
211
212    fn necessary_children_exprs(&self, _output_columns: &[usize]) -> Option<Vec<Vec<usize>>> {
213        if self.unfix.is_some() {
214            return None;
215        }
216
217        let left_len = self.left.schema().fields().len();
218        let right_len = self.right.schema().fields().len();
219        Some(vec![
220            (0..left_len).collect::<Vec<_>>(),
221            (0..right_len).collect::<Vec<_>>(),
222        ])
223    }
224
225    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226        write!(
227            f,
228            "UnionDistinctOn: on col=[{:?}], ts_col=[{}]",
229            self.compare_keys, self.ts_col
230        )
231    }
232
233    fn with_exprs_and_inputs(
234        &self,
235        _exprs: Vec<Expr>,
236        inputs: Vec<LogicalPlan>,
237    ) -> DataFusionResult<Self> {
238        if inputs.len() != 2 {
239            return Err(DataFusionError::Internal(
240                "UnionDistinctOn must have exactly 2 inputs".to_string(),
241            ));
242        }
243
244        let mut inputs = inputs.into_iter();
245        let left = inputs.next().unwrap();
246        let right = inputs.next().unwrap();
247
248        if let Some(unfix) = &self.unfix {
249            let output_schema = left.schema().clone();
250
251            let compare_keys = unfix
252                .compare_key_indices
253                .iter()
254                .map(|idx| {
255                    resolve_column_name(*idx, &output_schema, "UnionDistinctOn", "compare key")
256                })
257                .collect::<DataFusionResult<Vec<String>>>()?;
258
259            let ts_col =
260                resolve_column_name(unfix.ts_col_idx, &output_schema, "UnionDistinctOn", "ts")?;
261
262            Ok(Self {
263                left,
264                right,
265                compare_keys,
266                ts_col,
267                output_schema,
268                unfix: None,
269            })
270        } else {
271            Ok(Self {
272                left,
273                right,
274                compare_keys: self.compare_keys.clone(),
275                ts_col: self.ts_col.clone(),
276                output_schema: self.output_schema.clone(),
277                unfix: None,
278            })
279        }
280    }
281}
282
283#[derive(Debug)]
284pub struct UnionDistinctOnExec {
285    left: Arc<dyn ExecutionPlan>,
286    right: Arc<dyn ExecutionPlan>,
287    compare_keys: Vec<String>,
288    ts_col: String,
289    output_schema: SchemaRef,
290    metric: ExecutionPlanMetricsSet,
291    properties: Arc<PlanProperties>,
292
293    /// Shared the `RandomState` for the hashing algorithm
294    random_state: RandomState,
295}
296
297impl ExecutionPlan for UnionDistinctOnExec {
298    fn as_any(&self) -> &dyn Any {
299        self
300    }
301
302    fn schema(&self) -> SchemaRef {
303        self.output_schema.clone()
304    }
305
306    fn required_input_distribution(&self) -> Vec<Distribution> {
307        vec![Distribution::SinglePartition, Distribution::SinglePartition]
308    }
309
310    fn properties(&self) -> &PlanProperties {
311        self.properties.as_ref()
312    }
313
314    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
315        vec![&self.left, &self.right]
316    }
317
318    fn with_new_children(
319        self: Arc<Self>,
320        children: Vec<Arc<dyn ExecutionPlan>>,
321    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
322        assert_eq!(children.len(), 2);
323
324        let left = children[0].clone();
325        let right = children[1].clone();
326        Ok(Arc::new(UnionDistinctOnExec {
327            left,
328            right,
329            compare_keys: self.compare_keys.clone(),
330            ts_col: self.ts_col.clone(),
331            output_schema: self.output_schema.clone(),
332            metric: self.metric.clone(),
333            properties: self.properties.clone(),
334            random_state: self.random_state.clone(),
335        }))
336    }
337
338    fn execute(
339        &self,
340        partition: usize,
341        context: Arc<TaskContext>,
342    ) -> DataFusionResult<SendableRecordBatchStream> {
343        let left_stream = self.left.execute(partition, context.clone())?;
344        let right_stream = self.right.execute(partition, context.clone())?;
345
346        // Convert column name to column index. Add one for the time column.
347        let mut key_indices = Vec::with_capacity(self.compare_keys.len() + 1);
348        for key in &self.compare_keys {
349            let index = self
350                .output_schema
351                .column_with_name(key)
352                .map(|(i, _)| i)
353                .ok_or_else(|| DataFusionError::Internal(format!("Column {} not found", key)))?;
354            key_indices.push(index);
355        }
356        let ts_index = self
357            .output_schema
358            .column_with_name(&self.ts_col)
359            .map(|(i, _)| i)
360            .ok_or_else(|| {
361                DataFusionError::Internal(format!("Column {} not found", self.ts_col))
362            })?;
363        key_indices.push(ts_index);
364
365        // Build right hash table future.
366        let hashed_data_future = HashedDataFut::Pending(Box::pin(HashedData::new(
367            right_stream,
368            self.random_state.clone(),
369            key_indices.clone(),
370        )));
371
372        let baseline_metric = BaselineMetrics::new(&self.metric, partition);
373        Ok(Box::pin(UnionDistinctOnStream {
374            left: left_stream,
375            right: hashed_data_future,
376            compare_keys: key_indices,
377            output_schema: self.output_schema.clone(),
378            metric: baseline_metric,
379        }))
380    }
381
382    fn metrics(&self) -> Option<MetricsSet> {
383        Some(self.metric.clone_inner())
384    }
385
386    fn name(&self) -> &str {
387        "UnionDistinctOnExec"
388    }
389}
390
391impl DisplayAs for UnionDistinctOnExec {
392    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
393        match t {
394            DisplayFormatType::Default
395            | DisplayFormatType::Verbose
396            | DisplayFormatType::TreeRender => {
397                write!(
398                    f,
399                    "UnionDistinctOnExec: on col=[{:?}], ts_col=[{}]",
400                    self.compare_keys, self.ts_col
401                )
402            }
403        }
404    }
405}
406
407// TODO(ruihang): some unused fields are for metrics, which will be implemented later.
408#[allow(dead_code)]
409pub struct UnionDistinctOnStream {
410    left: SendableRecordBatchStream,
411    right: HashedDataFut,
412    /// Include time index
413    compare_keys: Vec<usize>,
414    output_schema: SchemaRef,
415    metric: BaselineMetrics,
416}
417
418impl UnionDistinctOnStream {
419    fn poll_impl(&mut self, cx: &mut Context<'_>) -> Poll<Option<<Self as Stream>::Item>> {
420        // resolve the right stream
421        let right = match self.right {
422            HashedDataFut::Pending(ref mut fut) => {
423                let right = ready!(fut.as_mut().poll(cx))?;
424                self.right = HashedDataFut::Ready(right);
425                let HashedDataFut::Ready(right_ref) = &mut self.right else {
426                    unreachable!()
427                };
428                right_ref
429            }
430            HashedDataFut::Ready(ref mut right) => right,
431            HashedDataFut::Empty => return Poll::Ready(None),
432        };
433
434        // poll left and probe with right
435        let next_left = ready!(self.left.poll_next_unpin(cx));
436        match next_left {
437            Some(Ok(left)) => {
438                // observe left batch and return it
439                right.update_map(&left)?;
440                Poll::Ready(Some(Ok(left)))
441            }
442            Some(Err(e)) => Poll::Ready(Some(Err(e))),
443            None => {
444                // left stream is exhausted, so we can send the right part
445                let right = std::mem::replace(&mut self.right, HashedDataFut::Empty);
446                let HashedDataFut::Ready(data) = right else {
447                    unreachable!()
448                };
449                Poll::Ready(Some(data.finish()))
450            }
451        }
452    }
453}
454
455impl RecordBatchStream for UnionDistinctOnStream {
456    fn schema(&self) -> SchemaRef {
457        self.output_schema.clone()
458    }
459}
460
461impl Stream for UnionDistinctOnStream {
462    type Item = DataFusionResult<RecordBatch>;
463
464    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
465        self.poll_impl(cx)
466    }
467}
468
469/// Simple future state for [HashedData]
470enum HashedDataFut {
471    /// The result is not ready
472    Pending(BoxFuture<'static, DataFusionResult<HashedData>>),
473    /// The result is ready
474    Ready(HashedData),
475    /// The result is taken
476    Empty,
477}
478
479/// ALL input batches and its hash table
480struct HashedData {
481    // TODO(ruihang): use `JoinHashMap` instead after upgrading to DF 34.0
482    /// Hash table for all input batches. The key is hash value, and the value
483    /// is the index of `bathc`.
484    hash_map: HashMap<u64, usize>,
485    /// Output batch.
486    batch: RecordBatch,
487    /// The indices of the columns to be hashed.
488    hash_key_indices: Vec<usize>,
489    random_state: RandomState,
490}
491
492impl HashedData {
493    pub async fn new(
494        input: SendableRecordBatchStream,
495        random_state: RandomState,
496        hash_key_indices: Vec<usize>,
497    ) -> DataFusionResult<Self> {
498        // Collect all batches from the input stream
499        let initial = (Vec::new(), 0);
500        let schema = input.schema();
501        let (batches, _num_rows) = input
502            .try_fold(initial, |mut acc, batch| async {
503                // Update rowcount
504                acc.1 += batch.num_rows();
505                // Push batch to output
506                acc.0.push(batch);
507                Ok(acc)
508            })
509            .await?;
510
511        // Create hash for each batch
512        let mut hash_map = HashMap::default();
513        let mut hashes_buffer = Vec::new();
514        let mut interleave_indices = Vec::new();
515        for (batch_number, batch) in batches.iter().enumerate() {
516            hashes_buffer.resize(batch.num_rows(), 0);
517            // get columns for hashing
518            let arrays = hash_key_indices
519                .iter()
520                .map(|i| batch.column(*i).clone())
521                .collect::<Vec<_>>();
522
523            // compute hash
524            let hash_values =
525                hash_utils::create_hashes(&arrays, &random_state, &mut hashes_buffer)?;
526            for (row_number, hash_value) in hash_values.iter().enumerate() {
527                // Only keeps the first observed row for each hash value
528                if hash_map
529                    .try_insert(*hash_value, interleave_indices.len())
530                    .is_ok()
531                {
532                    interleave_indices.push((batch_number, row_number));
533                }
534            }
535        }
536
537        // Finalize the hash map
538        let batch = interleave_batches(schema, batches, interleave_indices)?;
539
540        Ok(Self {
541            hash_map,
542            batch,
543            hash_key_indices,
544            random_state,
545        })
546    }
547
548    /// Remove rows that hash value present in the input
549    /// record batch from the hash map.
550    pub fn update_map(&mut self, input: &RecordBatch) -> DataFusionResult<()> {
551        // get columns for hashing
552        let mut hashes_buffer = Vec::new();
553        let arrays = self
554            .hash_key_indices
555            .iter()
556            .map(|i| input.column(*i).clone())
557            .collect::<Vec<_>>();
558
559        // compute hash
560        hashes_buffer.resize(input.num_rows(), 0);
561        let hash_values =
562            hash_utils::create_hashes(&arrays, &self.random_state, &mut hashes_buffer)?;
563
564        // remove those hashes
565        for hash in hash_values {
566            self.hash_map.remove(hash);
567        }
568
569        Ok(())
570    }
571
572    pub fn finish(self) -> DataFusionResult<RecordBatch> {
573        let valid_indices = self.hash_map.values().copied().collect::<Vec<_>>();
574        let result = take_batch(&self.batch, &valid_indices)?;
575        Ok(result)
576    }
577}
578
579/// Utility function to interleave batches. Based on [interleave](datafusion::arrow::compute::interleave)
580fn interleave_batches(
581    schema: SchemaRef,
582    batches: Vec<RecordBatch>,
583    indices: Vec<(usize, usize)>,
584) -> DataFusionResult<RecordBatch> {
585    if batches.is_empty() {
586        if indices.is_empty() {
587            return Ok(RecordBatch::new_empty(schema));
588        } else {
589            return Err(DataFusionError::Internal(
590                "Cannot interleave empty batches with non-empty indices".to_string(),
591            ));
592        }
593    }
594
595    // transform batches into arrays
596    let mut arrays = vec![vec![]; schema.fields().len()];
597    for batch in &batches {
598        for (i, array) in batch.columns().iter().enumerate() {
599            arrays[i].push(array.as_ref());
600        }
601    }
602
603    // interleave arrays
604    let interleaved_arrays: Vec<_> = arrays
605        .into_iter()
606        .map(|array| compute::interleave(&array, &indices))
607        .collect::<std::result::Result<_, _>>()?;
608
609    // assemble new record batch
610    RecordBatch::try_new(schema, interleaved_arrays)
611        .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
612}
613
614/// Utility function to take rows from a record batch. Based on [take](datafusion::arrow::compute::take)
615fn take_batch(batch: &RecordBatch, indices: &[usize]) -> DataFusionResult<RecordBatch> {
616    // fast path
617    if batch.num_rows() == indices.len() {
618        return Ok(batch.clone());
619    }
620
621    let schema = batch.schema();
622
623    let indices_array = UInt64Array::from_iter(indices.iter().map(|i| *i as u64));
624    let arrays = batch
625        .columns()
626        .iter()
627        .map(|array| compute::take(array, &indices_array, None))
628        .collect::<std::result::Result<Vec<_>, _>>()
629        .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
630
631    let result = RecordBatch::try_new(schema, arrays)
632        .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
633    Ok(result)
634}
635
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::arrow::array::Int32Array;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::common::ToDFSchema;
    use datafusion::logical_expr::{EmptyRelation, LogicalPlan};

    use super::*;

    /// An empty relation (producing no rows) with the given schema.
    fn empty_relation(schema: &DFSchemaRef) -> LogicalPlan {
        LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: schema.clone(),
        })
    }

    /// Schema with two non-nullable `Int32` columns named `a` and `b`.
    fn two_int_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]))
    }

    /// Record batch over [`two_int_schema`] built from the two column vectors.
    fn two_int_batch(schema: &SchemaRef, a: Vec<i32>, b: Vec<i32>) -> RecordBatch {
        RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(a)),
                Arc::new(Int32Array::from(b)),
            ],
        )
        .unwrap()
    }

    #[test]
    fn pruning_should_keep_all_columns_for_exec() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Int32, false),
            Field::new("k", DataType::Int32, false),
            Field::new("v", DataType::Int32, false),
        ]));
        let df_schema = schema.to_dfschema_ref().unwrap();
        let plan = UnionDistinctOn::new(
            empty_relation(&df_schema),
            empty_relation(&df_schema),
            vec!["k".to_string()],
            "ts".to_string(),
            df_schema,
        );

        // Simulate a parent projection requesting only one output column.
        let output_columns = [2usize];
        let required = plan.necessary_children_exprs(&output_columns).unwrap();

        // Both children must still be asked for all three of their columns.
        assert_eq!(required, vec![vec![0usize, 1, 2]; 2]);
    }

    #[test]
    fn test_interleave_batches() {
        let schema = two_int_schema();
        let batches = vec![
            two_int_batch(&schema, vec![1, 2, 3], vec![4, 5, 6]),
            two_int_batch(&schema, vec![7, 8, 9], vec![10, 11, 12]),
            two_int_batch(&schema, vec![13, 14, 15], vec![16, 17, 18]),
        ];

        // First row of every batch, then second row of every batch.
        let indices = vec![(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)];
        let result = interleave_batches(schema.clone(), batches, indices).unwrap();

        let expected = two_int_batch(
            &schema,
            vec![1, 7, 13, 2, 8, 14],
            vec![4, 10, 16, 5, 11, 17],
        );
        assert_eq!(result, expected);
    }

    #[test]
    fn test_take_batch() {
        let schema = two_int_schema();
        let batch = two_int_batch(&schema, vec![1, 2, 3], vec![4, 5, 6]);

        let indices = vec![0, 2];
        let result = take_batch(&batch, &indices).unwrap();

        let expected = two_int_batch(&schema, vec![1, 3], vec![4, 6]);
        assert_eq!(result, expected);
    }

    #[tokio::test]
    async fn encode_decode_union_distinct_on() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Int64, false),
            Field::new("job", DataType::Utf8, false),
            Field::new("value", DataType::Float64, false),
        ]));
        let df_schema = schema.clone().to_dfschema_ref().unwrap();
        let left_plan = empty_relation(&df_schema);
        let right_plan = empty_relation(&df_schema);

        let bytes = UnionDistinctOn::new(
            left_plan.clone(),
            right_plan.clone(),
            vec!["job".to_string()],
            "ts".to_string(),
            df_schema.clone(),
        )
        .serialize();

        // Decoding yields a placeholder; supplying the inputs resolves the names.
        let union_distinct_on = UnionDistinctOn::deserialize(&bytes)
            .unwrap()
            .with_exprs_and_inputs(vec![], vec![left_plan, right_plan])
            .unwrap();

        assert_eq!(union_distinct_on.compare_keys, vec!["job".to_string()]);
        assert_eq!(union_distinct_on.ts_col, "ts");
        assert_eq!(union_distinct_on.output_schema, df_schema);
    }
}
793        assert_eq!(union_distinct_on.output_schema, df_schema);
794    }
795}