query/datafusion/
planner.rs1use std::collections::HashMap;
16use std::collections::hash_map::Entry;
17use std::sync::Arc;
18
19use arrow_schema::DataType;
20use catalog::table_source::DfTableSourceProvider;
21use common_function::function::FunctionContext;
22use datafusion::common::TableReference;
23use datafusion::datasource::cte_worktable::CteWorkTable;
24use datafusion::datasource::file_format::{FileFormatFactory, format_as_file_type};
25use datafusion::datasource::provider_as_source;
26use datafusion::error::Result as DfResult;
27use datafusion::execution::SessionStateDefaults;
28use datafusion::execution::context::SessionState;
29use datafusion::sql::planner::ContextProvider;
30use datafusion::variable::VarType;
31use datafusion_common::DataFusionError;
32use datafusion_common::config::ConfigOptions;
33use datafusion_common::file_options::file_type::FileType;
34use datafusion_expr::planner::{ExprPlanner, TypePlanner};
35use datafusion_expr::var_provider::is_system_variables;
36use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF};
37use datafusion_sql::parser::Statement as DfStatement;
38use session::context::QueryContextRef;
39use snafu::{Location, ResultExt};
40
41use crate::error::{CatalogSnafu, Result};
42use crate::query_engine::{DefaultPlanDecoder, QueryEngineState};
43
44mod function_alias;
45
/// Adapts the query engine's state to DataFusion's `ContextProvider` (see the
/// trait impl below), letting the SQL planner resolve tables, functions,
/// variables, file formats and config options for one query.
pub struct DfContextProviderAdapter {
    // Shared engine-wide state: scalar/aggregate/table function registries,
    // catalog manager, function state.
    engine_state: Arc<QueryEngineState>,
    // DataFusion session for this query; supplies default function
    // registries and `ConfigOptions`.
    session_state: SessionState,
    // Tables referenced by the statement, resolved up front in `try_new`,
    // keyed by the fully resolved table name (see `resolve_tables`).
    tables: HashMap<String, Arc<dyn TableSource>>,
    // Resolves table references against the catalog.
    table_provider: DfTableSourceProvider,
    // Per-query context; cloned into `FunctionContext` for scalar functions.
    query_ctx: QueryContextRef,

    // DataFusion's default file formats, keyed by lowercase file extension.
    file_formats: HashMap<String, Arc<dyn FileFormatFactory>>,
    // DataFusion's default expression planners, returned verbatim from
    // `get_expr_planners`.
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}
61
62impl DfContextProviderAdapter {
63 pub(crate) async fn try_new(
64 engine_state: Arc<QueryEngineState>,
65 session_state: SessionState,
66 df_stmt: Option<&DfStatement>,
67 query_ctx: QueryContextRef,
68 ) -> Result<Self> {
69 let table_names = if let Some(df_stmt) = df_stmt {
70 session_state.resolve_table_references(df_stmt)?
71 } else {
72 vec![]
73 };
74
75 let mut table_provider = DfTableSourceProvider::new(
76 engine_state.catalog_manager().clone(),
77 engine_state.disallow_cross_catalog_query(),
78 query_ctx.clone(),
79 Arc::new(DefaultPlanDecoder::new(session_state.clone(), &query_ctx)?),
80 session_state
81 .config_options()
82 .sql_parser
83 .enable_ident_normalization,
84 );
85
86 let tables = resolve_tables(table_names, &mut table_provider).await?;
87 let file_formats = SessionStateDefaults::default_file_formats()
88 .into_iter()
89 .map(|format| (format.get_ext().to_lowercase(), format))
90 .collect();
91
92 Ok(Self {
93 engine_state,
94 session_state,
95 tables,
96 table_provider,
97 query_ctx,
98 file_formats,
99 expr_planners: SessionStateDefaults::default_expr_planners(),
100 })
101 }
102}
103
104async fn resolve_tables(
105 table_names: Vec<TableReference>,
106 table_provider: &mut DfTableSourceProvider,
107) -> Result<HashMap<String, Arc<dyn TableSource>>> {
108 let mut tables = HashMap::with_capacity(table_names.len());
109
110 for table_name in table_names {
111 let resolved_name = table_provider
112 .resolve_table_ref(table_name.clone())
113 .context(CatalogSnafu)?;
114
115 if let Entry::Vacant(v) = tables.entry(resolved_name.to_string()) {
116 match table_provider.resolve_table(table_name).await {
120 Ok(table) => {
121 let _ = v.insert(table);
122 }
123 Err(e) if e.should_fail() => {
124 return Err(e).context(CatalogSnafu);
125 }
126 _ => {
127 }
129 }
130 }
131 }
132 Ok(tables)
133}
134
135impl ContextProvider for DfContextProviderAdapter {
136 fn get_table_source(&self, name: TableReference) -> DfResult<Arc<dyn TableSource>> {
137 let table_ref = self.table_provider.resolve_table_ref(name)?;
138 self.tables
139 .get(&table_ref.to_string())
140 .cloned()
141 .ok_or_else(|| {
142 crate::error::Error::TableNotFound {
143 table: table_ref.to_string(),
144 location: Location::default(),
145 }
146 .into()
147 })
148 }
149
150 fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
151 self.engine_state.scalar_function(name).map_or_else(
152 || {
153 self.session_state
154 .scalar_functions()
155 .get(name)
156 .cloned()
157 .or_else(|| {
158 function_alias::resolve_scalar(name).and_then(|name| {
159 self.session_state.scalar_functions().get(name).cloned()
160 })
161 })
162 },
163 |func| {
164 Some(Arc::new(func.provide(FunctionContext {
165 query_ctx: self.query_ctx.clone(),
166 state: self.engine_state.function_state(),
167 })))
168 },
169 )
170 }
171
172 fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
173 self.engine_state.aggr_function(name).map_or_else(
174 || {
175 self.session_state
176 .aggregate_functions()
177 .get(name)
178 .cloned()
179 .or_else(|| {
180 function_alias::resolve_aggregate(name).and_then(|name| {
181 self.session_state.aggregate_functions().get(name).cloned()
182 })
183 })
184 },
185 |func| Some(Arc::new(func)),
186 )
187 }
188
189 fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
190 self.session_state.window_functions().get(name).cloned()
191 }
192
193 fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
194 if variable_names.is_empty() {
195 return None;
196 }
197
198 let provider_type = if is_system_variables(variable_names) {
199 VarType::System
200 } else {
201 VarType::UserDefined
202 };
203
204 self.session_state
205 .execution_props()
206 .var_providers
207 .as_ref()
208 .and_then(|provider| provider.get(&provider_type)?.get_type(variable_names))
209 }
210
211 fn options(&self) -> &ConfigOptions {
212 self.session_state.config_options()
213 }
214
215 fn udf_names(&self) -> Vec<String> {
216 let mut names = self.engine_state.scalar_names();
217 names.extend(self.session_state.scalar_functions().keys().cloned());
218 names.extend(function_alias::scalar_alias_names().map(|name| name.to_string()));
219 names
220 }
221
222 fn udaf_names(&self) -> Vec<String> {
223 let mut names = self.engine_state.aggr_names();
224 names.extend(self.session_state.aggregate_functions().keys().cloned());
225 names.extend(function_alias::aggregate_alias_names().map(|name| name.to_string()));
226 names
227 }
228
229 fn udwf_names(&self) -> Vec<String> {
230 self.session_state
231 .window_functions()
232 .keys()
233 .cloned()
234 .collect()
235 }
236
237 fn get_file_type(&self, ext: &str) -> DfResult<Arc<dyn FileType>> {
238 self.file_formats
239 .get(&ext.to_lowercase())
240 .ok_or_else(|| {
241 DataFusionError::Plan(format!("There is no registered file format with ext {ext}"))
242 })
243 .map(|file_type| format_as_file_type(Arc::clone(file_type)))
244 }
245
246 fn get_table_function_source(
247 &self,
248 name: &str,
249 args: Vec<datafusion_expr::Expr>,
250 ) -> DfResult<Arc<dyn TableSource>> {
251 if let Some(tbl_func) = self.engine_state.table_function(name) {
252 let provider = tbl_func.create_table_provider(&args)?;
253 Ok(provider_as_source(provider))
254 } else {
255 let tbl_func = self
256 .session_state
257 .table_functions()
258 .get(name)
259 .cloned()
260 .or_else(|| {
261 function_alias::resolve_scalar(name)
262 .and_then(|alias| self.session_state.table_functions().get(alias).cloned())
263 });
264
265 let tbl_func = tbl_func.ok_or_else(|| {
266 DataFusionError::Plan(format!("table function '{name}' not found"))
267 })?;
268 let provider = tbl_func.create_table_provider(&args)?;
269
270 Ok(provider_as_source(provider))
271 }
272 }
273
274 fn create_cte_work_table(
275 &self,
276 name: &str,
277 schema: arrow_schema::SchemaRef,
278 ) -> DfResult<Arc<dyn TableSource>> {
279 let table = Arc::new(CteWorkTable::new(name, schema));
280 Ok(provider_as_source(table))
281 }
282
283 fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
284 &self.expr_planners
285 }
286
287 fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
288 None
289 }
290}