1use std::sync::Arc;
18
19use ahash::HashMap;
20use async_trait::async_trait;
21use catalog::CatalogManagerRef;
22use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
23use datafusion::common::Result;
24use datafusion::datasource::DefaultTableSource;
25use datafusion::execution::context::SessionState;
26use datafusion::physical_plan::ExecutionPlan;
27use datafusion::physical_planner::{ExtensionPlanner, PhysicalPlanner};
28use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
29use datafusion_common::{DataFusionError, TableReference};
30use datafusion_expr::{LogicalPlan, UserDefinedLogicalNode};
31use partition::manager::PartitionRuleManagerRef;
32use session::context::QueryContext;
33use snafu::{OptionExt, ResultExt};
34use store_api::storage::RegionId;
35pub use table::metadata::TableType;
36use table::table::adapter::DfTableProviderAdapter;
37use table::table_name::TableName;
38
39use crate::dist_plan::merge_scan::{MergeScanExec, MergeScanLogicalPlan};
40use crate::dist_plan::merge_sort::MergeSortLogicalPlan;
41use crate::dist_plan::region_pruner::ConstraintPruner;
42use crate::dist_plan::PredicateExtractor;
43use crate::error::{CatalogSnafu, TableNotFoundSnafu};
44use crate::region_query::RegionQueryHandlerRef;
45
46pub struct MergeSortExtensionPlanner {}
56
57#[async_trait]
58impl ExtensionPlanner for MergeSortExtensionPlanner {
59 async fn plan_extension(
60 &self,
61 planner: &dyn PhysicalPlanner,
62 node: &dyn UserDefinedLogicalNode,
63 _logical_inputs: &[&LogicalPlan],
64 physical_inputs: &[Arc<dyn ExecutionPlan>],
65 session_state: &SessionState,
66 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
67 if let Some(merge_sort) = node.as_any().downcast_ref::<MergeSortLogicalPlan>() {
68 if let LogicalPlan::Extension(ext) = &merge_sort.input.as_ref()
69 && ext
70 .node
71 .as_any()
72 .downcast_ref::<MergeScanLogicalPlan>()
73 .is_some()
74 {
75 let merge_scan_exec = physical_inputs
76 .first()
77 .and_then(|p| p.as_any().downcast_ref::<MergeScanExec>())
78 .ok_or(DataFusionError::Internal(format!(
79 "Expect MergeSort's input is a MergeScanExec, found {:?}",
80 physical_inputs
81 )))?;
82
83 let partition_cnt = merge_scan_exec.partition_count();
84 let region_cnt = merge_scan_exec.region_count();
85 let can_merge_sort = partition_cnt >= region_cnt;
88 if can_merge_sort {
89 }
91 let ret = planner
94 .create_physical_plan(&merge_sort.clone().into_sort(), session_state)
95 .await?;
96 Ok(Some(ret))
97 } else {
98 Ok(None)
99 }
100 } else {
101 Ok(None)
102 }
103 }
104}
105
106pub struct DistExtensionPlanner {
107 catalog_manager: CatalogManagerRef,
108 partition_rule_manager: PartitionRuleManagerRef,
109 region_query_handler: RegionQueryHandlerRef,
110}
111
112impl DistExtensionPlanner {
113 pub fn new(
114 catalog_manager: CatalogManagerRef,
115 partition_rule_manager: PartitionRuleManagerRef,
116 region_query_handler: RegionQueryHandlerRef,
117 ) -> Self {
118 Self {
119 catalog_manager,
120 partition_rule_manager,
121 region_query_handler,
122 }
123 }
124}
125
126#[async_trait]
127impl ExtensionPlanner for DistExtensionPlanner {
128 async fn plan_extension(
129 &self,
130 planner: &dyn PhysicalPlanner,
131 node: &dyn UserDefinedLogicalNode,
132 _logical_inputs: &[&LogicalPlan],
133 _physical_inputs: &[Arc<dyn ExecutionPlan>],
134 session_state: &SessionState,
135 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
136 let Some(merge_scan) = node.as_any().downcast_ref::<MergeScanLogicalPlan>() else {
137 return Ok(None);
138 };
139
140 let input_plan = merge_scan.input();
141 let fallback = |logical_plan| async move {
142 let optimized_plan = self.optimize_input_logical_plan(session_state, logical_plan)?;
143 planner
144 .create_physical_plan(&optimized_plan, session_state)
145 .await
146 .map(Some)
147 };
148
149 if merge_scan.is_placeholder() {
150 return fallback(input_plan).await;
152 }
153
154 let optimized_plan = input_plan;
155 let Some(table_name) = Self::extract_full_table_name(input_plan)? else {
156 return fallback(optimized_plan).await;
158 };
159
160 let Ok(regions) = self.get_regions(&table_name, input_plan).await else {
161 return fallback(optimized_plan).await;
163 };
164
165 let schema = optimized_plan.schema().as_ref().into();
167 let query_ctx = session_state
168 .config()
169 .get_extension()
170 .unwrap_or_else(QueryContext::arc);
171 let merge_scan_plan = MergeScanExec::new(
172 session_state,
173 table_name,
174 regions,
175 input_plan.clone(),
176 &schema,
177 self.region_query_handler.clone(),
178 query_ctx,
179 session_state.config().target_partitions(),
180 merge_scan.partition_cols().to_vec(),
181 )?;
182 Ok(Some(Arc::new(merge_scan_plan) as _))
183 }
184}
185
186impl DistExtensionPlanner {
187 fn extract_full_table_name(plan: &LogicalPlan) -> Result<Option<TableName>> {
189 let mut extractor = TableNameExtractor::default();
190 let _ = plan.visit(&mut extractor)?;
191 Ok(extractor.table_name)
192 }
193
194 async fn get_regions(
195 &self,
196 table_name: &TableName,
197 logical_plan: &LogicalPlan,
198 ) -> Result<Vec<RegionId>> {
199 let table = self
200 .catalog_manager
201 .table(
202 &table_name.catalog_name,
203 &table_name.schema_name,
204 &table_name.table_name,
205 None,
206 )
207 .await
208 .context(CatalogSnafu)?
209 .with_context(|| TableNotFoundSnafu {
210 table: table_name.to_string(),
211 })?;
212
213 let table_info = table.table_info();
214 let all_regions = table_info.region_ids();
215
216 let partition_columns: Vec<String> = table_info
218 .meta
219 .partition_column_names()
220 .map(|s| s.to_string())
221 .collect();
222 if partition_columns.is_empty() {
223 return Ok(all_regions);
224 }
225 let partition_column_types = partition_columns
226 .iter()
227 .map(|col_name| {
228 let data_type = table_info
229 .meta
230 .schema
231 .column_schema_by_name(col_name)
232 .unwrap()
234 .data_type
235 .clone();
236 (col_name.clone(), data_type)
237 })
238 .collect::<HashMap<_, _>>();
239
240 let partition_expressions = match PredicateExtractor::extract_partition_expressions(
242 logical_plan,
243 &partition_columns,
244 ) {
245 Ok(expressions) => expressions,
246 Err(err) => {
247 common_telemetry::debug!(
248 "Failed to extract partition expressions for table {} (id: {}), using all regions: {:?}",
249 table_name,
250 table.table_info().table_id(),
251 err
252 );
253 return Ok(all_regions);
254 }
255 };
256
257 if partition_expressions.is_empty() {
258 return Ok(all_regions);
259 }
260
261 let partitions = match self
263 .partition_rule_manager
264 .find_table_partitions(table.table_info().table_id())
265 .await
266 {
267 Ok(partitions) => partitions,
268 Err(err) => {
269 common_telemetry::debug!(
270 "Failed to get partition information for table {}, using all regions: {:?}",
271 table_name,
272 err
273 );
274 return Ok(all_regions);
275 }
276 };
277 if partitions.is_empty() {
278 return Ok(all_regions);
279 }
280
281 let pruned_regions = match ConstraintPruner::prune_regions(
283 &partition_expressions,
284 &partitions,
285 partition_column_types,
286 ) {
287 Ok(regions) => regions,
288 Err(err) => {
289 common_telemetry::debug!(
290 "Failed to prune regions for table {}, using all regions: {:?}",
291 table_name,
292 err
293 );
294 return Ok(all_regions);
295 }
296 };
297
298 common_telemetry::debug!(
299 "Region pruning for table {}: {} partition expressions applied, pruned from {} to {} regions",
300 table_name,
301 partition_expressions.len(),
302 all_regions.len(),
303 pruned_regions.len()
304 );
305
306 Ok(pruned_regions)
307 }
308
309 fn optimize_input_logical_plan(
311 &self,
312 session_state: &SessionState,
313 plan: &LogicalPlan,
314 ) -> Result<LogicalPlan> {
315 let state = session_state.clone();
316 state.optimizer().optimize(plan.clone(), &state, |_, _| {})
317 }
318}
319
320#[derive(Default)]
322struct TableNameExtractor {
323 pub table_name: Option<TableName>,
324}
325
326impl TreeNodeVisitor<'_> for TableNameExtractor {
327 type Node = LogicalPlan;
328
329 fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
330 match node {
331 LogicalPlan::TableScan(scan) => {
332 if let Some(source) = scan.source.as_any().downcast_ref::<DefaultTableSource>() {
333 if let Some(provider) = source
334 .table_provider
335 .as_any()
336 .downcast_ref::<DfTableProviderAdapter>()
337 {
338 if provider.table().table_type() == TableType::Base {
339 let info = provider.table().table_info();
340 self.table_name = Some(TableName::new(
341 info.catalog_name.clone(),
342 info.schema_name.clone(),
343 info.name.clone(),
344 ));
345 }
346 return Ok(TreeNodeRecursion::Stop);
347 }
348 }
349 match &scan.table_name {
350 TableReference::Full {
351 catalog,
352 schema,
353 table,
354 } => {
355 self.table_name = Some(TableName::new(
356 catalog.to_string(),
357 schema.to_string(),
358 table.to_string(),
359 ));
360 Ok(TreeNodeRecursion::Stop)
361 }
362 TableReference::Partial { schema, table } => {
364 self.table_name = Some(TableName::new(
365 DEFAULT_CATALOG_NAME.to_string(),
366 schema.to_string(),
367 table.to_string(),
368 ));
369 Ok(TreeNodeRecursion::Stop)
370 }
371 TableReference::Bare { table } => {
372 self.table_name = Some(TableName::new(
373 DEFAULT_CATALOG_NAME.to_string(),
374 DEFAULT_SCHEMA_NAME.to_string(),
375 table.to_string(),
376 ));
377 Ok(TreeNodeRecursion::Stop)
378 }
379 }
380 }
381 _ => Ok(TreeNodeRecursion::Continue),
382 }
383 }
384}