query/
plan.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16
17use datafusion::datasource::DefaultTableSource;
18use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
19use datafusion_common::TableReference;
20use datafusion_expr::{Expr, LogicalPlan};
21use session::context::QueryContextRef;
22pub use table::metadata::TableType;
23use table::table::adapter::DfTableProviderAdapter;
24use table::table_name::TableName;
25
26use crate::error::Result;
27
28struct TableNamesExtractAndRewriter {
29    pub(crate) table_names: HashSet<TableName>,
30    query_ctx: QueryContextRef,
31}
32
33impl TreeNodeRewriter for TableNamesExtractAndRewriter {
34    type Node = LogicalPlan;
35
36    /// descend
37    fn f_down<'a>(
38        &mut self,
39        node: Self::Node,
40    ) -> datafusion::error::Result<Transformed<Self::Node>> {
41        match node {
42            LogicalPlan::TableScan(mut scan) => {
43                if let Some(source) = scan.source.as_any().downcast_ref::<DefaultTableSource>() {
44                    if let Some(provider) = source
45                        .table_provider
46                        .as_any()
47                        .downcast_ref::<DfTableProviderAdapter>()
48                    {
49                        if provider.table().table_type() == TableType::Base {
50                            let info = provider.table().table_info();
51                            self.table_names.insert(TableName::new(
52                                info.catalog_name.clone(),
53                                info.schema_name.clone(),
54                                info.name.clone(),
55                            ));
56                        }
57                    }
58                }
59                match &scan.table_name {
60                    TableReference::Full {
61                        catalog,
62                        schema,
63                        table,
64                    } => {
65                        self.table_names.insert(TableName::new(
66                            catalog.to_string(),
67                            schema.to_string(),
68                            table.to_string(),
69                        ));
70                    }
71                    TableReference::Partial { schema, table } => {
72                        self.table_names.insert(TableName::new(
73                            self.query_ctx.current_catalog(),
74                            schema.to_string(),
75                            table.to_string(),
76                        ));
77
78                        scan.table_name = TableReference::Full {
79                            catalog: self.query_ctx.current_catalog().into(),
80                            schema: schema.clone(),
81                            table: table.clone(),
82                        };
83                    }
84                    TableReference::Bare { table } => {
85                        self.table_names.insert(TableName::new(
86                            self.query_ctx.current_catalog(),
87                            self.query_ctx.current_schema(),
88                            table.to_string(),
89                        ));
90
91                        scan.table_name = TableReference::Full {
92                            catalog: self.query_ctx.current_catalog().into(),
93                            schema: self.query_ctx.current_schema().into(),
94                            table: table.clone(),
95                        };
96                    }
97                }
98                Ok(Transformed::yes(LogicalPlan::TableScan(scan)))
99            }
100            node => Ok(Transformed::no(node)),
101        }
102    }
103}
104
105impl TableNamesExtractAndRewriter {
106    fn new(query_ctx: QueryContextRef) -> Self {
107        Self {
108            query_ctx,
109            table_names: HashSet::new(),
110        }
111    }
112}
113
114/// Extracts and rewrites the table names in the plan in the fully qualified style,
115/// return the table names and new plan.
116pub fn extract_and_rewrite_full_table_names(
117    plan: LogicalPlan,
118    query_ctx: QueryContextRef,
119) -> Result<(HashSet<TableName>, LogicalPlan)> {
120    let mut extractor = TableNamesExtractAndRewriter::new(query_ctx);
121    let plan = plan.rewrite(&mut extractor)?;
122    Ok((extractor.table_names, plan.data))
123}
124
125/// A trait to extract expressions from a logical plan.
126pub trait ExtractExpr {
127    /// Gets expressions from a logical plan.
128    /// It handles [Join] specially so [LogicalPlan::with_new_exprs()] can use the expressions
129    /// this method returns.
130    fn expressions_consider_join(&self) -> Vec<Expr>;
131}
132
133impl ExtractExpr for LogicalPlan {
134    fn expressions_consider_join(&self) -> Vec<Expr> {
135        self.expressions()
136    }
137}
138
#[cfg(test)]
pub(crate) mod tests {

    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
    use common_catalog::consts::DEFAULT_CATALOG_NAME;
    use datafusion::logical_expr::builder::LogicalTableSource;
    use datafusion::logical_expr::{col, lit, LogicalPlan, LogicalPlanBuilder};
    use session::context::QueryContextBuilder;

    use super::*;

    /// Builds `Filter: id > 500` over a scan of a three-column mock table.
    ///
    /// `table_name` may be bare, partial or fully qualified, e.g. `"devices"`,
    /// `"public.devices"` or `"greptime.public.devices"`.
    fn mock_plan_with_name(table_name: &'static str) -> LogicalPlan {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, true),
            Field::new("name", DataType::Utf8, true),
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
        ]);
        let table_source = LogicalTableSource::new(SchemaRef::new(schema));

        LogicalPlanBuilder::scan(table_name, Arc::new(table_source), None)
            .unwrap()
            .filter(col("id").gt(lit(500)))
            .unwrap()
            .build()
            .unwrap()
    }

    /// Mock plan scanning a table referenced by the bare name `devices`.
    fn mock_plan() -> LogicalPlan {
        mock_plan_with_name("devices")
    }

    #[test]
    fn test_extract_full_table_names() {
        let ctx = QueryContextBuilder::default()
            .current_schema("test".to_string())
            .build();

        let (table_names, plan) =
            extract_and_rewrite_full_table_names(mock_plan(), Arc::new(ctx)).unwrap();

        assert_eq!(1, table_names.len());
        assert!(table_names.contains(&TableName::new(
            DEFAULT_CATALOG_NAME.to_string(),
            "test".to_string(),
            "devices".to_string()
        )));

        assert_eq!(
            "Filter: devices.id > Int32(500)\n  TableScan: greptime.test.devices",
            plan.to_string()
        );
    }

    #[test]
    fn test_extract_partial_table_names() {
        let ctx = QueryContextBuilder::default()
            .current_schema("test".to_string())
            .build();

        let (table_names, plan) = extract_and_rewrite_full_table_names(
            mock_plan_with_name("public.devices"),
            Arc::new(ctx),
        )
        .unwrap();

        // The schema comes from the reference itself; only the catalog comes from the session.
        assert_eq!(1, table_names.len());
        assert!(table_names.contains(&TableName::new(
            DEFAULT_CATALOG_NAME.to_string(),
            "public".to_string(),
            "devices".to_string()
        )));
        assert!(plan
            .to_string()
            .contains(&format!("TableScan: {DEFAULT_CATALOG_NAME}.public.devices")));
    }

    #[test]
    fn test_keep_fully_qualified_table_names() {
        let ctx = QueryContextBuilder::default()
            .current_schema("test".to_string())
            .build();

        let (table_names, plan) = extract_and_rewrite_full_table_names(
            mock_plan_with_name("mycatalog.myschema.devices"),
            Arc::new(ctx),
        )
        .unwrap();

        // A fully qualified reference must not be overridden by the session defaults.
        assert_eq!(1, table_names.len());
        assert!(table_names.contains(&TableName::new(
            "mycatalog".to_string(),
            "myschema".to_string(),
            "devices".to_string()
        )));
        assert!(plan
            .to_string()
            .contains("TableScan: mycatalog.myschema.devices"));
    }
}