query/datafusion/planner.rs

use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::Arc;

use arrow_schema::DataType;
use catalog::table_source::DfTableSourceProvider;
use common_function::aggr::{
    GeoPathAccumulator, HllState, UddSketchState, GEO_PATH_NAME, HLL_MERGE_NAME, HLL_NAME,
    UDDSKETCH_MERGE_NAME, UDDSKETCH_STATE_NAME,
};
use common_function::scalars::udf::create_udf;
use common_query::logical_plan::create_aggregate_function;
use datafusion::common::TableReference;
use datafusion::datasource::cte_worktable::CteWorkTable;
use datafusion::datasource::file_format::{format_as_file_type, FileFormatFactory};
use datafusion::datasource::provider_as_source;
use datafusion::error::Result as DfResult;
use datafusion::execution::context::SessionState;
use datafusion::execution::SessionStateDefaults;
use datafusion::sql::planner::ContextProvider;
use datafusion::variable::VarType;
use datafusion_common::config::ConfigOptions;
use datafusion_common::file_options::file_type::FileType;
use datafusion_common::DataFusionError;
use datafusion_expr::planner::{ExprPlanner, TypePlanner};
use datafusion_expr::var_provider::is_system_variables;
use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF};
use datafusion_sql::parser::Statement as DfStatement;
use session::context::QueryContextRef;
use snafu::{Location, ResultExt};

use crate::error::{CatalogSnafu, Result};
use crate::query_engine::{DefaultPlanDecoder, QueryEngineState};

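/// Bridges GreptimeDB's query engine state and catalog to DataFusion's
/// [`ContextProvider`], letting the SQL planner look up tables, UDFs,
/// variables, and session configuration.
///
/// A rough usage sketch (assuming an engine state, session state, parsed
/// statement, and query context are already in scope):
///
/// ```ignore
/// let provider =
///     DfContextProviderAdapter::try_new(engine_state, session_state, Some(&stmt), query_ctx)
///         .await?;
/// let source = provider.get_table_source(TableReference::bare("my_table"))?;
/// ```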
pub struct DfContextProviderAdapter {
    engine_state: Arc<QueryEngineState>,
    session_state: SessionState,
    tables: HashMap<String, Arc<dyn TableSource>>,
    table_provider: DfTableSourceProvider,
    query_ctx: QueryContextRef,

    /// File format factories from DataFusion's session defaults, keyed by
    /// lowercase file extension.
    file_formats: HashMap<String, Arc<dyn FileFormatFactory>>,
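    /// Expression planners taken from DataFusion's session defaults.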
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
}

impl DfContextProviderAdapter {
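    /// Builds the adapter, eagerly resolving every table referenced by the
    /// statement so that later `ContextProvider` lookups can stay synchronous.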
    pub(crate) async fn try_new(
        engine_state: Arc<QueryEngineState>,
        session_state: SessionState,
        df_stmt: Option<&DfStatement>,
        query_ctx: QueryContextRef,
    ) -> Result<Self> {
        let table_names = if let Some(df_stmt) = df_stmt {
            session_state.resolve_table_references(df_stmt)?
        } else {
            vec![]
        };

        let mut table_provider = DfTableSourceProvider::new(
            engine_state.catalog_manager().clone(),
            engine_state.disallow_cross_catalog_query(),
            query_ctx.clone(),
            Arc::new(DefaultPlanDecoder::new(session_state.clone(), &query_ctx)?),
            session_state
                .config_options()
                .sql_parser
                .enable_ident_normalization,
        );

        let tables = resolve_tables(table_names, &mut table_provider).await?;
        let file_formats = SessionStateDefaults::default_file_formats()
            .into_iter()
            .map(|format| (format.get_ext().to_lowercase(), format))
            .collect();

        Ok(Self {
            engine_state,
            session_state,
            tables,
            table_provider,
            query_ctx,
            file_formats,
            expr_planners: SessionStateDefaults::default_expr_planners(),
        })
    }
}

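/// Resolves the referenced table names to `TableSource`s ahead of planning.
/// Names that cannot be resolved are skipped rather than rejected, since they
/// may belong to CTEs or views that only materialize during planning.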
async fn resolve_tables(
    table_names: Vec<TableReference>,
    table_provider: &mut DfTableSourceProvider,
) -> Result<HashMap<String, Arc<dyn TableSource>>> {
    let mut tables = HashMap::with_capacity(table_names.len());

    for table_name in table_names {
        let resolved_name = table_provider
            .resolve_table_ref(table_name.clone())
            .context(CatalogSnafu)?;

        if let Entry::Vacant(v) = tables.entry(resolved_name.to_string()) {
            // Best-effort resolution: a name that cannot be resolved here may
            // be a CTE or view that only becomes available during planning.
            match table_provider.resolve_table(table_name).await {
                Ok(table) => {
                    let _ = v.insert(table);
                }
                Err(e) if e.should_fail() => {
                    return Err(e).context(CatalogSnafu);
                }
                _ => {
                    // Ignore non-fatal resolution errors; see above.
                }
            }
        }
    }
    Ok(tables)
}

impl ContextProvider for DfContextProviderAdapter {
    fn get_table_source(&self, name: TableReference) -> DfResult<Arc<dyn TableSource>> {
        let table_ref = self.table_provider.resolve_table_ref(name)?;
        self.tables
            .get(&table_ref.to_string())
            .cloned()
            .ok_or_else(|| {
                crate::error::Error::TableNotFound {
                    table: table_ref.to_string(),
                    location: Location::default(),
                }
                .into()
            })
    }

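    /// UDFs registered with the query engine take precedence over
    /// DataFusion's built-in scalar functions.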
    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
        self.engine_state.udf_function(name).map_or_else(
            || self.session_state.scalar_functions().get(name).cloned(),
            |func| {
                Some(Arc::new(create_udf(
                    func,
                    self.query_ctx.clone(),
                    self.engine_state.function_state(),
                )))
            },
        )
    }

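    /// The sketch, HLL, and geo-path accumulators are wired up explicitly;
    /// other names resolve to engine-registered aggregates first, then to
    /// DataFusion's built-ins.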
    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
        if name == UDDSKETCH_STATE_NAME {
            return Some(Arc::new(UddSketchState::state_udf_impl()));
        } else if name == UDDSKETCH_MERGE_NAME {
            return Some(Arc::new(UddSketchState::merge_udf_impl()));
        } else if name == HLL_NAME {
            return Some(Arc::new(HllState::state_udf_impl()));
        } else if name == HLL_MERGE_NAME {
            return Some(Arc::new(HllState::merge_udf_impl()));
        } else if name == GEO_PATH_NAME {
            return Some(Arc::new(GeoPathAccumulator::udf_impl()));
        }

        self.engine_state.aggregate_function(name).map_or_else(
            || self.session_state.aggregate_functions().get(name).cloned(),
            |func| {
                Some(Arc::new(
                    create_aggregate_function(func.name(), func.args_count(), func.create()).into(),
                ))
            },
        )
    }

    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
        self.session_state.window_functions().get(name).cloned()
    }

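    /// Resolves `@@`-prefixed names through the system variable provider and
    /// everything else through the user-defined one.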
    fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
        if variable_names.is_empty() {
            return None;
        }

        let provider_type = if is_system_variables(variable_names) {
            VarType::System
        } else {
            VarType::UserDefined
        };

        self.session_state
            .execution_props()
            .var_providers
            .as_ref()
            .and_then(|provider| provider.get(&provider_type)?.get_type(variable_names))
    }

    fn options(&self) -> &ConfigOptions {
        self.session_state.config_options()
    }

    fn udf_names(&self) -> Vec<String> {
        let mut names = self.engine_state.udf_names();
        names.extend(self.session_state.scalar_functions().keys().cloned());
        names
    }

    fn udaf_names(&self) -> Vec<String> {
        let mut names = self.engine_state.udaf_names();
        names.extend(self.session_state.aggregate_functions().keys().cloned());
        names
    }

    fn udwf_names(&self) -> Vec<String> {
        self.session_state
            .window_functions()
            .keys()
            .cloned()
            .collect()
    }

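    /// Extension lookup is case-insensitive; formats were registered with
    /// lowercased extensions in `try_new`.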
    fn get_file_type(&self, ext: &str) -> DfResult<Arc<dyn FileType>> {
        self.file_formats
            .get(&ext.to_lowercase())
            .ok_or_else(|| {
                DataFusionError::Plan(format!("There is no registered file format with ext {ext}"))
            })
            .map(|file_type| format_as_file_type(Arc::clone(file_type)))
    }

    fn get_table_function_source(
        &self,
        name: &str,
        args: Vec<datafusion_expr::Expr>,
    ) -> DfResult<Arc<dyn TableSource>> {
        let tbl_func = self
            .session_state
            .table_functions()
            .get(name)
            .cloned()
            .ok_or_else(|| DataFusionError::Plan(format!("table function '{name}' not found")))?;
        let provider = tbl_func.create_table_provider(&args)?;

        Ok(provider_as_source(provider))
    }

    fn create_cte_work_table(
        &self,
        name: &str,
        schema: arrow_schema::SchemaRef,
    ) -> DfResult<Arc<dyn TableSource>> {
        let table = Arc::new(CteWorkTable::new(name, schema));
        Ok(provider_as_source(table))
    }

    fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
        &self.expr_planners
    }

    fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
        None
    }
}