use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use ahash::{HashMap, RandomState};
use datafusion::arrow::array::UInt64Array;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::{DFSchema, DFSchemaRef};
use datafusion::error::{DataFusionError, Result as DataFusionResult};
use datafusion::execution::context::TaskContext;
use datafusion::logical_expr::{EmptyRelation, Expr, LogicalPlan, UserDefinedLogicalNodeCore};
use datafusion::physical_expr::EquivalenceProperties;
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{
    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PlanProperties,
    RecordBatchStream, SendableRecordBatchStream, hash_utils,
};
use datafusion_expr::col;
use datatypes::arrow::compute;
use futures::future::BoxFuture;
use futures::{Stream, StreamExt, TryStreamExt, ready};
use greptime_proto::substrait_extension as pb;
use prost::Message;
use snafu::ResultExt;

use crate::error::{DeserializeSnafu, Result};
use crate::extension_plan::{resolve_column_name, serialize_column_index};

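/// Custom logical plan node that unions two inputs while deduplicating on the
/// `compare_keys` columns plus the `ts_col` timestamp column: every left-side row
/// is kept, and a right-side row is emitted only when no left-side row hashes to
/// the same key columns. This is the shape needed by PromQL-style `or` semantics,
/// where the left operand takes precedence.
///
/// `output_schema` is shared by both inputs and the output. `unfix` temporarily
/// carries raw column indices decoded from a serialized plan until they can be
/// resolved back to names in `with_exprs_and_inputs`.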
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct UnionDistinctOn {
    left: LogicalPlan,
    right: LogicalPlan,
    compare_keys: Vec<String>,
    ts_col: String,
    output_schema: DFSchemaRef,
    unfix: Option<UnfixIndices>,
}

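/// Column indices recovered from a serialized [`UnionDistinctOn`]. They are kept
/// around until the real input plans are attached, at which point the indices are
/// resolved into column names.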
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd)]
struct UnfixIndices {
    pub compare_key_indices: Vec<u64>,
    pub ts_col_idx: u64,
}

impl UnionDistinctOn {
    pub fn name() -> &'static str {
        "UnionDistinctOn"
    }

    pub fn new(
        left: LogicalPlan,
        right: LogicalPlan,
        compare_keys: Vec<String>,
        ts_col: String,
        output_schema: DFSchemaRef,
    ) -> Self {
        Self {
            left,
            right,
            compare_keys,
            ts_col,
            output_schema,
            unfix: None,
        }
    }

    pub fn to_execution_plan(
        &self,
        left_exec: Arc<dyn ExecutionPlan>,
        right_exec: Arc<dyn ExecutionPlan>,
    ) -> Arc<dyn ExecutionPlan> {
        let output_schema: SchemaRef = self.output_schema.inner().clone();
        let properties = Arc::new(PlanProperties::new(
            EquivalenceProperties::new(output_schema.clone()),
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        ));
        Arc::new(UnionDistinctOnExec {
            left: left_exec,
            right: right_exec,
            compare_keys: self.compare_keys.clone(),
            ts_col: self.ts_col.clone(),
            output_schema,
            metric: ExecutionPlanMetricsSet::new(),
            properties,
            random_state: RandomState::new(),
        })
    }

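    /// Serialize this node to protobuf bytes. Columns are encoded as indices into
    /// `output_schema` rather than as names.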
    pub fn serialize(&self) -> Vec<u8> {
        let compare_key_indices = self
            .compare_keys
            .iter()
            .map(|name| serialize_column_index(&self.output_schema, name))
            .collect::<Vec<u64>>();

        let ts_col_idx = serialize_column_index(&self.output_schema, &self.ts_col);

        pb::UnionDistinctOn {
            compare_key_indices,
            ts_col_idx,
        }
        .encode_to_vec()
    }

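    /// Deserialize from protobuf bytes. The result is a placeholder node: its
    /// inputs are empty relations and the column names stay unresolved (stored in
    /// `unfix`) until `with_exprs_and_inputs` is called with the real inputs.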
    pub fn deserialize(bytes: &[u8]) -> Result<Self> {
        let pb_union = pb::UnionDistinctOn::decode(bytes).context(DeserializeSnafu)?;
        let placeholder_plan = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(DFSchema::empty()),
        });

        let unfix = UnfixIndices {
            compare_key_indices: pb_union.compare_key_indices.clone(),
            ts_col_idx: pb_union.ts_col_idx,
        };

        Ok(Self {
            left: placeholder_plan.clone(),
            right: placeholder_plan,
            compare_keys: Vec::new(),
            ts_col: String::new(),
            output_schema: Arc::new(DFSchema::empty()),
            unfix: Some(unfix),
        })
    }
}

impl PartialOrd for UnionDistinctOn {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Compare fields in order, ignoring `output_schema` and `unfix`.
        match self.left.partial_cmp(&other.left) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.right.partial_cmp(&other.right) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.compare_keys.partial_cmp(&other.compare_keys) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        self.ts_col.partial_cmp(&other.ts_col)
    }
}

impl UserDefinedLogicalNodeCore for UnionDistinctOn {
    fn name(&self) -> &str {
        Self::name()
    }

    fn inputs(&self) -> Vec<&LogicalPlan> {
        vec![&self.left, &self.right]
    }

    fn schema(&self) -> &DFSchemaRef {
        &self.output_schema
    }

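    /// Expressions over the compare keys plus the timestamp column. Empty while the
    /// node still carries unresolved indices from deserialization.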
    fn expressions(&self) -> Vec<Expr> {
        if self.unfix.is_some() {
            return vec![];
        }

        let mut exprs: Vec<Expr> = self.compare_keys.iter().map(col).collect();
        if !self.compare_keys.iter().any(|key| key == &self.ts_col) {
            exprs.push(col(&self.ts_col));
        }
        exprs
    }

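    /// Request every child column regardless of which output columns are needed:
    /// the execution plan resolves the compare keys and timestamp column by name in
    /// the output schema, so projection pruning must not drop any child column.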
    fn necessary_children_exprs(&self, _output_columns: &[usize]) -> Option<Vec<Vec<usize>>> {
        if self.unfix.is_some() {
            return None;
        }

        let left_len = self.left.schema().fields().len();
        let right_len = self.right.schema().fields().len();
        Some(vec![
            (0..left_len).collect::<Vec<_>>(),
            (0..right_len).collect::<Vec<_>>(),
        ])
    }

    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "UnionDistinctOn: on col=[{:?}], ts_col=[{}]",
            self.compare_keys, self.ts_col
        )
    }

    fn with_exprs_and_inputs(
        &self,
        _exprs: Vec<Expr>,
        inputs: Vec<LogicalPlan>,
    ) -> DataFusionResult<Self> {
        if inputs.len() != 2 {
            return Err(DataFusionError::Internal(
                "UnionDistinctOn must have exactly 2 inputs".to_string(),
            ));
        }

        let mut inputs = inputs.into_iter();
        let left = inputs.next().unwrap();
        let right = inputs.next().unwrap();

        if let Some(unfix) = &self.unfix {
            // Resolve the serialized column indices against the schema of the new inputs.
            let output_schema = left.schema().clone();

            let compare_keys = unfix
                .compare_key_indices
                .iter()
                .map(|idx| {
                    resolve_column_name(*idx, &output_schema, "UnionDistinctOn", "compare key")
                })
                .collect::<DataFusionResult<Vec<String>>>()?;

            let ts_col =
                resolve_column_name(unfix.ts_col_idx, &output_schema, "UnionDistinctOn", "ts")?;

            Ok(Self {
                left,
                right,
                compare_keys,
                ts_col,
                output_schema,
                unfix: None,
            })
        } else {
            Ok(Self {
                left,
                right,
                compare_keys: self.compare_keys.clone(),
                ts_col: self.ts_col.clone(),
                output_schema: self.output_schema.clone(),
                unfix: None,
            })
        }
    }
}

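/// Physical plan for [`UnionDistinctOn`]. It materializes and hash-deduplicates the
/// right input, streams the left input through unchanged while discarding matching
/// right-side rows, and finally emits whatever remains on the right.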
#[derive(Debug)]
pub struct UnionDistinctOnExec {
    left: Arc<dyn ExecutionPlan>,
    right: Arc<dyn ExecutionPlan>,
    compare_keys: Vec<String>,
    ts_col: String,
    output_schema: SchemaRef,
    metric: ExecutionPlanMetricsSet,
    properties: Arc<PlanProperties>,

    /// Shared random state so left and right batches hash with the same seed.
    random_state: RandomState,
}

impl ExecutionPlan for UnionDistinctOnExec {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.output_schema.clone()
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        vec![Distribution::SinglePartition, Distribution::SinglePartition]
    }

    fn properties(&self) -> &PlanProperties {
        self.properties.as_ref()
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.left, &self.right]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        assert_eq!(children.len(), 2);

        let left = children[0].clone();
        let right = children[1].clone();
        Ok(Arc::new(UnionDistinctOnExec {
            left,
            right,
            compare_keys: self.compare_keys.clone(),
            ts_col: self.ts_col.clone(),
            output_schema: self.output_schema.clone(),
            metric: self.metric.clone(),
            properties: self.properties.clone(),
            random_state: self.random_state.clone(),
        }))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let left_stream = self.left.execute(partition, context.clone())?;
        let right_stream = self.right.execute(partition, context.clone())?;

        // Resolve the compare keys and the timestamp column to indices in the output schema.
        let mut key_indices = Vec::with_capacity(self.compare_keys.len() + 1);
        for key in &self.compare_keys {
            let index = self
                .output_schema
                .column_with_name(key)
                .map(|(i, _)| i)
                .ok_or_else(|| DataFusionError::Internal(format!("Column {} not found", key)))?;
            key_indices.push(index);
        }
        let ts_index = self
            .output_schema
            .column_with_name(&self.ts_col)
            .map(|(i, _)| i)
            .ok_or_else(|| {
                DataFusionError::Internal(format!("Column {} not found", self.ts_col))
            })?;
        key_indices.push(ts_index);

        // Build the hash table over the right input lazily; it is only driven once
        // the output stream itself is polled.
        let hashed_data_future = HashedDataFut::Pending(Box::pin(HashedData::new(
            right_stream,
            self.random_state.clone(),
            key_indices.clone(),
        )));

        let baseline_metric = BaselineMetrics::new(&self.metric, partition);
        Ok(Box::pin(UnionDistinctOnStream {
            left: left_stream,
            right: hashed_data_future,
            compare_keys: key_indices,
            output_schema: self.output_schema.clone(),
            metric: baseline_metric,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metric.clone_inner())
    }

    fn name(&self) -> &str {
        "UnionDistinctOnExec"
    }
}

impl DisplayAs for UnionDistinctOnExec {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default
            | DisplayFormatType::Verbose
            | DisplayFormatType::TreeRender => {
                write!(
                    f,
                    "UnionDistinctOnExec: on col=[{:?}], ts_col=[{}]",
                    self.compare_keys, self.ts_col
                )
            }
        }
    }
}

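/// Stream that drives the union. Left-side batches are forwarded as-is while their
/// key hashes are removed from the right-side hash table; once the left side is
/// exhausted, the remaining (unmatched) right-side rows are emitted as one final batch.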
#[allow(dead_code)]
pub struct UnionDistinctOnStream {
    left: SendableRecordBatchStream,
    right: HashedDataFut,
    /// Indices (into `output_schema`) of the compare keys plus the timestamp column.
    compare_keys: Vec<usize>,
    output_schema: SchemaRef,
    metric: BaselineMetrics,
}

impl UnionDistinctOnStream {
    fn poll_impl(&mut self, cx: &mut Context<'_>) -> Poll<Option<<Self as Stream>::Item>> {
        // Make sure the hash table over the right input is ready.
        let right = match self.right {
            HashedDataFut::Pending(ref mut fut) => {
                let right = ready!(fut.as_mut().poll(cx))?;
                self.right = HashedDataFut::Ready(right);
                let HashedDataFut::Ready(right_ref) = &mut self.right else {
                    unreachable!()
                };
                right_ref
            }
            HashedDataFut::Ready(ref mut right) => right,
            HashedDataFut::Empty => return Poll::Ready(None),
        };

        // Forward left batches, dropping their key hashes from the right-side table.
        let next_left = ready!(self.left.poll_next_unpin(cx));
        match next_left {
            Some(Ok(left)) => {
                right.update_map(&left)?;
                Poll::Ready(Some(Ok(left)))
            }
            Some(Err(e)) => Poll::Ready(Some(Err(e))),
            None => {
                // The left side is exhausted; emit the remaining right-side rows once.
                let right = std::mem::replace(&mut self.right, HashedDataFut::Empty);
                let HashedDataFut::Ready(data) = right else {
                    unreachable!()
                };
                Poll::Ready(Some(data.finish()))
            }
        }
    }
}

impl RecordBatchStream for UnionDistinctOnStream {
    fn schema(&self) -> SchemaRef {
        self.output_schema.clone()
    }
}

impl Stream for UnionDistinctOnStream {
    type Item = DataFusionResult<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.poll_impl(cx)
    }
}

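/// State of the right-side hash table: still being built, ready for probing, or
/// already consumed by [`HashedData::finish`].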
enum HashedDataFut {
    /// The table is still being built from the right input.
    Pending(BoxFuture<'static, DataFusionResult<HashedData>>),
    /// The table is ready for probing.
    Ready(HashedData),
    /// The table has been consumed.
    Empty,
}

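/// Right-side rows deduplicated by the hash of their key columns. The map goes from
/// key hash to row index in `batch`; entries are removed as matching left-side rows
/// are observed.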
struct HashedData {
    /// Key hash -> row index in `batch`.
    hash_map: HashMap<u64, usize>,
    /// All distinct right-side rows, one per hash entry.
    batch: RecordBatch,
    /// Indices of the columns to hash (compare keys plus timestamp).
    hash_key_indices: Vec<usize>,
    random_state: RandomState,
}

impl HashedData {
    pub async fn new(
        input: SendableRecordBatchStream,
        random_state: RandomState,
        hash_key_indices: Vec<usize>,
    ) -> DataFusionResult<Self> {
        // Collect all batches from the right input.
        let initial = (Vec::new(), 0);
        let schema = input.schema();
        let (batches, _num_rows) = input
            .try_fold(initial, |mut acc, batch| async {
                acc.1 += batch.num_rows();
                acc.0.push(batch);
                Ok(acc)
            })
            .await?;

        // Hash the key columns of every batch and keep the first row seen for each hash.
        let mut hash_map = HashMap::default();
        let mut hashes_buffer = Vec::new();
        let mut interleave_indices = Vec::new();
        for (batch_number, batch) in batches.iter().enumerate() {
            hashes_buffer.resize(batch.num_rows(), 0);
            let arrays = hash_key_indices
                .iter()
                .map(|i| batch.column(*i).clone())
                .collect::<Vec<_>>();

            let hash_values =
                hash_utils::create_hashes(&arrays, &random_state, &mut hashes_buffer)?;
            for (row_number, hash_value) in hash_values.iter().enumerate() {
                if hash_map
                    .try_insert(*hash_value, interleave_indices.len())
                    .is_ok()
                {
                    interleave_indices.push((batch_number, row_number));
                }
            }
        }

        // Assemble the distinct rows into a single batch.
        let batch = interleave_batches(schema, batches, interleave_indices)?;

        Ok(Self {
            hash_map,
            batch,
            hash_key_indices,
            random_state,
        })
    }

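    /// Remove from the hash table every key hash that occurs in `input`, a batch
    /// coming from the left side.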
    pub fn update_map(&mut self, input: &RecordBatch) -> DataFusionResult<()> {
        // Compute hashes over the key columns of the left-side batch.
        let mut hashes_buffer = Vec::new();
        let arrays = self
            .hash_key_indices
            .iter()
            .map(|i| input.column(*i).clone())
            .collect::<Vec<_>>();

        hashes_buffer.resize(input.num_rows(), 0);
        let hash_values =
            hash_utils::create_hashes(&arrays, &self.random_state, &mut hashes_buffer)?;

        // Any right-side row with a matching hash is shadowed by the left side.
        for hash in hash_values {
            self.hash_map.remove(hash);
        }

        Ok(())
    }

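    /// Consume the table and return the right-side rows that were never matched by
    /// a left-side row, as a single [`RecordBatch`].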
    pub fn finish(self) -> DataFusionResult<RecordBatch> {
        let valid_indices = self.hash_map.values().copied().collect::<Vec<_>>();
        let result = take_batch(&self.batch, &valid_indices)?;
        Ok(result)
    }
}

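/// Build a new [`RecordBatch`] by picking rows from `batches` according to `indices`,
/// where each entry is a `(batch index, row index)` pair.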
fn interleave_batches(
    schema: SchemaRef,
    batches: Vec<RecordBatch>,
    indices: Vec<(usize, usize)>,
) -> DataFusionResult<RecordBatch> {
    if batches.is_empty() {
        if indices.is_empty() {
            return Ok(RecordBatch::new_empty(schema));
        } else {
            return Err(DataFusionError::Internal(
                "Cannot interleave empty batches with non-empty indices".to_string(),
            ));
        }
    }

    // Transpose the batches into per-column lists of arrays.
    let mut arrays = vec![vec![]; schema.fields().len()];
    for batch in &batches {
        for (i, array) in batch.columns().iter().enumerate() {
            arrays[i].push(array.as_ref());
        }
    }

    // Interleave each column separately, then assemble the result batch.
    let interleaved_arrays: Vec<_> = arrays
        .into_iter()
        .map(|array| compute::interleave(&array, &indices))
        .collect::<std::result::Result<_, _>>()?;

    RecordBatch::try_new(schema, interleaved_arrays)
        .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
}

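/// Copy the rows at `indices` out of `batch` into a new [`RecordBatch`].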
fn take_batch(batch: &RecordBatch, indices: &[usize]) -> DataFusionResult<RecordBatch> {
    // Fast path: the caller asked for as many rows as the batch has, so reuse it unchanged.
    if batch.num_rows() == indices.len() {
        return Ok(batch.clone());
    }

    let schema = batch.schema();

    let indices_array = UInt64Array::from_iter(indices.iter().map(|i| *i as u64));
    let arrays = batch
        .columns()
        .iter()
        .map(|array| compute::take(array, &indices_array, None))
        .collect::<std::result::Result<Vec<_>, _>>()
        .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;

    let result = RecordBatch::try_new(schema, arrays)
        .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
    Ok(result)
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::arrow::array::Int32Array;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::common::ToDFSchema;
    use datafusion::logical_expr::{EmptyRelation, LogicalPlan};

    use super::*;

    #[test]
    fn pruning_should_keep_all_columns_for_exec() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Int32, false),
            Field::new("k", DataType::Int32, false),
            Field::new("v", DataType::Int32, false),
        ]));
        let df_schema = schema.to_dfschema_ref().unwrap();
        let left = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: df_schema.clone(),
        });
        let right = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: df_schema.clone(),
        });
        let plan = UnionDistinctOn::new(
            left,
            right,
            vec!["k".to_string()],
            "ts".to_string(),
            df_schema,
        );

        // Even when only the value column is requested, both children must keep all columns.
        let output_columns = [2usize];
        let required = plan.necessary_children_exprs(&output_columns).unwrap();
        assert_eq!(required.len(), 2);
        assert_eq!(required[0].as_slice(), &[0, 1, 2]);
        assert_eq!(required[1].as_slice(), &[0, 1, 2]);
    }

    #[test]
    fn test_interleave_batches() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
            ],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(Int32Array::from(vec![7, 8, 9])),
                Arc::new(Int32Array::from(vec![10, 11, 12])),
            ],
        )
        .unwrap();

        let batch3 = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(Int32Array::from(vec![13, 14, 15])),
                Arc::new(Int32Array::from(vec![16, 17, 18])),
            ],
        )
        .unwrap();

        let batches = vec![batch1, batch2, batch3];
        let indices = vec![(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)];
        let result = interleave_batches(Arc::new(schema.clone()), batches, indices).unwrap();

        let expected = RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(Int32Array::from(vec![1, 7, 13, 2, 8, 14])),
                Arc::new(Int32Array::from(vec![4, 10, 16, 5, 11, 17])),
            ],
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_take_batch() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]);

        let batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
            ],
        )
        .unwrap();

        let indices = vec![0, 2];
        let result = take_batch(&batch, &indices).unwrap();

        let expected = RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(Int32Array::from(vec![1, 3])),
                Arc::new(Int32Array::from(vec![4, 6])),
            ],
        )
        .unwrap();

        assert_eq!(result, expected);
    }

    #[tokio::test]
    async fn encode_decode_union_distinct_on() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Int64, false),
            Field::new("job", DataType::Utf8, false),
            Field::new("value", DataType::Float64, false),
        ]));
        let df_schema = schema.clone().to_dfschema_ref().unwrap();
        let left_plan = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: df_schema.clone(),
        });
        let right_plan = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: df_schema.clone(),
        });
        let plan_node = UnionDistinctOn::new(
            left_plan.clone(),
            right_plan.clone(),
            vec!["job".to_string()],
            "ts".to_string(),
            df_schema.clone(),
        );

        let bytes = plan_node.serialize();

        let union_distinct_on = UnionDistinctOn::deserialize(&bytes).unwrap();
        let union_distinct_on = union_distinct_on
            .with_exprs_and_inputs(vec![], vec![left_plan, right_plan])
            .unwrap();

        assert_eq!(union_distinct_on.compare_keys, vec!["job".to_string()]);
        assert_eq!(union_distinct_on.ts_col, "ts");
        assert_eq!(union_distinct_on.output_schema, df_schema);
    }
}