promql/extension_plan/normalize.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::task::{Context, Poll};
19
20use datafusion::arrow::array::{BooleanArray, Float64Array};
21use datafusion::arrow::compute;
22use datafusion::common::{DFSchema, DFSchemaRef, Result as DataFusionResult, Statistics};
23use datafusion::error::DataFusionError;
24use datafusion::execution::context::TaskContext;
25use datafusion::logical_expr::{EmptyRelation, Expr, LogicalPlan, UserDefinedLogicalNodeCore};
26use datafusion::physical_plan::expressions::Column as ColumnExpr;
27use datafusion::physical_plan::metrics::{
28    BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricValue, MetricsSet,
29};
30use datafusion::physical_plan::{
31    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, PlanProperties, RecordBatchStream,
32    SendableRecordBatchStream,
33};
34use datafusion_expr::col;
35use datatypes::arrow::array::TimestampMillisecondArray;
36use datatypes::arrow::datatypes::SchemaRef;
37use datatypes::arrow::record_batch::RecordBatch;
38use futures::{Stream, StreamExt, ready};
39use greptime_proto::substrait_extension as pb;
40use prost::Message;
41use snafu::ResultExt;
42
43use crate::error::{DeserializeSnafu, Result};
44use crate::extension_plan::{
45    METRIC_NUM_SERIES, Millisecond, resolve_column_name, serialize_column_index,
46};
47use crate::metrics::PROMQL_SERIES_COUNT;
48
/// Logical plan node that normalizes each input record batch. For simplicity, every
/// batch is assumed to contain sample points from only one time series.
///
/// For each batch it:
/// - biases every sample's timestamp by `offset` milliseconds;
/// - optionally (`need_filter_out_nan`) removes rows holding NaN values.
///
/// It does **not** sort: rows are emitted in the order they arrive from the input.
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd)]
pub struct SeriesNormalize {
    /// Milliseconds added to every time-index value.
    offset: Millisecond,
    /// Name of the time index column in the input schema.
    time_index_column_name: String,
    /// Whether rows with NaN values are dropped.
    need_filter_out_nan: bool,
    /// Tag columns. They are always kept by projection pruning
    /// (`necessary_children_exprs`) and become the hash-partitioning key of the
    /// physical plan.
    tag_columns: Vec<String>,

    input: LogicalPlan,
    /// Set by `deserialize`: column positions that are turned into names once
    /// `with_exprs_and_inputs` supplies the real input. While this is `Some`, the
    /// name fields above are empty placeholders.
    unfix: Option<UnfixIndices>,
}
66
/// Column positions (indices into the input schema, as written by
/// `SeriesNormalize::serialize`) carried by a deserialized node until they can be
/// resolved into column names.
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd)]
struct UnfixIndices {
    /// Position of the time index column.
    pub time_index_idx: u64,
    /// Positions of the tag columns, in the same order as `tag_columns`.
    pub tag_column_indices: Vec<u64>,
}
72
73impl UserDefinedLogicalNodeCore for SeriesNormalize {
74    fn name(&self) -> &str {
75        Self::name()
76    }
77
78    fn inputs(&self) -> Vec<&LogicalPlan> {
79        vec![&self.input]
80    }
81
82    fn schema(&self) -> &DFSchemaRef {
83        self.input.schema()
84    }
85
86    fn expressions(&self) -> Vec<datafusion::logical_expr::Expr> {
87        if self.unfix.is_some() {
88            return vec![];
89        }
90
91        self.tag_columns
92            .iter()
93            .map(col)
94            .chain(std::iter::once(col(&self.time_index_column_name)))
95            .collect()
96    }
97
98    fn necessary_children_exprs(&self, output_columns: &[usize]) -> Option<Vec<Vec<usize>>> {
99        if self.unfix.is_some() {
100            return None;
101        }
102
103        let input_schema = self.input.schema();
104        if output_columns.is_empty() {
105            let indices = (0..input_schema.fields().len()).collect::<Vec<_>>();
106            return Some(vec![indices]);
107        }
108
109        let mut required = Vec::with_capacity(output_columns.len() + 1 + self.tag_columns.len());
110        required.extend_from_slice(output_columns);
111        required.push(input_schema.index_of_column_by_name(None, &self.time_index_column_name)?);
112        for tag in &self.tag_columns {
113            required.push(input_schema.index_of_column_by_name(None, tag)?);
114        }
115
116        required.sort_unstable();
117        required.dedup();
118        Some(vec![required])
119    }
120
121    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
122        write!(
123            f,
124            "PromSeriesNormalize: offset=[{}], time index=[{}], filter NaN: [{}]",
125            self.offset, self.time_index_column_name, self.need_filter_out_nan
126        )
127    }
128
129    fn with_exprs_and_inputs(
130        &self,
131        _exprs: Vec<Expr>,
132        inputs: Vec<LogicalPlan>,
133    ) -> DataFusionResult<Self> {
134        if inputs.is_empty() {
135            return Err(DataFusionError::Internal(
136                "SeriesNormalize should have at least one input".to_string(),
137            ));
138        }
139
140        let input: LogicalPlan = inputs.into_iter().next().unwrap();
141        let input_schema = input.schema();
142
143        if let Some(unfix) = &self.unfix {
144            // transform indices to names
145            let time_index_column_name = resolve_column_name(
146                unfix.time_index_idx,
147                input_schema,
148                "SeriesNormalize",
149                "time index",
150            )?;
151
152            let tag_columns = unfix
153                .tag_column_indices
154                .iter()
155                .map(|idx| resolve_column_name(*idx, input_schema, "SeriesNormalize", "tag"))
156                .collect::<DataFusionResult<Vec<String>>>()?;
157
158            Ok(Self {
159                offset: self.offset,
160                time_index_column_name,
161                need_filter_out_nan: self.need_filter_out_nan,
162                tag_columns,
163                input,
164                unfix: None,
165            })
166        } else {
167            Ok(Self {
168                offset: self.offset,
169                time_index_column_name: self.time_index_column_name.clone(),
170                need_filter_out_nan: self.need_filter_out_nan,
171                tag_columns: self.tag_columns.clone(),
172                input,
173                unfix: None,
174            })
175        }
176    }
177}
178
179impl SeriesNormalize {
180    pub fn new<N: AsRef<str>>(
181        offset: Millisecond,
182        time_index_column_name: N,
183        need_filter_out_nan: bool,
184        tag_columns: Vec<String>,
185        input: LogicalPlan,
186    ) -> Self {
187        Self {
188            offset,
189            time_index_column_name: time_index_column_name.as_ref().to_string(),
190            need_filter_out_nan,
191            tag_columns,
192            input,
193            unfix: None,
194        }
195    }
196
197    pub const fn name() -> &'static str {
198        "SeriesNormalize"
199    }
200
201    pub fn to_execution_plan(&self, exec_input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
202        Arc::new(SeriesNormalizeExec {
203            offset: self.offset,
204            time_index_column_name: self.time_index_column_name.clone(),
205            need_filter_out_nan: self.need_filter_out_nan,
206            input: exec_input,
207            tag_columns: self.tag_columns.clone(),
208            metric: ExecutionPlanMetricsSet::new(),
209        })
210    }
211
212    pub fn serialize(&self) -> Vec<u8> {
213        let time_index_idx =
214            serialize_column_index(self.input.schema(), &self.time_index_column_name);
215
216        let tag_column_indices = self
217            .tag_columns
218            .iter()
219            .map(|name| serialize_column_index(self.input.schema(), name))
220            .collect::<Vec<u64>>();
221
222        pb::SeriesNormalize {
223            offset: self.offset,
224            time_index_idx,
225            filter_nan: self.need_filter_out_nan,
226            tag_column_indices,
227            ..Default::default()
228        }
229        .encode_to_vec()
230    }
231
232    pub fn deserialize(bytes: &[u8]) -> Result<Self> {
233        let pb_normalize = pb::SeriesNormalize::decode(bytes).context(DeserializeSnafu)?;
234        let placeholder_plan = LogicalPlan::EmptyRelation(EmptyRelation {
235            produce_one_row: false,
236            schema: Arc::new(DFSchema::empty()),
237        });
238
239        let unfix = UnfixIndices {
240            time_index_idx: pb_normalize.time_index_idx,
241            tag_column_indices: pb_normalize.tag_column_indices.clone(),
242        };
243
244        Ok(Self {
245            offset: pb_normalize.offset,
246            time_index_column_name: String::new(),
247            need_filter_out_nan: pb_normalize.filter_nan,
248            tag_columns: Vec::new(),
249            input: placeholder_plan,
250            unfix: Some(unfix),
251        })
252    }
253}
254
/// Physical counterpart of `SeriesNormalize`, created by
/// `SeriesNormalize::to_execution_plan`.
#[derive(Debug)]
pub struct SeriesNormalizeExec {
    /// Milliseconds added to every time-index value.
    offset: Millisecond,
    /// Name of the time index column; resolved to a position in `execute`.
    time_index_column_name: String,
    /// Whether rows with NaN in any `Float64` column are dropped.
    need_filter_out_nan: bool,
    /// Columns the input must be hash-partitioned on (`required_input_distribution`).
    tag_columns: Vec<String>,

    input: Arc<dyn ExecutionPlan>,
    /// Per-partition metrics registered in `execute`; cloned into plans created by
    /// `with_new_children`.
    metric: ExecutionPlanMetricsSet,
}
265
266impl ExecutionPlan for SeriesNormalizeExec {
267    fn as_any(&self) -> &dyn Any {
268        self
269    }
270
271    fn schema(&self) -> SchemaRef {
272        self.input.schema()
273    }
274
275    fn required_input_distribution(&self) -> Vec<Distribution> {
276        let schema = self.input.schema();
277        vec![Distribution::HashPartitioned(
278            self.tag_columns
279                .iter()
280                // Safety: the tag column names is verified in the planning phase
281                .map(|tag| Arc::new(ColumnExpr::new_with_schema(tag, &schema).unwrap()) as _)
282                .collect(),
283        )]
284    }
285
286    fn properties(&self) -> &PlanProperties {
287        self.input.properties()
288    }
289
290    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
291        vec![&self.input]
292    }
293
294    fn with_new_children(
295        self: Arc<Self>,
296        children: Vec<Arc<dyn ExecutionPlan>>,
297    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
298        assert!(!children.is_empty());
299        Ok(Arc::new(Self {
300            offset: self.offset,
301            time_index_column_name: self.time_index_column_name.clone(),
302            need_filter_out_nan: self.need_filter_out_nan,
303            input: children[0].clone(),
304            tag_columns: self.tag_columns.clone(),
305            metric: self.metric.clone(),
306        }))
307    }
308
309    fn execute(
310        &self,
311        partition: usize,
312        context: Arc<TaskContext>,
313    ) -> DataFusionResult<SendableRecordBatchStream> {
314        let baseline_metric = BaselineMetrics::new(&self.metric, partition);
315        let metrics_builder = MetricBuilder::new(&self.metric);
316        let num_series = Count::new();
317        metrics_builder
318            .with_partition(partition)
319            .build(MetricValue::Count {
320                name: METRIC_NUM_SERIES.into(),
321                count: num_series.clone(),
322            });
323
324        let input = self.input.execute(partition, context)?;
325        let schema = input.schema();
326        let time_index = schema
327            .column_with_name(&self.time_index_column_name)
328            .expect("time index column not found")
329            .0;
330        Ok(Box::pin(SeriesNormalizeStream {
331            offset: self.offset,
332            time_index,
333            need_filter_out_nan: self.need_filter_out_nan,
334            schema,
335            input,
336            metric: baseline_metric,
337            num_series,
338        }))
339    }
340
341    fn metrics(&self) -> Option<MetricsSet> {
342        Some(self.metric.clone_inner())
343    }
344
345    fn partition_statistics(&self, partition: Option<usize>) -> DataFusionResult<Statistics> {
346        self.input.partition_statistics(partition)
347    }
348
349    fn name(&self) -> &str {
350        "SeriesNormalizeExec"
351    }
352}
353
354impl DisplayAs for SeriesNormalizeExec {
355    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
356        match t {
357            DisplayFormatType::Default
358            | DisplayFormatType::Verbose
359            | DisplayFormatType::TreeRender => {
360                write!(
361                    f,
362                    "PromSeriesNormalizeExec: offset=[{}], time index=[{}], filter NaN: [{}]",
363                    self.offset, self.time_index_column_name, self.need_filter_out_nan
364                )
365            }
366        }
367    }
368}
369
/// Stream that applies `SeriesNormalizeStream::normalize` to every batch pulled from
/// `input`.
pub struct SeriesNormalizeStream {
    /// Milliseconds added to every time-index value.
    offset: Millisecond,
    // Column index of TIME INDEX column's position in schema
    time_index: usize,
    /// Whether rows with NaN in any `Float64` column are dropped.
    need_filter_out_nan: bool,

    schema: SchemaRef,
    input: SendableRecordBatchStream,
    metric: BaselineMetrics,
    /// Number of series processed. It is incremented once per input batch, since each
    /// batch is assumed to hold exactly one series.
    num_series: Count,
}
382
383impl SeriesNormalizeStream {
384    pub fn normalize(&self, input: RecordBatch) -> DataFusionResult<RecordBatch> {
385        let ts_column = input
386            .column(self.time_index)
387            .as_any()
388            .downcast_ref::<TimestampMillisecondArray>()
389            .ok_or_else(|| {
390                DataFusionError::Execution(
391                    "Time index Column downcast to TimestampMillisecondArray failed".into(),
392                )
393            })?;
394
395        // bias the timestamp column by offset
396        let ts_column_biased = if self.offset == 0 {
397            Arc::new(ts_column.clone()) as _
398        } else {
399            Arc::new(TimestampMillisecondArray::from_iter(
400                ts_column.iter().map(|ts| ts.map(|ts| ts + self.offset)),
401            ))
402        };
403        let mut columns = input.columns().to_vec();
404        columns[self.time_index] = ts_column_biased;
405
406        let result_batch = RecordBatch::try_new(input.schema(), columns)?;
407        if !self.need_filter_out_nan {
408            return Ok(result_batch);
409        }
410
411        // TODO(ruihang): consider the "special NaN"
412        // filter out NaN
413        let mut filter = vec![true; input.num_rows()];
414        for column in result_batch.columns() {
415            if let Some(float_column) = column.as_any().downcast_ref::<Float64Array>() {
416                for (i, flag) in filter.iter_mut().enumerate() {
417                    if float_column.value(i).is_nan() {
418                        *flag = false;
419                    }
420                }
421            }
422        }
423
424        let result = compute::filter_record_batch(&result_batch, &BooleanArray::from(filter))
425            .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
426        Ok(result)
427    }
428}
429
430impl RecordBatchStream for SeriesNormalizeStream {
431    fn schema(&self) -> SchemaRef {
432        self.schema.clone()
433    }
434}
435
436impl Stream for SeriesNormalizeStream {
437    type Item = DataFusionResult<RecordBatch>;
438
439    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
440        let poll = match ready!(self.input.poll_next_unpin(cx)) {
441            Some(Ok(batch)) => {
442                self.num_series.add(1);
443                let timer = std::time::Instant::now();
444                let result = Ok(batch).and_then(|batch| self.normalize(batch));
445                self.metric.elapsed_compute().add_elapsed(timer);
446                Poll::Ready(Some(result))
447            }
448            None => {
449                PROMQL_SERIES_COUNT.observe(self.num_series.value() as f64);
450                Poll::Ready(None)
451            }
452            Some(Err(e)) => Poll::Ready(Some(Err(e))),
453        };
454        self.metric.record_poll(poll)
455    }
456}
457
#[cfg(test)]
mod test {
    use datafusion::arrow::array::Float64Array;
    use datafusion::arrow::datatypes::{
        ArrowPrimitiveType, DataType, Field, Schema, TimestampMillisecondType,
    };
    use datafusion::common::ToDFSchema;
    use datafusion::datasource::memory::MemorySourceConfig;
    use datafusion::datasource::source::DataSourceExec;
    use datafusion::logical_expr::{EmptyRelation, LogicalPlan};
    use datafusion::prelude::SessionContext;
    use datatypes::arrow::array::TimestampMillisecondArray;
    use datatypes::arrow_array::StringArray;

    use super::*;

    const TIME_INDEX_COLUMN: &str = "timestamp";

    /// Builds a single-batch in-memory source with columns `timestamp`, `value` and
    /// `path`. Every row belongs to the series `path = "foo"`.
    fn build_test_source(timestamps: Vec<i64>, values: Vec<f64>) -> DataSourceExec {
        let schema = Arc::new(Schema::new(vec![
            Field::new(TIME_INDEX_COLUMN, TimestampMillisecondType::DATA_TYPE, true),
            Field::new("value", DataType::Float64, true),
            Field::new("path", DataType::Utf8, true),
        ]));
        let num_rows = timestamps.len();
        let timestamp_column = Arc::new(TimestampMillisecondArray::from(timestamps)) as _;
        let field_column = Arc::new(Float64Array::from(values)) as _;
        let path_column = Arc::new(StringArray::from(vec!["foo"; num_rows])) as _;
        let data = RecordBatch::try_new(
            schema.clone(),
            vec![timestamp_column, field_column, path_column],
        )
        .unwrap();

        DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[vec![data]], schema, None).unwrap(),
        ))
    }

    /// Five unsorted samples with distinct, non-NaN values.
    fn prepare_test_data() -> DataSourceExec {
        build_test_source(
            vec![60_000, 120_000, 0, 30_000, 90_000],
            vec![0.0, 1.0, 10.0, 100.0, 1000.0],
        )
    }

    /// Runs `SeriesNormalizeExec` (tag column `path`) over `source` and collects the
    /// output batches.
    async fn run_normalize(
        source: DataSourceExec,
        offset: Millisecond,
        need_filter_out_nan: bool,
    ) -> Vec<RecordBatch> {
        let normalize_exec = Arc::new(SeriesNormalizeExec {
            offset,
            time_index_column_name: TIME_INDEX_COLUMN.to_string(),
            need_filter_out_nan,
            input: Arc::new(source),
            tag_columns: vec!["path".to_string()],
            metric: ExecutionPlanMetricsSet::new(),
        });
        let session_context = SessionContext::default();
        datafusion::physical_plan::collect(normalize_exec, session_context.task_ctx())
            .await
            .unwrap()
    }

    /// Collects the `value` column of every batch into one vector.
    fn collect_values(batches: &[RecordBatch]) -> Vec<f64> {
        batches
            .iter()
            .flat_map(|batch| {
                batch
                    .column(1)
                    .as_any()
                    .downcast_ref::<Float64Array>()
                    .unwrap()
                    .values()
                    .to_vec()
            })
            .collect()
    }

    #[test]
    fn pruning_should_keep_time_and_tag_columns_for_exec() {
        let df_schema = prepare_test_data().schema().to_dfschema_ref().unwrap();
        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: df_schema,
        });
        let plan =
            SeriesNormalize::new(0, TIME_INDEX_COLUMN, true, vec!["path".to_string()], input);

        // Simulate a parent projection requesting only the `value` column.
        let output_columns = [1usize];
        let required = plan.necessary_children_exprs(&output_columns).unwrap();
        let required = &required[0];
        assert_eq!(required.as_slice(), &[0, 1, 2]);
    }

    /// Normalization does not sort: rows come out in input order.
    #[tokio::test]
    async fn test_sort_record_batch() {
        let result = run_normalize(prepare_test_data(), 0, true).await;
        let result_literal = datatypes::arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();

        let expected = String::from(
            "+---------------------+--------+------+\
            \n| timestamp           | value  | path |\
            \n+---------------------+--------+------+\
            \n| 1970-01-01T00:01:00 | 0.0    | foo  |\
            \n| 1970-01-01T00:02:00 | 1.0    | foo  |\
            \n| 1970-01-01T00:00:00 | 10.0   | foo  |\
            \n| 1970-01-01T00:00:30 | 100.0  | foo  |\
            \n| 1970-01-01T00:01:30 | 1000.0 | foo  |\
            \n+---------------------+--------+------+",
        );

        assert_eq!(result_literal, expected);
    }

    #[tokio::test]
    async fn test_offset_record_batch() {
        let result = run_normalize(prepare_test_data(), 1_000, true).await;
        let result_literal = datatypes::arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();

        let expected = String::from(
            "+---------------------+--------+------+\
            \n| timestamp           | value  | path |\
            \n+---------------------+--------+------+\
            \n| 1970-01-01T00:01:01 | 0.0    | foo  |\
            \n| 1970-01-01T00:02:01 | 1.0    | foo  |\
            \n| 1970-01-01T00:00:01 | 10.0   | foo  |\
            \n| 1970-01-01T00:00:31 | 100.0  | foo  |\
            \n| 1970-01-01T00:01:31 | 1000.0 | foo  |\
            \n+---------------------+--------+------+",
        );

        assert_eq!(result_literal, expected);
    }

    /// NaN rows are removed only when `need_filter_out_nan` is set.
    #[tokio::test]
    async fn test_filter_nan_record_batch() {
        let source = build_test_source(vec![0, 30_000, 60_000], vec![1.0, f64::NAN, 3.0]);
        let filtered = collect_values(&run_normalize(source, 0, true).await);
        assert_eq!(filtered, vec![1.0, 3.0]);

        let source = build_test_source(vec![0, 30_000, 60_000], vec![1.0, f64::NAN, 3.0]);
        let kept = collect_values(&run_normalize(source, 0, false).await);
        assert_eq!(kept.len(), 3);
        assert!(kept[1].is_nan());
    }
}