1use std::sync::Arc;
18
19use ahash::HashMap;
20use async_trait::async_trait;
21use catalog::CatalogManagerRef;
22use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
23use datafusion::common::Result;
24use datafusion::datasource::DefaultTableSource;
25use datafusion::execution::context::SessionState;
26use datafusion::physical_plan::ExecutionPlan;
27use datafusion::physical_planner::{ExtensionPlanner, PhysicalPlanner};
28use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
29use datafusion_common::{DataFusionError, TableReference};
30use datafusion_expr::{LogicalPlan, UserDefinedLogicalNode};
31use partition::manager::PartitionRuleManagerRef;
32use session::context::QueryContext;
33use snafu::{OptionExt, ResultExt};
34use store_api::storage::RegionId;
35pub use table::metadata::TableType;
36use table::table::adapter::DfTableProviderAdapter;
37use table::table_name::TableName;
38
39use crate::dist_plan::PredicateExtractor;
40use crate::dist_plan::merge_scan::{MergeScanExec, MergeScanLogicalPlan};
41use crate::dist_plan::merge_sort::MergeSortLogicalPlan;
42use crate::dist_plan::region_pruner::ConstraintPruner;
43use crate::error::{CatalogSnafu, TableNotFoundSnafu};
44use crate::region_query::RegionQueryHandlerRef;
45
46pub struct MergeSortExtensionPlanner {}
56
57#[async_trait]
58impl ExtensionPlanner for MergeSortExtensionPlanner {
59 async fn plan_extension(
60 &self,
61 planner: &dyn PhysicalPlanner,
62 node: &dyn UserDefinedLogicalNode,
63 _logical_inputs: &[&LogicalPlan],
64 physical_inputs: &[Arc<dyn ExecutionPlan>],
65 session_state: &SessionState,
66 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
67 if let Some(merge_sort) = node.as_any().downcast_ref::<MergeSortLogicalPlan>() {
68 if let LogicalPlan::Extension(ext) = &merge_sort.input.as_ref()
69 && ext
70 .node
71 .as_any()
72 .downcast_ref::<MergeScanLogicalPlan>()
73 .is_some()
74 {
75 let merge_scan_exec = physical_inputs
76 .first()
77 .and_then(|p| p.as_any().downcast_ref::<MergeScanExec>())
78 .ok_or(DataFusionError::Internal(format!(
79 "Expect MergeSort's input is a MergeScanExec, found {:?}",
80 physical_inputs
81 )))?;
82
83 let partition_cnt = merge_scan_exec.partition_count();
84 let region_cnt = merge_scan_exec.region_count();
85 let can_merge_sort = partition_cnt >= region_cnt;
88 if can_merge_sort {
89 }
91 let ret = planner
94 .create_physical_plan(&merge_sort.clone().into_sort(), session_state)
95 .await?;
96 Ok(Some(ret))
97 } else {
98 Ok(None)
99 }
100 } else {
101 Ok(None)
102 }
103 }
104}
105
106pub struct DistExtensionPlanner {
107 catalog_manager: CatalogManagerRef,
108 partition_rule_manager: PartitionRuleManagerRef,
109 region_query_handler: RegionQueryHandlerRef,
110}
111
112impl DistExtensionPlanner {
113 pub fn new(
114 catalog_manager: CatalogManagerRef,
115 partition_rule_manager: PartitionRuleManagerRef,
116 region_query_handler: RegionQueryHandlerRef,
117 ) -> Self {
118 Self {
119 catalog_manager,
120 partition_rule_manager,
121 region_query_handler,
122 }
123 }
124}
125
126#[async_trait]
127impl ExtensionPlanner for DistExtensionPlanner {
128 async fn plan_extension(
129 &self,
130 planner: &dyn PhysicalPlanner,
131 node: &dyn UserDefinedLogicalNode,
132 _logical_inputs: &[&LogicalPlan],
133 _physical_inputs: &[Arc<dyn ExecutionPlan>],
134 session_state: &SessionState,
135 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
136 let Some(merge_scan) = node.as_any().downcast_ref::<MergeScanLogicalPlan>() else {
137 return Ok(None);
138 };
139
140 let input_plan = merge_scan.input();
141 let fallback = |logical_plan| async move {
142 let optimized_plan = self.optimize_input_logical_plan(session_state, logical_plan)?;
143 planner
144 .create_physical_plan(&optimized_plan, session_state)
145 .await
146 .map(Some)
147 };
148
149 if merge_scan.is_placeholder() {
150 return fallback(input_plan).await;
152 }
153
154 let optimized_plan = input_plan;
155 let Some(table_name) = Self::extract_full_table_name(input_plan)? else {
156 return fallback(optimized_plan).await;
158 };
159
160 let Ok(regions) = self.get_regions(&table_name, input_plan).await else {
161 return fallback(optimized_plan).await;
163 };
164
165 let schema = optimized_plan.schema().as_ref().into();
167 let query_ctx = session_state
168 .config()
169 .get_extension()
170 .unwrap_or_else(QueryContext::arc);
171 let merge_scan_plan = MergeScanExec::new(
172 session_state,
173 table_name,
174 regions,
175 input_plan.clone(),
176 &schema,
177 self.region_query_handler.clone(),
178 query_ctx,
179 session_state.config().target_partitions(),
180 merge_scan.partition_cols().clone(),
181 )?;
182 Ok(Some(Arc::new(merge_scan_plan) as _))
183 }
184}
185
186impl DistExtensionPlanner {
187 fn extract_full_table_name(plan: &LogicalPlan) -> Result<Option<TableName>> {
189 let mut extractor = TableNameExtractor::default();
190 let _ = plan.visit(&mut extractor)?;
191 Ok(extractor.table_name)
192 }
193
194 async fn get_regions(
195 &self,
196 table_name: &TableName,
197 logical_plan: &LogicalPlan,
198 ) -> Result<Vec<RegionId>> {
199 let table = self
200 .catalog_manager
201 .table(
202 &table_name.catalog_name,
203 &table_name.schema_name,
204 &table_name.table_name,
205 None,
206 )
207 .await
208 .context(CatalogSnafu)?
209 .with_context(|| TableNotFoundSnafu {
210 table: table_name.to_string(),
211 })?;
212
213 let table_info = table.table_info();
214 let all_regions = table_info.region_ids();
215
216 let partition_columns: Vec<String> =
218 table_info.meta.partition_column_names().cloned().collect();
219 if partition_columns.is_empty() {
220 return Ok(all_regions);
221 }
222 let partition_column_types = partition_columns
223 .iter()
224 .map(|col_name| {
225 let data_type = table_info
226 .meta
227 .schema
228 .column_schema_by_name(col_name)
229 .unwrap()
231 .data_type
232 .clone();
233 (col_name.clone(), data_type)
234 })
235 .collect::<HashMap<_, _>>();
236
237 let partition_expressions = match PredicateExtractor::extract_partition_expressions(
239 logical_plan,
240 &partition_columns,
241 ) {
242 Ok(expressions) => expressions,
243 Err(err) => {
244 common_telemetry::debug!(
245 "Failed to extract partition expressions for table {} (id: {}), using all regions: {:?}",
246 table_name,
247 table.table_info().table_id(),
248 err
249 );
250 return Ok(all_regions);
251 }
252 };
253
254 if partition_expressions.is_empty() {
255 return Ok(all_regions);
256 }
257
258 let partitions = match self
260 .partition_rule_manager
261 .find_table_partitions(table.table_info().table_id())
262 .await
263 {
264 Ok(partitions) => partitions,
265 Err(err) => {
266 common_telemetry::debug!(
267 "Failed to get partition information for table {}, using all regions: {:?}",
268 table_name,
269 err
270 );
271 return Ok(all_regions);
272 }
273 };
274 if partitions.is_empty() {
275 return Ok(all_regions);
276 }
277
278 let pruned_regions = match ConstraintPruner::prune_regions(
280 &partition_expressions,
281 &partitions,
282 partition_column_types,
283 ) {
284 Ok(regions) => regions,
285 Err(err) => {
286 common_telemetry::debug!(
287 "Failed to prune regions for table {}, using all regions: {:?}",
288 table_name,
289 err
290 );
291 return Ok(all_regions);
292 }
293 };
294
295 common_telemetry::debug!(
296 "Region pruning for table {}: {} partition expressions applied, pruned from {} to {} regions",
297 table_name,
298 partition_expressions.len(),
299 all_regions.len(),
300 pruned_regions.len()
301 );
302
303 Ok(pruned_regions)
304 }
305
306 fn optimize_input_logical_plan(
308 &self,
309 session_state: &SessionState,
310 plan: &LogicalPlan,
311 ) -> Result<LogicalPlan> {
312 let state = session_state.clone();
313 state.optimizer().optimize(plan.clone(), &state, |_, _| {})
314 }
315}
316
317#[derive(Default)]
319struct TableNameExtractor {
320 pub table_name: Option<TableName>,
321}
322
323impl TreeNodeVisitor<'_> for TableNameExtractor {
324 type Node = LogicalPlan;
325
326 fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
327 match node {
328 LogicalPlan::TableScan(scan) => {
329 if let Some(source) = scan.source.as_any().downcast_ref::<DefaultTableSource>()
330 && let Some(provider) = source
331 .table_provider
332 .as_any()
333 .downcast_ref::<DfTableProviderAdapter>()
334 {
335 if provider.table().table_type() == TableType::Base {
336 let info = provider.table().table_info();
337 self.table_name = Some(TableName::new(
338 info.catalog_name.clone(),
339 info.schema_name.clone(),
340 info.name.clone(),
341 ));
342 }
343 return Ok(TreeNodeRecursion::Stop);
344 }
345 match &scan.table_name {
346 TableReference::Full {
347 catalog,
348 schema,
349 table,
350 } => {
351 self.table_name = Some(TableName::new(
352 catalog.to_string(),
353 schema.to_string(),
354 table.to_string(),
355 ));
356 Ok(TreeNodeRecursion::Stop)
357 }
358 TableReference::Partial { schema, table } => {
360 self.table_name = Some(TableName::new(
361 DEFAULT_CATALOG_NAME.to_string(),
362 schema.to_string(),
363 table.to_string(),
364 ));
365 Ok(TreeNodeRecursion::Stop)
366 }
367 TableReference::Bare { table } => {
368 self.table_name = Some(TableName::new(
369 DEFAULT_CATALOG_NAME.to_string(),
370 DEFAULT_SCHEMA_NAME.to_string(),
371 table.to_string(),
372 ));
373 Ok(TreeNodeRecursion::Stop)
374 }
375 }
376 }
377 _ => Ok(TreeNodeRecursion::Continue),
378 }
379 }
380}