promql/extension_plan/union_distinct_on.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use ahash::{HashMap, RandomState};
use datafusion::arrow::array::UInt64Array;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::DFSchemaRef;
use datafusion::error::{DataFusionError, Result as DataFusionResult};
use datafusion::execution::context::TaskContext;
use datafusion::logical_expr::{Expr, LogicalPlan, UserDefinedLogicalNodeCore};
use datafusion::physical_expr::EquivalenceProperties;
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{
    hash_utils, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning,
    PlanProperties, RecordBatchStream, SendableRecordBatchStream,
};
use datatypes::arrow::compute;
use futures::future::BoxFuture;
use futures::{ready, Stream, StreamExt, TryStreamExt};

/// A special kind of `UNION` (`OR` in PromQL) operator for PromQL-specific use cases.
///
/// This operator is similar to `UNION` in SQL, but it only accepts two inputs. The
/// main difference is that it treats the left child and the right child differently:
/// - All rows from the left child are output.
/// - Collisions (i.e., non-distinct rows) are only checked on the columns specified
///   by `compare_keys`.
/// - When there is a collision:
///   - If the collision is within the right child itself, only the first observed row
///     is preserved. All others are discarded.
///   - If the collision is with the left child, the colliding row in the right child
///     is discarded.
/// - The output order is not guaranteed; this plan outputs the left child first,
///   then the right child.
/// - The output schema contains all columns from the left or right child plan.
///
/// From the implementation perspective, this operator is similar to `HashJoin`, but
/// the build side is the right child and the probe side is the left child. Another
/// difference is that probing is opt-out: a match removes the row from the hash table
/// instead of emitting it.
///
/// This plan first exhausts the right child to build the hash table, then streams the
/// left side, using the left side to "mask" the hash table.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct UnionDistinctOn {
    left: LogicalPlan,
    right: LogicalPlan,
    /// The columns to compare for equality.
    /// The time index column (`ts_col`) also participates in the comparison.
    compare_keys: Vec<String>,
    ts_col: String,
    output_schema: DFSchemaRef,
}

impl UnionDistinctOn {
    pub fn name() -> &'static str {
        "UnionDistinctOn"
    }

    pub fn new(
        left: LogicalPlan,
        right: LogicalPlan,
        compare_keys: Vec<String>,
        ts_col: String,
        output_schema: DFSchemaRef,
    ) -> Self {
        Self {
            left,
            right,
            compare_keys,
            ts_col,
            output_schema,
        }
    }

    pub fn to_execution_plan(
        &self,
        left_exec: Arc<dyn ExecutionPlan>,
        right_exec: Arc<dyn ExecutionPlan>,
    ) -> Arc<dyn ExecutionPlan> {
        let output_schema: SchemaRef = Arc::new(self.output_schema.as_ref().into());
        let properties = Arc::new(PlanProperties::new(
            EquivalenceProperties::new(output_schema.clone()),
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        ));
        Arc::new(UnionDistinctOnExec {
            left: left_exec,
            right: right_exec,
            compare_keys: self.compare_keys.clone(),
            ts_col: self.ts_col.clone(),
            output_schema,
            metric: ExecutionPlanMetricsSet::new(),
            properties,
            random_state: RandomState::new(),
        })
    }
}

impl PartialOrd for UnionDistinctOn {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Compare fields in order, excluding output_schema
        match self.left.partial_cmp(&other.left) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.right.partial_cmp(&other.right) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.compare_keys.partial_cmp(&other.compare_keys) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        self.ts_col.partial_cmp(&other.ts_col)
    }
}

impl UserDefinedLogicalNodeCore for UnionDistinctOn {
    fn name(&self) -> &str {
        Self::name()
    }

    fn inputs(&self) -> Vec<&LogicalPlan> {
        vec![&self.left, &self.right]
    }

    fn schema(&self) -> &DFSchemaRef {
        &self.output_schema
    }

    fn expressions(&self) -> Vec<Expr> {
        vec![]
    }

    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "UnionDistinctOn: on col=[{:?}], ts_col=[{}]",
            self.compare_keys, self.ts_col
        )
    }

    fn with_exprs_and_inputs(
        &self,
        _exprs: Vec<Expr>,
        inputs: Vec<LogicalPlan>,
    ) -> DataFusionResult<Self> {
        if inputs.len() != 2 {
            return Err(DataFusionError::Internal(
                "UnionDistinctOn must have exactly 2 inputs".to_string(),
            ));
        }

        let mut inputs = inputs.into_iter();
        let left = inputs.next().unwrap();
        let right = inputs.next().unwrap();

        Ok(Self {
            left,
            right,
            compare_keys: self.compare_keys.clone(),
            ts_col: self.ts_col.clone(),
            output_schema: self.output_schema.clone(),
        })
    }
}

#[derive(Debug)]
pub struct UnionDistinctOnExec {
    left: Arc<dyn ExecutionPlan>,
    right: Arc<dyn ExecutionPlan>,
    compare_keys: Vec<String>,
    ts_col: String,
    output_schema: SchemaRef,
    metric: ExecutionPlanMetricsSet,
    properties: Arc<PlanProperties>,

    /// The shared `RandomState` for the hashing algorithm.
    random_state: RandomState,
}

impl ExecutionPlan for UnionDistinctOnExec {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.output_schema.clone()
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        vec![Distribution::SinglePartition, Distribution::SinglePartition]
    }

    fn properties(&self) -> &PlanProperties {
        self.properties.as_ref()
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.left, &self.right]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        assert_eq!(children.len(), 2);

        let left = children[0].clone();
        let right = children[1].clone();
        Ok(Arc::new(UnionDistinctOnExec {
            left,
            right,
            compare_keys: self.compare_keys.clone(),
            ts_col: self.ts_col.clone(),
            output_schema: self.output_schema.clone(),
            metric: self.metric.clone(),
            properties: self.properties.clone(),
            random_state: self.random_state.clone(),
        }))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let left_stream = self.left.execute(partition, context.clone())?;
        let right_stream = self.right.execute(partition, context.clone())?;

        // Convert column names to column indices. The extra slot is for the time index column.
        let mut key_indices = Vec::with_capacity(self.compare_keys.len() + 1);
        for key in &self.compare_keys {
            let index = self
                .output_schema
                .column_with_name(key)
                .map(|(i, _)| i)
                .ok_or_else(|| DataFusionError::Internal(format!("Column {} not found", key)))?;
            key_indices.push(index);
        }
        let ts_index = self
            .output_schema
            .column_with_name(&self.ts_col)
            .map(|(i, _)| i)
            .ok_or_else(|| {
                DataFusionError::Internal(format!("Column {} not found", self.ts_col))
            })?;
        key_indices.push(ts_index);

        // Build the right hash table future.
        let hashed_data_future = HashedDataFut::Pending(Box::pin(HashedData::new(
            right_stream,
            self.random_state.clone(),
            key_indices.clone(),
        )));

        let baseline_metric = BaselineMetrics::new(&self.metric, partition);
        Ok(Box::pin(UnionDistinctOnStream {
            left: left_stream,
            right: hashed_data_future,
            compare_keys: key_indices,
            output_schema: self.output_schema.clone(),
            metric: baseline_metric,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metric.clone_inner())
    }

    fn name(&self) -> &str {
        "UnionDistinctOnExec"
    }
}

impl DisplayAs for UnionDistinctOnExec {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "UnionDistinctOnExec: on col=[{:?}], ts_col=[{}]",
                    self.compare_keys, self.ts_col
                )
            }
        }
    }
}

// TODO(ruihang): some unused fields are for metrics, which will be implemented later.
#[allow(dead_code)]
pub struct UnionDistinctOnStream {
    left: SendableRecordBatchStream,
    right: HashedDataFut,
    /// Key column indices. Includes the time index.
    compare_keys: Vec<usize>,
    output_schema: SchemaRef,
    metric: BaselineMetrics,
}

impl UnionDistinctOnStream {
    fn poll_impl(&mut self, cx: &mut Context<'_>) -> Poll<Option<<Self as Stream>::Item>> {
        // resolve the right stream
        let right = match self.right {
            HashedDataFut::Pending(ref mut fut) => {
                let right = ready!(fut.as_mut().poll(cx))?;
                self.right = HashedDataFut::Ready(right);
                let HashedDataFut::Ready(right_ref) = &mut self.right else {
                    unreachable!()
                };
                right_ref
            }
            HashedDataFut::Ready(ref mut right) => right,
            HashedDataFut::Empty => return Poll::Ready(None),
        };

        // poll left and probe with right
        let next_left = ready!(self.left.poll_next_unpin(cx));
        match next_left {
            Some(Ok(left)) => {
                // observe the left batch and return it
                right.update_map(&left)?;
                Poll::Ready(Some(Ok(left)))
            }
            Some(Err(e)) => Poll::Ready(Some(Err(e))),
            None => {
                // left stream is exhausted, so we can send the right part
                let right = std::mem::replace(&mut self.right, HashedDataFut::Empty);
                let HashedDataFut::Ready(data) = right else {
                    unreachable!()
                };
                Poll::Ready(Some(data.finish()))
            }
        }
    }
}

impl RecordBatchStream for UnionDistinctOnStream {
    fn schema(&self) -> SchemaRef {
        self.output_schema.clone()
    }
}

impl Stream for UnionDistinctOnStream {
    type Item = DataFusionResult<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.poll_impl(cx)
    }
}

/// Simple future state for [HashedData]
enum HashedDataFut {
    /// The result is not ready
    Pending(BoxFuture<'static, DataFusionResult<HashedData>>),
    /// The result is ready
    Ready(HashedData),
    /// The result is taken
    Empty,
}

/// All input batches and their hash table
struct HashedData {
    // TODO(ruihang): use `JoinHashMap` instead after upgrading to DF 34.0
    /// Hash table over all input batches. The key is the hash value, and the
    /// value is the row index in `batch`.
    hash_map: HashMap<u64, usize>,
    /// Output batch.
    batch: RecordBatch,
    /// The indices of the columns to be hashed.
    hash_key_indices: Vec<usize>,
    random_state: RandomState,
}

impl HashedData {
    pub async fn new(
        input: SendableRecordBatchStream,
        random_state: RandomState,
        hash_key_indices: Vec<usize>,
    ) -> DataFusionResult<Self> {
        // Collect all batches from the input stream
        let initial = (Vec::new(), 0);
        let schema = input.schema();
        let (batches, _num_rows) = input
            .try_fold(initial, |mut acc, batch| async {
                // Update the row count
                acc.1 += batch.num_rows();
                // Push the batch to the output
                acc.0.push(batch);
                Ok(acc)
            })
            .await?;

        // Compute hashes for each batch
        let mut hash_map = HashMap::default();
        let mut hashes_buffer = Vec::new();
        let mut interleave_indices = Vec::new();
        for (batch_number, batch) in batches.iter().enumerate() {
            hashes_buffer.resize(batch.num_rows(), 0);
            // get columns for hashing
            let arrays = hash_key_indices
                .iter()
                .map(|i| batch.column(*i).clone())
                .collect::<Vec<_>>();

            // compute hashes
            let hash_values =
                hash_utils::create_hashes(&arrays, &random_state, &mut hashes_buffer)?;
            for (row_number, hash_value) in hash_values.iter().enumerate() {
                // Only keep the first observed row for each hash value
                if hash_map
                    .try_insert(*hash_value, interleave_indices.len())
                    .is_ok()
                {
                    interleave_indices.push((batch_number, row_number));
                }
            }
        }

        // Assemble the deduplicated rows into a single output batch
        let batch = interleave_batches(schema, batches, interleave_indices)?;

        Ok(Self {
            hash_map,
            batch,
            hash_key_indices,
            random_state,
        })
    }

    /// Remove entries from the hash map whose hash values are present in the
    /// input record batch.
    pub fn update_map(&mut self, input: &RecordBatch) -> DataFusionResult<()> {
        // get columns for hashing
        let mut hashes_buffer = Vec::new();
        let arrays = self
            .hash_key_indices
            .iter()
            .map(|i| input.column(*i).clone())
            .collect::<Vec<_>>();

        // compute hashes
        hashes_buffer.resize(input.num_rows(), 0);
        let hash_values =
            hash_utils::create_hashes(&arrays, &self.random_state, &mut hashes_buffer)?;

        // remove those hashes
        for hash in hash_values {
            self.hash_map.remove(hash);
        }

        Ok(())
    }

    /// Take the rows that survived the masking and return them as one batch.
    pub fn finish(self) -> DataFusionResult<RecordBatch> {
        let valid_indices = self.hash_map.values().copied().collect::<Vec<_>>();
        let result = take_batch(&self.batch, &valid_indices)?;
        Ok(result)
    }
}

/// Utility function to interleave batches. Based on [interleave](datafusion::arrow::compute::interleave).
fn interleave_batches(
    schema: SchemaRef,
    batches: Vec<RecordBatch>,
    indices: Vec<(usize, usize)>,
) -> DataFusionResult<RecordBatch> {
    if batches.is_empty() {
        if indices.is_empty() {
            return Ok(RecordBatch::new_empty(schema));
        } else {
            return Err(DataFusionError::Internal(
                "Cannot interleave empty batches with non-empty indices".to_string(),
            ));
        }
    }

    // transform batches into arrays
    let mut arrays = vec![vec![]; schema.fields().len()];
    for batch in &batches {
        for (i, array) in batch.columns().iter().enumerate() {
            arrays[i].push(array.as_ref());
        }
    }

    // interleave arrays
    let interleaved_arrays: Vec<_> = arrays
        .into_iter()
        .map(|array| compute::interleave(&array, &indices))
        .collect::<Result<_, _>>()?;

    // assemble the new record batch
    RecordBatch::try_new(schema, interleaved_arrays)
        .map_err(|e| DataFusionError::ArrowError(e, None))
}

/// Utility function to take rows from a record batch. Based on [take](datafusion::arrow::compute::take).
fn take_batch(batch: &RecordBatch, indices: &[usize]) -> DataFusionResult<RecordBatch> {
    // Fast path: all rows are taken. Note that this keeps the original row order
    // rather than the order in `indices`, which is fine because this plan does not
    // guarantee output order.
    if batch.num_rows() == indices.len() {
        return Ok(batch.clone());
    }

    let schema = batch.schema();

    let indices_array = UInt64Array::from_iter(indices.iter().map(|i| *i as u64));
    let arrays = batch
        .columns()
        .iter()
        .map(|array| compute::take(array, &indices_array, None))
        .collect::<std::result::Result<Vec<_>, _>>()
        .map_err(|e| DataFusionError::ArrowError(e, None))?;

    let result =
        RecordBatch::try_new(schema, arrays).map_err(|e| DataFusionError::ArrowError(e, None))?;
    Ok(result)
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::arrow::array::Int32Array;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};

    use super::*;

    #[test]
    fn test_interleave_batches() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
            ],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(Int32Array::from(vec![7, 8, 9])),
                Arc::new(Int32Array::from(vec![10, 11, 12])),
            ],
        )
        .unwrap();

        let batch3 = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(Int32Array::from(vec![13, 14, 15])),
                Arc::new(Int32Array::from(vec![16, 17, 18])),
            ],
        )
        .unwrap();

        let batches = vec![batch1, batch2, batch3];
        let indices = vec![(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)];
        let result = interleave_batches(Arc::new(schema.clone()), batches, indices).unwrap();

        let expected = RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(Int32Array::from(vec![1, 7, 13, 2, 8, 14])),
                Arc::new(Int32Array::from(vec![4, 10, 16, 5, 11, 17])),
            ],
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_take_batch() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]);

        let batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
            ],
        )
        .unwrap();

        let indices = vec![0, 2];
        let result = take_batch(&batch, &indices).unwrap();

        let expected = RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(Int32Array::from(vec![1, 3])),
                Arc::new(Int32Array::from(vec![4, 6])),
            ],
        )
        .unwrap();

        assert_eq!(result, expected);
    }
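
    // A minimal sketch of the build-then-mask flow through `HashedData`, not part of
    // the original suite. It assumes `tokio` is available as a dev-dependency and uses
    // DataFusion's `RecordBatchStreamAdapter` to wrap a plain stream; the test name
    // and data are illustrative.
    #[tokio::test]
    async fn test_hashed_data_mask() {
        use datafusion::physical_plan::stream::RecordBatchStreamAdapter;

        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let right = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )
        .unwrap();
        let right_stream = Box::pin(RecordBatchStreamAdapter::new(
            schema.clone(),
            futures::stream::iter(vec![Ok(right)]),
        ));

        // Build the hash table over the "right" side, keyed on column 0.
        let mut hashed = HashedData::new(right_stream, RandomState::new(), vec![0])
            .await
            .unwrap();

        // Mask with a "left" batch containing the key 2; that right-side row is dropped.
        let left =
            RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![2]))]).unwrap();
        hashed.update_map(&left).unwrap();

        // The rows with keys 1 and 3 survive (output order is not guaranteed).
        let result = hashed.finish().unwrap();
        assert_eq!(result.num_rows(), 2);
    }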
}