query/
plan.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16
17use datafusion::datasource::DefaultTableSource;
18use datafusion_common::TableReference;
19use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
20use datafusion_expr::{Expr, LogicalPlan};
21use session::context::QueryContextRef;
22pub use table::metadata::TableType;
23use table::table::adapter::DfTableProviderAdapter;
24use table::table_name::TableName;
25
26use crate::error::Result;
27
28struct TableNamesExtractAndRewriter {
29    pub(crate) table_names: HashSet<TableName>,
30    query_ctx: QueryContextRef,
31}
32
33impl TreeNodeRewriter for TableNamesExtractAndRewriter {
34    type Node = LogicalPlan;
35
36    /// descend
37    fn f_down<'a>(
38        &mut self,
39        node: Self::Node,
40    ) -> datafusion::error::Result<Transformed<Self::Node>> {
41        match node {
42            LogicalPlan::TableScan(mut scan) => {
43                if let Some(source) = scan.source.as_any().downcast_ref::<DefaultTableSource>()
44                    && let Some(provider) = source
45                        .table_provider
46                        .as_any()
47                        .downcast_ref::<DfTableProviderAdapter>()
48                    && provider.table().table_type() == TableType::Base
49                {
50                    let info = provider.table().table_info();
51                    self.table_names.insert(TableName::new(
52                        info.catalog_name.clone(),
53                        info.schema_name.clone(),
54                        info.name.clone(),
55                    ));
56                }
57                match &scan.table_name {
58                    TableReference::Full {
59                        catalog,
60                        schema,
61                        table,
62                    } => {
63                        self.table_names.insert(TableName::new(
64                            catalog.to_string(),
65                            schema.to_string(),
66                            table.to_string(),
67                        ));
68                    }
69                    TableReference::Partial { schema, table } => {
70                        self.table_names.insert(TableName::new(
71                            self.query_ctx.current_catalog(),
72                            schema.to_string(),
73                            table.to_string(),
74                        ));
75
76                        scan.table_name = TableReference::Full {
77                            catalog: self.query_ctx.current_catalog().into(),
78                            schema: schema.clone(),
79                            table: table.clone(),
80                        };
81                    }
82                    TableReference::Bare { table } => {
83                        self.table_names.insert(TableName::new(
84                            self.query_ctx.current_catalog(),
85                            self.query_ctx.current_schema(),
86                            table.to_string(),
87                        ));
88
89                        scan.table_name = TableReference::Full {
90                            catalog: self.query_ctx.current_catalog().into(),
91                            schema: self.query_ctx.current_schema().into(),
92                            table: table.clone(),
93                        };
94                    }
95                }
96                Ok(Transformed::yes(LogicalPlan::TableScan(scan)))
97            }
98            node => Ok(Transformed::no(node)),
99        }
100    }
101}
102
103impl TableNamesExtractAndRewriter {
104    fn new(query_ctx: QueryContextRef) -> Self {
105        Self {
106            query_ctx,
107            table_names: HashSet::new(),
108        }
109    }
110}
111
112/// Extracts and rewrites the table names in the plan in the fully qualified style,
113/// return the table names and new plan.
114pub fn extract_and_rewrite_full_table_names(
115    plan: LogicalPlan,
116    query_ctx: QueryContextRef,
117) -> Result<(HashSet<TableName>, LogicalPlan)> {
118    let mut extractor = TableNamesExtractAndRewriter::new(query_ctx);
119    let plan = plan.rewrite(&mut extractor)?;
120    Ok((extractor.table_names, plan.data))
121}
122
123/// A trait to extract expressions from a logical plan.
124pub trait ExtractExpr {
125    /// Gets expressions from a logical plan.
126    /// It handles [Join] specially so [LogicalPlan::with_new_exprs()] can use the expressions
127    /// this method returns.
128    fn expressions_consider_join(&self) -> Vec<Expr>;
129}
130
131impl ExtractExpr for LogicalPlan {
132    fn expressions_consider_join(&self) -> Vec<Expr> {
133        self.expressions()
134    }
135}
136
#[cfg(test)]
pub(crate) mod tests {

    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
    use common_catalog::consts::DEFAULT_CATALOG_NAME;
    use datafusion::logical_expr::builder::LogicalTableSource;
    use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder, col, lit};
    use session::context::QueryContextBuilder;

    use super::*;

    /// Builds `Filter: id > 500` on top of a scan of `table_ref`.
    ///
    /// `table_ref` goes through the builder's string conversion, so dotted names
    /// (`schema.table`, `catalog.schema.table`) are expected to produce `Partial` /
    /// `Full` references; the tests below assert on the resulting plan text.
    fn mock_plan_with_table(table_ref: &str) -> LogicalPlan {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, true),
            Field::new("name", DataType::Utf8, true),
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
        ]);
        let table_source = LogicalTableSource::new(SchemaRef::new(schema));

        LogicalPlanBuilder::scan(table_ref, Arc::new(table_source), None)
            .unwrap()
            .filter(col("id").gt(lit(500)))
            .unwrap()
            .build()
            .unwrap()
    }

    /// Plan scanning the unqualified (`Bare`) table `devices`.
    fn mock_plan() -> LogicalPlan {
        mock_plan_with_table("devices")
    }

    /// Query context whose current schema is `test` and catalog is the default one.
    fn test_query_ctx() -> QueryContextRef {
        Arc::new(
            QueryContextBuilder::default()
                .current_schema("test".to_string())
                .build(),
        )
    }

    #[test]
    fn test_extract_full_table_names() {
        let (table_names, plan) =
            extract_and_rewrite_full_table_names(mock_plan(), test_query_ctx()).unwrap();

        assert_eq!(1, table_names.len());
        assert!(table_names.contains(&TableName::new(
            DEFAULT_CATALOG_NAME.to_string(),
            "test".to_string(),
            "devices".to_string()
        )));

        assert_eq!(
            "Filter: devices.id > Int32(500)\n  TableScan: greptime.test.devices",
            plan.to_string()
        );
    }

    #[test]
    fn test_extract_partial_table_names() {
        let (table_names, plan) = extract_and_rewrite_full_table_names(
            mock_plan_with_table("public.devices"),
            test_query_ctx(),
        )
        .unwrap();

        // Only the catalog comes from the context; the explicit schema wins.
        assert_eq!(1, table_names.len());
        assert!(table_names.contains(&TableName::new(
            DEFAULT_CATALOG_NAME.to_string(),
            "public".to_string(),
            "devices".to_string()
        )));

        assert_eq!(
            "Filter: public.devices.id > Int32(500)\n  TableScan: greptime.public.devices",
            plan.to_string()
        );
    }

    #[test]
    fn test_keep_fully_qualified_table_names() {
        let (table_names, plan) = extract_and_rewrite_full_table_names(
            mock_plan_with_table("other_catalog.public.devices"),
            test_query_ctx(),
        )
        .unwrap();

        // A fully qualified reference must not be overridden by the context defaults.
        assert_eq!(1, table_names.len());
        assert!(table_names.contains(&TableName::new(
            "other_catalog".to_string(),
            "public".to_string(),
            "devices".to_string()
        )));

        assert_eq!(
            "Filter: other_catalog.public.devices.id > Int32(500)\n  TableScan: other_catalog.public.devices",
            plan.to_string()
        );
    }
}