query/dist_plan/
planner.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! [ExtensionPlanner] implementation for distributed planner
16
17use std::sync::Arc;
18
19use ahash::HashMap;
20use async_trait::async_trait;
21use catalog::CatalogManagerRef;
22use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
23use datafusion::common::Result;
24use datafusion::datasource::DefaultTableSource;
25use datafusion::execution::context::SessionState;
26use datafusion::physical_plan::ExecutionPlan;
27use datafusion::physical_planner::{ExtensionPlanner, PhysicalPlanner};
28use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
29use datafusion_common::{DataFusionError, TableReference};
30use datafusion_expr::{LogicalPlan, UserDefinedLogicalNode};
31use partition::manager::PartitionRuleManagerRef;
32use session::context::QueryContext;
33use snafu::{OptionExt, ResultExt};
34use store_api::storage::RegionId;
35pub use table::metadata::TableType;
36use table::table::adapter::DfTableProviderAdapter;
37use table::table_name::TableName;
38
39use crate::dist_plan::merge_scan::{MergeScanExec, MergeScanLogicalPlan};
40use crate::dist_plan::merge_sort::MergeSortLogicalPlan;
41use crate::dist_plan::region_pruner::ConstraintPruner;
42use crate::dist_plan::PredicateExtractor;
43use crate::error::{CatalogSnafu, TableNotFoundSnafu};
44use crate::region_query::RegionQueryHandlerRef;
45
/// Planner that converts a merge sort logical plan into a physical plan.
///
/// It is currently a fallback to sort, and doesn't change the execution plan:
/// `MergeSort(MergeScan) -> Sort(MergeScan) - to physical plan -> ...`
/// It should be applied after `DistExtensionPlanner`.
///
/// (Later, when this merge sort is actually implemented:)
///
/// We should ensure that the number of partitions is not smaller than the number of regions
/// at present; otherwise this would result in incorrect output.
pub struct MergeSortExtensionPlanner {}
56
57#[async_trait]
58impl ExtensionPlanner for MergeSortExtensionPlanner {
59    async fn plan_extension(
60        &self,
61        planner: &dyn PhysicalPlanner,
62        node: &dyn UserDefinedLogicalNode,
63        _logical_inputs: &[&LogicalPlan],
64        physical_inputs: &[Arc<dyn ExecutionPlan>],
65        session_state: &SessionState,
66    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
67        if let Some(merge_sort) = node.as_any().downcast_ref::<MergeSortLogicalPlan>() {
68            if let LogicalPlan::Extension(ext) = &merge_sort.input.as_ref()
69                && ext
70                    .node
71                    .as_any()
72                    .downcast_ref::<MergeScanLogicalPlan>()
73                    .is_some()
74            {
75                let merge_scan_exec = physical_inputs
76                    .first()
77                    .and_then(|p| p.as_any().downcast_ref::<MergeScanExec>())
78                    .ok_or(DataFusionError::Internal(format!(
79                        "Expect MergeSort's input is a MergeScanExec, found {:?}",
80                        physical_inputs
81                    )))?;
82
83                let partition_cnt = merge_scan_exec.partition_count();
84                let region_cnt = merge_scan_exec.region_count();
85                // if partition >= region, we know that every partition stream of merge scan is ordered
86                // and we only need to do a merge sort, otherwise fallback to quick sort
87                let can_merge_sort = partition_cnt >= region_cnt;
88                if can_merge_sort {
89                    // TODO(discord9): use `SortPreversingMergeExec here`
90                }
91                // for now merge sort only exist in logical plan, and have the same effect as `Sort`
92                // doesn't change the execution plan, this will change in the future
93                let ret = planner
94                    .create_physical_plan(&merge_sort.clone().into_sort(), session_state)
95                    .await?;
96                Ok(Some(ret))
97            } else {
98                Ok(None)
99            }
100        } else {
101            Ok(None)
102        }
103    }
104}
105
/// Extension planner that turns a `MergeScanLogicalPlan` into a `MergeScanExec`.
pub struct DistExtensionPlanner {
    /// Used to look up the table referenced by the input plan.
    catalog_manager: CatalogManagerRef,
    /// Used to fetch the table's partitions when pruning regions.
    partition_rule_manager: PartitionRuleManagerRef,
    /// Handed to every `MergeScanExec` this planner creates.
    region_query_handler: RegionQueryHandlerRef,
}
111
112impl DistExtensionPlanner {
113    pub fn new(
114        catalog_manager: CatalogManagerRef,
115        partition_rule_manager: PartitionRuleManagerRef,
116        region_query_handler: RegionQueryHandlerRef,
117    ) -> Self {
118        Self {
119            catalog_manager,
120            partition_rule_manager,
121            region_query_handler,
122        }
123    }
124}
125
126#[async_trait]
127impl ExtensionPlanner for DistExtensionPlanner {
128    async fn plan_extension(
129        &self,
130        planner: &dyn PhysicalPlanner,
131        node: &dyn UserDefinedLogicalNode,
132        _logical_inputs: &[&LogicalPlan],
133        _physical_inputs: &[Arc<dyn ExecutionPlan>],
134        session_state: &SessionState,
135    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
136        let Some(merge_scan) = node.as_any().downcast_ref::<MergeScanLogicalPlan>() else {
137            return Ok(None);
138        };
139
140        let input_plan = merge_scan.input();
141        let fallback = |logical_plan| async move {
142            let optimized_plan = self.optimize_input_logical_plan(session_state, logical_plan)?;
143            planner
144                .create_physical_plan(&optimized_plan, session_state)
145                .await
146                .map(Some)
147        };
148
149        if merge_scan.is_placeholder() {
150            // ignore placeholder
151            return fallback(input_plan).await;
152        }
153
154        let optimized_plan = input_plan;
155        let Some(table_name) = Self::extract_full_table_name(input_plan)? else {
156            // no relation found in input plan, going to execute them locally
157            return fallback(optimized_plan).await;
158        };
159
160        let Ok(regions) = self.get_regions(&table_name, input_plan).await else {
161            // no peers found, going to execute them locally
162            return fallback(optimized_plan).await;
163        };
164
165        // TODO(ruihang): generate different execution plans for different variant merge operation
166        let schema = optimized_plan.schema().as_ref().into();
167        let query_ctx = session_state
168            .config()
169            .get_extension()
170            .unwrap_or_else(QueryContext::arc);
171        let merge_scan_plan = MergeScanExec::new(
172            session_state,
173            table_name,
174            regions,
175            input_plan.clone(),
176            &schema,
177            self.region_query_handler.clone(),
178            query_ctx,
179            session_state.config().target_partitions(),
180            merge_scan.partition_cols().to_vec(),
181        )?;
182        Ok(Some(Arc::new(merge_scan_plan) as _))
183    }
184}
185
186impl DistExtensionPlanner {
187    /// Extract fully resolved table name from logical plan
188    fn extract_full_table_name(plan: &LogicalPlan) -> Result<Option<TableName>> {
189        let mut extractor = TableNameExtractor::default();
190        let _ = plan.visit(&mut extractor)?;
191        Ok(extractor.table_name)
192    }
193
194    async fn get_regions(
195        &self,
196        table_name: &TableName,
197        logical_plan: &LogicalPlan,
198    ) -> Result<Vec<RegionId>> {
199        let table = self
200            .catalog_manager
201            .table(
202                &table_name.catalog_name,
203                &table_name.schema_name,
204                &table_name.table_name,
205                None,
206            )
207            .await
208            .context(CatalogSnafu)?
209            .with_context(|| TableNotFoundSnafu {
210                table: table_name.to_string(),
211            })?;
212
213        let table_info = table.table_info();
214        let all_regions = table_info.region_ids();
215
216        // Extract partition columns
217        let partition_columns: Vec<String> = table_info
218            .meta
219            .partition_column_names()
220            .map(|s| s.to_string())
221            .collect();
222        if partition_columns.is_empty() {
223            return Ok(all_regions);
224        }
225        let partition_column_types = partition_columns
226            .iter()
227            .map(|col_name| {
228                let data_type = table_info
229                    .meta
230                    .schema
231                    .column_schema_by_name(col_name)
232                    // Safety: names are retrieved above from the same table
233                    .unwrap()
234                    .data_type
235                    .clone();
236                (col_name.clone(), data_type)
237            })
238            .collect::<HashMap<_, _>>();
239
240        // Extract predicates from logical plan
241        let partition_expressions = match PredicateExtractor::extract_partition_expressions(
242            logical_plan,
243            &partition_columns,
244        ) {
245            Ok(expressions) => expressions,
246            Err(err) => {
247                common_telemetry::debug!(
248                    "Failed to extract partition expressions for table {} (id: {}), using all regions: {:?}",
249                    table_name,
250                    table.table_info().table_id(),
251                    err
252                );
253                return Ok(all_regions);
254            }
255        };
256
257        if partition_expressions.is_empty() {
258            return Ok(all_regions);
259        }
260
261        // Get partition information for the table if partition rule manager is available
262        let partitions = match self
263            .partition_rule_manager
264            .find_table_partitions(table.table_info().table_id())
265            .await
266        {
267            Ok(partitions) => partitions,
268            Err(err) => {
269                common_telemetry::debug!(
270                    "Failed to get partition information for table {}, using all regions: {:?}",
271                    table_name,
272                    err
273                );
274                return Ok(all_regions);
275            }
276        };
277        if partitions.is_empty() {
278            return Ok(all_regions);
279        }
280
281        // Apply region pruning based on partition rules
282        let pruned_regions = match ConstraintPruner::prune_regions(
283            &partition_expressions,
284            &partitions,
285            partition_column_types,
286        ) {
287            Ok(regions) => regions,
288            Err(err) => {
289                common_telemetry::debug!(
290                    "Failed to prune regions for table {}, using all regions: {:?}",
291                    table_name,
292                    err
293                );
294                return Ok(all_regions);
295            }
296        };
297
298        common_telemetry::debug!(
299            "Region pruning for table {}: {} partition expressions applied, pruned from {} to {} regions",
300            table_name,
301            partition_expressions.len(),
302            all_regions.len(),
303            pruned_regions.len()
304        );
305
306        Ok(pruned_regions)
307    }
308
309    /// Input logical plan is analyzed. Thus only call logical optimizer to optimize it.
310    fn optimize_input_logical_plan(
311        &self,
312        session_state: &SessionState,
313        plan: &LogicalPlan,
314    ) -> Result<LogicalPlan> {
315        let state = session_state.clone();
316        state.optimizer().optimize(plan.clone(), &state, |_, _| {})
317    }
318}
319
/// Visitor to extract table name from logical plan (TableScan node)
#[derive(Default)]
struct TableNameExtractor {
    /// Table name recorded by the visitor; stays `None` until a table scan sets it.
    pub table_name: Option<TableName>,
}
325
326impl TreeNodeVisitor<'_> for TableNameExtractor {
327    type Node = LogicalPlan;
328
329    fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
330        match node {
331            LogicalPlan::TableScan(scan) => {
332                if let Some(source) = scan.source.as_any().downcast_ref::<DefaultTableSource>() {
333                    if let Some(provider) = source
334                        .table_provider
335                        .as_any()
336                        .downcast_ref::<DfTableProviderAdapter>()
337                    {
338                        if provider.table().table_type() == TableType::Base {
339                            let info = provider.table().table_info();
340                            self.table_name = Some(TableName::new(
341                                info.catalog_name.clone(),
342                                info.schema_name.clone(),
343                                info.name.clone(),
344                            ));
345                        }
346                        return Ok(TreeNodeRecursion::Stop);
347                    }
348                }
349                match &scan.table_name {
350                    TableReference::Full {
351                        catalog,
352                        schema,
353                        table,
354                    } => {
355                        self.table_name = Some(TableName::new(
356                            catalog.to_string(),
357                            schema.to_string(),
358                            table.to_string(),
359                        ));
360                        Ok(TreeNodeRecursion::Stop)
361                    }
362                    // TODO(ruihang): Maybe the following two cases should not be valid
363                    TableReference::Partial { schema, table } => {
364                        self.table_name = Some(TableName::new(
365                            DEFAULT_CATALOG_NAME.to_string(),
366                            schema.to_string(),
367                            table.to_string(),
368                        ));
369                        Ok(TreeNodeRecursion::Stop)
370                    }
371                    TableReference::Bare { table } => {
372                        self.table_name = Some(TableName::new(
373                            DEFAULT_CATALOG_NAME.to_string(),
374                            DEFAULT_SCHEMA_NAME.to_string(),
375                            table.to_string(),
376                        ));
377                        Ok(TreeNodeRecursion::Stop)
378                    }
379                }
380            }
381            _ => Ok(TreeNodeRecursion::Continue),
382        }
383    }
384}