// query/optimizer/count_wildcard.rs
use datafusion::datasource::DefaultTableSource;
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor,
};
use datafusion_common::{Column, Result as DataFusionResult};
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, WindowFunction};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::{col, lit, Expr, LogicalPlan, WindowFunctionDefinition};
use datafusion_optimizer::utils::NamePreserver;
use datafusion_optimizer::AnalyzerRule;
use datafusion_sql::TableReference;
use table::table::adapter::DfTableProviderAdapter;
/// Analyzer rule that rewrites `count(*)` in aggregate and window expressions.
///
/// The single wildcard argument is replaced by the time index column of the
/// table found below the plan node when that column is visible in one of the
/// node's inputs; otherwise it is replaced by the `COUNT_STAR_EXPANSION`
/// literal. The original output name of each rewritten expression is kept.
pub struct CountWildcardToTimeIndexRule;
impl AnalyzerRule for CountWildcardToTimeIndexRule {
    /// Returns the fixed identifier of this rule.
    fn name(&self) -> &str {
        "count_wildcard_to_time_index_rule"
    }

    /// Applies the rewrite to every node of `plan`, top-down, descending into
    /// subqueries as well. The configuration is not consulted.
    fn analyze(
        &self,
        plan: LogicalPlan,
        _config: &datafusion::config::ConfigOptions,
    ) -> DataFusionResult<LogicalPlan> {
        // Run the per-node rewrite first, then strip the `Transformed` wrapper
        // to hand back just the resulting plan.
        let outcome = plan.transform_down_with_subqueries(&Self::analyze_internal);
        outcome.data()
    }
}
impl CountWildcardToTimeIndexRule {
    /// Rewrites the expressions of a single plan node.
    ///
    /// Every `count(*)` aggregate or window call gets its argument list
    /// replaced: by the time index column when one is resolvable for this
    /// node, by the `COUNT_STAR_EXPANSION` literal otherwise. Each top-level
    /// expression has its name saved before and restored after the rewrite.
    fn analyze_internal(plan: LogicalPlan) -> DataFusionResult<Transformed<LogicalPlan>> {
        let name_preserver = NamePreserver::new(&plan);

        // Decide once per node what `*` is going to be replaced with.
        let replacement_args = match Self::try_find_time_index_col(&plan) {
            Some(time_index) => vec![col(time_index)],
            None => vec![lit(COUNT_STAR_EXPANSION)],
        };

        plan.map_expressions(|expr| {
            let saved_name = name_preserver.save(&expr)?;

            let rewritten = expr.transform_up(|node| {
                let outcome = match node {
                    Expr::WindowFunction(mut func)
                        if Self::is_count_star_window_aggregate(&func) =>
                    {
                        func.args = replacement_args.clone();
                        Transformed::yes(Expr::WindowFunction(func))
                    }
                    Expr::AggregateFunction(mut func)
                        if Self::is_count_star_aggregate(&func) =>
                    {
                        func.args = replacement_args.clone();
                        Transformed::yes(Expr::AggregateFunction(func))
                    }
                    untouched => Transformed::no(untouched),
                };
                Ok(outcome)
            })?;

            rewritten.map_data(|data| saved_name.restore(data))
        })
    }

    /// Looks for a time index column below `plan` and returns it only if at
    /// least one direct input of `plan` exposes that column in its schema.
    fn try_find_time_index_col(plan: &LogicalPlan) -> Option<Column> {
        let mut finder = TimeIndexFinder::default();
        // `TimeIndexFinder`'s visitor callbacks only ever return `Ok`, so the
        // traversal is not expected to fail.
        plan.visit(&mut finder).unwrap();

        let candidate = finder.into_column()?;

        // The column must be referable from this node's inputs, otherwise the
        // rewritten expression would point at something that is not there.
        let visible_in_inputs = plan
            .inputs()
            .iter()
            .any(|input| input.schema().has_column(&candidate));

        visible_in_inputs.then_some(candidate)
    }
}
impl CountWildcardToTimeIndexRule {
    /// True only for an unqualified wildcard (`*`); a qualified one does not match.
    fn is_wildcard(expr: &Expr) -> bool {
        if let Expr::Wildcard { qualifier } = expr {
            qualifier.is_none()
        } else {
            false
        }
    }

    /// True when `args` consists of exactly one unqualified wildcard.
    fn has_single_wildcard_arg(args: &[Expr]) -> bool {
        matches!(args, [only] if Self::is_wildcard(only))
    }

    /// True when `aggregate_function` is the built-in `Count` called with a
    /// single unqualified wildcard argument.
    fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
        let is_builtin_count = matches!(
            &aggregate_function.func_def,
            AggregateFunctionDefinition::BuiltIn(
                datafusion_expr::aggregate_function::AggregateFunction::Count,
            )
        );
        is_builtin_count && Self::has_single_wildcard_arg(&aggregate_function.args)
    }

    /// True when `window_function` wraps the built-in `Count` aggregate called
    /// with a single unqualified wildcard argument.
    fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
        let is_builtin_count = matches!(
            &window_function.fun,
            WindowFunctionDefinition::AggregateFunction(
                datafusion_expr::aggregate_function::AggregateFunction::Count,
            )
        );
        is_builtin_count && Self::has_single_wildcard_arg(&window_function.args)
    }
}
/// Plan visitor state: collects the time index column name and the table
/// qualifier from the first matching table scan reached while walking down.
#[derive(Default)]
struct TimeIndexFinder {
// Name of the timestamp column taken from the scanned table's schema, if any.
time_index_col: Option<String>,
// Alias of the most recently visited `SubqueryAlias`; when none was seen
// before the table scan, the table's own bare name is stored instead.
table_alias: Option<TableReference>,
}
impl TreeNodeVisitor<'_> for TimeIndexFinder {
    type Node = LogicalPlan;

    /// Records subquery aliases on the way down and stops at the first table
    /// scan backed by a `DfTableProviderAdapter`, capturing its time index
    /// column name and, if no alias was seen, its bare table name.
    fn f_down(&mut self, node: &Self::Node) -> DataFusionResult<TreeNodeRecursion> {
        match node {
            LogicalPlan::SubqueryAlias(aliased) => {
                // A later (deeper) alias overwrites an earlier one.
                self.table_alias = Some(aliased.alias.clone());
            }
            LogicalPlan::TableScan(scan) => {
                // Peel the two wrapper layers: table source -> provider adapter.
                let adapter = scan
                    .source
                    .as_any()
                    .downcast_ref::<DefaultTableSource>()
                    .and_then(|source| {
                        source
                            .table_provider
                            .as_any()
                            .downcast_ref::<DfTableProviderAdapter>()
                    });

                if let Some(adapter) = adapter {
                    let table_info = adapter.table().table_info();

                    // Only fall back to the table's name when no alias was recorded.
                    if self.table_alias.is_none() {
                        self.table_alias = Some(TableReference::bare(table_info.name.clone()));
                    }

                    self.time_index_col = table_info
                        .meta
                        .schema
                        .timestamp_column()
                        .map(|column| column.name.clone());

                    return Ok(TreeNodeRecursion::Stop);
                }
            }
            _ => {}
        }

        Ok(TreeNodeRecursion::Continue)
    }

    /// Always asks the traversal to stop as soon as it starts going back up.
    fn f_up(&mut self, _node: &Self::Node) -> DataFusionResult<TreeNodeRecursion> {
        Ok(TreeNodeRecursion::Stop)
    }
}
impl TimeIndexFinder {
    /// Consumes the finder and builds the qualified time index column.
    ///
    /// Returns `None` when no time index column name was recorded; the
    /// recorded table alias (possibly absent) becomes the column's relation.
    fn into_column(self) -> Option<Column> {
        let Self {
            time_index_col,
            table_alias,
        } = self;
        let column_name = time_index_col?;
        Some(Column::new(table_alias, column_name))
    }
}
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datafusion_expr::{count, wildcard, LogicalPlanBuilder};
    use table::table::numbers::NumbersTable;

    use super::*;

    /// A mixed-case, quoted subquery alias must be recorded by the finder with
    /// its case intact, and no time index column is expected for this table.
    #[test]
    fn uppercase_table_name() {
        // Table whose own name is mixed-case, scanned under the name "t".
        let table = NumbersTable::table_with_name(0, "AbCdE".to_string());
        let provider = Arc::new(DfTableProviderAdapter::new(table));
        let source = Arc::new(DefaultTableSource::new(provider));

        // scan -> count(*) aggregate -> quoted mixed-case alias
        let scan = LogicalPlanBuilder::scan_with_filters("t", source, None, vec![]).unwrap();
        let aggregated = scan
            .aggregate(Vec::<Expr>::new(), vec![count(wildcard())])
            .unwrap();
        let aliased = aggregated.alias(r#""FgHiJ""#).unwrap();
        let plan = aliased.build().unwrap();

        let mut finder = TimeIndexFinder::default();
        plan.visit(&mut finder).unwrap();

        assert_eq!(finder.table_alias, Some(TableReference::bare("FgHiJ")));
        assert!(finder.time_index_col.is_none());
    }
}