promql/extension_plan/normalize.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use datafusion::arrow::array::{BooleanArray, Float64Array};
use datafusion::arrow::compute;
use datafusion::common::{DFSchema, DFSchemaRef, Result as DataFusionResult, Statistics};
use datafusion::error::DataFusionError;
use datafusion::execution::context::TaskContext;
use datafusion::logical_expr::{EmptyRelation, Expr, LogicalPlan, UserDefinedLogicalNodeCore};
use datafusion::physical_plan::expressions::Column as ColumnExpr;
use datafusion::physical_plan::metrics::{
    BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricValue, MetricsSet,
};
use datafusion::physical_plan::{
    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, PlanProperties, RecordBatchStream,
    SendableRecordBatchStream,
};
use datatypes::arrow::array::TimestampMillisecondArray;
use datatypes::arrow::datatypes::SchemaRef;
use datatypes::arrow::record_batch::RecordBatch;
use futures::{ready, Stream, StreamExt};
use greptime_proto::substrait_extension as pb;
use prost::Message;
use snafu::ResultExt;

use crate::error::{DeserializeSnafu, Result};
use crate::extension_plan::{Millisecond, METRIC_NUM_SERIES};
use crate::metrics::PROMQL_SERIES_COUNT;

/// Normalize the input record batch. Note that for simplicity, this plan assumes
/// the input batch only contains sample points from one time series.
///
/// Roughly speaking, this plan does two things:
/// - bias each sample's timestamp by the offset
/// - filter out NaN values (optional)
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd)]
pub struct SeriesNormalize {
    offset: Millisecond,
    time_index_column_name: String,
    need_filter_out_nan: bool,
    tag_columns: Vec<String>,

    input: LogicalPlan,
}

impl UserDefinedLogicalNodeCore for SeriesNormalize {
    fn name(&self) -> &str {
        Self::name()
    }

    fn inputs(&self) -> Vec<&LogicalPlan> {
        vec![&self.input]
    }

    fn schema(&self) -> &DFSchemaRef {
        self.input.schema()
    }

    fn expressions(&self) -> Vec<datafusion::logical_expr::Expr> {
        vec![]
    }

    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(
            f,
            "PromSeriesNormalize: offset=[{}], time index=[{}], filter NaN: [{}]",
            self.offset, self.time_index_column_name, self.need_filter_out_nan
        )
    }

    fn with_exprs_and_inputs(
        &self,
        _exprs: Vec<Expr>,
        inputs: Vec<LogicalPlan>,
    ) -> DataFusionResult<Self> {
        if inputs.is_empty() {
            return Err(DataFusionError::Internal(
                "SeriesNormalize should have at least one input".to_string(),
            ));
        }

        Ok(Self {
            offset: self.offset,
            time_index_column_name: self.time_index_column_name.clone(),
            need_filter_out_nan: self.need_filter_out_nan,
            input: inputs.into_iter().next().unwrap(),
            tag_columns: self.tag_columns.clone(),
        })
    }
}

impl SeriesNormalize {
    pub fn new<N: AsRef<str>>(
        offset: Millisecond,
        time_index_column_name: N,
        need_filter_out_nan: bool,
        tag_columns: Vec<String>,
        input: LogicalPlan,
    ) -> Self {
        Self {
            offset,
            time_index_column_name: time_index_column_name.as_ref().to_string(),
            need_filter_out_nan,
            tag_columns,
            input,
        }
    }

    pub const fn name() -> &'static str {
        "SeriesNormalize"
    }

    /// Build the physical plan ([`SeriesNormalizeExec`]) from this logical node
    /// and the given physical input.
    pub fn to_execution_plan(&self, exec_input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        Arc::new(SeriesNormalizeExec {
            offset: self.offset,
            time_index_column_name: self.time_index_column_name.clone(),
            need_filter_out_nan: self.need_filter_out_nan,
            input: exec_input,
            tag_columns: self.tag_columns.clone(),
            metric: ExecutionPlanMetricsSet::new(),
        })
    }

    /// Serialize this plan into a protobuf [`pb::SeriesNormalize`] message.
    pub fn serialize(&self) -> Vec<u8> {
        pb::SeriesNormalize {
            offset: self.offset,
            time_index: self.time_index_column_name.clone(),
            filter_nan: self.need_filter_out_nan,
            tag_columns: self.tag_columns.clone(),
        }
        .encode_to_vec()
    }

    /// Deserialize a plan from protobuf bytes. The input is set to a placeholder
    /// [`EmptyRelation`] and is expected to be replaced by the caller.
    pub fn deserialize(bytes: &[u8]) -> Result<Self> {
        let pb_normalize = pb::SeriesNormalize::decode(bytes).context(DeserializeSnafu)?;
        let placeholder_plan = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(DFSchema::empty()),
        });
        Ok(Self::new(
            pb_normalize.offset,
            pb_normalize.time_index,
            pb_normalize.filter_nan,
            pb_normalize.tag_columns,
            placeholder_plan,
        ))
    }
}

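/// Physical plan counterpart of [`SeriesNormalize`].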
#[derive(Debug)]
pub struct SeriesNormalizeExec {
    offset: Millisecond,
    time_index_column_name: String,
    need_filter_out_nan: bool,
    tag_columns: Vec<String>,

    input: Arc<dyn ExecutionPlan>,
    metric: ExecutionPlanMetricsSet,
}

impl ExecutionPlan for SeriesNormalizeExec {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.input.schema()
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        let schema = self.input.schema();
        vec![Distribution::HashPartitioned(
            self.tag_columns
                .iter()
                // Safety: the tag column names are verified in the planning phase
                .map(|tag| Arc::new(ColumnExpr::new_with_schema(tag, &schema).unwrap()) as _)
                .collect(),
        )]
    }

    fn properties(&self) -> &PlanProperties {
        self.input.properties()
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        assert!(!children.is_empty());
        Ok(Arc::new(Self {
            offset: self.offset,
            time_index_column_name: self.time_index_column_name.clone(),
            need_filter_out_nan: self.need_filter_out_nan,
            input: children[0].clone(),
            tag_columns: self.tag_columns.clone(),
            metric: self.metric.clone(),
        }))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let baseline_metric = BaselineMetrics::new(&self.metric, partition);
        // Register a custom counter tracking how many series this partition processes.
        let metrics_builder = MetricBuilder::new(&self.metric);
        let num_series = Count::new();
        metrics_builder
            .with_partition(partition)
            .build(MetricValue::Count {
                name: METRIC_NUM_SERIES.into(),
                count: num_series.clone(),
            });

        let input = self.input.execute(partition, context)?;
        let schema = input.schema();
        let time_index = schema
            .column_with_name(&self.time_index_column_name)
            .expect("time index column not found")
            .0;
        Ok(Box::pin(SeriesNormalizeStream {
            offset: self.offset,
            time_index,
            need_filter_out_nan: self.need_filter_out_nan,
            schema,
            input,
            metric: baseline_metric,
            num_series,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metric.clone_inner())
    }

    fn statistics(&self) -> DataFusionResult<Statistics> {
        self.input.statistics()
    }

    fn name(&self) -> &str {
        "SeriesNormalizeExec"
    }
}

impl DisplayAs for SeriesNormalizeExec {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "PromSeriesNormalizeExec: offset=[{}], time index=[{}], filter NaN: [{}]",
                    self.offset, self.time_index_column_name, self.need_filter_out_nan
                )
            }
        }
    }
}

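/// Stream that applies the normalization in [`SeriesNormalize`] to each record batch.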
pub struct SeriesNormalizeStream {
    offset: Millisecond,
    /// Index of the time index column in the schema.
    time_index: usize,
    need_filter_out_nan: bool,

    schema: SchemaRef,
    input: SendableRecordBatchStream,
    metric: BaselineMetrics,
    /// Number of series processed.
    num_series: Count,
}

impl SeriesNormalizeStream {
    /// Bias the timestamp column by the offset and, if requested, filter out
    /// rows that contain NaN in any float column.
    pub fn normalize(&self, input: RecordBatch) -> DataFusionResult<RecordBatch> {
        let ts_column = input
            .column(self.time_index)
            .as_any()
            .downcast_ref::<TimestampMillisecondArray>()
            .ok_or_else(|| {
                DataFusionError::Execution(
                    "Time index column downcast to TimestampMillisecondArray failed".into(),
                )
            })?;

        // bias the timestamp column by offset
        let ts_column_biased = if self.offset == 0 {
            Arc::new(ts_column.clone()) as _
        } else {
            Arc::new(TimestampMillisecondArray::from_iter(
                ts_column.iter().map(|ts| ts.map(|ts| ts + self.offset)),
            ))
        };
        let mut columns = input.columns().to_vec();
        columns[self.time_index] = ts_column_biased;

        let result_batch = RecordBatch::try_new(input.schema(), columns)?;
        if !self.need_filter_out_nan {
            return Ok(result_batch);
        }

        // TODO(ruihang): consider the "special NaN"
        // Build a row filter that is false wherever any float column holds NaN.
        let mut filter = vec![true; input.num_rows()];
        for column in result_batch.columns() {
            if let Some(float_column) = column.as_any().downcast_ref::<Float64Array>() {
                for (i, flag) in filter.iter_mut().enumerate() {
                    if float_column.value(i).is_nan() {
                        *flag = false;
                    }
                }
            }
        }

        let result = compute::filter_record_batch(&result_batch, &BooleanArray::from(filter))
            .map_err(|e| DataFusionError::ArrowError(e, None))?;
        Ok(result)
    }
}

impl RecordBatchStream for SeriesNormalizeStream {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }
}

impl Stream for SeriesNormalizeStream {
    type Item = DataFusionResult<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let poll = match ready!(self.input.poll_next_unpin(cx)) {
            Some(Ok(batch)) => {
                self.num_series.add(1);
                let timer = std::time::Instant::now();
                let result = self.normalize(batch);
                self.metric.elapsed_compute().add_elapsed(timer);
                Poll::Ready(Some(result))
            }
            None => {
                PROMQL_SERIES_COUNT.observe(self.num_series.value() as f64);
                Poll::Ready(None)
            }
            Some(Err(e)) => Poll::Ready(Some(Err(e))),
        };
        self.metric.record_poll(poll)
    }
}

#[cfg(test)]
mod test {
    use datafusion::arrow::array::Float64Array;
    use datafusion::arrow::datatypes::{
        ArrowPrimitiveType, DataType, Field, Schema, TimestampMillisecondType,
    };
    use datafusion::physical_plan::memory::MemoryExec;
    use datafusion::prelude::SessionContext;
    use datatypes::arrow::array::TimestampMillisecondArray;
    use datatypes::arrow_array::StringArray;

    use super::*;

    const TIME_INDEX_COLUMN: &str = "timestamp";

    fn prepare_test_data() -> MemoryExec {
        let schema = Arc::new(Schema::new(vec![
            Field::new(TIME_INDEX_COLUMN, TimestampMillisecondType::DATA_TYPE, true),
            Field::new("value", DataType::Float64, true),
            Field::new("path", DataType::Utf8, true),
        ]));
        let timestamp_column = Arc::new(TimestampMillisecondArray::from(vec![
            60_000, 120_000, 0, 30_000, 90_000,
        ])) as _;
        let field_column = Arc::new(Float64Array::from(vec![0.0, 1.0, 10.0, 100.0, 1000.0])) as _;
        let path_column = Arc::new(StringArray::from(vec!["foo", "foo", "foo", "foo", "foo"])) as _;
        let data = RecordBatch::try_new(
            schema.clone(),
            vec![timestamp_column, field_column, path_column],
        )
        .unwrap();

        MemoryExec::try_new(&[vec![data]], schema, None).unwrap()
    }

    #[tokio::test]
    async fn test_normalize_record_batch() {
        let memory_exec = Arc::new(prepare_test_data());
        let normalize_exec = Arc::new(SeriesNormalizeExec {
            offset: 0,
            time_index_column_name: TIME_INDEX_COLUMN.to_string(),
            need_filter_out_nan: true,
            input: memory_exec,
            tag_columns: vec!["path".to_string()],
            metric: ExecutionPlanMetricsSet::new(),
        });
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(normalize_exec, session_context.task_ctx())
            .await
            .unwrap();
        let result_literal = datatypes::arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();

        // With a zero offset the input order is preserved; this plan does not sort.
        let expected = String::from(
            "+---------------------+--------+------+\
            \n| timestamp           | value  | path |\
            \n+---------------------+--------+------+\
            \n| 1970-01-01T00:01:00 | 0.0    | foo  |\
            \n| 1970-01-01T00:02:00 | 1.0    | foo  |\
            \n| 1970-01-01T00:00:00 | 10.0   | foo  |\
            \n| 1970-01-01T00:00:30 | 100.0  | foo  |\
            \n| 1970-01-01T00:01:30 | 1000.0 | foo  |\
            \n+---------------------+--------+------+",
        );

        assert_eq!(result_literal, expected);
    }

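    // Sketch (not in the original suite): rows whose float value is NaN should
    // be dropped when `need_filter_out_nan` is set.
    #[tokio::test]
    async fn test_filter_nan_record_batch() {
        let schema = Arc::new(Schema::new(vec![
            Field::new(TIME_INDEX_COLUMN, TimestampMillisecondType::DATA_TYPE, true),
            Field::new("value", DataType::Float64, true),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(TimestampMillisecondArray::from(vec![0, 1_000])) as _,
                Arc::new(Float64Array::from(vec![f64::NAN, 1.0])) as _,
            ],
        )
        .unwrap();
        let memory_exec = Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap());
        let normalize_exec = Arc::new(SeriesNormalizeExec {
            offset: 0,
            time_index_column_name: TIME_INDEX_COLUMN.to_string(),
            need_filter_out_nan: true,
            input: memory_exec,
            tag_columns: vec![],
            metric: ExecutionPlanMetricsSet::new(),
        });
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(normalize_exec, session_context.task_ctx())
            .await
            .unwrap();
        // Only the non-NaN row should survive the filter.
        assert_eq!(result.iter().map(|b| b.num_rows()).sum::<usize>(), 1);
    }
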
    #[tokio::test]
    async fn test_offset_record_batch() {
        let memory_exec = Arc::new(prepare_test_data());
        let normalize_exec = Arc::new(SeriesNormalizeExec {
            offset: 1_000,
            time_index_column_name: TIME_INDEX_COLUMN.to_string(),
            need_filter_out_nan: true,
            input: memory_exec,
            metric: ExecutionPlanMetricsSet::new(),
            tag_columns: vec!["path".to_string()],
        });
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(normalize_exec, session_context.task_ctx())
            .await
            .unwrap();
        let result_literal = datatypes::arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();

        let expected = String::from(
            "+---------------------+--------+------+\
            \n| timestamp           | value  | path |\
            \n+---------------------+--------+------+\
            \n| 1970-01-01T00:01:01 | 0.0    | foo  |\
            \n| 1970-01-01T00:02:01 | 1.0    | foo  |\
            \n| 1970-01-01T00:00:31 | 100.0  | foo  |\
            \n| 1970-01-01T00:00:01 | 10.0   | foo  |\
            \n| 1970-01-01T00:01:31 | 1000.0 | foo  |\
            \n+---------------------+--------+------+",
        );

        assert_eq!(result_literal, expected);
    }
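
    // Sketch (not in the original suite): `serialize`/`deserialize` should
    // round-trip the plan's parameters; the input plan itself is replaced by
    // an `EmptyRelation` placeholder on the way back.
    #[test]
    fn test_serialize_deserialize_roundtrip() {
        let placeholder_plan = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(DFSchema::empty()),
        });
        let plan = SeriesNormalize::new(
            1_000,
            TIME_INDEX_COLUMN,
            true,
            vec!["path".to_string()],
            placeholder_plan,
        );
        let bytes = plan.serialize();
        let decoded = SeriesNormalize::deserialize(&bytes).unwrap();
        assert_eq!(decoded.offset, 1_000);
        assert_eq!(decoded.time_index_column_name, TIME_INDEX_COLUMN);
        assert!(decoded.need_filter_out_nan);
        assert_eq!(decoded.tag_columns, vec!["path".to_string()]);
    }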
}