1use std::sync::Arc;
16
17use arrow_schema::{DataType, SchemaRef};
18use datafusion::config::ConfigOptions;
19use datafusion::physical_optimizer::PhysicalOptimizerRule;
20use datafusion::physical_plan::ExecutionPlan;
21use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
22use datafusion_common::Result as DfResult;
23use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
24use datafusion_expr::JoinType;
25use datafusion_physical_expr::expressions::Column;
26use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
27
/// Physical optimizer rule that switches qualifying hash joins to
/// `PartitionMode::CollectLeft`.
///
/// A `HashJoinExec` qualifies when all of the following hold:
///
/// - it is an inner join in `PartitionMode::Partitioned` with no extra join filter;
/// - its right input has strictly more columns than its left input;
/// - its left input consists of exactly one value column, the tsid column and
///   at least one timestamp column;
/// - its equi-join keys are plain columns that include the tsid column on both
///   sides and a timestamp column on both sides.
///
/// The rule carries no state, so besides `Debug` it derives `Clone`, `Copy`
/// and `Default` for convenient registration and reuse.
// NOTE(review): presumably collecting the narrow build side is cheaper than
// repartitioning both inputs; the rationale is not visible in this file.
#[derive(Debug, Clone, Copy, Default)]
pub struct PromqlTsidNarrowJoin;
37
38impl PhysicalOptimizerRule for PromqlTsidNarrowJoin {
39 fn optimize(
40 &self,
41 plan: Arc<dyn ExecutionPlan>,
42 _config: &ConfigOptions,
43 ) -> DfResult<Arc<dyn ExecutionPlan>> {
44 plan.transform_up(Self::rewrite_join).data()
45 }
46
47 fn name(&self) -> &str {
48 "PromqlTsidNarrowJoin"
49 }
50
51 fn schema_check(&self) -> bool {
52 true
53 }
54}
55
56impl PromqlTsidNarrowJoin {
57 fn rewrite_join(plan: Arc<dyn ExecutionPlan>) -> DfResult<Transformed<Arc<dyn ExecutionPlan>>> {
58 let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() else {
59 return Ok(Transformed::no(plan));
60 };
61
62 if !Self::should_collect_left(hash_join) {
63 return Ok(Transformed::no(plan));
64 }
65
66 Ok(Transformed::yes(
67 hash_join
68 .builder()
69 .with_partition_mode(PartitionMode::CollectLeft)
70 .reset_state()
71 .build_exec()?,
72 ))
73 }
74
75 fn should_collect_left(hash_join: &HashJoinExec) -> bool {
76 hash_join.partition_mode() == &PartitionMode::Partitioned
77 && hash_join.join_type() == &JoinType::Inner
78 && hash_join.filter().is_none()
79 && hash_join.right().schema().fields().len() > hash_join.left().schema().fields().len()
80 && Self::is_promql_value_tsid_time_schema(&hash_join.left().schema())
81 && Self::joins_on_tsid_and_time(hash_join)
82 }
83
84 fn is_promql_value_tsid_time_schema(schema: &SchemaRef) -> bool {
85 let mut value_columns = 0;
86 let mut has_tsid = false;
87 let mut has_time = false;
88
89 for field in schema.fields() {
90 match field.name().as_str() {
91 DATA_SCHEMA_TSID_COLUMN_NAME => has_tsid = true,
92 _ if matches!(field.data_type(), DataType::Timestamp(_, _)) => has_time = true,
93 _ => value_columns += 1,
94 }
95 }
96
97 value_columns == 1 && has_tsid && has_time
98 }
99
100 fn joins_on_tsid_and_time(hash_join: &HashJoinExec) -> bool {
101 let mut has_tsid = false;
102 let mut has_time = false;
103
104 for (left, right) in hash_join.on() {
105 let (Some(left_col), Some(right_col)) = (
106 left.as_any().downcast_ref::<Column>(),
107 right.as_any().downcast_ref::<Column>(),
108 ) else {
109 return false;
110 };
111
112 if left_col.name() == DATA_SCHEMA_TSID_COLUMN_NAME
113 && right_col.name() == DATA_SCHEMA_TSID_COLUMN_NAME
114 {
115 has_tsid = true;
116 } else if matches!(
117 hash_join
118 .left()
119 .schema()
120 .field(left_col.index())
121 .data_type(),
122 DataType::Timestamp(_, _)
123 ) && matches!(
124 hash_join
125 .right()
126 .schema()
127 .field(right_col.index())
128 .data_type(),
129 DataType::Timestamp(_, _)
130 ) {
131 has_time = true;
132 }
133 }
134
135 has_tsid && has_time
136 }
137}
138
#[cfg(test)]
mod tests {
    use arrow_schema::{DataType, Field, Schema, TimeUnit};
    use datafusion::common::NullEquality;
    use datafusion::physical_optimizer::PhysicalOptimizerRule;
    use datafusion::physical_plan::displayable;
    use datafusion::physical_plan::empty::EmptyExec;
    use datafusion::physical_plan::joins::HashJoinExec;
    use datafusion_common::config::ConfigOptions;
    use datafusion_physical_expr::PhysicalExpr;

    use super::*;

    /// Name of the timestamp column used by every fixture.
    const TIME_COLUMN: &str = "greptime_timestamp";

    /// Nullable `Float64` column with the given name.
    fn value_field(name: &str) -> Field {
        Field::new(name, DataType::Float64, true)
    }

    /// Nullable `Utf8` label column named `host`.
    fn host_field() -> Field {
        Field::new("host", DataType::Utf8, true)
    }

    /// Non-nullable `UInt64` tsid column.
    fn tsid_field() -> Field {
        Field::new(DATA_SCHEMA_TSID_COLUMN_NAME, DataType::UInt64, false)
    }

    /// Non-nullable millisecond timestamp column without a time zone.
    fn time_field() -> Field {
        Field::new(
            TIME_COLUMN,
            DataType::Timestamp(TimeUnit::Millisecond, None),
            false,
        )
    }

    /// Empty leaf plan exposing the given fields as its schema.
    fn plan_with_fields(fields: Vec<Field>) -> Arc<dyn ExecutionPlan> {
        Arc::new(EmptyExec::new(Arc::new(Schema::new(fields))))
    }

    /// Input shaped as `[value, tsid, time]`.
    fn narrow_input(value_name: &str) -> Arc<dyn ExecutionPlan> {
        plan_with_fields(vec![value_field(value_name), tsid_field(), time_field()])
    }

    /// Input shaped as `[greptime_value, host, tsid, time]`.
    fn labeled_input() -> Arc<dyn ExecutionPlan> {
        plan_with_fields(vec![
            value_field("greptime_value"),
            host_field(),
            tsid_field(),
            time_field(),
        ])
    }

    /// Column reference as a type-erased physical expression.
    fn column(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new(name, index))
    }

    /// Partitioned inner hash join on `(tsid, time)`.
    ///
    /// `tsid_at` and `time_at` are the `(left, right)` column indices of the
    /// two join keys.
    fn partitioned_inner_join(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        tsid_at: (usize, usize),
        time_at: (usize, usize),
        projection: Option<Vec<usize>>,
    ) -> Arc<dyn ExecutionPlan> {
        let on = vec![
            (
                column(DATA_SCHEMA_TSID_COLUMN_NAME, tsid_at.0),
                column(DATA_SCHEMA_TSID_COLUMN_NAME, tsid_at.1),
            ),
            (
                column(TIME_COLUMN, time_at.0),
                column(TIME_COLUMN, time_at.1),
            ),
        ];
        Arc::new(
            HashJoinExec::try_new(
                left,
                right,
                on,
                None,
                &JoinType::Inner,
                projection,
                PartitionMode::Partitioned,
                NullEquality::NullEqualsNull,
                false,
            )
            .unwrap(),
        )
    }

    /// Runs the rule under test with default configuration.
    fn run_rule(plan: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        PromqlTsidNarrowJoin
            .optimize(plan, &ConfigOptions::default())
            .unwrap()
    }

    #[test]
    fn chooses_collect_left_for_narrow_promql_build_side() {
        let join = partitioned_inner_join(
            narrow_input("greptime_value"),
            labeled_input(),
            (1, 2),
            (2, 3),
            Some(vec![0, 3, 4, 5, 6]),
        );
        let original_schema = join.schema();

        let optimized = run_rule(join);
        let optimized_join = optimized.as_any().downcast_ref::<HashJoinExec>().unwrap();

        assert_eq!(optimized_join.partition_mode(), &PartitionMode::CollectLeft);
        assert_eq!(optimized.schema(), original_schema);

        // The embedded projection must survive the rebuild.
        let rendered = displayable(optimized.as_ref()).one_line().to_string();
        assert!(
            rendered.contains(
                "projection=[greptime_value@0, greptime_value@3, host@4, __tsid@5, greptime_timestamp@6]"
            )
        );
    }

    #[test]
    fn chooses_collect_left_for_computed_narrow_value_column() {
        let join = partitioned_inner_join(
            narrow_input("prom_rate(greptime_value)"),
            labeled_input(),
            (1, 2),
            (2, 3),
            Some(vec![0, 3, 4, 5, 6]),
        );

        let optimized = run_rule(join);
        let optimized_join = optimized.as_any().downcast_ref::<HashJoinExec>().unwrap();

        assert_eq!(optimized_join.partition_mode(), &PartitionMode::CollectLeft);
    }

    #[test]
    fn keeps_partitioned_join_when_left_side_carries_labels() {
        let join = partitioned_inner_join(
            labeled_input(),
            narrow_input("greptime_value"),
            (2, 1),
            (3, 2),
            None,
        );

        let optimized = run_rule(join);
        let optimized_join = optimized.as_any().downcast_ref::<HashJoinExec>().unwrap();

        assert_eq!(optimized_join.partition_mode(), &PartitionMode::Partitioned);
    }
}