Skip to main content

query/optimizer/
promql_tsid_narrow_join.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use arrow_schema::{DataType, SchemaRef};
18use datafusion::config::ConfigOptions;
19use datafusion::physical_optimizer::PhysicalOptimizerRule;
20use datafusion::physical_plan::ExecutionPlan;
21use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
22use datafusion_common::Result as DfResult;
23use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
24use datafusion_expr::JoinType;
25use datafusion_physical_expr::expressions::Column;
26use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
27
/// Physical optimizer rule that turns a partitioned PromQL vector-vector hash
/// join into a `CollectLeft` (broadcast-style) join when the build side is narrow.
///
/// A join is rewritten only when all of the following hold:
/// - it is an inner [`HashJoinExec`] in [`PartitionMode::Partitioned`] with no
///   extra join filter;
/// - the left (build) input exposes `__tsid`, at least one timestamp column and
///   exactly one other column, and the right (probe) input has more columns;
/// - the join keys are plain columns and include both a `__tsid` pair and a
///   timestamp pair.
///
/// Motivation: PromQL arithmetic joins often keep one side narrow (no raw
/// labels) while the other side is wide and carries every output label.
/// Partitioning both inputs shuffles the wide stream, whereas `CollectLeft`
/// only gathers the narrow build side and lets the wide probe side keep its
/// existing partitioning.
#[derive(Debug)]
pub struct PromqlTsidNarrowJoin;
37
38impl PhysicalOptimizerRule for PromqlTsidNarrowJoin {
39    fn optimize(
40        &self,
41        plan: Arc<dyn ExecutionPlan>,
42        _config: &ConfigOptions,
43    ) -> DfResult<Arc<dyn ExecutionPlan>> {
44        plan.transform_up(Self::rewrite_join).data()
45    }
46
47    fn name(&self) -> &str {
48        "PromqlTsidNarrowJoin"
49    }
50
51    fn schema_check(&self) -> bool {
52        true
53    }
54}
55
56impl PromqlTsidNarrowJoin {
57    fn rewrite_join(plan: Arc<dyn ExecutionPlan>) -> DfResult<Transformed<Arc<dyn ExecutionPlan>>> {
58        let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() else {
59            return Ok(Transformed::no(plan));
60        };
61
62        if !Self::should_collect_left(hash_join) {
63            return Ok(Transformed::no(plan));
64        }
65
66        Ok(Transformed::yes(
67            hash_join
68                .builder()
69                .with_partition_mode(PartitionMode::CollectLeft)
70                .reset_state()
71                .build_exec()?,
72        ))
73    }
74
75    fn should_collect_left(hash_join: &HashJoinExec) -> bool {
76        hash_join.partition_mode() == &PartitionMode::Partitioned
77            && hash_join.join_type() == &JoinType::Inner
78            && hash_join.filter().is_none()
79            && hash_join.right().schema().fields().len() > hash_join.left().schema().fields().len()
80            && Self::is_promql_value_tsid_time_schema(&hash_join.left().schema())
81            && Self::joins_on_tsid_and_time(hash_join)
82    }
83
84    fn is_promql_value_tsid_time_schema(schema: &SchemaRef) -> bool {
85        let mut value_columns = 0;
86        let mut has_tsid = false;
87        let mut has_time = false;
88
89        for field in schema.fields() {
90            match field.name().as_str() {
91                DATA_SCHEMA_TSID_COLUMN_NAME => has_tsid = true,
92                _ if matches!(field.data_type(), DataType::Timestamp(_, _)) => has_time = true,
93                _ => value_columns += 1,
94            }
95        }
96
97        value_columns == 1 && has_tsid && has_time
98    }
99
100    fn joins_on_tsid_and_time(hash_join: &HashJoinExec) -> bool {
101        let mut has_tsid = false;
102        let mut has_time = false;
103
104        for (left, right) in hash_join.on() {
105            let (Some(left_col), Some(right_col)) = (
106                left.as_any().downcast_ref::<Column>(),
107                right.as_any().downcast_ref::<Column>(),
108            ) else {
109                return false;
110            };
111
112            if left_col.name() == DATA_SCHEMA_TSID_COLUMN_NAME
113                && right_col.name() == DATA_SCHEMA_TSID_COLUMN_NAME
114            {
115                has_tsid = true;
116            } else if matches!(
117                hash_join
118                    .left()
119                    .schema()
120                    .field(left_col.index())
121                    .data_type(),
122                DataType::Timestamp(_, _)
123            ) && matches!(
124                hash_join
125                    .right()
126                    .schema()
127                    .field(right_col.index())
128                    .data_type(),
129                DataType::Timestamp(_, _)
130            ) {
131                has_time = true;
132            }
133        }
134
135        has_tsid && has_time
136    }
137}
138
#[cfg(test)]
mod tests {
    use arrow_schema::{DataType, Field, Schema, TimeUnit};
    use datafusion::common::NullEquality;
    use datafusion::physical_optimizer::PhysicalOptimizerRule;
    use datafusion::physical_plan::displayable;
    use datafusion::physical_plan::empty::EmptyExec;
    use datafusion::physical_plan::joins::HashJoinExec;
    use datafusion_common::config::ConfigOptions;
    use datafusion_physical_expr::PhysicalExpr;

    use super::*;

    /// Name of the timestamp column used by every test input.
    const TIMESTAMP_COLUMN: &str = "greptime_timestamp";

    /// Wraps `fields` in an `EmptyExec` so only the schema matters.
    fn empty_input(fields: Vec<Field>) -> Arc<dyn ExecutionPlan> {
        Arc::new(EmptyExec::new(Arc::new(Schema::new(fields))))
    }

    /// Non-nullable `__tsid` column.
    fn tsid_field() -> Field {
        Field::new(DATA_SCHEMA_TSID_COLUMN_NAME, DataType::UInt64, false)
    }

    /// Non-nullable millisecond timestamp column.
    fn timestamp_field() -> Field {
        Field::new(
            TIMESTAMP_COLUMN,
            DataType::Timestamp(TimeUnit::Millisecond, None),
            false,
        )
    }

    /// Input with the layout `[value_column, __tsid, timestamp]`.
    fn narrow_input(value_column: &str) -> Arc<dyn ExecutionPlan> {
        empty_input(vec![
            Field::new(value_column, DataType::Float64, true),
            tsid_field(),
            timestamp_field(),
        ])
    }

    /// Input with the layout `[greptime_value, host, __tsid, timestamp]`.
    fn wide_input() -> Arc<dyn ExecutionPlan> {
        empty_input(vec![
            Field::new("greptime_value", DataType::Float64, true),
            Field::new("host", DataType::Utf8, true),
            tsid_field(),
            timestamp_field(),
        ])
    }

    /// Column reference as a join-key expression.
    fn column(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new(name, index))
    }

    /// Builds a partitioned inner hash join keyed on `__tsid` and the timestamp.
    ///
    /// Each index pair is `(left index, right index)`.
    fn partitioned_inner_join(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        tsid_indices: (usize, usize),
        time_indices: (usize, usize),
        projection: Option<Vec<usize>>,
    ) -> Arc<dyn ExecutionPlan> {
        let on = vec![
            (
                column(DATA_SCHEMA_TSID_COLUMN_NAME, tsid_indices.0),
                column(DATA_SCHEMA_TSID_COLUMN_NAME, tsid_indices.1),
            ),
            (
                column(TIMESTAMP_COLUMN, time_indices.0),
                column(TIMESTAMP_COLUMN, time_indices.1),
            ),
        ];
        let join = HashJoinExec::try_new(
            left,
            right,
            on,
            None,
            &JoinType::Inner,
            projection,
            PartitionMode::Partitioned,
            NullEquality::NullEqualsNull,
            false,
        )
        .unwrap();

        Arc::new(join)
    }

    /// Runs the rule under test with default configuration.
    fn apply_rule(plan: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        PromqlTsidNarrowJoin
            .optimize(plan, &ConfigOptions::default())
            .unwrap()
    }

    /// Partition mode of `plan`, which must be a `HashJoinExec`.
    fn partition_mode(plan: &Arc<dyn ExecutionPlan>) -> &PartitionMode {
        plan.as_any()
            .downcast_ref::<HashJoinExec>()
            .unwrap()
            .partition_mode()
    }

    #[test]
    fn chooses_collect_left_for_narrow_promql_build_side() {
        let join = partitioned_inner_join(
            narrow_input("greptime_value"),
            wide_input(),
            (1, 2),
            (2, 3),
            Some(vec![0, 3, 4, 5, 6]),
        );
        let original_schema = join.schema();

        let optimized = apply_rule(join);

        assert_eq!(partition_mode(&optimized), &PartitionMode::CollectLeft);
        // The rewrite must keep both the output schema and the embedded projection.
        assert_eq!(optimized.schema(), original_schema);
        let rendered = displayable(optimized.as_ref()).one_line().to_string();
        assert!(rendered.contains(
            "projection=[greptime_value@0, greptime_value@3, host@4, __tsid@5, greptime_timestamp@6]"
        ));
    }

    #[test]
    fn chooses_collect_left_for_computed_narrow_value_column() {
        let join = partitioned_inner_join(
            narrow_input("prom_rate(greptime_value)"),
            wide_input(),
            (1, 2),
            (2, 3),
            Some(vec![0, 3, 4, 5, 6]),
        );

        let optimized = apply_rule(join);

        assert_eq!(partition_mode(&optimized), &PartitionMode::CollectLeft);
    }

    #[test]
    fn keeps_partitioned_join_when_left_side_carries_labels() {
        // Wide input on the build side: the rule must leave the join alone.
        let join = partitioned_inner_join(
            wide_input(),
            narrow_input("greptime_value"),
            (2, 1),
            (3, 2),
            None,
        );

        let optimized = apply_rule(join);

        assert_eq!(partition_mode(&optimized), &PartitionMode::Partitioned);
    }
}