1use std::any::Any;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::task::{Context, Poll};
19
20use ahash::{HashMap, RandomState};
21use datafusion::arrow::array::UInt64Array;
22use datafusion::arrow::datatypes::SchemaRef;
23use datafusion::arrow::record_batch::RecordBatch;
24use datafusion::common::DFSchemaRef;
25use datafusion::error::{DataFusionError, Result as DataFusionResult};
26use datafusion::execution::context::TaskContext;
27use datafusion::logical_expr::{Expr, LogicalPlan, UserDefinedLogicalNodeCore};
28use datafusion::physical_expr::EquivalenceProperties;
29use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
30use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
31use datafusion::physical_plan::{
32 hash_utils, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning,
33 PlanProperties, RecordBatchStream, SendableRecordBatchStream,
34};
35use datatypes::arrow::compute;
36use futures::future::BoxFuture;
37use futures::{ready, Stream, StreamExt, TryStreamExt};
38
/// Logical plan node that emits every row of `left` plus those rows of `right`
/// whose key (the `compare_keys` columns followed by `ts_col`) is not present in `left`.
///
/// The physical work is done by [`UnionDistinctOnExec`], which is built via
/// `to_execution_plan`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct UnionDistinctOn {
    /// Left input; all of its rows are passed through.
    left: LogicalPlan,
    /// Right input; only rows whose key does not appear on the left are kept.
    right: LogicalPlan,
    /// Names of the columns (besides `ts_col`) that make up the row key.
    compare_keys: Vec<String>,
    /// Name of the timestamp column; appended to `compare_keys` to form the row key.
    ts_col: String,
    /// Output schema of the node; also used to resolve key column indices at execution.
    output_schema: DFSchemaRef,
}
68
69impl UnionDistinctOn {
70 pub fn name() -> &'static str {
71 "UnionDistinctOn"
72 }
73
74 pub fn new(
75 left: LogicalPlan,
76 right: LogicalPlan,
77 compare_keys: Vec<String>,
78 ts_col: String,
79 output_schema: DFSchemaRef,
80 ) -> Self {
81 Self {
82 left,
83 right,
84 compare_keys,
85 ts_col,
86 output_schema,
87 }
88 }
89
90 pub fn to_execution_plan(
91 &self,
92 left_exec: Arc<dyn ExecutionPlan>,
93 right_exec: Arc<dyn ExecutionPlan>,
94 ) -> Arc<dyn ExecutionPlan> {
95 let output_schema: SchemaRef = Arc::new(self.output_schema.as_ref().into());
96 let properties = Arc::new(PlanProperties::new(
97 EquivalenceProperties::new(output_schema.clone()),
98 Partitioning::UnknownPartitioning(1),
99 EmissionType::Incremental,
100 Boundedness::Bounded,
101 ));
102 Arc::new(UnionDistinctOnExec {
103 left: left_exec,
104 right: right_exec,
105 compare_keys: self.compare_keys.clone(),
106 ts_col: self.ts_col.clone(),
107 output_schema,
108 metric: ExecutionPlanMetricsSet::new(),
109 properties,
110 random_state: RandomState::new(),
111 })
112 }
113}
114
115impl PartialOrd for UnionDistinctOn {
116 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
117 match self.left.partial_cmp(&other.left) {
119 Some(core::cmp::Ordering::Equal) => {}
120 ord => return ord,
121 }
122 match self.right.partial_cmp(&other.right) {
123 Some(core::cmp::Ordering::Equal) => {}
124 ord => return ord,
125 }
126 match self.compare_keys.partial_cmp(&other.compare_keys) {
127 Some(core::cmp::Ordering::Equal) => {}
128 ord => return ord,
129 }
130 self.ts_col.partial_cmp(&other.ts_col)
131 }
132}
133
134impl UserDefinedLogicalNodeCore for UnionDistinctOn {
135 fn name(&self) -> &str {
136 Self::name()
137 }
138
139 fn inputs(&self) -> Vec<&LogicalPlan> {
140 vec![&self.left, &self.right]
141 }
142
143 fn schema(&self) -> &DFSchemaRef {
144 &self.output_schema
145 }
146
147 fn expressions(&self) -> Vec<Expr> {
148 vec![]
149 }
150
151 fn fmt_for_explain(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 write!(
153 f,
154 "UnionDistinctOn: on col=[{:?}], ts_col=[{}]",
155 self.compare_keys, self.ts_col
156 )
157 }
158
159 fn with_exprs_and_inputs(
160 &self,
161 _exprs: Vec<Expr>,
162 inputs: Vec<LogicalPlan>,
163 ) -> DataFusionResult<Self> {
164 if inputs.len() != 2 {
165 return Err(DataFusionError::Internal(
166 "UnionDistinctOn must have exactly 2 inputs".to_string(),
167 ));
168 }
169
170 let mut inputs = inputs.into_iter();
171 let left = inputs.next().unwrap();
172 let right = inputs.next().unwrap();
173
174 Ok(Self {
175 left,
176 right,
177 compare_keys: self.compare_keys.clone(),
178 ts_col: self.ts_col.clone(),
179 output_schema: self.output_schema.clone(),
180 })
181 }
182}
183
/// Physical operator for [`UnionDistinctOn`].
///
/// At execution it first collects the whole `right` input and keeps one row per distinct
/// key hash. It then streams `left` through unchanged, dropping from the right side every
/// key hash seen on the left. Finally it emits the surviving `right` rows as one batch.
#[derive(Debug)]
pub struct UnionDistinctOnExec {
    /// Left input; its batches are forwarded as-is.
    left: Arc<dyn ExecutionPlan>,
    /// Right input; buffered and deduplicated against `left`.
    right: Arc<dyn ExecutionPlan>,
    /// Names of the key columns other than `ts_col`.
    compare_keys: Vec<String>,
    /// Name of the timestamp column, appended to the key.
    ts_col: String,
    /// Arrow schema of the output; key column indices are resolved against it.
    output_schema: SchemaRef,
    /// Metrics shared by all streams produced by `execute`.
    metric: ExecutionPlanMetricsSet,
    /// Cached plan properties (single partition, bounded).
    properties: Arc<PlanProperties>,

    /// Hasher state cloned into each execution; the same state hashes both the right rows
    /// and the incoming left batches so their key hashes are comparable.
    random_state: RandomState,
}
197
198impl ExecutionPlan for UnionDistinctOnExec {
199 fn as_any(&self) -> &dyn Any {
200 self
201 }
202
203 fn schema(&self) -> SchemaRef {
204 self.output_schema.clone()
205 }
206
207 fn required_input_distribution(&self) -> Vec<Distribution> {
208 vec![Distribution::SinglePartition, Distribution::SinglePartition]
209 }
210
211 fn properties(&self) -> &PlanProperties {
212 self.properties.as_ref()
213 }
214
215 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
216 vec![&self.left, &self.right]
217 }
218
219 fn with_new_children(
220 self: Arc<Self>,
221 children: Vec<Arc<dyn ExecutionPlan>>,
222 ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
223 assert_eq!(children.len(), 2);
224
225 let left = children[0].clone();
226 let right = children[1].clone();
227 Ok(Arc::new(UnionDistinctOnExec {
228 left,
229 right,
230 compare_keys: self.compare_keys.clone(),
231 ts_col: self.ts_col.clone(),
232 output_schema: self.output_schema.clone(),
233 metric: self.metric.clone(),
234 properties: self.properties.clone(),
235 random_state: self.random_state.clone(),
236 }))
237 }
238
239 fn execute(
240 &self,
241 partition: usize,
242 context: Arc<TaskContext>,
243 ) -> DataFusionResult<SendableRecordBatchStream> {
244 let left_stream = self.left.execute(partition, context.clone())?;
245 let right_stream = self.right.execute(partition, context.clone())?;
246
247 let mut key_indices = Vec::with_capacity(self.compare_keys.len() + 1);
249 for key in &self.compare_keys {
250 let index = self
251 .output_schema
252 .column_with_name(key)
253 .map(|(i, _)| i)
254 .ok_or_else(|| DataFusionError::Internal(format!("Column {} not found", key)))?;
255 key_indices.push(index);
256 }
257 let ts_index = self
258 .output_schema
259 .column_with_name(&self.ts_col)
260 .map(|(i, _)| i)
261 .ok_or_else(|| {
262 DataFusionError::Internal(format!("Column {} not found", self.ts_col))
263 })?;
264 key_indices.push(ts_index);
265
266 let hashed_data_future = HashedDataFut::Pending(Box::pin(HashedData::new(
268 right_stream,
269 self.random_state.clone(),
270 key_indices.clone(),
271 )));
272
273 let baseline_metric = BaselineMetrics::new(&self.metric, partition);
274 Ok(Box::pin(UnionDistinctOnStream {
275 left: left_stream,
276 right: hashed_data_future,
277 compare_keys: key_indices,
278 output_schema: self.output_schema.clone(),
279 metric: baseline_metric,
280 }))
281 }
282
283 fn metrics(&self) -> Option<MetricsSet> {
284 Some(self.metric.clone_inner())
285 }
286
287 fn name(&self) -> &str {
288 "UnionDistinctOnExec"
289 }
290}
291
292impl DisplayAs for UnionDistinctOnExec {
293 fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
294 match t {
295 DisplayFormatType::Default
296 | DisplayFormatType::Verbose
297 | DisplayFormatType::TreeRender => {
298 write!(
299 f,
300 "UnionDistinctOnExec: on col=[{:?}], ts_col=[{}]",
301 self.compare_keys, self.ts_col
302 )
303 }
304 }
305 }
306}
307
// Some fields (at least `compare_keys`) are stored but never read by the stream,
// hence the lint allowance.
#[allow(dead_code)]
/// Output stream of [`UnionDistinctOnExec`]: yields every `left` batch as it arrives,
/// then one final batch holding the `right` rows whose key was never seen on the left.
pub struct UnionDistinctOnStream {
    /// Left input stream, forwarded batch by batch.
    left: SendableRecordBatchStream,
    /// Right input: still being hashed, ready, or already emitted.
    right: HashedDataFut,
    /// Column indices of the key (compare keys followed by the timestamp column).
    compare_keys: Vec<usize>,
    /// Schema reported by this stream.
    output_schema: SchemaRef,
    /// Per-partition baseline metrics.
    metric: BaselineMetrics,
}
318
319impl UnionDistinctOnStream {
320 fn poll_impl(&mut self, cx: &mut Context<'_>) -> Poll<Option<<Self as Stream>::Item>> {
321 let right = match self.right {
323 HashedDataFut::Pending(ref mut fut) => {
324 let right = ready!(fut.as_mut().poll(cx))?;
325 self.right = HashedDataFut::Ready(right);
326 let HashedDataFut::Ready(right_ref) = &mut self.right else {
327 unreachable!()
328 };
329 right_ref
330 }
331 HashedDataFut::Ready(ref mut right) => right,
332 HashedDataFut::Empty => return Poll::Ready(None),
333 };
334
335 let next_left = ready!(self.left.poll_next_unpin(cx));
337 match next_left {
338 Some(Ok(left)) => {
339 right.update_map(&left)?;
341 Poll::Ready(Some(Ok(left)))
342 }
343 Some(Err(e)) => Poll::Ready(Some(Err(e))),
344 None => {
345 let right = std::mem::replace(&mut self.right, HashedDataFut::Empty);
347 let HashedDataFut::Ready(data) = right else {
348 unreachable!()
349 };
350 Poll::Ready(Some(data.finish()))
351 }
352 }
353 }
354}
355
356impl RecordBatchStream for UnionDistinctOnStream {
357 fn schema(&self) -> SchemaRef {
358 self.output_schema.clone()
359 }
360}
361
362impl Stream for UnionDistinctOnStream {
363 type Item = DataFusionResult<RecordBatch>;
364
365 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
366 self.poll_impl(cx)
367 }
368}
369
/// Lifecycle of the buffered right-hand input inside [`UnionDistinctOnStream`].
enum HashedDataFut {
    /// The right input is still being collected and hashed.
    Pending(BoxFuture<'static, DataFusionResult<HashedData>>),
    /// The right input is hashed; incoming left batches are matched against it.
    Ready(HashedData),
    /// The remaining right rows have been emitted; the stream is exhausted.
    Empty,
}
379
/// Right-input rows deduplicated by key hash.
///
/// Rows are identified only by a 64-bit hash of the key columns. The key values are never
/// compared afterwards, so distinct keys whose hashes collide are treated as the same key.
struct HashedData {
    /// Key hash -> row index in `batch` of the first right row with that hash.
    /// Entries are removed when a left batch contains the same hash.
    hash_map: HashMap<u64, usize>,
    /// One row per distinct key hash from the right input, in first-seen order.
    batch: RecordBatch,
    /// Indices of the columns hashed to form the key.
    hash_key_indices: Vec<usize>,
    /// Hasher state used for both the right rows and incoming left batches.
    random_state: RandomState,
}
392
393impl HashedData {
394 pub async fn new(
395 input: SendableRecordBatchStream,
396 random_state: RandomState,
397 hash_key_indices: Vec<usize>,
398 ) -> DataFusionResult<Self> {
399 let initial = (Vec::new(), 0);
401 let schema = input.schema();
402 let (batches, _num_rows) = input
403 .try_fold(initial, |mut acc, batch| async {
404 acc.1 += batch.num_rows();
406 acc.0.push(batch);
408 Ok(acc)
409 })
410 .await?;
411
412 let mut hash_map = HashMap::default();
414 let mut hashes_buffer = Vec::new();
415 let mut interleave_indices = Vec::new();
416 for (batch_number, batch) in batches.iter().enumerate() {
417 hashes_buffer.resize(batch.num_rows(), 0);
418 let arrays = hash_key_indices
420 .iter()
421 .map(|i| batch.column(*i).clone())
422 .collect::<Vec<_>>();
423
424 let hash_values =
426 hash_utils::create_hashes(&arrays, &random_state, &mut hashes_buffer)?;
427 for (row_number, hash_value) in hash_values.iter().enumerate() {
428 if hash_map
430 .try_insert(*hash_value, interleave_indices.len())
431 .is_ok()
432 {
433 interleave_indices.push((batch_number, row_number));
434 }
435 }
436 }
437
438 let batch = interleave_batches(schema, batches, interleave_indices)?;
440
441 Ok(Self {
442 hash_map,
443 batch,
444 hash_key_indices,
445 random_state,
446 })
447 }
448
449 pub fn update_map(&mut self, input: &RecordBatch) -> DataFusionResult<()> {
452 let mut hashes_buffer = Vec::new();
454 let arrays = self
455 .hash_key_indices
456 .iter()
457 .map(|i| input.column(*i).clone())
458 .collect::<Vec<_>>();
459
460 hashes_buffer.resize(input.num_rows(), 0);
462 let hash_values =
463 hash_utils::create_hashes(&arrays, &self.random_state, &mut hashes_buffer)?;
464
465 for hash in hash_values {
467 self.hash_map.remove(hash);
468 }
469
470 Ok(())
471 }
472
473 pub fn finish(self) -> DataFusionResult<RecordBatch> {
474 let valid_indices = self.hash_map.values().copied().collect::<Vec<_>>();
475 let result = take_batch(&self.batch, &valid_indices)?;
476 Ok(result)
477 }
478}
479
480fn interleave_batches(
482 schema: SchemaRef,
483 batches: Vec<RecordBatch>,
484 indices: Vec<(usize, usize)>,
485) -> DataFusionResult<RecordBatch> {
486 if batches.is_empty() {
487 if indices.is_empty() {
488 return Ok(RecordBatch::new_empty(schema));
489 } else {
490 return Err(DataFusionError::Internal(
491 "Cannot interleave empty batches with non-empty indices".to_string(),
492 ));
493 }
494 }
495
496 let mut arrays = vec![vec![]; schema.fields().len()];
498 for batch in &batches {
499 for (i, array) in batch.columns().iter().enumerate() {
500 arrays[i].push(array.as_ref());
501 }
502 }
503
504 let interleaved_arrays: Vec<_> = arrays
506 .into_iter()
507 .map(|array| compute::interleave(&array, &indices))
508 .collect::<Result<_, _>>()?;
509
510 RecordBatch::try_new(schema, interleaved_arrays)
512 .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
513}
514
515fn take_batch(batch: &RecordBatch, indices: &[usize]) -> DataFusionResult<RecordBatch> {
517 if batch.num_rows() == indices.len() {
519 return Ok(batch.clone());
520 }
521
522 let schema = batch.schema();
523
524 let indices_array = UInt64Array::from_iter(indices.iter().map(|i| *i as u64));
525 let arrays = batch
526 .columns()
527 .iter()
528 .map(|array| compute::take(array, &indices_array, None))
529 .collect::<std::result::Result<Vec<_>, _>>()
530 .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
531
532 let result = RecordBatch::try_new(schema, arrays)
533 .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
534 Ok(result)
535}
536
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::arrow::array::Int32Array;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};

    use super::*;

    /// Two non-nullable Int32 columns, `a` and `b`, used by every test here.
    fn ab_schema() -> Schema {
        Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ])
    }

    /// Builds a batch of `schema` whose `a` and `b` columns hold the given values.
    fn ab_batch(schema: &Schema, a: Vec<i32>, b: Vec<i32>) -> RecordBatch {
        RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![Arc::new(Int32Array::from(a)), Arc::new(Int32Array::from(b))],
        )
        .unwrap()
    }

    #[test]
    fn test_interleave_batches() {
        let schema = ab_schema();
        let batches = vec![
            ab_batch(&schema, vec![1, 2, 3], vec![4, 5, 6]),
            ab_batch(&schema, vec![7, 8, 9], vec![10, 11, 12]),
            ab_batch(&schema, vec![13, 14, 15], vec![16, 17, 18]),
        ];
        // Row 0 of each batch, then row 1 of each batch.
        let indices = vec![(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)];

        let result = interleave_batches(Arc::new(schema.clone()), batches, indices).unwrap();

        let expected = ab_batch(
            &schema,
            vec![1, 7, 13, 2, 8, 14],
            vec![4, 10, 16, 5, 11, 17],
        );
        assert_eq!(result, expected);
    }

    #[test]
    fn test_take_batch() {
        let schema = ab_schema();
        let batch = ab_batch(&schema, vec![1, 2, 3], vec![4, 5, 6]);

        let indices = vec![0, 2];
        let result = take_batch(&batch, &indices).unwrap();

        let expected = ab_batch(&schema, vec![1, 3], vec![4, 6]);
        assert_eq!(result, expected);
    }
}