promql/extension_plan/union_distinct_on.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use ahash::{HashMap, RandomState};
use datafusion::arrow::array::UInt64Array;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::DFSchemaRef;
use datafusion::error::{DataFusionError, Result as DataFusionResult};
use datafusion::execution::context::TaskContext;
use datafusion::logical_expr::{Expr, LogicalPlan, UserDefinedLogicalNodeCore};
use datafusion::physical_expr::EquivalenceProperties;
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{
    hash_utils, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning,
    PlanProperties, RecordBatchStream, SendableRecordBatchStream,
};
use datatypes::arrow::compute;
use futures::future::BoxFuture;
use futures::{ready, Stream, StreamExt, TryStreamExt};

/// A special kind of `UNION` (`OR` in PromQL) operator, for PromQL-specific use cases.
///
/// This operator is similar to `UNION` in SQL, but it only accepts two inputs. The
/// main difference is that it treats the left and right children differently:
/// - All columns from the left child are output.
/// - Collisions (rows that are not distinct) are only checked on the columns specified
///   by `compare_keys`.
/// - When there is a collision:
///   - If the collision is within the right child itself, only the first observed row is
///     preserved. All others are discarded.
///   - If the collision is with the left child, the colliding row in the right child is
///     discarded.
/// - The output order is not maintained. This plan outputs the left child first, then
///   the right child.
/// - The output schema contains all columns from the left or right child plans.
///
/// From the implementation perspective, this operator is similar to `HashJoin`, but the
/// probe side is the right child, and the build side is the left child. Another
/// difference is that the probe is opt-out: a matching row is discarded, not kept.
///
/// This plan exhausts the right child first to build the hash table, then streams the
/// left side and uses it to "mask" the hash table.
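///
/// An illustrative sketch (hypothetical data, not taken from the original source):
/// with `compare_keys = ["tag"]` plus the time index `ts`,
///
/// ```text
/// Left input:           Right input:
/// tag | ts | val        tag | ts | val
/// a   | 1  | 10         a   | 1  | 99   <- collides with a left row, discarded
/// b   | 1  | 20         c   | 1  | 30
///                       c   | 1  | 31   <- collides with an earlier right row, discarded
///
/// Output (left rows first, then the surviving right rows):
/// tag | ts | val
/// a   | 1  | 10
/// b   | 1  | 20
/// c   | 1  | 30
/// ```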
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct UnionDistinctOn {
    left: LogicalPlan,
    right: LogicalPlan,
    /// The columns to compare for equality.
    /// TIME INDEX is included.
    compare_keys: Vec<String>,
    ts_col: String,
    output_schema: DFSchemaRef,
}

impl UnionDistinctOn {
    pub fn name() -> &'static str {
        "UnionDistinctOn"
    }

    pub fn new(
        left: LogicalPlan,
        right: LogicalPlan,
        compare_keys: Vec<String>,
        ts_col: String,
        output_schema: DFSchemaRef,
    ) -> Self {
        Self {
            left,
            right,
            compare_keys,
            ts_col,
            output_schema,
        }
    }

    pub fn to_execution_plan(
        &self,
        left_exec: Arc<dyn ExecutionPlan>,
        right_exec: Arc<dyn ExecutionPlan>,
    ) -> Arc<dyn ExecutionPlan> {
        let output_schema: SchemaRef = Arc::new(self.output_schema.as_ref().into());
        let properties = Arc::new(PlanProperties::new(
            EquivalenceProperties::new(output_schema.clone()),
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        ));
        Arc::new(UnionDistinctOnExec {
            left: left_exec,
            right: right_exec,
            compare_keys: self.compare_keys.clone(),
            ts_col: self.ts_col.clone(),
            output_schema,
            metric: ExecutionPlanMetricsSet::new(),
            properties,
            random_state: RandomState::new(),
        })
    }
}

impl PartialOrd for UnionDistinctOn {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Compare fields in order, excluding `output_schema`.
        match self.left.partial_cmp(&other.left) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.right.partial_cmp(&other.right) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.compare_keys.partial_cmp(&other.compare_keys) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        self.ts_col.partial_cmp(&other.ts_col)
    }
}

impl UserDefinedLogicalNodeCore for UnionDistinctOn {
    fn name(&self) -> &str {
        Self::name()
    }

    fn inputs(&self) -> Vec<&LogicalPlan> {
        vec![&self.left, &self.right]
    }

    fn schema(&self) -> &DFSchemaRef {
        &self.output_schema
    }

    fn expressions(&self) -> Vec<Expr> {
        vec![]
    }

    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "UnionDistinctOn: on col=[{:?}], ts_col=[{}]",
            self.compare_keys, self.ts_col
        )
    }

    fn with_exprs_and_inputs(
        &self,
        _exprs: Vec<Expr>,
        inputs: Vec<LogicalPlan>,
    ) -> DataFusionResult<Self> {
        if inputs.len() != 2 {
            return Err(DataFusionError::Internal(
                "UnionDistinctOn must have exactly 2 inputs".to_string(),
            ));
        }

        let mut inputs = inputs.into_iter();
        let left = inputs.next().unwrap();
        let right = inputs.next().unwrap();

        Ok(Self {
            left,
            right,
            compare_keys: self.compare_keys.clone(),
            ts_col: self.ts_col.clone(),
            output_schema: self.output_schema.clone(),
        })
    }
}

#[derive(Debug)]
pub struct UnionDistinctOnExec {
    left: Arc<dyn ExecutionPlan>,
    right: Arc<dyn ExecutionPlan>,
    compare_keys: Vec<String>,
    ts_col: String,
    output_schema: SchemaRef,
    metric: ExecutionPlanMetricsSet,
    properties: Arc<PlanProperties>,

    /// The shared `RandomState` for the hashing algorithm.
    random_state: RandomState,
}

impl ExecutionPlan for UnionDistinctOnExec {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.output_schema.clone()
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        vec![Distribution::SinglePartition, Distribution::SinglePartition]
    }

    fn properties(&self) -> &PlanProperties {
        self.properties.as_ref()
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.left, &self.right]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        assert_eq!(children.len(), 2);

        let left = children[0].clone();
        let right = children[1].clone();
        Ok(Arc::new(UnionDistinctOnExec {
            left,
            right,
            compare_keys: self.compare_keys.clone(),
            ts_col: self.ts_col.clone(),
            output_schema: self.output_schema.clone(),
            metric: self.metric.clone(),
            properties: self.properties.clone(),
            random_state: self.random_state.clone(),
        }))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let left_stream = self.left.execute(partition, context.clone())?;
        let right_stream = self.right.execute(partition, context.clone())?;

        // Convert column names to column indices, reserving one extra slot for the
        // time index column.
        let mut key_indices = Vec::with_capacity(self.compare_keys.len() + 1);
        for key in &self.compare_keys {
            let index = self
                .output_schema
                .column_with_name(key)
                .map(|(i, _)| i)
                .ok_or_else(|| DataFusionError::Internal(format!("Column {} not found", key)))?;
            key_indices.push(index);
        }
        let ts_index = self
            .output_schema
            .column_with_name(&self.ts_col)
            .map(|(i, _)| i)
            .ok_or_else(|| {
                DataFusionError::Internal(format!("Column {} not found", self.ts_col))
            })?;
        key_indices.push(ts_index);

        // Build the right-side hash table future.
        let hashed_data_future = HashedDataFut::Pending(Box::pin(HashedData::new(
            right_stream,
            self.random_state.clone(),
            key_indices.clone(),
        )));

        let baseline_metric = BaselineMetrics::new(&self.metric, partition);
        Ok(Box::pin(UnionDistinctOnStream {
            left: left_stream,
            right: hashed_data_future,
            compare_keys: key_indices,
            output_schema: self.output_schema.clone(),
            metric: baseline_metric,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metric.clone_inner())
    }

    fn name(&self) -> &str {
        "UnionDistinctOnExec"
    }
}

impl DisplayAs for UnionDistinctOnExec {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default
            | DisplayFormatType::Verbose
            | DisplayFormatType::TreeRender => {
                write!(
                    f,
                    "UnionDistinctOnExec: on col=[{:?}], ts_col=[{}]",
                    self.compare_keys, self.ts_col
                )
            }
        }
    }
}

// TODO(ruihang): some unused fields are for metrics, which will be implemented later.
#[allow(dead_code)]
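/// Stream implementation for [`UnionDistinctOnExec`]. A descriptive note on the flow:
/// the right child is first driven to completion to build the hash table; left batches
/// are then forwarded unchanged while their hash values are removed from the table
/// ("masking"); once the left side is exhausted, the surviving right rows are emitted
/// as one final batch.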
pub struct UnionDistinctOnStream {
    left: SendableRecordBatchStream,
    right: HashedDataFut,
    /// Includes the time index.
    compare_keys: Vec<usize>,
    output_schema: SchemaRef,
    metric: BaselineMetrics,
}

impl UnionDistinctOnStream {
    fn poll_impl(&mut self, cx: &mut Context<'_>) -> Poll<Option<<Self as Stream>::Item>> {
        // Resolve the right stream first.
        let right = match self.right {
            HashedDataFut::Pending(ref mut fut) => {
                let right = ready!(fut.as_mut().poll(cx))?;
                self.right = HashedDataFut::Ready(right);
                let HashedDataFut::Ready(right_ref) = &mut self.right else {
                    unreachable!()
                };
                right_ref
            }
            HashedDataFut::Ready(ref mut right) => right,
            HashedDataFut::Empty => return Poll::Ready(None),
        };

        // Poll the left side and probe it against the right hash table.
        let next_left = ready!(self.left.poll_next_unpin(cx));
        match next_left {
            Some(Ok(left)) => {
                // Observe the left batch, then forward it unchanged.
                right.update_map(&left)?;
                Poll::Ready(Some(Ok(left)))
            }
            Some(Err(e)) => Poll::Ready(Some(Err(e))),
            None => {
                // The left stream is exhausted, so emit the remaining right rows.
                let right = std::mem::replace(&mut self.right, HashedDataFut::Empty);
                let HashedDataFut::Ready(data) = right else {
                    unreachable!()
                };
                Poll::Ready(Some(data.finish()))
            }
        }
    }
}

impl RecordBatchStream for UnionDistinctOnStream {
    fn schema(&self) -> SchemaRef {
        self.output_schema.clone()
    }
}

impl Stream for UnionDistinctOnStream {
    type Item = DataFusionResult<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.poll_impl(cx)
    }
}

/// Simple future state for [HashedData]
enum HashedDataFut {
    /// The result is not ready
    Pending(BoxFuture<'static, DataFusionResult<HashedData>>),
    /// The result is ready
    Ready(HashedData),
    /// The result is taken
    Empty,
}
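
// State transitions are driven by `UnionDistinctOnStream::poll_impl`:
// `Pending` becomes `Ready` once the right-side future resolves, and `Ready` becomes
// `Empty` when the remaining right rows are taken and emitted after the left stream
// ends.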

/// All input batches and their hash table.
struct HashedData {
    // TODO(ruihang): use `JoinHashMap` instead after upgrading to DF 34.0
    /// Hash table over all input batches. The key is the hash value, and the value
    /// is the row index into `batch`.
    hash_map: HashMap<u64, usize>,
    /// Output batch.
    batch: RecordBatch,
    /// The indices of the columns to be hashed.
    hash_key_indices: Vec<usize>,
    random_state: RandomState,
}

impl HashedData {
    pub async fn new(
        input: SendableRecordBatchStream,
        random_state: RandomState,
        hash_key_indices: Vec<usize>,
    ) -> DataFusionResult<Self> {
        // Collect all batches from the input stream
        let initial = (Vec::new(), 0);
        let schema = input.schema();
        let (batches, _num_rows) = input
            .try_fold(initial, |mut acc, batch| async {
                // Update row count
                acc.1 += batch.num_rows();
                // Push batch to output
                acc.0.push(batch);
                Ok(acc)
            })
            .await?;

        // Create hashes for each batch
        let mut hash_map = HashMap::default();
        let mut hashes_buffer = Vec::new();
        let mut interleave_indices = Vec::new();
        for (batch_number, batch) in batches.iter().enumerate() {
            hashes_buffer.resize(batch.num_rows(), 0);
            // get columns for hashing
            let arrays = hash_key_indices
                .iter()
                .map(|i| batch.column(*i).clone())
                .collect::<Vec<_>>();

            // compute hash
            let hash_values =
                hash_utils::create_hashes(&arrays, &random_state, &mut hashes_buffer)?;
            for (row_number, hash_value) in hash_values.iter().enumerate() {
                // Keep only the first observed row for each hash value
                if hash_map
                    .try_insert(*hash_value, interleave_indices.len())
                    .is_ok()
                {
                    interleave_indices.push((batch_number, row_number));
                }
            }
        }

        // Assemble the deduplicated rows into a single output batch
        let batch = interleave_batches(schema, batches, interleave_indices)?;

        Ok(Self {
            hash_map,
            batch,
            hash_key_indices,
            random_state,
        })
    }

    /// Remove from the hash map the rows whose hash values are present in the
    /// given input record batch.
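    ///
    /// An illustrative sketch (hypothetical data): if the map holds hashes of
    /// right-side rows keyed on `(tag, ts)`, and a left batch contains a row with
    /// the same `(tag, ts)`, that hash is removed, so the colliding right row will
    /// not appear in the output of [`Self::finish`].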
    pub fn update_map(&mut self, input: &RecordBatch) -> DataFusionResult<()> {
        // get columns for hashing
        let mut hashes_buffer = Vec::new();
        let arrays = self
            .hash_key_indices
            .iter()
            .map(|i| input.column(*i).clone())
            .collect::<Vec<_>>();

        // compute hash
        hashes_buffer.resize(input.num_rows(), 0);
        let hash_values =
            hash_utils::create_hashes(&arrays, &self.random_state, &mut hashes_buffer)?;

        // remove those hashes
        for hash in hash_values {
            self.hash_map.remove(hash);
        }

        Ok(())
    }

    pub fn finish(self) -> DataFusionResult<RecordBatch> {
        let valid_indices = self.hash_map.values().copied().collect::<Vec<_>>();
        let result = take_batch(&self.batch, &valid_indices)?;
        Ok(result)
    }
}

/// Utility function to interleave batches. Based on [interleave](datafusion::arrow::compute::interleave).
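///
/// An illustrative sketch (hypothetical values): `indices` are `(batch_index, row_index)`
/// pairs. Given two single-column Int32 batches `b0 = [10, 20]` and `b1 = [30, 40]`,
/// the indices `[(0, 1), (1, 0)]` select row 1 of `b0` then row 0 of `b1`, producing
/// the column `[20, 30]`.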
fn interleave_batches(
    schema: SchemaRef,
    batches: Vec<RecordBatch>,
    indices: Vec<(usize, usize)>,
) -> DataFusionResult<RecordBatch> {
    if batches.is_empty() {
        if indices.is_empty() {
            return Ok(RecordBatch::new_empty(schema));
        } else {
            return Err(DataFusionError::Internal(
                "Cannot interleave empty batches with non-empty indices".to_string(),
            ));
        }
    }

    // transpose the batches into per-column arrays
    let mut arrays = vec![vec![]; schema.fields().len()];
    for batch in &batches {
        for (i, array) in batch.columns().iter().enumerate() {
            arrays[i].push(array.as_ref());
        }
    }

    // interleave arrays column by column
    let interleaved_arrays: Vec<_> = arrays
        .into_iter()
        .map(|array| compute::interleave(&array, &indices))
        .collect::<Result<_, _>>()?;

    // assemble new record batch
    RecordBatch::try_new(schema, interleaved_arrays)
        .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
}

/// Utility function to take rows from a record batch. Based on [take](datafusion::arrow::compute::take).
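///
/// A small illustration (mirroring the test below): for a batch whose Int32 column is
/// `[1, 2, 3]`, `take_batch(&batch, &[0, 2])` yields a batch whose column is `[1, 3]`.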
fn take_batch(batch: &RecordBatch, indices: &[usize]) -> DataFusionResult<RecordBatch> {
    // Fast path: every row is kept. The caller passes distinct indices (hash map
    // values), so equal lengths imply all rows survive; this operator does not
    // guarantee row order anyway.
    if batch.num_rows() == indices.len() {
        return Ok(batch.clone());
    }

    let schema = batch.schema();

    let indices_array = UInt64Array::from_iter(indices.iter().map(|i| *i as u64));
    let arrays = batch
        .columns()
        .iter()
        .map(|array| compute::take(array, &indices_array, None))
        .collect::<std::result::Result<Vec<_>, _>>()
        .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;

    let result = RecordBatch::try_new(schema, arrays)
        .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
    Ok(result)
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::arrow::array::Int32Array;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};

    use super::*;

    #[test]
    fn test_interleave_batches() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
            ],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(Int32Array::from(vec![7, 8, 9])),
                Arc::new(Int32Array::from(vec![10, 11, 12])),
            ],
        )
        .unwrap();

        let batch3 = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(Int32Array::from(vec![13, 14, 15])),
                Arc::new(Int32Array::from(vec![16, 17, 18])),
            ],
        )
        .unwrap();

        let batches = vec![batch1, batch2, batch3];
        let indices = vec![(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)];
        let result = interleave_batches(Arc::new(schema.clone()), batches, indices).unwrap();

        let expected = RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(Int32Array::from(vec![1, 7, 13, 2, 8, 14])),
                Arc::new(Int32Array::from(vec![4, 10, 16, 5, 11, 17])),
            ],
        )
        .unwrap();

        assert_eq!(result, expected);
    }
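
    // An added test sketch (not in the original source): exercises the empty-input
    // fast path of `interleave_batches`, which should yield an empty batch.
    #[test]
    fn test_interleave_empty_batches() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let result = interleave_batches(schema, vec![], vec![]).unwrap();
        assert_eq!(result.num_rows(), 0);
    }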

    #[test]
    fn test_take_batch() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]);

        let batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
            ],
        )
        .unwrap();

        let indices = vec![0, 2];
        let result = take_batch(&batch, &indices).unwrap();

        let expected = RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(Int32Array::from(vec![1, 3])),
                Arc::new(Int32Array::from(vec![4, 6])),
            ],
        )
        .unwrap();

        assert_eq!(result, expected);
    }
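
    // An added test sketch (not in the original source): with empty indices the
    // fast path is skipped and `take` produces an empty batch.
    #[test]
    fn test_take_batch_empty_indices() {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )
        .unwrap();

        let result = take_batch(&batch, &[]).unwrap();
        assert_eq!(result.num_rows(), 0);
    }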
}