query/datafusion/planner.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::sync::Arc;

use arrow_schema::DataType;
use catalog::table_source::DfTableSourceProvider;
use common_function::function::FunctionContext;
use datafusion::common::TableReference;
use datafusion::datasource::cte_worktable::CteWorkTable;
use datafusion::datasource::file_format::{FileFormatFactory, format_as_file_type};
use datafusion::datasource::provider_as_source;
use datafusion::error::Result as DfResult;
use datafusion::execution::SessionStateDefaults;
use datafusion::execution::context::SessionState;
use datafusion::sql::planner::ContextProvider;
use datafusion::variable::VarType;
use datafusion_common::DataFusionError;
use datafusion_common::config::ConfigOptions;
use datafusion_common::file_options::file_type::FileType;
use datafusion_expr::planner::{ExprPlanner, TypePlanner};
use datafusion_expr::var_provider::is_system_variables;
use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF};
use datafusion_sql::parser::Statement as DfStatement;
use session::context::QueryContextRef;
use snafu::{Location, ResultExt};

use crate::error::{CatalogSnafu, Result};
use crate::query_engine::{DefaultPlanDecoder, QueryEngineState};

mod function_alias;

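/// Adapts GreptimeDB's query engine state and catalog to DataFusion's
/// [`ContextProvider`], so the SQL planner can look up tables, functions,
/// variables, and file formats during logical planning.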
pub struct DfContextProviderAdapter {
    engine_state: Arc<QueryEngineState>,
    session_state: SessionState,
    tables: HashMap<String, Arc<dyn TableSource>>,
    table_provider: DfTableSourceProvider,
    query_ctx: QueryContextRef,

    // Fields mirrored from the session state defaults:
    /// Holds registered external FileFormat implementations.
    /// DataFusion doesn't expose this field publicly, so we keep our own copy here.
    file_formats: HashMap<String, Arc<dyn FileFormatFactory>>,
    /// Supports customizing the SQL planner, e.g. to add support for custom operators like `->>` or `?`.
    /// DataFusion doesn't expose this field publicly, so we keep our own copy here.
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}

impl DfContextProviderAdapter {
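    /// Builds an adapter for the given statement, resolving the table
    /// references it mentions up front so that the synchronous
    /// [`ContextProvider`] callbacks can serve them from a local map.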
    pub(crate) async fn try_new(
        engine_state: Arc<QueryEngineState>,
        session_state: SessionState,
        df_stmt: Option<&DfStatement>,
        query_ctx: QueryContextRef,
    ) -> Result<Self> {
        let table_names = if let Some(df_stmt) = df_stmt {
            session_state.resolve_table_references(df_stmt)?
        } else {
            vec![]
        };

        let mut table_provider = DfTableSourceProvider::new(
            engine_state.catalog_manager().clone(),
            engine_state.disallow_cross_catalog_query(),
            query_ctx.clone(),
            Arc::new(DefaultPlanDecoder::new(session_state.clone(), &query_ctx)?),
            session_state
                .config_options()
                .sql_parser
                .enable_ident_normalization,
        );

        let tables = resolve_tables(table_names, &mut table_provider).await?;
        let file_formats = SessionStateDefaults::default_file_formats()
            .into_iter()
            .map(|format| (format.get_ext().to_lowercase(), format))
            .collect();

        Ok(Self {
            engine_state,
            session_state,
            tables,
            table_provider,
            query_ctx,
            file_formats,
            expr_planners: SessionStateDefaults::default_expr_planners(),
        })
    }
}

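/// Resolves the given table references through the table provider and returns
/// the ones that could be found, keyed by their fully resolved names.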
async fn resolve_tables(
    table_names: Vec<TableReference>,
    table_provider: &mut DfTableSourceProvider,
) -> Result<HashMap<String, Arc<dyn TableSource>>> {
    let mut tables = HashMap::with_capacity(table_names.len());

    for table_name in table_names {
        let resolved_name = table_provider
            .resolve_table_ref(table_name.clone())
            .context(CatalogSnafu)?;

        if let Entry::Vacant(v) = tables.entry(resolved_name.to_string()) {
            // Try our best to resolve the tables here, but don't return an error if a
            // table is not found: the name may be the temporary name of a CTE, which
            // can't be resolved until plan execution.
            match table_provider.resolve_table(table_name).await {
                Ok(table) => {
                    let _ = v.insert(table);
                }
                Err(e) if e.should_fail() => {
                    return Err(e).context(CatalogSnafu);
                }
                _ => {
                    // ignore
                }
            }
        }
    }
    Ok(tables)
}

impl ContextProvider for DfContextProviderAdapter {
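    /// Looks the table up in the map that was pre-resolved in `try_new`;
    /// anything missing from that map is reported as `TableNotFound`.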
    fn get_table_source(&self, name: TableReference) -> DfResult<Arc<dyn TableSource>> {
        let table_ref = self.table_provider.resolve_table_ref(name)?;
        self.tables
            .get(&table_ref.to_string())
            .cloned()
            .ok_or_else(|| {
                crate::error::Error::TableNotFound {
                    table: table_ref.to_string(),
                    location: Location::default(),
                }
                .into()
            })
    }

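    /// Scalar UDFs provided by the query engine take precedence over the ones
    /// registered in the DataFusion session state; aliases from `function_alias`
    /// are tried as a last resort.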
    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
        self.engine_state.scalar_function(name).map_or_else(
            || {
                self.session_state
                    .scalar_functions()
                    .get(name)
                    .cloned()
                    .or_else(|| {
                        function_alias::resolve_scalar(name).and_then(|name| {
                            self.session_state.scalar_functions().get(name).cloned()
                        })
                    })
            },
            |func| {
                Some(Arc::new(func.provide(FunctionContext {
                    query_ctx: self.query_ctx.clone(),
                    state: self.engine_state.function_state(),
                })))
            },
        )
    }

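    /// Same lookup order as `get_function_meta`, but for aggregate UDFs.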
    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
        self.engine_state.aggr_function(name).map_or_else(
            || {
                self.session_state
                    .aggregate_functions()
                    .get(name)
                    .cloned()
                    .or_else(|| {
                        function_alias::resolve_aggregate(name).and_then(|name| {
                            self.session_state.aggregate_functions().get(name).cloned()
                        })
                    })
            },
            |func| Some(Arc::new(func)),
        )
    }

    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
        self.session_state.window_functions().get(name).cloned()
    }

    fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
        if variable_names.is_empty() {
            return None;
        }

        let provider_type = if is_system_variables(variable_names) {
            VarType::System
        } else {
            VarType::UserDefined
        };

        self.session_state
            .execution_props()
            .var_providers
            .as_ref()
            .and_then(|provider| provider.get(&provider_type)?.get_type(variable_names))
    }

    fn options(&self) -> &ConfigOptions {
        self.session_state.config_options()
    }

    fn udf_names(&self) -> Vec<String> {
        let mut names = self.engine_state.scalar_names();
        names.extend(self.session_state.scalar_functions().keys().cloned());
        names.extend(function_alias::scalar_alias_names().map(|name| name.to_string()));
        names
    }

    fn udaf_names(&self) -> Vec<String> {
        let mut names = self.engine_state.aggr_names();
        names.extend(self.session_state.aggregate_functions().keys().cloned());
        names.extend(function_alias::aggregate_alias_names().map(|name| name.to_string()));
        names
    }

    fn udwf_names(&self) -> Vec<String> {
        self.session_state
            .window_functions()
            .keys()
            .cloned()
            .collect()
    }

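    /// Maps a file extension (case-insensitively) to one of the default file
    /// formats captured in `try_new`.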
    fn get_file_type(&self, ext: &str) -> DfResult<Arc<dyn FileType>> {
        self.file_formats
            .get(&ext.to_lowercase())
            .ok_or_else(|| {
                DataFusionError::Plan(format!("There is no registered file format with ext {ext}"))
            })
            .map(|file_type| format_as_file_type(Arc::clone(file_type)))
    }

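    /// Table functions provided by the query engine take precedence; otherwise
    /// the lookup falls back to the session state, again honoring `function_alias`.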
    fn get_table_function_source(
        &self,
        name: &str,
        args: Vec<datafusion_expr::Expr>,
    ) -> DfResult<Arc<dyn TableSource>> {
        if let Some(tbl_func) = self.engine_state.table_function(name) {
            let provider = tbl_func.create_table_provider(&args)?;
            Ok(provider_as_source(provider))
        } else {
            let tbl_func = self
                .session_state
                .table_functions()
                .get(name)
                .cloned()
                .or_else(|| {
                    function_alias::resolve_scalar(name)
                        .and_then(|alias| self.session_state.table_functions().get(alias).cloned())
                });

            let tbl_func = tbl_func.ok_or_else(|| {
                DataFusionError::Plan(format!("table function '{name}' not found"))
            })?;
            let provider = tbl_func.create_table_provider(&args)?;

            Ok(provider_as_source(provider))
        }
    }

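    /// Creates the placeholder work table that DataFusion uses while planning a
    /// recursive CTE.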
    fn create_cte_work_table(
        &self,
        name: &str,
        schema: arrow_schema::SchemaRef,
    ) -> DfResult<Arc<dyn TableSource>> {
        let table = Arc::new(CteWorkTable::new(name, schema));
        Ok(provider_as_source(table))
    }

    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
        &self.expr_planners
    }

    fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
        None
    }
}