1use std::any::Any;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::task::{Context, Poll};
19
20use ahash::{HashMap, RandomState};
21use datafusion::arrow::array::UInt64Array;
22use datafusion::arrow::datatypes::SchemaRef;
23use datafusion::arrow::record_batch::RecordBatch;
24use datafusion::common::DFSchemaRef;
25use datafusion::error::{DataFusionError, Result as DataFusionResult};
26use datafusion::execution::context::TaskContext;
27use datafusion::logical_expr::{Expr, LogicalPlan, UserDefinedLogicalNodeCore};
28use datafusion::physical_expr::EquivalenceProperties;
29use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
30use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
31use datafusion::physical_plan::{
32 hash_utils, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning,
33 PlanProperties, RecordBatchStream, SendableRecordBatchStream,
34};
35use datatypes::arrow::compute;
36use futures::future::BoxFuture;
37use futures::{ready, Stream, StreamExt, TryStreamExt};
38
39#[derive(Debug, PartialEq, Eq, Hash)]
59pub struct UnionDistinctOn {
60 left: LogicalPlan,
61 right: LogicalPlan,
62 compare_keys: Vec<String>,
65 ts_col: String,
66 output_schema: DFSchemaRef,
67}
68
69impl UnionDistinctOn {
70 pub fn name() -> &'static str {
71 "UnionDistinctOn"
72 }
73
74 pub fn new(
75 left: LogicalPlan,
76 right: LogicalPlan,
77 compare_keys: Vec<String>,
78 ts_col: String,
79 output_schema: DFSchemaRef,
80 ) -> Self {
81 Self {
82 left,
83 right,
84 compare_keys,
85 ts_col,
86 output_schema,
87 }
88 }
89
90 pub fn to_execution_plan(
91 &self,
92 left_exec: Arc<dyn ExecutionPlan>,
93 right_exec: Arc<dyn ExecutionPlan>,
94 ) -> Arc<dyn ExecutionPlan> {
95 let output_schema: SchemaRef = Arc::new(self.output_schema.as_ref().into());
96 let properties = Arc::new(PlanProperties::new(
97 EquivalenceProperties::new(output_schema.clone()),
98 Partitioning::UnknownPartitioning(1),
99 EmissionType::Incremental,
100 Boundedness::Bounded,
101 ));
102 Arc::new(UnionDistinctOnExec {
103 left: left_exec,
104 right: right_exec,
105 compare_keys: self.compare_keys.clone(),
106 ts_col: self.ts_col.clone(),
107 output_schema,
108 metric: ExecutionPlanMetricsSet::new(),
109 properties,
110 random_state: RandomState::new(),
111 })
112 }
113}
114
115impl PartialOrd for UnionDistinctOn {
116 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
117 match self.left.partial_cmp(&other.left) {
119 Some(core::cmp::Ordering::Equal) => {}
120 ord => return ord,
121 }
122 match self.right.partial_cmp(&other.right) {
123 Some(core::cmp::Ordering::Equal) => {}
124 ord => return ord,
125 }
126 match self.compare_keys.partial_cmp(&other.compare_keys) {
127 Some(core::cmp::Ordering::Equal) => {}
128 ord => return ord,
129 }
130 self.ts_col.partial_cmp(&other.ts_col)
131 }
132}
133
134impl UserDefinedLogicalNodeCore for UnionDistinctOn {
135 fn name(&self) -> &str {
136 Self::name()
137 }
138
139 fn inputs(&self) -> Vec<&LogicalPlan> {
140 vec![&self.left, &self.right]
141 }
142
143 fn schema(&self) -> &DFSchemaRef {
144 &self.output_schema
145 }
146
147 fn expressions(&self) -> Vec<Expr> {
148 vec![]
149 }
150
151 fn fmt_for_explain(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 write!(
153 f,
154 "UnionDistinctOn: on col=[{:?}], ts_col=[{}]",
155 self.compare_keys, self.ts_col
156 )
157 }
158
159 fn with_exprs_and_inputs(
160 &self,
161 _exprs: Vec<Expr>,
162 inputs: Vec<LogicalPlan>,
163 ) -> DataFusionResult<Self> {
164 if inputs.len() != 2 {
165 return Err(DataFusionError::Internal(
166 "UnionDistinctOn must have exactly 2 inputs".to_string(),
167 ));
168 }
169
170 let mut inputs = inputs.into_iter();
171 let left = inputs.next().unwrap();
172 let right = inputs.next().unwrap();
173
174 Ok(Self {
175 left,
176 right,
177 compare_keys: self.compare_keys.clone(),
178 ts_col: self.ts_col.clone(),
179 output_schema: self.output_schema.clone(),
180 })
181 }
182}
183
184#[derive(Debug)]
185pub struct UnionDistinctOnExec {
186 left: Arc<dyn ExecutionPlan>,
187 right: Arc<dyn ExecutionPlan>,
188 compare_keys: Vec<String>,
189 ts_col: String,
190 output_schema: SchemaRef,
191 metric: ExecutionPlanMetricsSet,
192 properties: Arc<PlanProperties>,
193
194 random_state: RandomState,
196}
197
198impl ExecutionPlan for UnionDistinctOnExec {
199 fn as_any(&self) -> &dyn Any {
200 self
201 }
202
203 fn schema(&self) -> SchemaRef {
204 self.output_schema.clone()
205 }
206
207 fn required_input_distribution(&self) -> Vec<Distribution> {
208 vec![Distribution::SinglePartition, Distribution::SinglePartition]
209 }
210
211 fn properties(&self) -> &PlanProperties {
212 self.properties.as_ref()
213 }
214
215 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
216 vec![&self.left, &self.right]
217 }
218
219 fn with_new_children(
220 self: Arc<Self>,
221 children: Vec<Arc<dyn ExecutionPlan>>,
222 ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
223 assert_eq!(children.len(), 2);
224
225 let left = children[0].clone();
226 let right = children[1].clone();
227 Ok(Arc::new(UnionDistinctOnExec {
228 left,
229 right,
230 compare_keys: self.compare_keys.clone(),
231 ts_col: self.ts_col.clone(),
232 output_schema: self.output_schema.clone(),
233 metric: self.metric.clone(),
234 properties: self.properties.clone(),
235 random_state: self.random_state.clone(),
236 }))
237 }
238
239 fn execute(
240 &self,
241 partition: usize,
242 context: Arc<TaskContext>,
243 ) -> DataFusionResult<SendableRecordBatchStream> {
244 let left_stream = self.left.execute(partition, context.clone())?;
245 let right_stream = self.right.execute(partition, context.clone())?;
246
247 let mut key_indices = Vec::with_capacity(self.compare_keys.len() + 1);
249 for key in &self.compare_keys {
250 let index = self
251 .output_schema
252 .column_with_name(key)
253 .map(|(i, _)| i)
254 .ok_or_else(|| DataFusionError::Internal(format!("Column {} not found", key)))?;
255 key_indices.push(index);
256 }
257 let ts_index = self
258 .output_schema
259 .column_with_name(&self.ts_col)
260 .map(|(i, _)| i)
261 .ok_or_else(|| {
262 DataFusionError::Internal(format!("Column {} not found", self.ts_col))
263 })?;
264 key_indices.push(ts_index);
265
266 let hashed_data_future = HashedDataFut::Pending(Box::pin(HashedData::new(
268 right_stream,
269 self.random_state.clone(),
270 key_indices.clone(),
271 )));
272
273 let baseline_metric = BaselineMetrics::new(&self.metric, partition);
274 Ok(Box::pin(UnionDistinctOnStream {
275 left: left_stream,
276 right: hashed_data_future,
277 compare_keys: key_indices,
278 output_schema: self.output_schema.clone(),
279 metric: baseline_metric,
280 }))
281 }
282
283 fn metrics(&self) -> Option<MetricsSet> {
284 Some(self.metric.clone_inner())
285 }
286
287 fn name(&self) -> &str {
288 "UnionDistinctOnExec"
289 }
290}
291
292impl DisplayAs for UnionDistinctOnExec {
293 fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
294 match t {
295 DisplayFormatType::Default | DisplayFormatType::Verbose => {
296 write!(
297 f,
298 "UnionDistinctOnExec: on col=[{:?}], ts_col=[{}]",
299 self.compare_keys, self.ts_col
300 )
301 }
302 }
303 }
304}
305
306#[allow(dead_code)]
308pub struct UnionDistinctOnStream {
309 left: SendableRecordBatchStream,
310 right: HashedDataFut,
311 compare_keys: Vec<usize>,
313 output_schema: SchemaRef,
314 metric: BaselineMetrics,
315}
316
317impl UnionDistinctOnStream {
318 fn poll_impl(&mut self, cx: &mut Context<'_>) -> Poll<Option<<Self as Stream>::Item>> {
319 let right = match self.right {
321 HashedDataFut::Pending(ref mut fut) => {
322 let right = ready!(fut.as_mut().poll(cx))?;
323 self.right = HashedDataFut::Ready(right);
324 let HashedDataFut::Ready(right_ref) = &mut self.right else {
325 unreachable!()
326 };
327 right_ref
328 }
329 HashedDataFut::Ready(ref mut right) => right,
330 HashedDataFut::Empty => return Poll::Ready(None),
331 };
332
333 let next_left = ready!(self.left.poll_next_unpin(cx));
335 match next_left {
336 Some(Ok(left)) => {
337 right.update_map(&left)?;
339 Poll::Ready(Some(Ok(left)))
340 }
341 Some(Err(e)) => Poll::Ready(Some(Err(e))),
342 None => {
343 let right = std::mem::replace(&mut self.right, HashedDataFut::Empty);
345 let HashedDataFut::Ready(data) = right else {
346 unreachable!()
347 };
348 Poll::Ready(Some(data.finish()))
349 }
350 }
351 }
352}
353
354impl RecordBatchStream for UnionDistinctOnStream {
355 fn schema(&self) -> SchemaRef {
356 self.output_schema.clone()
357 }
358}
359
360impl Stream for UnionDistinctOnStream {
361 type Item = DataFusionResult<RecordBatch>;
362
363 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
364 self.poll_impl(cx)
365 }
366}
367
368enum HashedDataFut {
370 Pending(BoxFuture<'static, DataFusionResult<HashedData>>),
372 Ready(HashedData),
374 Empty,
376}
377
378struct HashedData {
380 hash_map: HashMap<u64, usize>,
384 batch: RecordBatch,
386 hash_key_indices: Vec<usize>,
388 random_state: RandomState,
389}
390
391impl HashedData {
392 pub async fn new(
393 input: SendableRecordBatchStream,
394 random_state: RandomState,
395 hash_key_indices: Vec<usize>,
396 ) -> DataFusionResult<Self> {
397 let initial = (Vec::new(), 0);
399 let schema = input.schema();
400 let (batches, _num_rows) = input
401 .try_fold(initial, |mut acc, batch| async {
402 acc.1 += batch.num_rows();
404 acc.0.push(batch);
406 Ok(acc)
407 })
408 .await?;
409
410 let mut hash_map = HashMap::default();
412 let mut hashes_buffer = Vec::new();
413 let mut interleave_indices = Vec::new();
414 for (batch_number, batch) in batches.iter().enumerate() {
415 hashes_buffer.resize(batch.num_rows(), 0);
416 let arrays = hash_key_indices
418 .iter()
419 .map(|i| batch.column(*i).clone())
420 .collect::<Vec<_>>();
421
422 let hash_values =
424 hash_utils::create_hashes(&arrays, &random_state, &mut hashes_buffer)?;
425 for (row_number, hash_value) in hash_values.iter().enumerate() {
426 if hash_map
428 .try_insert(*hash_value, interleave_indices.len())
429 .is_ok()
430 {
431 interleave_indices.push((batch_number, row_number));
432 }
433 }
434 }
435
436 let batch = interleave_batches(schema, batches, interleave_indices)?;
438
439 Ok(Self {
440 hash_map,
441 batch,
442 hash_key_indices,
443 random_state,
444 })
445 }
446
447 pub fn update_map(&mut self, input: &RecordBatch) -> DataFusionResult<()> {
450 let mut hashes_buffer = Vec::new();
452 let arrays = self
453 .hash_key_indices
454 .iter()
455 .map(|i| input.column(*i).clone())
456 .collect::<Vec<_>>();
457
458 hashes_buffer.resize(input.num_rows(), 0);
460 let hash_values =
461 hash_utils::create_hashes(&arrays, &self.random_state, &mut hashes_buffer)?;
462
463 for hash in hash_values {
465 self.hash_map.remove(hash);
466 }
467
468 Ok(())
469 }
470
471 pub fn finish(self) -> DataFusionResult<RecordBatch> {
472 let valid_indices = self.hash_map.values().copied().collect::<Vec<_>>();
473 let result = take_batch(&self.batch, &valid_indices)?;
474 Ok(result)
475 }
476}
477
478fn interleave_batches(
480 schema: SchemaRef,
481 batches: Vec<RecordBatch>,
482 indices: Vec<(usize, usize)>,
483) -> DataFusionResult<RecordBatch> {
484 if batches.is_empty() {
485 if indices.is_empty() {
486 return Ok(RecordBatch::new_empty(schema));
487 } else {
488 return Err(DataFusionError::Internal(
489 "Cannot interleave empty batches with non-empty indices".to_string(),
490 ));
491 }
492 }
493
494 let mut arrays = vec![vec![]; schema.fields().len()];
496 for batch in &batches {
497 for (i, array) in batch.columns().iter().enumerate() {
498 arrays[i].push(array.as_ref());
499 }
500 }
501
502 let interleaved_arrays: Vec<_> = arrays
504 .into_iter()
505 .map(|array| compute::interleave(&array, &indices))
506 .collect::<Result<_, _>>()?;
507
508 RecordBatch::try_new(schema, interleaved_arrays)
510 .map_err(|e| DataFusionError::ArrowError(e, None))
511}
512
513fn take_batch(batch: &RecordBatch, indices: &[usize]) -> DataFusionResult<RecordBatch> {
515 if batch.num_rows() == indices.len() {
517 return Ok(batch.clone());
518 }
519
520 let schema = batch.schema();
521
522 let indices_array = UInt64Array::from_iter(indices.iter().map(|i| *i as u64));
523 let arrays = batch
524 .columns()
525 .iter()
526 .map(|array| compute::take(array, &indices_array, None))
527 .collect::<std::result::Result<Vec<_>, _>>()
528 .map_err(|e| DataFusionError::ArrowError(e, None))?;
529
530 let result =
531 RecordBatch::try_new(schema, arrays).map_err(|e| DataFusionError::ArrowError(e, None))?;
532 Ok(result)
533}
534
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::arrow::array::Int32Array;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};

    use super::*;

    /// Schema with two non-nullable `Int32` columns named `a` and `b`.
    fn two_int_schema() -> Schema {
        Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ])
    }

    /// Builds a batch of `schema` from the values of columns `a` and `b`.
    fn int_batch(schema: &Schema, a: Vec<i32>, b: Vec<i32>) -> RecordBatch {
        RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(Int32Array::from(a)),
                Arc::new(Int32Array::from(b)),
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_interleave_batches() {
        let schema = two_int_schema();

        let batches = vec![
            int_batch(&schema, vec![1, 2, 3], vec![4, 5, 6]),
            int_batch(&schema, vec![7, 8, 9], vec![10, 11, 12]),
            int_batch(&schema, vec![13, 14, 15], vec![16, 17, 18]),
        ];
        // First row of every batch, then the second row of every batch.
        let indices = vec![(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)];

        let result = interleave_batches(Arc::new(schema.clone()), batches, indices).unwrap();
        let expected = int_batch(
            &schema,
            vec![1, 7, 13, 2, 8, 14],
            vec![4, 10, 16, 5, 11, 17],
        );

        assert_eq!(result, expected);
    }

    #[test]
    fn test_take_batch() {
        let schema = two_int_schema();
        let batch = int_batch(&schema, vec![1, 2, 3], vec![4, 5, 6]);

        // Keep the first and the last row.
        let indices = vec![0, 2];
        let result = take_batch(&batch, &indices).unwrap();
        let expected = int_batch(&schema, vec![1, 3], vec![4, 6]);

        assert_eq!(result, expected);
    }
}