use std::collections::HashSet;
use std::sync::Arc;
use datafusion::physical_optimizer::PhysicalOptimizerRule;
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::Result as DataFusionResult;
use datafusion_physical_expr::expressions::Column as PhysicalColumn;
use datafusion_physical_expr::LexOrdering;
use store_api::region_engine::PartitionRange;
use table::table::scan::RegionScanExec;
use crate::part_sort::PartSortExec;
use crate::window_sort::WindowedSortExec;
/// Physical optimizer rule that rewrites a `SortExec` whose single sort key is
/// the region scan's time-index column into a windowed sort pipeline
/// (`PartSortExec` + `WindowedSortExec`) driven by the scan's per-partition
/// time ranges. See [`WindowedSortPhysicalRule::do_optimize`] for the details.
#[derive(Debug)]
pub struct WindowedSortPhysicalRule;
impl PhysicalOptimizerRule for WindowedSortPhysicalRule {
fn optimize(
&self,
plan: Arc<dyn ExecutionPlan>,
config: &datafusion::config::ConfigOptions,
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
Self::do_optimize(plan, config)
}
fn name(&self) -> &str {
"WindowedSortRule"
}
fn schema_check(&self) -> bool {
false
}
}
impl WindowedSortPhysicalRule {
    /// Core rewrite: finds a `SortExec` whose single sort key is the scan's
    /// time-index column and replaces it with a windowed sort pipeline.
    ///
    /// Steps, once a candidate `SortExec` is found while walking top-down:
    /// 1. strip a `RepartitionExec` (between filter and scan) and, when the
    ///    fetch limit is small enough, a `CoalesceBatchesExec` from the input;
    /// 2. collect partition ranges / time-index names / tag columns from the
    ///    underlying `RegionScanExec` via [`fetch_partition_range`];
    /// 3. wrap the input in `PartSortExec` unless there are no tag columns and
    ///    the sort is ascending (in that case the input is used as-is);
    /// 4. wrap the result in `WindowedSortExec`, and additionally in
    ///    `SortPreservingMergeExec` when the original sort merged partitions
    ///    (`preserve_partitioning == false`).
    ///
    /// Traversal stops at the first rewritten sort (`TreeNodeRecursion::Stop`).
    fn do_optimize(
        plan: Arc<dyn ExecutionPlan>,
        _config: &datafusion::config::ConfigOptions,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        let result = plan
            .transform_down(|plan| {
                if let Some(sort_exec) = plan.as_any().downcast_ref::<SortExec>() {
                    // Only single-key sorts can be mapped onto the time-window order.
                    if sort_exec.expr().len() != 1 {
                        return Ok(Transformed::no(plan));
                    }
                    let preserve_partitioning = sort_exec.preserve_partitioning();
                    // Clean up the input so the scan's partition ranges still
                    // describe what this sort actually consumes.
                    let sort_input = remove_repartition(sort_exec.input().clone())?.data;
                    let sort_input =
                        remove_coalesce_batches_exec(sort_input, sort_exec.fetch())?.data;
                    let Some(scanner_info) = fetch_partition_range(sort_input.clone())? else {
                        return Ok(Transformed::no(plan));
                    };
                    let input_schema = sort_input.schema();
                    // Guard: the (single) sort key must be a plain column whose
                    // name resolves to the scan's time index; the empty `if`
                    // body is just a let-chain used as a negative guard.
                    if let Some(first_sort_expr) = sort_exec.expr().first()
                        && let Some(column_expr) = first_sort_expr
                            .expr
                            .as_any()
                            .downcast_ref::<PhysicalColumn>()
                        && scanner_info
                            .time_index
                            .contains(input_schema.field(column_expr.index()).name())
                    {
                    } else {
                        return Ok(Transformed::no(plan));
                    }
                    // Safe: the guard above proved `expr()` is non-empty.
                    let first_sort_expr = sort_exec.expr().first().unwrap().clone();
                    // With no tag columns and an ascending sort, the scan output
                    // is used directly; otherwise each partition range is sorted
                    // individually by `PartSortExec` first.
                    let new_input = if scanner_info.tag_columns.is_empty()
                        && !first_sort_expr.options.descending
                    {
                        sort_input
                    } else {
                        Arc::new(PartSortExec::new(
                            first_sort_expr.clone(),
                            sort_exec.fetch(),
                            scanner_info.partition_ranges.clone(),
                            sort_input,
                        ))
                    };
                    let windowed_sort_exec = WindowedSortExec::try_new(
                        first_sort_expr,
                        sort_exec.fetch(),
                        scanner_info.partition_ranges,
                        new_input,
                    )?;
                    if !preserve_partitioning {
                        // The original sort produced a single sorted partition,
                        // so merge the windowed-sort partitions back together.
                        let order_preserving_merge = SortPreservingMergeExec::new(
                            LexOrdering::new(sort_exec.expr().to_vec()),
                            Arc::new(windowed_sort_exec),
                        );
                        return Ok(Transformed {
                            data: Arc::new(order_preserving_merge),
                            transformed: true,
                            tnr: datafusion_common::tree_node::TreeNodeRecursion::Stop,
                        });
                    } else {
                        return Ok(Transformed {
                            data: Arc::new(windowed_sort_exec),
                            transformed: true,
                            tnr: datafusion_common::tree_node::TreeNodeRecursion::Stop,
                        });
                    }
                }
                Ok(Transformed::no(plan))
            })?
            .data;
        Ok(result)
    }
}
/// Information gathered from the `RegionScanExec` (and the projections above
/// it) that feeds a candidate windowed sort.
#[derive(Debug)]
struct ScannerInfo {
    /// Per-output-partition time ranges reported by the region scan.
    partition_ranges: Vec<Vec<PartitionRange>>,
    /// Names the time-index column is visible under after projection aliases
    /// (see `resolve_alias`).
    time_index: HashSet<String>,
    /// Tag column names from the scan; when empty (and the sort is ascending)
    /// `do_optimize` skips the `PartSortExec` stage.
    tag_columns: Vec<String>,
}
/// Walks `input` bottom-up to locate the `RegionScanExec` feeding the sort and
/// collects the information needed for the windowed-sort rewrite.
///
/// Returns `Ok(None)` when no scan is found, or when a node that reshuffles
/// rows (`RepartitionExec`, `CoalescePartitionsExec`, `SortExec`,
/// `WindowedSortExec`) sits above the scan and invalidates its partition
/// ranges.
fn fetch_partition_range(input: Arc<dyn ExecutionPlan>) -> DataFusionResult<Option<ScannerInfo>> {
    let mut partition_ranges = None;
    let mut time_index = HashSet::new();
    let mut alias_map = Vec::new();
    let mut tag_columns = None;
    let mut is_batch_coalesced = false;
    input.transform_up(|plan| {
        // These nodes reorder or redistribute rows, so partition ranges found
        // beneath them no longer describe this subtree's output.
        if plan.as_any().is::<RepartitionExec>()
            || plan.as_any().is::<CoalescePartitionsExec>()
            || plan.as_any().is::<SortExec>()
            || plan.as_any().is::<WindowedSortExec>()
        {
            partition_ranges = None;
        }
        if plan.as_any().is::<CoalesceBatchesExec>() {
            is_batch_coalesced = true;
        }
        // Track column renames so the time index can still be matched by its
        // post-projection name.
        if let Some(projection) = plan.as_any().downcast_ref::<ProjectionExec>() {
            for (expr, output_name) in projection.expr() {
                if let Some(column_expr) = expr.as_any().downcast_ref::<PhysicalColumn>() {
                    alias_map.push((column_expr.name().to_string(), output_name.clone()));
                }
            }
            time_index = resolve_alias(&alias_map, &time_index);
        }
        if let Some(region_scan_exec) = plan.as_any().downcast_ref::<RegionScanExec>() {
            partition_ranges = Some(region_scan_exec.get_uncollapsed_partition_ranges());
            time_index = HashSet::from([region_scan_exec.time_index()]);
            tag_columns = Some(region_scan_exec.tag_columns());
            // NOTE(review): `transform_up` is post-order, so a
            // `CoalesceBatchesExec` *above* the scan is visited after this
            // leaf — at this point the flag only reflects earlier sibling
            // subtrees. Confirm this ordering is intended.
            if !is_batch_coalesced {
                region_scan_exec.with_distinguish_partition_range(true);
            }
        }
        Ok(Transformed::no(plan))
    })?;
    // Both pieces come from the scan: if either is absent the rewrite is not
    // applicable. (Stable `Option::zip` replaces the unstable `try` block.)
    let result = partition_ranges
        .zip(tag_columns)
        .map(|(partition_ranges, tag_columns)| ScannerInfo {
            partition_ranges,
            time_index,
            tag_columns,
        });
    Ok(result)
}
/// Drops a `RepartitionExec` sandwiched between a `FilterExec` and a
/// `RegionScanExec`, reattaching the filter directly onto the scan.
/// Any other plan shape is left untouched.
fn remove_repartition(
    plan: Arc<dyn ExecutionPlan>,
) -> DataFusionResult<Transformed<Arc<dyn ExecutionPlan>>> {
    plan.transform_down(|node| {
        // Match the exact Filter -> Repartition -> RegionScan chain; bail out
        // on the first mismatch.
        if !node.as_any().is::<FilterExec>() {
            return Ok(Transformed::no(node));
        }
        let below_filter = node.children()[0];
        if !below_filter.as_any().is::<RepartitionExec>() {
            return Ok(Transformed::no(node));
        }
        let below_repartition = below_filter.children()[0];
        if !below_repartition.as_any().is::<RegionScanExec>() {
            return Ok(Transformed::no(node));
        }
        // Rewire the filter to consume the scan directly.
        let rewired = node.clone().with_new_children(vec![below_repartition.clone()])?;
        Ok(Transformed::yes(rewired))
    })
}
/// Removes at most one `CoalesceBatchesExec` from the plan — the first one
/// (top-down) whose `target_batch_size` exceeds `fetch`.
/// With no `fetch` limit the plan is returned unchanged.
fn remove_coalesce_batches_exec(
    plan: Arc<dyn ExecutionPlan>,
    fetch: Option<usize>,
) -> DataFusionResult<Transformed<Arc<dyn ExecutionPlan>>> {
    let Some(fetch) = fetch else {
        return Ok(Transformed::no(plan));
    };
    // Only a single removal is performed per call.
    let mut removed = false;
    plan.transform_down(|node| {
        if removed {
            return Ok(Transformed::no(node));
        }
        match node.as_any().downcast_ref::<CoalesceBatchesExec>() {
            // The coalescer batches more rows than the limit asks for:
            // replace it with its input.
            Some(coalesce) if fetch < coalesce.target_batch_size() => {
                removed = true;
                Ok(Transformed::yes(coalesce.input().clone()))
            }
            _ => Ok(Transformed::no(node)),
        }
    })
}
/// Maps the set of known time-index column names through a projection's alias
/// list, returning the names under which the time index is visible after the
/// projection.
///
/// `alias_map` holds `(input_name, output_name)` pairs in projection order.
/// For each pair:
/// - if `old` is a time-index name, the time index is also visible as `new`;
/// - if a *different* column is projected onto a time-index name
///   (`new` collides while `old` does not match), that name is shadowed and
///   removed from the surviving originals;
/// - unrelated pairs are ignored.
///
/// Surviving original names are kept alongside the newly added aliases
/// (see the `test_alias` cases below for the exact expected behavior).
fn resolve_alias(alias_map: &[(String, String)], time_index: &HashSet<String>) -> HashSet<String> {
    // Original names still usable after the projection (i.e. not shadowed).
    let mut avail_old_name = time_index.clone();
    let mut new_time_index = HashSet::new();
    for (old, new) in alias_map {
        if time_index.contains(old) {
            // The time index itself is projected, possibly under a new name.
            new_time_index.insert(new.clone());
        } else if time_index.contains(new) && old != new {
            // Another column now occupies this time-index name: shadowed.
            avail_old_name.remove(new);
        }
    }
    new_time_index.extend(avail_old_name);
    new_time_index
}
#[cfg(test)]
mod test {
    use super::*;

    /// Exercises `resolve_alias` against alias chains, shadowing renames,
    /// swaps, unrelated aliases, and empty inputs.
    #[test]
    fn test_alias() {
        // Helpers converting the compact &str tables into owned arguments.
        fn owned_pairs(pairs: &[(&str, &str)]) -> Vec<(String, String)> {
            pairs
                .iter()
                .map(|(k, v)| (k.to_string(), v.to_string()))
                .collect()
        }
        fn owned_set(names: &[&str]) -> HashSet<String> {
            names.iter().map(|n| n.to_string()).collect()
        }

        // (alias_map, time_index, expected result)
        let cases: [(&[(&str, &str)], &[&str], &[&str]); 6] = [
            (&[("a", "b"), ("b", "c")], &["a"], &["a", "b"]),
            (&[("b", "a"), ("a", "b")], &["a"], &["b"]),
            (&[("b", "a"), ("b", "c")], &["a"], &[]),
            (&[("c", "d"), ("d", "c")], &["a"], &["a"]),
            (&[], &["a"], &["a"]),
            (&[], &[], &[]),
        ];

        for (alias_map, time_index, expected) in cases {
            let alias_map = owned_pairs(alias_map);
            let time_index = owned_set(time_index);
            let expected = owned_set(expected);
            assert_eq!(
                expected,
                resolve_alias(&alias_map, &time_index),
                "alias_map={alias_map:?}, time_index={time_index:?}"
            );
        }
    }
}