query/dist_plan/
planner.rs1use std::sync::Arc;
18
19use async_trait::async_trait;
20use catalog::CatalogManagerRef;
21use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
22use datafusion::common::Result;
23use datafusion::datasource::DefaultTableSource;
24use datafusion::execution::context::SessionState;
25use datafusion::physical_plan::ExecutionPlan;
26use datafusion::physical_planner::{ExtensionPlanner, PhysicalPlanner};
27use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
28use datafusion_common::{DataFusionError, TableReference};
29use datafusion_expr::{LogicalPlan, UserDefinedLogicalNode};
30use session::context::QueryContext;
31use snafu::{OptionExt, ResultExt};
32use store_api::storage::RegionId;
33pub use table::metadata::TableType;
34use table::table::adapter::DfTableProviderAdapter;
35use table::table_name::TableName;
36
37use crate::dist_plan::merge_scan::{MergeScanExec, MergeScanLogicalPlan};
38use crate::dist_plan::merge_sort::MergeSortLogicalPlan;
39use crate::error::{CatalogSnafu, TableNotFoundSnafu};
40use crate::region_query::RegionQueryHandlerRef;
41
42pub struct MergeSortExtensionPlanner {}
52
53#[async_trait]
54impl ExtensionPlanner for MergeSortExtensionPlanner {
55 async fn plan_extension(
56 &self,
57 planner: &dyn PhysicalPlanner,
58 node: &dyn UserDefinedLogicalNode,
59 _logical_inputs: &[&LogicalPlan],
60 physical_inputs: &[Arc<dyn ExecutionPlan>],
61 session_state: &SessionState,
62 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
63 if let Some(merge_sort) = node.as_any().downcast_ref::<MergeSortLogicalPlan>() {
64 if let LogicalPlan::Extension(ext) = &merge_sort.input.as_ref()
65 && ext
66 .node
67 .as_any()
68 .downcast_ref::<MergeScanLogicalPlan>()
69 .is_some()
70 {
71 let merge_scan_exec = physical_inputs
72 .first()
73 .and_then(|p| p.as_any().downcast_ref::<MergeScanExec>())
74 .ok_or(DataFusionError::Internal(format!(
75 "Expect MergeSort's input is a MergeScanExec, found {:?}",
76 physical_inputs
77 )))?;
78
79 let partition_cnt = merge_scan_exec.partition_count();
80 let region_cnt = merge_scan_exec.region_count();
81 let can_merge_sort = partition_cnt >= region_cnt;
84 if can_merge_sort {
85 }
87 let ret = planner
90 .create_physical_plan(&merge_sort.clone().into_sort(), session_state)
91 .await?;
92 Ok(Some(ret))
93 } else {
94 Ok(None)
95 }
96 } else {
97 Ok(None)
98 }
99 }
100}
101
102pub struct DistExtensionPlanner {
103 catalog_manager: CatalogManagerRef,
104 region_query_handler: RegionQueryHandlerRef,
105}
106
107impl DistExtensionPlanner {
108 pub fn new(
109 catalog_manager: CatalogManagerRef,
110 region_query_handler: RegionQueryHandlerRef,
111 ) -> Self {
112 Self {
113 catalog_manager,
114 region_query_handler,
115 }
116 }
117}
118
119#[async_trait]
120impl ExtensionPlanner for DistExtensionPlanner {
121 async fn plan_extension(
122 &self,
123 planner: &dyn PhysicalPlanner,
124 node: &dyn UserDefinedLogicalNode,
125 _logical_inputs: &[&LogicalPlan],
126 _physical_inputs: &[Arc<dyn ExecutionPlan>],
127 session_state: &SessionState,
128 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
129 let Some(merge_scan) = node.as_any().downcast_ref::<MergeScanLogicalPlan>() else {
130 return Ok(None);
131 };
132
133 let input_plan = merge_scan.input();
134 let fallback = |logical_plan| async move {
135 let optimized_plan = self.optimize_input_logical_plan(session_state, logical_plan)?;
136 planner
137 .create_physical_plan(&optimized_plan, session_state)
138 .await
139 .map(Some)
140 };
141
142 if merge_scan.is_placeholder() {
143 return fallback(input_plan).await;
145 }
146
147 let optimized_plan = input_plan;
148 let Some(table_name) = Self::extract_full_table_name(input_plan)? else {
149 return fallback(optimized_plan).await;
151 };
152
153 let Ok(regions) = self.get_regions(&table_name).await else {
154 return fallback(optimized_plan).await;
156 };
157
158 let schema = optimized_plan.schema().as_ref().into();
160 let query_ctx = session_state
161 .config()
162 .get_extension()
163 .unwrap_or_else(QueryContext::arc);
164 let merge_scan_plan = MergeScanExec::new(
165 session_state,
166 table_name,
167 regions,
168 input_plan.clone(),
169 &schema,
170 self.region_query_handler.clone(),
171 query_ctx,
172 session_state.config().target_partitions(),
173 merge_scan.partition_cols().to_vec(),
174 )?;
175 Ok(Some(Arc::new(merge_scan_plan) as _))
176 }
177}
178
179impl DistExtensionPlanner {
180 fn extract_full_table_name(plan: &LogicalPlan) -> Result<Option<TableName>> {
182 let mut extractor = TableNameExtractor::default();
183 let _ = plan.visit(&mut extractor)?;
184 Ok(extractor.table_name)
185 }
186
187 async fn get_regions(&self, table_name: &TableName) -> Result<Vec<RegionId>> {
188 let table = self
189 .catalog_manager
190 .table(
191 &table_name.catalog_name,
192 &table_name.schema_name,
193 &table_name.table_name,
194 None,
195 )
196 .await
197 .context(CatalogSnafu)?
198 .with_context(|| TableNotFoundSnafu {
199 table: table_name.to_string(),
200 })?;
201 Ok(table.table_info().region_ids())
202 }
203
204 fn optimize_input_logical_plan(
206 &self,
207 session_state: &SessionState,
208 plan: &LogicalPlan,
209 ) -> Result<LogicalPlan> {
210 let state = session_state.clone();
211 state.optimizer().optimize(plan.clone(), &state, |_, _| {})
212 }
213}
214
215#[derive(Default)]
217struct TableNameExtractor {
218 pub table_name: Option<TableName>,
219}
220
221impl TreeNodeVisitor<'_> for TableNameExtractor {
222 type Node = LogicalPlan;
223
224 fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
225 match node {
226 LogicalPlan::TableScan(scan) => {
227 if let Some(source) = scan.source.as_any().downcast_ref::<DefaultTableSource>() {
228 if let Some(provider) = source
229 .table_provider
230 .as_any()
231 .downcast_ref::<DfTableProviderAdapter>()
232 {
233 if provider.table().table_type() == TableType::Base {
234 let info = provider.table().table_info();
235 self.table_name = Some(TableName::new(
236 info.catalog_name.clone(),
237 info.schema_name.clone(),
238 info.name.clone(),
239 ));
240 }
241 return Ok(TreeNodeRecursion::Stop);
242 }
243 }
244 match &scan.table_name {
245 TableReference::Full {
246 catalog,
247 schema,
248 table,
249 } => {
250 self.table_name = Some(TableName::new(
251 catalog.to_string(),
252 schema.to_string(),
253 table.to_string(),
254 ));
255 Ok(TreeNodeRecursion::Stop)
256 }
257 TableReference::Partial { schema, table } => {
259 self.table_name = Some(TableName::new(
260 DEFAULT_CATALOG_NAME.to_string(),
261 schema.to_string(),
262 table.to_string(),
263 ));
264 Ok(TreeNodeRecursion::Stop)
265 }
266 TableReference::Bare { table } => {
267 self.table_name = Some(TableName::new(
268 DEFAULT_CATALOG_NAME.to_string(),
269 DEFAULT_SCHEMA_NAME.to_string(),
270 table.to_string(),
271 ));
272 Ok(TreeNodeRecursion::Stop)
273 }
274 }
275 }
276 _ => Ok(TreeNodeRecursion::Continue),
277 }
278 }
279}