use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use ahash::{HashMap, RandomState};
use datafusion::arrow::array::UInt64Array;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::DFSchemaRef;
use datafusion::error::{DataFusionError, Result as DataFusionResult};
use datafusion::execution::context::TaskContext;
use datafusion::logical_expr::{Expr, LogicalPlan, UserDefinedLogicalNodeCore};
use datafusion::physical_expr::EquivalenceProperties;
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{
hash_utils, DisplayAs, DisplayFormatType, Distribution, ExecutionMode, ExecutionPlan,
Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream,
};
use datatypes::arrow::compute;
use futures::future::BoxFuture;
use futures::{ready, Stream, StreamExt, TryStreamExt};
/// Logical plan node that unions two inputs while de-duplicating on a set of
/// key columns plus a timestamp column.
///
/// As lowered by `to_execution_plan`, every row of `left` is emitted, and a row
/// of `right` is emitted only when no `left` row has the same hash over the
/// compare keys and the timestamp column; among `right` rows, only the first
/// row per key hash is kept.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct UnionDistinctOn {
    /// Input whose rows are always emitted.
    left: LogicalPlan,
    /// Input whose rows are emitted only when not shadowed by `left`.
    right: LogicalPlan,
    /// Names of the key columns, resolved against the output schema at execution.
    compare_keys: Vec<String>,
    /// Name of the timestamp column, matched in addition to `compare_keys`.
    ts_col: String,
    /// Schema produced by this node.
    output_schema: DFSchemaRef,
}
impl UnionDistinctOn {
    /// Name under which this node is identified in plans and explain output.
    pub fn name() -> &'static str {
        "UnionDistinctOn"
    }

    /// Builds a node from its two inputs, the key column names, the timestamp
    /// column name and the schema the node produces.
    pub fn new(
        left: LogicalPlan,
        right: LogicalPlan,
        compare_keys: Vec<String>,
        ts_col: String,
        output_schema: DFSchemaRef,
    ) -> Self {
        Self {
            left,
            right,
            compare_keys,
            ts_col,
            output_schema,
        }
    }

    /// Lowers this logical node into a [`UnionDistinctOnExec`] running over the
    /// already-planned physical children.
    ///
    /// The resulting plan reports a single (unknown) output partition and a
    /// bounded execution mode.
    pub fn to_execution_plan(
        &self,
        left_exec: Arc<dyn ExecutionPlan>,
        right_exec: Arc<dyn ExecutionPlan>,
    ) -> Arc<dyn ExecutionPlan> {
        let schema: SchemaRef = Arc::new(self.output_schema.as_ref().into());
        let equivalence = EquivalenceProperties::new(Arc::clone(&schema));
        let plan_properties = PlanProperties::new(
            equivalence,
            Partitioning::UnknownPartitioning(1),
            ExecutionMode::Bounded,
        );
        let exec = UnionDistinctOnExec {
            left: left_exec,
            right: right_exec,
            compare_keys: self.compare_keys.clone(),
            ts_col: self.ts_col.clone(),
            output_schema: schema,
            metric: ExecutionPlanMetricsSet::new(),
            properties: Arc::new(plan_properties),
            random_state: RandomState::new(),
        };
        Arc::new(exec)
    }
}
impl PartialOrd for UnionDistinctOn {
    /// Compares nodes by `left`, then `right`, then `compare_keys`, then `ts_col`.
    ///
    /// `output_schema` has no ordering, so it cannot break ties. The derived
    /// `PartialEq` *does* compare it, so two nodes that tie on every ordered
    /// field but carry different schemas are reported as incomparable (`None`)
    /// rather than `Equal`. This keeps `partial_cmp(a, b) == Some(Equal)`
    /// equivalent to `a == b`, as the `PartialOrd` contract requires.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        use std::cmp::Ordering;

        match self.left.partial_cmp(&other.left)? {
            Ordering::Equal => {}
            ord => return Some(ord),
        }
        match self.right.partial_cmp(&other.right)? {
            Ordering::Equal => {}
            ord => return Some(ord),
        }
        match self.compare_keys.partial_cmp(&other.compare_keys)? {
            Ordering::Equal => {}
            ord => return Some(ord),
        }
        match self.ts_col.partial_cmp(&other.ts_col)? {
            Ordering::Equal if self.output_schema != other.output_schema => None,
            ord => Some(ord),
        }
    }
}
impl UserDefinedLogicalNodeCore for UnionDistinctOn {
fn name(&self) -> &str {
Self::name()
}
fn inputs(&self) -> Vec<&LogicalPlan> {
vec![&self.left, &self.right]
}
fn schema(&self) -> &DFSchemaRef {
&self.output_schema
}
fn expressions(&self) -> Vec<Expr> {
vec![]
}
fn fmt_for_explain(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"UnionDistinctOn: on col=[{:?}], ts_col=[{}]",
self.compare_keys, self.ts_col
)
}
fn with_exprs_and_inputs(
&self,
_exprs: Vec<Expr>,
inputs: Vec<LogicalPlan>,
) -> DataFusionResult<Self> {
if inputs.len() != 2 {
return Err(DataFusionError::Internal(
"UnionDistinctOn must have exactly 2 inputs".to_string(),
));
}
let mut inputs = inputs.into_iter();
let left = inputs.next().unwrap();
let right = inputs.next().unwrap();
Ok(Self {
left,
right,
compare_keys: self.compare_keys.clone(),
ts_col: self.ts_col.clone(),
output_schema: self.output_schema.clone(),
})
}
}
/// Physical operator for [`UnionDistinctOn`].
///
/// Passes through every batch of `left`, then emits the rows of `right` whose
/// key hash (compare keys followed by the timestamp column) never appeared on
/// the left, keeping only the first right row per key hash.
#[derive(Debug)]
pub struct UnionDistinctOnExec {
    /// Input whose rows are always emitted.
    left: Arc<dyn ExecutionPlan>,
    /// Input that is fully collected and hashed before any left batch is emitted.
    right: Arc<dyn ExecutionPlan>,
    /// Key column names, resolved against `output_schema` in `execute`.
    compare_keys: Vec<String>,
    /// Timestamp column name, hashed after the compare keys.
    ts_col: String,
    output_schema: SchemaRef,
    metric: ExecutionPlanMetricsSet,
    /// Plan properties built by `UnionDistinctOn::to_execution_plan`.
    properties: Arc<PlanProperties>,
    /// Hashing seed used for both inputs so their key hashes are comparable.
    random_state: RandomState,
}
impl ExecutionPlan for UnionDistinctOnExec {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.output_schema.clone()
    }

    /// Both children must be merged into one partition: the right side is
    /// collected into a single hash table and the left side is streamed
    /// against it.
    fn required_input_distribution(&self) -> Vec<Distribution> {
        vec![Distribution::SinglePartition, Distribution::SinglePartition]
    }

    fn properties(&self) -> &PlanProperties {
        self.properties.as_ref()
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.left, &self.right]
    }

    /// Returns a copy of this node over new children.
    ///
    /// Keys, schema and properties are carried over. The metrics set and the
    /// hashing seed are cloned from the original node.
    ///
    /// # Errors
    /// Returns an internal error (instead of panicking) unless exactly two
    /// children are supplied.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        let num_children = children.len();
        let mut children = children.into_iter();
        let (Some(left), Some(right), None) = (children.next(), children.next(), children.next())
        else {
            return Err(DataFusionError::Internal(format!(
                "UnionDistinctOnExec must have exactly 2 children, got {num_children}"
            )));
        };
        Ok(Arc::new(UnionDistinctOnExec {
            left,
            right,
            compare_keys: self.compare_keys.clone(),
            ts_col: self.ts_col.clone(),
            output_schema: Arc::clone(&self.output_schema),
            metric: self.metric.clone(),
            properties: Arc::clone(&self.properties),
            random_state: self.random_state.clone(),
        }))
    }

    /// Starts both children on `partition` and returns the de-duplicating stream.
    ///
    /// # Errors
    /// Returns an internal error if a compare key or the timestamp column is
    /// missing from the output schema. Also propagates errors from starting
    /// either child.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        // Resolve key columns first so a malformed plan fails before any child
        // stream is started.
        let column_index = |name: &str| {
            self.output_schema
                .column_with_name(name)
                .map(|(i, _)| i)
                .ok_or_else(|| DataFusionError::Internal(format!("Column {} not found", name)))
        };
        let mut key_indices = self
            .compare_keys
            .iter()
            .map(|key| column_index(key.as_str()))
            .collect::<DataFusionResult<Vec<_>>>()?;
        // The timestamp column is always hashed last, after the compare keys.
        key_indices.push(column_index(self.ts_col.as_str())?);

        let left_stream = self.left.execute(partition, Arc::clone(&context))?;
        let right_stream = self.right.execute(partition, context)?;

        // The right side is collected and hashed lazily, on the first poll of
        // the returned stream.
        let hashed_data_future = HashedDataFut::Pending(Box::pin(HashedData::new(
            right_stream,
            self.random_state.clone(),
            key_indices.clone(),
        )));
        let baseline_metric = BaselineMetrics::new(&self.metric, partition);
        Ok(Box::pin(UnionDistinctOnStream {
            left: left_stream,
            right: hashed_data_future,
            compare_keys: key_indices,
            output_schema: self.output_schema.clone(),
            metric: baseline_metric,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metric.clone_inner())
    }

    fn name(&self) -> &str {
        "UnionDistinctOnExec"
    }
}
impl DisplayAs for UnionDistinctOnExec {
    /// Renders the operator the same way for the default and verbose formats.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let keys = &self.compare_keys;
        let ts_col = &self.ts_col;
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => write!(
                f,
                "UnionDistinctOnExec: on col=[{:?}], ts_col=[{}]",
                keys, ts_col
            ),
        }
    }
}
/// Output stream of [`UnionDistinctOnExec`].
// `compare_keys` is stored here but never read by the stream logic in this
// file (the same indices are handed to `HashedData`), hence the allow.
#[allow(dead_code)]
pub struct UnionDistinctOnStream {
    /// Left input, passed through batch by batch.
    left: SendableRecordBatchStream,
    /// Right input state; it is resolved before the left input is polled.
    right: HashedDataFut,
    /// Indices of the key columns (compare keys, then the timestamp column).
    compare_keys: Vec<usize>,
    output_schema: SchemaRef,
    metric: BaselineMetrics,
}
impl UnionDistinctOnStream {
    /// Core state machine of the stream.
    ///
    /// 1. Wait until the right input is fully collected and hashed.
    /// 2. Pass every left batch through unchanged, removing its key hashes
    ///    from the right-side table.
    /// 3. Once the left input is exhausted, emit the remaining right-side rows
    ///    as one final batch and end the stream.
    ///
    /// If collecting the right input fails, that error is yielded once and the
    /// stream then ends. Errors from the left input are forwarded as-is.
    fn poll_impl(&mut self, cx: &mut Context<'_>) -> Poll<Option<<Self as Stream>::Item>> {
        if let HashedDataFut::Pending(fut) = &mut self.right {
            let collected = ready!(fut.as_mut().poll(cx));
            match collected {
                Ok(data) => self.right = HashedDataFut::Ready(data),
                Err(e) => {
                    // The async future has completed; polling it again would
                    // panic, so move to the terminal state before reporting.
                    self.right = HashedDataFut::Empty;
                    return Poll::Ready(Some(Err(e)));
                }
            }
        }
        let right = match &mut self.right {
            HashedDataFut::Ready(data) => data,
            HashedDataFut::Empty => return Poll::Ready(None),
            HashedDataFut::Pending(_) => unreachable!("right input was resolved above"),
        };

        match ready!(self.left.poll_next_unpin(cx)) {
            Some(Ok(left)) => {
                // Left rows take priority: drop matching right rows, emit left as-is.
                right.update_map(&left)?;
                Poll::Ready(Some(Ok(left)))
            }
            Some(Err(e)) => Poll::Ready(Some(Err(e))),
            None => match std::mem::replace(&mut self.right, HashedDataFut::Empty) {
                HashedDataFut::Ready(data) => Poll::Ready(Some(data.finish())),
                _ => unreachable!(),
            },
        }
    }
}
impl RecordBatchStream for UnionDistinctOnStream {
fn schema(&self) -> SchemaRef {
self.output_schema.clone()
}
}
impl Stream for UnionDistinctOnStream {
    type Item = DataFusionResult<RecordBatch>;

    /// Drives [`UnionDistinctOnStream::poll_impl`] and feeds each result through
    /// `BaselineMetrics::record_poll`.
    ///
    /// Without this, the baseline metrics exposed by
    /// `UnionDistinctOnExec::metrics` would never record any output rows.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let poll = self.poll_impl(cx);
        self.metric.record_poll(poll)
    }
}
/// State of the right-hand side of [`UnionDistinctOnStream`].
enum HashedDataFut {
    /// Right input is still being collected and hashed.
    Pending(BoxFuture<'static, DataFusionResult<HashedData>>),
    /// Right input is hashed; left batches are being applied to it.
    Ready(HashedData),
    /// Terminal state after the right side has been consumed; the stream
    /// yields `None` from here on.
    Empty,
}
/// De-duplicated rows of the right input, indexed by the hash of their key columns.
struct HashedData {
    /// Key hash -> position in `batch` of the first row seen with that hash.
    /// Entries are removed when the same hash is seen on the left input.
    hash_map: HashMap<u64, usize>,
    /// The retained (first-per-hash) right-side rows, gathered into one batch.
    batch: RecordBatch,
    /// Column indices that form the key (compare keys, then the timestamp column).
    hash_key_indices: Vec<usize>,
    /// Hashing seed; the same seed is used for the left input so hashes compare.
    random_state: RandomState,
}
impl HashedData {
    /// Collects every batch from `input` and keeps only the first row for each
    /// distinct hash of the key columns at `hash_key_indices`.
    ///
    /// The surviving rows are gathered into a single batch in first-seen order.
    /// `hash_map` maps each key hash to that row's position in the batch.
    ///
    /// NOTE(review): rows are identified only by their 64-bit hash. Two
    /// different key tuples whose hashes collide are treated as duplicates,
    /// both here and in [`HashedData::update_map`].
    ///
    /// # Errors
    /// Propagates errors from the input stream, from hashing, and from
    /// building the combined batch.
    pub async fn new(
        input: SendableRecordBatchStream,
        random_state: RandomState,
        hash_key_indices: Vec<usize>,
    ) -> DataFusionResult<Self> {
        let schema = input.schema();
        let batches: Vec<RecordBatch> = input.try_collect().await?;

        let mut hash_map = HashMap::default();
        let mut hashes_buffer = Vec::new();
        let mut interleave_indices = Vec::new();
        for (batch_number, batch) in batches.iter().enumerate() {
            let arrays = hash_key_indices
                .iter()
                .map(|i| batch.column(*i).clone())
                .collect::<Vec<_>>();
            // `create_hashes` does not write the slot of a row whose leading key
            // is null. The reused buffer must therefore be zeroed, not merely
            // resized. Otherwise a hash from the previous batch leaks into this
            // one and disagrees with the zero-seeded hashes of `update_map`.
            hashes_buffer.clear();
            hashes_buffer.resize(batch.num_rows(), 0);
            let hash_values =
                hash_utils::create_hashes(&arrays, &random_state, &mut hashes_buffer)?;
            for (row_number, hash_value) in hash_values.iter().enumerate() {
                // Keep only the first row per hash.
                let entry = hash_map.entry(*hash_value);
                if let std::collections::hash_map::Entry::Vacant(slot) = entry {
                    slot.insert(interleave_indices.len());
                    interleave_indices.push((batch_number, row_number));
                }
            }
        }
        let batch = interleave_batches(schema, batches, interleave_indices)?;
        Ok(Self {
            hash_map,
            batch,
            hash_key_indices,
            random_state,
        })
    }

    /// Forgets every retained row whose key hash also occurs in `input`, so
    /// those rows are not returned by [`HashedData::finish`].
    ///
    /// # Errors
    /// Propagates hashing errors.
    pub fn update_map(&mut self, input: &RecordBatch) -> DataFusionResult<()> {
        let arrays = self
            .hash_key_indices
            .iter()
            .map(|i| input.column(*i).clone())
            .collect::<Vec<_>>();
        // Zero-initialised, matching how `new` seeds the buffer for each batch.
        let mut hashes_buffer = vec![0; input.num_rows()];
        let hash_values =
            hash_utils::create_hashes(&arrays, &self.random_state, &mut hashes_buffer)?;
        for hash in hash_values.iter() {
            self.hash_map.remove(hash);
        }
        Ok(())
    }

    /// Returns the retained rows that were not removed by
    /// [`HashedData::update_map`], in the order they were first seen.
    ///
    /// # Errors
    /// Propagates errors from gathering the rows.
    pub fn finish(self) -> DataFusionResult<RecordBatch> {
        let mut valid_indices = self.hash_map.values().copied().collect::<Vec<_>>();
        // HashMap iteration order is arbitrary; sort so the output order is
        // deterministic and follows the input order.
        valid_indices.sort_unstable();
        take_batch(&self.batch, &valid_indices)
    }
}
/// Gathers rows from `batches` into a single batch with `schema`.
///
/// Each `(batch, row)` pair in `indices` selects one row. Output rows appear in
/// the order of `indices`.
///
/// # Errors
/// Returns an internal error when `batches` is empty but `indices` is not, and
/// propagates Arrow errors from interleaving or from building the batch.
fn interleave_batches(
    schema: SchemaRef,
    batches: Vec<RecordBatch>,
    indices: Vec<(usize, usize)>,
) -> DataFusionResult<RecordBatch> {
    if batches.is_empty() {
        return if indices.is_empty() {
            Ok(RecordBatch::new_empty(schema))
        } else {
            Err(DataFusionError::Internal(
                "Cannot interleave empty batches with non-empty indices".to_string(),
            ))
        };
    }

    // Transpose batches-of-columns into columns-of-batches so each output
    // column can be interleaved independently.
    let mut per_column = vec![vec![]; schema.fields().len()];
    batches.iter().for_each(|batch| {
        for (col, array) in batch.columns().iter().enumerate() {
            per_column[col].push(array.as_ref());
        }
    });

    let columns = per_column
        .iter()
        .map(|sources| compute::interleave(sources, &indices))
        .collect::<Result<Vec<_>, _>>()?;
    RecordBatch::try_new(schema, columns).map_err(|e| DataFusionError::ArrowError(e, None))
}
/// Returns a new batch holding the rows of `batch` at `indices`, in that order.
///
/// # Errors
/// Propagates Arrow errors from `take` and from building the result batch.
fn take_batch(batch: &RecordBatch, indices: &[usize]) -> DataFusionResult<RecordBatch> {
    // Fast path only for the exact selection 0..num_rows. A length check alone
    // would also accept reordered or repeated indices and return wrong rows.
    let is_identity = indices.len() == batch.num_rows()
        && indices.iter().enumerate().all(|(pos, &idx)| pos == idx);
    if is_identity {
        return Ok(batch.clone());
    }
    // usize -> u64 is lossless on all supported (<= 64-bit) targets.
    let indices_array = UInt64Array::from_iter(indices.iter().map(|&i| i as u64));
    let arrays = batch
        .columns()
        .iter()
        .map(|array| compute::take(array, &indices_array, None))
        .collect::<std::result::Result<Vec<_>, _>>()
        .map_err(|e| DataFusionError::ArrowError(e, None))?;
    RecordBatch::try_new(batch.schema(), arrays).map_err(|e| DataFusionError::ArrowError(e, None))
}
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion::arrow::array::Int32Array;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};

    use super::*;

    /// Schema with two non-nullable Int32 columns, `a` and `b`.
    fn two_int_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]))
    }

    /// Builds a batch for `two_int_schema` from the values of columns `a` and `b`.
    fn int_batch(schema: &Arc<Schema>, a: Vec<i32>, b: Vec<i32>) -> RecordBatch {
        RecordBatch::try_new(
            Arc::clone(schema),
            vec![Arc::new(Int32Array::from(a)), Arc::new(Int32Array::from(b))],
        )
        .unwrap()
    }

    #[test]
    fn test_interleave_batches() {
        let schema = two_int_schema();
        let batches = vec![
            int_batch(&schema, vec![1, 2, 3], vec![4, 5, 6]),
            int_batch(&schema, vec![7, 8, 9], vec![10, 11, 12]),
            int_batch(&schema, vec![13, 14, 15], vec![16, 17, 18]),
        ];
        // Row 0 of every batch, then row 1 of every batch.
        let indices = vec![(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)];

        let result = interleave_batches(Arc::clone(&schema), batches, indices).unwrap();

        let expected = int_batch(
            &schema,
            vec![1, 7, 13, 2, 8, 14],
            vec![4, 10, 16, 5, 11, 17],
        );
        assert_eq!(result, expected);
    }

    #[test]
    fn test_take_batch() {
        let schema = two_int_schema();
        let batch = int_batch(&schema, vec![1, 2, 3], vec![4, 5, 6]);

        let result = take_batch(&batch, &[0, 2]).unwrap();

        assert_eq!(result, int_batch(&schema, vec![1, 3], vec![4, 6]));
    }
}