query/
plan.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::HashSet;
16
17use datafusion::datasource::DefaultTableSource;
18use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
19use datafusion_common::TableReference;
20use datafusion_expr::{BinaryExpr, Expr, Join, LogicalPlan, Operator};
21use session::context::QueryContextRef;
22pub use table::metadata::TableType;
23use table::table::adapter::DfTableProviderAdapter;
24use table::table_name::TableName;
25
26use crate::error::Result;
27
28struct TableNamesExtractAndRewriter {
29    pub(crate) table_names: HashSet<TableName>,
30    query_ctx: QueryContextRef,
31}
32
33impl TreeNodeRewriter for TableNamesExtractAndRewriter {
34    type Node = LogicalPlan;
35
36    /// descend
37    fn f_down<'a>(
38        &mut self,
39        node: Self::Node,
40    ) -> datafusion::error::Result<Transformed<Self::Node>> {
41        match node {
42            LogicalPlan::TableScan(mut scan) => {
43                if let Some(source) = scan.source.as_any().downcast_ref::<DefaultTableSource>() {
44                    if let Some(provider) = source
45                        .table_provider
46                        .as_any()
47                        .downcast_ref::<DfTableProviderAdapter>()
48                    {
49                        if provider.table().table_type() == TableType::Base {
50                            let info = provider.table().table_info();
51                            self.table_names.insert(TableName::new(
52                                info.catalog_name.clone(),
53                                info.schema_name.clone(),
54                                info.name.clone(),
55                            ));
56                        }
57                    }
58                }
59                match &scan.table_name {
60                    TableReference::Full {
61                        catalog,
62                        schema,
63                        table,
64                    } => {
65                        self.table_names.insert(TableName::new(
66                            catalog.to_string(),
67                            schema.to_string(),
68                            table.to_string(),
69                        ));
70                    }
71                    TableReference::Partial { schema, table } => {
72                        self.table_names.insert(TableName::new(
73                            self.query_ctx.current_catalog(),
74                            schema.to_string(),
75                            table.to_string(),
76                        ));
77
78                        scan.table_name = TableReference::Full {
79                            catalog: self.query_ctx.current_catalog().into(),
80                            schema: schema.clone(),
81                            table: table.clone(),
82                        };
83                    }
84                    TableReference::Bare { table } => {
85                        self.table_names.insert(TableName::new(
86                            self.query_ctx.current_catalog(),
87                            self.query_ctx.current_schema(),
88                            table.to_string(),
89                        ));
90
91                        scan.table_name = TableReference::Full {
92                            catalog: self.query_ctx.current_catalog().into(),
93                            schema: self.query_ctx.current_schema().into(),
94                            table: table.clone(),
95                        };
96                    }
97                }
98                Ok(Transformed::yes(LogicalPlan::TableScan(scan)))
99            }
100            node => Ok(Transformed::no(node)),
101        }
102    }
103}
104
105impl TableNamesExtractAndRewriter {
106    fn new(query_ctx: QueryContextRef) -> Self {
107        Self {
108            query_ctx,
109            table_names: HashSet::new(),
110        }
111    }
112}
113
114/// Extracts and rewrites the table names in the plan in the fully qualified style,
115/// return the table names and new plan.
116pub fn extract_and_rewrite_full_table_names(
117    plan: LogicalPlan,
118    query_ctx: QueryContextRef,
119) -> Result<(HashSet<TableName>, LogicalPlan)> {
120    let mut extractor = TableNamesExtractAndRewriter::new(query_ctx);
121    let plan = plan.rewrite(&mut extractor)?;
122    Ok((extractor.table_names, plan.data))
123}
124
125/// A trait to extract expressions from a logical plan.
126pub trait ExtractExpr {
127    /// Gets expressions from a logical plan.
128    /// It handles [Join] specially so [LogicalPlan::with_new_exprs()] can use the expressions
129    /// this method returns.
130    fn expressions_consider_join(&self) -> Vec<Expr>;
131}
132
133impl ExtractExpr for LogicalPlan {
134    fn expressions_consider_join(&self) -> Vec<Expr> {
135        match self {
136            LogicalPlan::Join(Join { on, filter, .. }) => {
137                // The first part of expr is equi-exprs,
138                // and the struct of each equi-expr is like `left-expr = right-expr`.
139                // We only normalize the filter_expr (non equality predicate from ON clause).
140                on.iter()
141                    .map(|(left, right)| {
142                        Expr::BinaryExpr(BinaryExpr {
143                            left: Box::new(left.clone()),
144                            op: Operator::Eq,
145                            right: Box::new(right.clone()),
146                        })
147                    })
148                    .chain(filter.clone())
149                    .collect()
150            }
151            _ => self.expressions(),
152        }
153    }
154}
155
#[cfg(test)]
pub(crate) mod tests {

    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
    use common_catalog::consts::DEFAULT_CATALOG_NAME;
    use datafusion::logical_expr::builder::LogicalTableSource;
    use datafusion::logical_expr::{col, lit, LogicalPlan, LogicalPlanBuilder};
    use session::context::QueryContextBuilder;

    use super::*;

    /// Builds `Filter: id > 500` on top of a scan of the bare table name `devices`.
    fn mock_plan() -> LogicalPlan {
        let fields = vec![
            Field::new("id", DataType::Int32, true),
            Field::new("name", DataType::Utf8, true),
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
        ];
        let source = Arc::new(LogicalTableSource::new(SchemaRef::new(Schema::new(fields))));

        LogicalPlanBuilder::scan("devices", source, None)
            .and_then(|builder| builder.filter(col("id").gt(lit(500))))
            .and_then(|builder| builder.build())
            .unwrap()
    }

    #[test]
    fn test_extract_full_table_names() {
        let query_ctx = Arc::new(
            QueryContextBuilder::default()
                .current_schema("test".to_string())
                .build(),
        );

        let (table_names, plan) =
            extract_and_rewrite_full_table_names(mock_plan(), query_ctx).unwrap();

        // The bare `devices` reference is resolved against the default catalog and the
        // session's current schema.
        let expected = TableName::new(
            DEFAULT_CATALOG_NAME.to_string(),
            "test".to_string(),
            "devices".to_string(),
        );
        assert_eq!(1, table_names.len());
        assert!(table_names.contains(&expected));

        assert_eq!(
            "Filter: devices.id > Int32(500)\n  TableScan: greptime.test.devices",
            plan.to_string()
        );
    }
}