query/datafusion/planner.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::hash_map::Entry;
16use std::collections::HashMap;
17use std::sync::Arc;
18
19use arrow_schema::DataType;
20use catalog::table_source::DfTableSourceProvider;
21use common_function::aggr::{
22    GeoPathAccumulator, HllState, UddSketchState, GEO_PATH_NAME, HLL_MERGE_NAME, HLL_NAME,
23    UDDSKETCH_MERGE_NAME, UDDSKETCH_STATE_NAME,
24};
25use common_function::scalars::udf::create_udf;
26use common_query::logical_plan::create_aggregate_function;
27use datafusion::common::TableReference;
28use datafusion::datasource::cte_worktable::CteWorkTable;
29use datafusion::datasource::file_format::{format_as_file_type, FileFormatFactory};
30use datafusion::datasource::provider_as_source;
31use datafusion::error::Result as DfResult;
32use datafusion::execution::context::SessionState;
33use datafusion::execution::SessionStateDefaults;
34use datafusion::sql::planner::ContextProvider;
35use datafusion::variable::VarType;
36use datafusion_common::config::ConfigOptions;
37use datafusion_common::file_options::file_type::FileType;
38use datafusion_common::DataFusionError;
39use datafusion_expr::planner::{ExprPlanner, TypePlanner};
40use datafusion_expr::var_provider::is_system_variables;
41use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF};
42use datafusion_sql::parser::Statement as DfStatement;
43use session::context::QueryContextRef;
44use snafu::{Location, ResultExt};
45
46use crate::error::{CatalogSnafu, Result};
47use crate::query_engine::{DefaultPlanDecoder, QueryEngineState};
48
/// A DataFusion [`ContextProvider`] backed by GreptimeDB's query engine state.
///
/// Bridges the engine's catalog, UDF/UDAF registries and per-query context to
/// DataFusion's SQL planner. Table sources are resolved eagerly at
/// construction time (see `try_new`) so lookups during planning are
/// synchronous.
pub struct DfContextProviderAdapter {
    // Engine-wide state: UDF/UDAF registries and the catalog manager.
    engine_state: Arc<QueryEngineState>,
    // The DataFusion session this query is planned in.
    session_state: SessionState,
    // Table sources pre-resolved from the statement, keyed by the
    // fully-resolved table name (`resolve_table_ref(..).to_string()`).
    tables: HashMap<String, Arc<dyn TableSource>>,
    // Resolves table references against the catalog.
    table_provider: DfTableSourceProvider,
    // Per-query session context (also handed to UDF instantiation).
    query_ctx: QueryContextRef,

    // Fields from session state defaults:
    /// Holds registered external FileFormat implementations
    /// DataFusion doesn't pub this field, so we need to store it here.
    file_formats: HashMap<String, Arc<dyn FileFormatFactory>>,
    /// Provides support for customising the SQL planner, e.g. to add support for custom operators like `->>` or `?`
    /// DataFusion doesn't pub this field, so we need to store it here.
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
64
65impl DfContextProviderAdapter {
66    pub(crate) async fn try_new(
67        engine_state: Arc<QueryEngineState>,
68        session_state: SessionState,
69        df_stmt: Option<&DfStatement>,
70        query_ctx: QueryContextRef,
71    ) -> Result<Self> {
72        let table_names = if let Some(df_stmt) = df_stmt {
73            session_state.resolve_table_references(df_stmt)?
74        } else {
75            vec![]
76        };
77
78        let mut table_provider = DfTableSourceProvider::new(
79            engine_state.catalog_manager().clone(),
80            engine_state.disallow_cross_catalog_query(),
81            query_ctx.clone(),
82            Arc::new(DefaultPlanDecoder::new(session_state.clone(), &query_ctx)?),
83            session_state
84                .config_options()
85                .sql_parser
86                .enable_ident_normalization,
87        );
88
89        let tables = resolve_tables(table_names, &mut table_provider).await?;
90        let file_formats = SessionStateDefaults::default_file_formats()
91            .into_iter()
92            .map(|format| (format.get_ext().to_lowercase(), format))
93            .collect();
94
95        Ok(Self {
96            engine_state,
97            session_state,
98            tables,
99            table_provider,
100            query_ctx,
101            file_formats,
102            expr_planners: SessionStateDefaults::default_expr_planners(),
103        })
104    }
105}
106
107async fn resolve_tables(
108    table_names: Vec<TableReference>,
109    table_provider: &mut DfTableSourceProvider,
110) -> Result<HashMap<String, Arc<dyn TableSource>>> {
111    let mut tables = HashMap::with_capacity(table_names.len());
112
113    for table_name in table_names {
114        let resolved_name = table_provider
115            .resolve_table_ref(table_name.clone())
116            .context(CatalogSnafu)?;
117
118        if let Entry::Vacant(v) = tables.entry(resolved_name.to_string()) {
119            // Try our best to resolve the tables here, but we don't return an error if table is not found,
120            // because the table name may be a temporary name of CTE, they can't be found until plan
121            // execution.
122            match table_provider.resolve_table(table_name).await {
123                Ok(table) => {
124                    let _ = v.insert(table);
125                }
126                Err(e) if e.should_fail() => {
127                    return Err(e).context(CatalogSnafu);
128                }
129                _ => {
130                    // ignore
131                }
132            }
133        }
134    }
135    Ok(tables)
136}
137
138impl ContextProvider for DfContextProviderAdapter {
139    fn get_table_source(&self, name: TableReference) -> DfResult<Arc<dyn TableSource>> {
140        let table_ref = self.table_provider.resolve_table_ref(name)?;
141        self.tables
142            .get(&table_ref.to_string())
143            .cloned()
144            .ok_or_else(|| {
145                crate::error::Error::TableNotFound {
146                    table: table_ref.to_string(),
147                    location: Location::default(),
148                }
149                .into()
150            })
151    }
152
153    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
154        self.engine_state.udf_function(name).map_or_else(
155            || self.session_state.scalar_functions().get(name).cloned(),
156            |func| {
157                Some(Arc::new(create_udf(
158                    func,
159                    self.query_ctx.clone(),
160                    self.engine_state.function_state(),
161                )))
162            },
163        )
164    }
165
166    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
167        if name == UDDSKETCH_STATE_NAME {
168            return Some(Arc::new(UddSketchState::state_udf_impl()));
169        } else if name == UDDSKETCH_MERGE_NAME {
170            return Some(Arc::new(UddSketchState::merge_udf_impl()));
171        } else if name == HLL_NAME {
172            return Some(Arc::new(HllState::state_udf_impl()));
173        } else if name == HLL_MERGE_NAME {
174            return Some(Arc::new(HllState::merge_udf_impl()));
175        } else if name == GEO_PATH_NAME {
176            return Some(Arc::new(GeoPathAccumulator::udf_impl()));
177        }
178
179        self.engine_state.aggregate_function(name).map_or_else(
180            || self.session_state.aggregate_functions().get(name).cloned(),
181            |func| {
182                Some(Arc::new(
183                    create_aggregate_function(func.name(), func.args_count(), func.create()).into(),
184                ))
185            },
186        )
187    }
188
189    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
190        self.session_state.window_functions().get(name).cloned()
191    }
192
193    fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
194        if variable_names.is_empty() {
195            return None;
196        }
197
198        let provider_type = if is_system_variables(variable_names) {
199            VarType::System
200        } else {
201            VarType::UserDefined
202        };
203
204        self.session_state
205            .execution_props()
206            .var_providers
207            .as_ref()
208            .and_then(|provider| provider.get(&provider_type)?.get_type(variable_names))
209    }
210
211    fn options(&self) -> &ConfigOptions {
212        self.session_state.config_options()
213    }
214
215    fn udf_names(&self) -> Vec<String> {
216        let mut names = self.engine_state.udf_names();
217        names.extend(self.session_state.scalar_functions().keys().cloned());
218        names
219    }
220
221    fn udaf_names(&self) -> Vec<String> {
222        let mut names = self.engine_state.udaf_names();
223        names.extend(self.session_state.aggregate_functions().keys().cloned());
224        names
225    }
226
227    fn udwf_names(&self) -> Vec<String> {
228        self.session_state
229            .window_functions()
230            .keys()
231            .cloned()
232            .collect()
233    }
234
235    fn get_file_type(&self, ext: &str) -> DfResult<Arc<dyn FileType>> {
236        self.file_formats
237            .get(&ext.to_lowercase())
238            .ok_or_else(|| {
239                DataFusionError::Plan(format!("There is no registered file format with ext {ext}"))
240            })
241            .map(|file_type| format_as_file_type(Arc::clone(file_type)))
242    }
243
244    fn get_table_function_source(
245        &self,
246        name: &str,
247        args: Vec<datafusion_expr::Expr>,
248    ) -> DfResult<Arc<dyn TableSource>> {
249        let tbl_func = self
250            .session_state
251            .table_functions()
252            .get(name)
253            .cloned()
254            .ok_or_else(|| DataFusionError::Plan(format!("table function '{name}' not found")))?;
255        let provider = tbl_func.create_table_provider(&args)?;
256
257        Ok(provider_as_source(provider))
258    }
259
260    fn create_cte_work_table(
261        &self,
262        name: &str,
263        schema: arrow_schema::SchemaRef,
264    ) -> DfResult<Arc<dyn TableSource>> {
265        let table = Arc::new(CteWorkTable::new(name, schema));
266        Ok(provider_as_source(table))
267    }
268
269    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
270        &self.expr_planners
271    }
272
273    fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
274        None
275    }
276}