1use std::collections::HashSet;
16
17use datafusion::datasource::DefaultTableSource;
18use datafusion_common::TableReference;
19use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
20use datafusion_expr::{Expr, LogicalPlan};
21use session::context::QueryContextRef;
22pub use table::metadata::TableType;
23use table::table::adapter::DfTableProviderAdapter;
24use table::table_name::TableName;
25
26use crate::error::Result;
27
28struct TableNamesExtractAndRewriter {
29 pub(crate) table_names: HashSet<TableName>,
30 query_ctx: QueryContextRef,
31}
32
33impl TreeNodeRewriter for TableNamesExtractAndRewriter {
34 type Node = LogicalPlan;
35
36 fn f_down<'a>(
38 &mut self,
39 node: Self::Node,
40 ) -> datafusion::error::Result<Transformed<Self::Node>> {
41 match node {
42 LogicalPlan::TableScan(mut scan) => {
43 if let Some(source) = scan.source.as_any().downcast_ref::<DefaultTableSource>()
44 && let Some(provider) = source
45 .table_provider
46 .as_any()
47 .downcast_ref::<DfTableProviderAdapter>()
48 && provider.table().table_type() == TableType::Base
49 {
50 let info = provider.table().table_info();
51 self.table_names.insert(TableName::new(
52 info.catalog_name.clone(),
53 info.schema_name.clone(),
54 info.name.clone(),
55 ));
56 }
57 match &scan.table_name {
58 TableReference::Full {
59 catalog,
60 schema,
61 table,
62 } => {
63 self.table_names.insert(TableName::new(
64 catalog.to_string(),
65 schema.to_string(),
66 table.to_string(),
67 ));
68 }
69 TableReference::Partial { schema, table } => {
70 self.table_names.insert(TableName::new(
71 self.query_ctx.current_catalog(),
72 schema.to_string(),
73 table.to_string(),
74 ));
75
76 scan.table_name = TableReference::Full {
77 catalog: self.query_ctx.current_catalog().into(),
78 schema: schema.clone(),
79 table: table.clone(),
80 };
81 }
82 TableReference::Bare { table } => {
83 self.table_names.insert(TableName::new(
84 self.query_ctx.current_catalog(),
85 self.query_ctx.current_schema(),
86 table.to_string(),
87 ));
88
89 scan.table_name = TableReference::Full {
90 catalog: self.query_ctx.current_catalog().into(),
91 schema: self.query_ctx.current_schema().into(),
92 table: table.clone(),
93 };
94 }
95 }
96 Ok(Transformed::yes(LogicalPlan::TableScan(scan)))
97 }
98 node => Ok(Transformed::no(node)),
99 }
100 }
101}
102
103impl TableNamesExtractAndRewriter {
104 fn new(query_ctx: QueryContextRef) -> Self {
105 Self {
106 query_ctx,
107 table_names: HashSet::new(),
108 }
109 }
110}
111
112pub fn extract_and_rewrite_full_table_names(
115 plan: LogicalPlan,
116 query_ctx: QueryContextRef,
117) -> Result<(HashSet<TableName>, LogicalPlan)> {
118 let mut extractor = TableNamesExtractAndRewriter::new(query_ctx);
119 let plan = plan.rewrite(&mut extractor)?;
120 Ok((extractor.table_names, plan.data))
121}
122
123pub trait ExtractExpr {
125 fn expressions_consider_join(&self) -> Vec<Expr>;
129}
130
131impl ExtractExpr for LogicalPlan {
132 fn expressions_consider_join(&self) -> Vec<Expr> {
133 self.expressions()
134 }
135}
136
#[cfg(test)]
pub(crate) mod tests {

    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
    use common_catalog::consts::DEFAULT_CATALOG_NAME;
    use datafusion::logical_expr::builder::LogicalTableSource;
    use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder, col, lit};
    use session::context::QueryContextBuilder;

    use super::*;

    /// Builds `Filter(id > 500)` on top of a scan of the bare table reference
    /// `devices` with columns `id`, `name` and `ts`.
    fn mock_plan() -> LogicalPlan {
        let fields = vec![
            Field::new("id", DataType::Int32, true),
            Field::new("name", DataType::Utf8, true),
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
        ];
        let source = Arc::new(LogicalTableSource::new(SchemaRef::new(Schema::new(fields))));

        LogicalPlanBuilder::scan("devices", source, None)
            .unwrap()
            .filter(col("id").gt(lit(500)))
            .unwrap()
            .build()
            .unwrap()
    }

    /// A bare reference is qualified with the context's catalog and schema,
    /// both in the returned name set and in the rewritten scan.
    #[test]
    fn test_extract_full_table_names() {
        let query_ctx = Arc::new(
            QueryContextBuilder::default()
                .current_schema("test".to_string())
                .build(),
        );

        let (table_names, plan) =
            extract_and_rewrite_full_table_names(mock_plan(), query_ctx).unwrap();

        let expected_name = TableName::new(
            DEFAULT_CATALOG_NAME.to_string(),
            "test".to_string(),
            "devices".to_string(),
        );
        assert_eq!(1, table_names.len());
        assert!(table_names.contains(&expected_name));

        // DataFusion's indented plan display uses two spaces per nesting level.
        // The filter keeps its original `devices.id` qualifier; only the scan's
        // table reference is rewritten.
        assert_eq!(
            "Filter: devices.id > Int32(500)\n  TableScan: greptime.test.devices",
            plan.to_string()
        );
    }
}