query/datafusion/
planner.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::hash_map::Entry;
16use std::collections::HashMap;
17use std::sync::Arc;
18
19use arrow_schema::DataType;
20use catalog::table_source::DfTableSourceProvider;
21use common_function::function::FunctionContext;
22use datafusion::common::TableReference;
23use datafusion::datasource::cte_worktable::CteWorkTable;
24use datafusion::datasource::file_format::{format_as_file_type, FileFormatFactory};
25use datafusion::datasource::provider_as_source;
26use datafusion::error::Result as DfResult;
27use datafusion::execution::context::SessionState;
28use datafusion::execution::SessionStateDefaults;
29use datafusion::sql::planner::ContextProvider;
30use datafusion::variable::VarType;
31use datafusion_common::config::ConfigOptions;
32use datafusion_common::file_options::file_type::FileType;
33use datafusion_common::DataFusionError;
34use datafusion_expr::planner::{ExprPlanner, TypePlanner};
35use datafusion_expr::var_provider::is_system_variables;
36use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF};
37use datafusion_sql::parser::Statement as DfStatement;
38use session::context::QueryContextRef;
39use snafu::{Location, ResultExt};
40
41use crate::error::{CatalogSnafu, Result};
42use crate::query_engine::{DefaultPlanDecoder, QueryEngineState};
43
/// Adapts GreptimeDB's catalog, function registry and session state to
/// DataFusion's [`ContextProvider`] so the SQL planner can resolve tables,
/// UDFs, variables and file formats during logical planning.
pub struct DfContextProviderAdapter {
    // Engine-wide state: catalog manager plus scalar/aggregate function registries.
    engine_state: Arc<QueryEngineState>,
    // Per-session DataFusion state (config options, function registries, var providers).
    session_state: SessionState,
    // Tables pre-resolved in `try_new`, keyed by the fully-resolved table reference string.
    tables: HashMap<String, Arc<dyn TableSource>>,
    // Resolves table references against the catalog manager.
    table_provider: DfTableSourceProvider,
    // Query context of the current session (used when instantiating scalar functions).
    query_ctx: QueryContextRef,

    // Fields from session state defaults:
    /// Holds registered external FileFormat implementations
    /// DataFusion doesn't pub this field, so we need to store it here.
    file_formats: HashMap<String, Arc<dyn FileFormatFactory>>,
    /// Provides support for customising the SQL planner, e.g. to add support for custom operators like `->>` or `?`
    /// DataFusion doesn't pub this field, so we need to store it here.
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
59
60impl DfContextProviderAdapter {
61    pub(crate) async fn try_new(
62        engine_state: Arc<QueryEngineState>,
63        session_state: SessionState,
64        df_stmt: Option<&DfStatement>,
65        query_ctx: QueryContextRef,
66    ) -> Result<Self> {
67        let table_names = if let Some(df_stmt) = df_stmt {
68            session_state.resolve_table_references(df_stmt)?
69        } else {
70            vec![]
71        };
72
73        let mut table_provider = DfTableSourceProvider::new(
74            engine_state.catalog_manager().clone(),
75            engine_state.disallow_cross_catalog_query(),
76            query_ctx.clone(),
77            Arc::new(DefaultPlanDecoder::new(session_state.clone(), &query_ctx)?),
78            session_state
79                .config_options()
80                .sql_parser
81                .enable_ident_normalization,
82        );
83
84        let tables = resolve_tables(table_names, &mut table_provider).await?;
85        let file_formats = SessionStateDefaults::default_file_formats()
86            .into_iter()
87            .map(|format| (format.get_ext().to_lowercase(), format))
88            .collect();
89
90        Ok(Self {
91            engine_state,
92            session_state,
93            tables,
94            table_provider,
95            query_ctx,
96            file_formats,
97            expr_planners: SessionStateDefaults::default_expr_planners(),
98        })
99    }
100}
101
102async fn resolve_tables(
103    table_names: Vec<TableReference>,
104    table_provider: &mut DfTableSourceProvider,
105) -> Result<HashMap<String, Arc<dyn TableSource>>> {
106    let mut tables = HashMap::with_capacity(table_names.len());
107
108    for table_name in table_names {
109        let resolved_name = table_provider
110            .resolve_table_ref(table_name.clone())
111            .context(CatalogSnafu)?;
112
113        if let Entry::Vacant(v) = tables.entry(resolved_name.to_string()) {
114            // Try our best to resolve the tables here, but we don't return an error if table is not found,
115            // because the table name may be a temporary name of CTE, they can't be found until plan
116            // execution.
117            match table_provider.resolve_table(table_name).await {
118                Ok(table) => {
119                    let _ = v.insert(table);
120                }
121                Err(e) if e.should_fail() => {
122                    return Err(e).context(CatalogSnafu);
123                }
124                _ => {
125                    // ignore
126                }
127            }
128        }
129    }
130    Ok(tables)
131}
132
133impl ContextProvider for DfContextProviderAdapter {
134    fn get_table_source(&self, name: TableReference) -> DfResult<Arc<dyn TableSource>> {
135        let table_ref = self.table_provider.resolve_table_ref(name)?;
136        self.tables
137            .get(&table_ref.to_string())
138            .cloned()
139            .ok_or_else(|| {
140                crate::error::Error::TableNotFound {
141                    table: table_ref.to_string(),
142                    location: Location::default(),
143                }
144                .into()
145            })
146    }
147
148    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
149        self.engine_state.scalar_function(name).map_or_else(
150            || self.session_state.scalar_functions().get(name).cloned(),
151            |func| {
152                Some(Arc::new(func.provide(FunctionContext {
153                    query_ctx: self.query_ctx.clone(),
154                    state: self.engine_state.function_state(),
155                })))
156            },
157        )
158    }
159
160    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
161        self.engine_state.aggr_function(name).map_or_else(
162            || self.session_state.aggregate_functions().get(name).cloned(),
163            |func| Some(Arc::new(func)),
164        )
165    }
166
167    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
168        self.session_state.window_functions().get(name).cloned()
169    }
170
171    fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
172        if variable_names.is_empty() {
173            return None;
174        }
175
176        let provider_type = if is_system_variables(variable_names) {
177            VarType::System
178        } else {
179            VarType::UserDefined
180        };
181
182        self.session_state
183            .execution_props()
184            .var_providers
185            .as_ref()
186            .and_then(|provider| provider.get(&provider_type)?.get_type(variable_names))
187    }
188
189    fn options(&self) -> &ConfigOptions {
190        self.session_state.config_options()
191    }
192
193    fn udf_names(&self) -> Vec<String> {
194        let mut names = self.engine_state.scalar_names();
195        names.extend(self.session_state.scalar_functions().keys().cloned());
196        names
197    }
198
199    fn udaf_names(&self) -> Vec<String> {
200        let mut names = self.engine_state.aggr_names();
201        names.extend(self.session_state.aggregate_functions().keys().cloned());
202        names
203    }
204
205    fn udwf_names(&self) -> Vec<String> {
206        self.session_state
207            .window_functions()
208            .keys()
209            .cloned()
210            .collect()
211    }
212
213    fn get_file_type(&self, ext: &str) -> DfResult<Arc<dyn FileType>> {
214        self.file_formats
215            .get(&ext.to_lowercase())
216            .ok_or_else(|| {
217                DataFusionError::Plan(format!("There is no registered file format with ext {ext}"))
218            })
219            .map(|file_type| format_as_file_type(Arc::clone(file_type)))
220    }
221
222    fn get_table_function_source(
223        &self,
224        name: &str,
225        args: Vec<datafusion_expr::Expr>,
226    ) -> DfResult<Arc<dyn TableSource>> {
227        let tbl_func = self
228            .session_state
229            .table_functions()
230            .get(name)
231            .cloned()
232            .ok_or_else(|| DataFusionError::Plan(format!("table function '{name}' not found")))?;
233        let provider = tbl_func.create_table_provider(&args)?;
234
235        Ok(provider_as_source(provider))
236    }
237
238    fn create_cte_work_table(
239        &self,
240        name: &str,
241        schema: arrow_schema::SchemaRef,
242    ) -> DfResult<Arc<dyn TableSource>> {
243        let table = Arc::new(CteWorkTable::new(name, schema));
244        Ok(provider_as_source(table))
245    }
246
247    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
248        &self.expr_planners
249    }
250
251    fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
252        None
253    }
254}