1use std::collections::HashSet;
16
17use datafusion::datasource::DefaultTableSource;
18use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
19use datafusion_common::TableReference;
20use datafusion_expr::{Expr, LogicalPlan};
21use session::context::QueryContextRef;
22pub use table::metadata::TableType;
23use table::table::adapter::DfTableProviderAdapter;
24use table::table_name::TableName;
25
26use crate::error::Result;
27
28struct TableNamesExtractAndRewriter {
29 pub(crate) table_names: HashSet<TableName>,
30 query_ctx: QueryContextRef,
31}
32
33impl TreeNodeRewriter for TableNamesExtractAndRewriter {
34 type Node = LogicalPlan;
35
36 fn f_down<'a>(
38 &mut self,
39 node: Self::Node,
40 ) -> datafusion::error::Result<Transformed<Self::Node>> {
41 match node {
42 LogicalPlan::TableScan(mut scan) => {
43 if let Some(source) = scan.source.as_any().downcast_ref::<DefaultTableSource>() {
44 if let Some(provider) = source
45 .table_provider
46 .as_any()
47 .downcast_ref::<DfTableProviderAdapter>()
48 {
49 if provider.table().table_type() == TableType::Base {
50 let info = provider.table().table_info();
51 self.table_names.insert(TableName::new(
52 info.catalog_name.clone(),
53 info.schema_name.clone(),
54 info.name.clone(),
55 ));
56 }
57 }
58 }
59 match &scan.table_name {
60 TableReference::Full {
61 catalog,
62 schema,
63 table,
64 } => {
65 self.table_names.insert(TableName::new(
66 catalog.to_string(),
67 schema.to_string(),
68 table.to_string(),
69 ));
70 }
71 TableReference::Partial { schema, table } => {
72 self.table_names.insert(TableName::new(
73 self.query_ctx.current_catalog(),
74 schema.to_string(),
75 table.to_string(),
76 ));
77
78 scan.table_name = TableReference::Full {
79 catalog: self.query_ctx.current_catalog().into(),
80 schema: schema.clone(),
81 table: table.clone(),
82 };
83 }
84 TableReference::Bare { table } => {
85 self.table_names.insert(TableName::new(
86 self.query_ctx.current_catalog(),
87 self.query_ctx.current_schema(),
88 table.to_string(),
89 ));
90
91 scan.table_name = TableReference::Full {
92 catalog: self.query_ctx.current_catalog().into(),
93 schema: self.query_ctx.current_schema().into(),
94 table: table.clone(),
95 };
96 }
97 }
98 Ok(Transformed::yes(LogicalPlan::TableScan(scan)))
99 }
100 node => Ok(Transformed::no(node)),
101 }
102 }
103}
104
105impl TableNamesExtractAndRewriter {
106 fn new(query_ctx: QueryContextRef) -> Self {
107 Self {
108 query_ctx,
109 table_names: HashSet::new(),
110 }
111 }
112}
113
114pub fn extract_and_rewrite_full_table_names(
117 plan: LogicalPlan,
118 query_ctx: QueryContextRef,
119) -> Result<(HashSet<TableName>, LogicalPlan)> {
120 let mut extractor = TableNamesExtractAndRewriter::new(query_ctx);
121 let plan = plan.rewrite(&mut extractor)?;
122 Ok((extractor.table_names, plan.data))
123}
124
/// Extracts the expressions held by a plan node.
pub trait ExtractExpr {
    /// Returns the expressions of this node.
    ///
    /// As the name suggests, join conditions are meant to be included.
    /// NOTE(review): confirm this against each implementation.
    fn expressions_consider_join(&self) -> Vec<Expr>;
}
132
133impl ExtractExpr for LogicalPlan {
134 fn expressions_consider_join(&self) -> Vec<Expr> {
135 self.expressions()
136 }
137}
138
#[cfg(test)]
pub(crate) mod tests {

    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
    use common_catalog::consts::DEFAULT_CATALOG_NAME;
    use datafusion::logical_expr::builder::LogicalTableSource;
    use datafusion::logical_expr::{col, lit, LogicalPlan, LogicalPlanBuilder};
    use session::context::QueryContextBuilder;

    use super::*;

    /// Builds `Filter: id > 500` over an unprojected scan of `table_ref`,
    /// a table with `id`, `name` and `ts` columns.
    fn mock_plan_for(table_ref: &str) -> LogicalPlan {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, true),
            Field::new("name", DataType::Utf8, true),
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
        ]);
        let table_source = LogicalTableSource::new(SchemaRef::new(schema));

        let projection = None;

        let builder =
            LogicalPlanBuilder::scan(table_ref, Arc::new(table_source), projection).unwrap();

        builder
            .filter(col("id").gt(lit(500)))
            .unwrap()
            .build()
            .unwrap()
    }

    /// The same plan over the bare table name `devices`.
    fn mock_plan() -> LogicalPlan {
        mock_plan_for("devices")
    }

    /// Query context whose current schema is `test`. The catalog is left at
    /// the builder's default.
    fn mock_ctx() -> QueryContextRef {
        Arc::new(
            QueryContextBuilder::default()
                .current_schema("test".to_string())
                .build(),
        )
    }

    #[test]
    fn test_extract_full_table_names() {
        let (table_names, plan) =
            extract_and_rewrite_full_table_names(mock_plan(), mock_ctx()).unwrap();

        assert_eq!(1, table_names.len());
        assert!(table_names.contains(&TableName::new(
            DEFAULT_CATALOG_NAME.to_string(),
            "test".to_string(),
            "devices".to_string()
        )));

        // DataFusion indents each child plan by two spaces per level.
        assert_eq!(
            "Filter: devices.id > Int32(500)\n  TableScan: greptime.test.devices",
            plan.to_string()
        );
    }

    #[test]
    fn test_extract_partial_table_names() {
        // A partial reference keeps its own schema and only borrows the catalog.
        let (table_names, plan) =
            extract_and_rewrite_full_table_names(mock_plan_for("public.devices"), mock_ctx())
                .unwrap();

        assert_eq!(1, table_names.len());
        assert!(table_names.contains(&TableName::new(
            DEFAULT_CATALOG_NAME.to_string(),
            "public".to_string(),
            "devices".to_string()
        )));
        assert_eq!(
            "Filter: public.devices.id > Int32(500)\n  TableScan: greptime.public.devices",
            plan.to_string()
        );
    }

    #[test]
    fn test_extract_qualified_table_names() {
        // A fully qualified reference is recorded and left unchanged.
        let (table_names, plan) = extract_and_rewrite_full_table_names(
            mock_plan_for("my_catalog.my_schema.devices"),
            mock_ctx(),
        )
        .unwrap();

        assert_eq!(1, table_names.len());
        assert!(table_names.contains(&TableName::new(
            "my_catalog".to_string(),
            "my_schema".to_string(),
            "devices".to_string()
        )));
        assert_eq!(
            "Filter: my_catalog.my_schema.devices.id > Int32(500)\n  TableScan: my_catalog.my_schema.devices",
            plan.to_string()
        );
    }
}