query/optimizer/windowed_sort.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashSet;
use std::sync::Arc;

use datafusion::physical_optimizer::PhysicalOptimizerRule;
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::coop::CooperativeExec;
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::Result as DataFusionResult;
use datafusion_physical_expr::expressions::Column as PhysicalColumn;
use store_api::region_engine::PartitionRange;
use table::table::scan::RegionScanExec;

use crate::part_sort::PartSortExec;
use crate::window_sort::WindowedSortExec;

/// Optimize rule for windowed sort.
///
/// This rule is expected to run after [`ScanHint`] and [`ParallelizeScan`].
/// It replaces the original sort with a custom windowed sort plan. To make sure
/// other rules are applied correctly, this rule should run as late as possible.
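///
/// Roughly speaking (a sketch; actual plan shapes may vary), a plan like
///
/// ```text
/// SortExec(ts ASC)
///   RegionScanExec
/// ```
///
/// is rewritten into
///
/// ```text
/// SortPreservingMergeExec(ts ASC)   // only when the sort doesn't preserve partitioning
///   WindowedSortExec(ts ASC)
///     PartSortExec(ts ASC)          // skipped when there are no tag columns and the sort is ascending
///       RegionScanExec
/// ```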
///
/// [`ScanHint`]: crate::optimizer::scan_hint::ScanHintRule
/// [`ParallelizeScan`]: crate::optimizer::parallelize_scan::ParallelizeScan
#[derive(Debug)]
pub struct WindowedSortPhysicalRule;

impl PhysicalOptimizerRule for WindowedSortPhysicalRule {
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        config: &datafusion::config::ConfigOptions,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        Self::do_optimize(plan, config)
    }

    fn name(&self) -> &str {
        "WindowedSortRule"
    }

    fn schema_check(&self) -> bool {
        false
    }
}

impl WindowedSortPhysicalRule {
    fn do_optimize(
        plan: Arc<dyn ExecutionPlan>,
        _config: &datafusion::config::ConfigOptions,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        let result = plan
            .transform_down(|plan| {
                if let Some(sort_exec) = plan.as_any().downcast_ref::<SortExec>() {
                    // TODO: support multiple expr in windowed sort
                    if sort_exec.expr().len() != 1 {
                        return Ok(Transformed::no(plan));
                    }

                    let preserve_partitioning = sort_exec.preserve_partitioning();

                    let sort_input = remove_repartition(sort_exec.input().clone())?.data;
                    let sort_input =
                        remove_coalesce_batches_exec(sort_input, sort_exec.fetch())?.data;

                    // Gets scanner info from the input, with the repartition before the filter removed.
                    let Some(scanner_info) = fetch_partition_range(sort_input.clone())? else {
                        return Ok(Transformed::no(plan));
                    };
                    let input_schema = sort_input.schema();

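                    // Windowed sort only applies when the single sort expression is a plain
                    // column reference to the time index column of the scanned region.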
                    let first_sort_expr = sort_exec.expr().first();
                    if let Some(column_expr) = first_sort_expr
                        .expr
                        .as_any()
                        .downcast_ref::<PhysicalColumn>()
                        && scanner_info
                            .time_index
                            .contains(input_schema.field(column_expr.index()).name())
                    {
                    } else {
                        return Ok(Transformed::no(plan));
                    }

                    // PartSortExec is unnecessary if:
                    // - there is no tag column, and
                    // - the sort is ascending on the time index column
                    let new_input = if scanner_info.tag_columns.is_empty()
                        && !first_sort_expr.options.descending
                    {
                        sort_input
                    } else {
                        Arc::new(PartSortExec::new(
                            first_sort_expr.clone(),
                            sort_exec.fetch(),
                            scanner_info.partition_ranges.clone(),
                            sort_input,
                        ))
                    };

                    let windowed_sort_exec = WindowedSortExec::try_new(
                        first_sort_expr.clone(),
                        sort_exec.fetch(),
                        scanner_info.partition_ranges,
                        new_input,
                    )?;

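                    // A sort without `preserve_partitioning` produces a single output partition,
                    // so merge the per-partition sorted streams back into one here.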
                    if !preserve_partitioning {
                        let order_preserving_merge = SortPreservingMergeExec::new(
                            sort_exec.expr().clone(),
                            Arc::new(windowed_sort_exec),
                        );
                        return Ok(Transformed {
                            data: Arc::new(order_preserving_merge),
                            transformed: true,
                            tnr: datafusion_common::tree_node::TreeNodeRecursion::Stop,
                        });
                    } else {
                        return Ok(Transformed {
                            data: Arc::new(windowed_sort_exec),
                            transformed: true,
                            tnr: datafusion_common::tree_node::TreeNodeRecursion::Stop,
                        });
                    }
                }

                Ok(Transformed::no(plan))
            })?
            .data;

        Ok(result)
    }
}

#[derive(Debug)]
struct ScannerInfo {
    partition_ranges: Vec<Vec<PartitionRange>>,
    time_index: HashSet<String>,
    tag_columns: Vec<String>,
}

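/// Collects [`ScannerInfo`] from the plan tree rooted at `input`.
///
/// Walks the plan bottom-up to find the underlying [`RegionScanExec`] and gathers its
/// partition ranges, time index column name (tracking projection aliases) and tag columns.
/// Plan nodes other than projection / filter / coalesce-batches between the scan and the
/// sort invalidate the collected ranges, in which case `None` is returned.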
fn fetch_partition_range(input: Arc<dyn ExecutionPlan>) -> DataFusionResult<Option<ScannerInfo>> {
    let mut partition_ranges = None;
    let mut time_index = HashSet::new();
    let mut alias_map = Vec::new();
    let mut tag_columns = None;
    let mut is_batch_coalesced = false;

    input.transform_up(|plan| {
        if plan.as_any().is::<CooperativeExec>() {
            return Ok(Transformed::no(plan));
        }

        // Inapplicable case, reset the state.
        if plan.as_any().is::<RepartitionExec>()
            || plan.as_any().is::<CoalescePartitionsExec>()
            || plan.as_any().is::<SortExec>()
            || plan.as_any().is::<WindowedSortExec>()
        {
            partition_ranges = None;
        }

        if plan.as_any().is::<CoalesceBatchesExec>() {
            is_batch_coalesced = true;
        }

        // Only a very limited set of plans may appear between the region scan and the
        // sort exec. Other plans might make this optimization incorrect, so play it
        // safe by resetting the collected ranges for anything else.
        if !(plan.as_any().is::<ProjectionExec>()
            || plan.as_any().is::<FilterExec>()
            || plan.as_any().is::<CoalesceBatchesExec>())
        {
            partition_ranges = None;
        }

        // TODO(discord9): do this in the logical plan instead as it's less bug-prone there.
        // Collects alias of the time index column.
        if let Some(projection) = plan.as_any().downcast_ref::<ProjectionExec>() {
            for (expr, output_name) in projection.expr() {
                if let Some(column_expr) = expr.as_any().downcast_ref::<PhysicalColumn>() {
                    alias_map.push((column_expr.name().to_string(), output_name.clone()));
                }
            }
            // resolve alias properly
            time_index = resolve_alias(&alias_map, &time_index);
        }

        if let Some(region_scan_exec) = plan.as_any().downcast_ref::<RegionScanExec>() {
            // `PerSeries` distribution is not supported in windowed sort.
            if region_scan_exec.distribution()
                == Some(store_api::storage::TimeSeriesDistribution::PerSeries)
            {
                partition_ranges = None;
                return Ok(Transformed::no(plan));
            }

            partition_ranges = Some(region_scan_exec.get_uncollapsed_partition_ranges());
            // Reset time index column.
            time_index = HashSet::from([region_scan_exec.time_index()]);
            tag_columns = Some(region_scan_exec.tag_columns());

            // Set `distinguish_partition_ranges` to true. This is a workaround and not strictly correct.
            if !is_batch_coalesced {
                region_scan_exec.with_distinguish_partition_range(true);
            }
        }

        Ok(Transformed::no(plan))
    })?;

    // The `try` block yields `None` if any required piece of info is missing.
    // (This syntax relies on the nightly `try_blocks` feature.)
    let result = try {
        ScannerInfo {
            partition_ranges: partition_ranges?,
            time_index,
            tag_columns: tag_columns?,
        }
    };

    Ok(result)
}

/// Removes the repartition plan between the filter and the region scan.
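///
/// i.e. a `FilterExec -> RepartitionExec -> RegionScanExec` chain becomes
/// `FilterExec -> RegionScanExec`.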
fn remove_repartition(
    plan: Arc<dyn ExecutionPlan>,
) -> DataFusionResult<Transformed<Arc<dyn ExecutionPlan>>> {
    plan.transform_down(|plan| {
        if plan.as_any().is::<FilterExec>() {
            // Checks child.
            let maybe_repartition = plan.children()[0];
            if maybe_repartition.as_any().is::<RepartitionExec>() {
                let maybe_scan = maybe_repartition.children()[0];
                if maybe_scan.as_any().is::<RegionScanExec>() {
                    let new_filter = plan.clone().with_new_children(vec![maybe_scan.clone()])?;
                    return Ok(Transformed::yes(new_filter));
                }
            }
        }

        Ok(Transformed::no(plan))
    })
}

/// Removes `CoalesceBatchesExec` if the limit is less than its target batch size,
/// so that a small limit doesn't force scanning more rows than necessary.
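///
/// For example, with `LIMIT 10` and DataFusion's default target batch size of 8192,
/// keeping the coalescing would buffer far more rows than the limit actually needs.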
fn remove_coalesce_batches_exec(
    plan: Arc<dyn ExecutionPlan>,
    fetch: Option<usize>,
) -> DataFusionResult<Transformed<Arc<dyn ExecutionPlan>>> {
    let Some(fetch) = fetch else {
        return Ok(Transformed::no(plan));
    };

    // Avoid removing multiple coalesce batches.
    let mut is_done = false;

    plan.transform_down(|plan| {
        if let Some(coalesce_batches_exec) = plan.as_any().downcast_ref::<CoalesceBatchesExec>() {
            let target_batch_size = coalesce_batches_exec.target_batch_size();
            if fetch < target_batch_size && !is_done {
                is_done = true;
                return Ok(Transformed::yes(coalesce_batches_exec.input().clone()));
            }
        }

        Ok(Transformed::no(plan))
    })
}

/// Resolves aliases of the time index column.
///
/// E.g. if `a` is the time index and the aliases are `{a: b, b: c}`, the result should be
/// `{a, b}` (not `{a, c}`), because projection aliasing is not transitive.
/// If the aliases are `{b: a}` and `a` is the time index, the result is empty.
fn resolve_alias(alias_map: &[(String, String)], time_index: &HashSet<String>) -> HashSet<String> {
    // Old names of the time index that are still available (not shadowed by another alias).
    let mut avail_old_name = time_index.clone();
    let mut new_time_index = HashSet::new();
    for (old, new) in alias_map {
        if time_index.contains(old) {
            new_time_index.insert(new.clone());
        } else if time_index.contains(new) && old != new {
            // Another column is aliased to the time index's name, so the old name is shadowed.
            avail_old_name.remove(new);
            continue;
        }
    }
    // Keep the remaining time index names that are not shadowed by the alias map.
    new_time_index.extend(avail_old_name);
    new_time_index
}

#[cfg(test)]
mod test {
    use itertools::Itertools;

    use super::*;

    #[test]
    fn test_alias() {
        let testcases = [
            // notice the old name is still in the result
            (
                vec![("a", "b"), ("b", "c")],
                HashSet::from(["a"]),
                HashSet::from(["a", "b"]),
            ),
            // alias swap
            (
                vec![("b", "a"), ("a", "b")],
                HashSet::from(["a"]),
                HashSet::from(["b"]),
            ),
            (
                vec![("b", "a"), ("b", "c")],
                HashSet::from(["a"]),
                HashSet::from([]),
            ),
            // not in alias map
            (
                vec![("c", "d"), ("d", "c")],
                HashSet::from(["a"]),
                HashSet::from(["a"]),
            ),
            // no alias
            (vec![], HashSet::from(["a"]), HashSet::from(["a"])),
            // empty time index
            (vec![], HashSet::from([]), HashSet::from([])),
        ];
        for (alias_map, time_index, expected) in testcases {
            let alias_map = alias_map
                .into_iter()
                .map(|(k, v)| (k.to_string(), v.to_string()))
                .collect_vec();
            let time_index = time_index.into_iter().map(|i| i.to_string()).collect();
            let expected: HashSet<String> = expected.into_iter().map(|i| i.to_string()).collect();

            assert_eq!(
                expected,
                resolve_alias(&alias_map, &time_index),
                "alias_map={:?}, time_index={:?}",
                alias_map,
                time_index
            );
        }
    }
}