1use std::collections::HashSet;
16
17use datafusion::datasource::DefaultTableSource;
18use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
19use datafusion_common::TableReference;
20use datafusion_expr::{BinaryExpr, Expr, Join, LogicalPlan, Operator};
21use session::context::QueryContextRef;
22pub use table::metadata::TableType;
23use table::table::adapter::DfTableProviderAdapter;
24use table::table_name::TableName;
25
26use crate::error::Result;
27
/// Plan rewriter that collects the name of every table referenced by a
/// `LogicalPlan` and rewrites partial/bare table references into fully
/// qualified (`catalog.schema.table`) ones using the session context.
struct TableNamesExtractAndRewriter {
    // Fully qualified names of all tables encountered during the rewrite;
    // a set, so a table referenced multiple times is recorded once.
    pub(crate) table_names: HashSet<TableName>,
    // Session context supplying the current catalog/schema used to qualify
    // partial and bare table references.
    query_ctx: QueryContextRef,
}
32
33impl TreeNodeRewriter for TableNamesExtractAndRewriter {
34 type Node = LogicalPlan;
35
36 fn f_down<'a>(
38 &mut self,
39 node: Self::Node,
40 ) -> datafusion::error::Result<Transformed<Self::Node>> {
41 match node {
42 LogicalPlan::TableScan(mut scan) => {
43 if let Some(source) = scan.source.as_any().downcast_ref::<DefaultTableSource>() {
44 if let Some(provider) = source
45 .table_provider
46 .as_any()
47 .downcast_ref::<DfTableProviderAdapter>()
48 {
49 if provider.table().table_type() == TableType::Base {
50 let info = provider.table().table_info();
51 self.table_names.insert(TableName::new(
52 info.catalog_name.clone(),
53 info.schema_name.clone(),
54 info.name.clone(),
55 ));
56 }
57 }
58 }
59 match &scan.table_name {
60 TableReference::Full {
61 catalog,
62 schema,
63 table,
64 } => {
65 self.table_names.insert(TableName::new(
66 catalog.to_string(),
67 schema.to_string(),
68 table.to_string(),
69 ));
70 }
71 TableReference::Partial { schema, table } => {
72 self.table_names.insert(TableName::new(
73 self.query_ctx.current_catalog(),
74 schema.to_string(),
75 table.to_string(),
76 ));
77
78 scan.table_name = TableReference::Full {
79 catalog: self.query_ctx.current_catalog().into(),
80 schema: schema.clone(),
81 table: table.clone(),
82 };
83 }
84 TableReference::Bare { table } => {
85 self.table_names.insert(TableName::new(
86 self.query_ctx.current_catalog(),
87 self.query_ctx.current_schema(),
88 table.to_string(),
89 ));
90
91 scan.table_name = TableReference::Full {
92 catalog: self.query_ctx.current_catalog().into(),
93 schema: self.query_ctx.current_schema().into(),
94 table: table.clone(),
95 };
96 }
97 }
98 Ok(Transformed::yes(LogicalPlan::TableScan(scan)))
99 }
100 node => Ok(Transformed::no(node)),
101 }
102 }
103}
104
105impl TableNamesExtractAndRewriter {
106 fn new(query_ctx: QueryContextRef) -> Self {
107 Self {
108 query_ctx,
109 table_names: HashSet::new(),
110 }
111 }
112}
113
114pub fn extract_and_rewrite_full_table_names(
117 plan: LogicalPlan,
118 query_ctx: QueryContextRef,
119) -> Result<(HashSet<TableName>, LogicalPlan)> {
120 let mut extractor = TableNamesExtractAndRewriter::new(query_ctx);
121 let plan = plan.rewrite(&mut extractor)?;
122 Ok((extractor.table_names, plan.data))
123}
124
/// Extension trait for [`LogicalPlan`] that exposes a node's expressions with
/// special handling for join nodes.
pub trait ExtractExpr {
    /// Returns the node's expressions. For `Join` nodes, each equi-join `on`
    /// pair is materialized as a `left = right` binary expression and the
    /// optional join filter is appended; every other node defers to
    /// `LogicalPlan::expressions`.
    fn expressions_consider_join(&self) -> Vec<Expr>;
}
132
133impl ExtractExpr for LogicalPlan {
134 fn expressions_consider_join(&self) -> Vec<Expr> {
135 match self {
136 LogicalPlan::Join(Join { on, filter, .. }) => {
137 on.iter()
141 .map(|(left, right)| {
142 Expr::BinaryExpr(BinaryExpr {
143 left: Box::new(left.clone()),
144 op: Operator::Eq,
145 right: Box::new(right.clone()),
146 })
147 })
148 .chain(filter.clone())
149 .collect()
150 }
151 _ => self.expressions(),
152 }
153 }
154}
155
#[cfg(test)]
pub(crate) mod tests {

    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
    use common_catalog::consts::DEFAULT_CATALOG_NAME;
    use datafusion::logical_expr::builder::LogicalTableSource;
    use datafusion::logical_expr::{col, lit, LogicalPlan, LogicalPlanBuilder};
    use session::context::QueryContextBuilder;

    use super::*;

    /// Builds a simple `Filter -> TableScan` plan over a *bare* table
    /// reference ("devices") so the rewriter has something to qualify.
    fn mock_plan() -> LogicalPlan {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, true),
            Field::new("name", DataType::Utf8, true),
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
        ]);
        // LogicalTableSource is schema-only; it is not a
        // DefaultTableSource, so only the table-reference path of the
        // rewriter is exercised here.
        let table_source = LogicalTableSource::new(SchemaRef::new(schema));

        let projection = None;

        let builder =
            LogicalPlanBuilder::scan("devices", Arc::new(table_source), projection).unwrap();

        builder
            .filter(col("id").gt(lit(500)))
            .unwrap()
            .build()
            .unwrap()
    }

    #[test]
    fn test_extract_full_table_names() {
        // Session context with current schema "test"; the catalog falls back
        // to the builder's default.
        let ctx = QueryContextBuilder::default()
            .current_schema("test".to_string())
            .build();

        let (table_names, plan) =
            extract_and_rewrite_full_table_names(mock_plan(), Arc::new(ctx)).unwrap();

        // The bare "devices" reference is collected exactly once, fully
        // qualified with the session's catalog and schema.
        assert_eq!(1, table_names.len());
        assert!(table_names.contains(&TableName::new(
            DEFAULT_CATALOG_NAME.to_string(),
            "test".to_string(),
            "devices".to_string()
        )));

        // The scan inside the plan must now carry the fully qualified name.
        assert_eq!(
            "Filter: devices.id > Int32(500)\n  TableScan: greptime.test.devices",
            plan.to_string()
        );
    }
}